Device mismatch with DP training

Hi!

This is my first time with distributed training so the error is probably due to something I’m missing.

I tried to run on two GPUs and I'm getting errors because tensors end up on different devices at run time. I don't manually move any tensor between devices, since I expect PL to handle that for me, but I don't know where to start debugging this error.

Let me share my code:

Dataset:

    """Dataset."""

    def __init__(self, csv_file, test=False):
        """
        Args:
            csv_file (string): Path to the csv file with user,past,future.
        """
        self.frame = pd.read_csv(
            csv_file,
            delimiter="\t",
            names=["user_id", "past", "future"],
            header=None,
            # iterator=True,
        )
        self.test = test
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.frame )

    def __getitem__(self, idx):
        data = self.frame .loc[idx]
        user_id = data[0]
        past = torch.LongTensor(ast.literal_eval(data[1]))
        future = data[2]
        if self.test:
            return user_id, past, 0

        return user_id, past, future

Model:


class RNN(pl.LightningModule):
    def __init__(
        self, args=None,
    ):
        super().__init__()

        self.save_hyperparameters()
        self.args = args
        self.embeddings_user = nn.Embedding(
            self.args.n_users, self.args.embedding_dim
        )
        self.embeddings_past = nn.Embedding(
            self.args.n_items, self.args.embedding_dim
        )

        self.lstm = nn.LSTM(
            self.args.embedding_dim,
            self.args.hidden_dim,
            num_layers=self.args.n_layers,
        )
        self.linear = nn.Linear(
            self.args.hidden_dim + self.args.embedding_dim, self.args.n_items
        )
        self.criterion = torch.nn.CrossEntropyLoss()
        self.acc = torchmetrics.Accuracy(top_k=14)
        self.recall = torchmetrics.Recall(top_k=14)

    def forward(self, x):
        user, past = x
        past = past.permute(1, 0)
        print("FORWARD SHAPES", user.shape,past.shape)
        user = self.embeddings_user(user)
        past = self.embeddings_past(past)
        lstm_out, self.hidden = self.lstm(past)

        concat = torch.cat((user, lstm_out[-1]), -1)
        return self.linear(concat)

    def training_step(self, batch, batch_idx):
        user, past, target = batch[0], batch[1], batch[2]
        out = self((user, past))
        loss = self.criterion(out, target)

        out = out.softmax(1)
        acc = self.acc(out, target)
        recall = self.recall(out, target)

        self.log(
            "train/step_recall_top_14",
            recall,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
        )
        self.log(
            "train/step_acc_top_14", acc, on_step=True, on_epoch=False, prog_bar=False
        )
        self.log("train/step_loss", loss, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        user, past, target = batch[0], batch[1], batch[2]
        out = self((user, past))
        loss = self.criterion(out, target)

        out = out.softmax(1)
        acc = self.acc(out, target)
        recall = self.recall(out, target)

        return {
            "val_loss": loss,
            "val_acc": acc.detach(),
            "val_recall": recall.detach(),
        }

    def test_step(self, batch, batch_idx):
        user, past, target = batch[0], batch[1], batch[2]
        out = self((user, past))
        top14 = torch.topk(out, 14).indices  # Top 14 products for the user
        return {"users": user, "top14": top14}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_ac = torch.stack([x["val_acc"] for x in outputs]).mean()
        avg_recall = torch.stack([x["val_recall"] for x in outputs]).mean()

        self.log("val/loss", avg_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/acc", avg_ac, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/recall", avg_recall, on_step=False, on_epoch=True, prog_bar=False)

        # return {"val_loss": avg_loss, "val_auc": avg_ac, "val_acc": avg_recall}

    def test_epoch_end(self, outputs):
        users = torch.cat([x["users"] for x in outputs])
        y_hat = torch.cat([x["top14"] for x in outputs])
        users = users.tolist()
        y_hat = y_hat.tolist()
        print(len(users))
        print(len(y_hat))

        data = {"users": users, "top14": y_hat}
        df = pd.DataFrame.from_dict(data)
        print(len(df))
        df.to_csv("lightning_logs/predict.csv", index=False)
        # Yet to be processed into a compliant format for the evaluator from the
        # production model; done in the parse_predict.ipynb notebook.

    def training_epoch_end(self, training_step_outputs):
        print("training end")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.args.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--embedding_dim", type=int, default=300)
        parser.add_argument("--hidden_dim", type=int, default=512)
        parser.add_argument("--n_layers", type=int, default=3)
        parser.add_argument("--n_users", type=int, default=692874)
        parser.add_argument("--n_items", type=int, default=58011)
        parser.add_argument("--learning_rate", type=float, default=0.0001)
        return parser

    ####################
    # DATA RELATED HOOKS
    ####################

    def setup(self, stage=None):
        print("Loading datasets")
        self.train_dataset = Dataset(self.args.train_path)
        self.val_dataset = Dataset(self.args.val_path)
        self.test_dataset = Dataset(self.args.test_path, test=True)
        print("Done")

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.n_workers,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.n_workers,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.n_workers,
        )

And training script:

def main(args):
    # ------------
    # model
    # ------------
    model = RNN(args)

    args.val_check_interval = 0.33  # Evaluate after every 33% of the train batches
    args.logger = pl_loggers.TensorBoardLogger(
        file_api_root_path + "lightning_logs/", name=args.experiment_name
    )
    args.progress_bar_refresh_rate = 2
    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    trainer.fit(model)


def cli_main():
    pl.seed_everything(17)

    parser = ArgumentParser()
    # parser = pl.Trainer.add_argparse_args(parser)
    parser = RNN.add_model_specific_args(parser)

    past_length = 50
    ## cli
    parser.add_argument(
        "--train_path",
        default=file_api_root_path + f"validate_window{past_length}_processed.csv",
        # default=file_api_root_path + f"train_window{past_length}_processed.csv",
        type=str,
    )
    parser.add_argument(
        "--val_path",
        default=file_api_root_path + f"validate_window{past_length}_processed.csv",
        type=str,
    )
    parser.add_argument(
        "--test_path",
        default=file_api_root_path + f"test_window{past_length}_processed.csv",
        type=str,
    )
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--n_workers", default=8, type=int)
    parser.add_argument("--max_epochs", default=10, type=int)
    parser.add_argument("--gpus", default=-1, type=int)
    parser.add_argument("--auto_select_gpus", default=True, type=bool)
    parser.add_argument("--log_gpu_memory", default=False, type=bool)
    parser.add_argument("--experiment_name", type=str, default="base_lstm")
    parser.add_argument("--accelerator", type=str, default="dp")

    args = parser.parse_args("")
    print(args)
    main(args)


if __name__ == "__main__":
    cli_main()

I'm currently running on two K80 GPUs and I get the following error:

Original Traceback (most recent call last):
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/pytorch_lightning/overrides/data_parallel.py", line 77, in forward
    output = super().forward(*inputs, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "<command-1320645486208119>", line 68, in validation_step
    acc = self.acc(out, target)
  File "/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torchmetrics/metric.py", line 168, in forward
    self.update(*args, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torchmetrics/metric.py", line 216, in wrapped_func
    return update(*args, **kwargs)
  File "/databricks/python/lib/python3.8/site-packages/torchmetrics/classification/accuracy.py", line 255, in update
    self.tp += tp
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

If I add a line to print the devices of out and target right before calling acc = self.acc(out, target):

print("INPUT DEVICE", target.get_device(), "OUT DEVICE", out.get_device())

the output shows that both tensors are on the same device; the mismatch only appears when the torchmetrics objects are called.

This happens during the validation steps the trainer runs at the start to check everything is good to go. I thought those ran on the CPU? I have no code that moves tensors between devices and can't really figure out why this is happening.

This seems to be caused by the torchmetrics objects not being shareable in DP mode. Metrics are nn.Modules, and in DP, nn.Module state is replicated onto each GPU and then destroyed after the forward pass; only the state on the root device survives.

The solution is to move the usage of the metric objects into the *_step_end methods (training_step_end / validation_step_end), which run on the root device with the outputs gathered from all GPUs.
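
A rough sketch of that rearrangement, adapted from the code above (untested): do only the forward pass in training_step / validation_step, return the tensors, and compute the loss and metrics in the matching *_step_end hooks, which under dp are called on the root device with the outputs of all replicas gathered together. The epoch-end hooks shown earlier can stay unchanged, since the same keys are returned.

    def training_step(self, batch, batch_idx):
        user, past, target = batch[0], batch[1], batch[2]
        # Forward pass only: under DP this runs once per GPU replica, and each
        # replica's metric state is discarded after the forward pass.
        return {"out": self((user, past)), "target": target}

    def training_step_end(self, outputs):
        # Runs on the root device with the outputs of all replicas gathered,
        # so the loss and self.acc / self.recall only ever see one device.
        out, target = outputs["out"], outputs["target"]
        loss = self.criterion(out, target)
        out = out.softmax(1)
        self.log("train/step_recall_top_14", self.recall(out, target), on_step=True, on_epoch=False)
        self.log("train/step_acc_top_14", self.acc(out, target), on_step=True, on_epoch=False)
        self.log("train/step_loss", loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        user, past, target = batch[0], batch[1], batch[2]
        return {"out": self((user, past)), "target": target}

    def validation_step_end(self, outputs):
        out, target = outputs["out"], outputs["target"]
        loss = self.criterion(out, target)
        out = out.softmax(1)
        # Same keys as before, so validation_epoch_end keeps working as-is.
        return {
            "val_loss": loss,
            "val_acc": self.acc(out, target).detach(),
            "val_recall": self.recall(out, target).detach(),
        }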

Solved in the Slack community.