Suppose you do something like this:
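import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        z = torch.rand(2, 3)  # created on the CPU by default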
How do you make sure it ends up on the same device as the model, automatically?
In this case, use the .device attribute of the LightningModule:
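import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        z = torch.rand(2, 3, device=self.device)  # same device as the module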
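This keeps your code device-agnostic: the tensor is created on whichever device the module currently lives on (CPU or GPU), so the same code runs unchanged in either setting.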
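To see the same idea without a Trainer, here is a minimal plain-PyTorch sketch. A plain nn.Module has no .device attribute, so TinyModel (a hypothetical stand-in for a LightningModule) defines one itself, assuming all of its parameters sit on a single device. This only illustrates the pattern; it is not how Lightning implements .device internally.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a LightningModule, using plain PyTorch."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(3, 1)

    @property
    def device(self) -> torch.device:
        # Assumption: every parameter lives on the same device,
        # so the first one tells us where the whole module is.
        return next(self.parameters()).device

    def step(self) -> torch.Tensor:
        # Create the new tensor directly on the module's device
        # instead of calling .cuda() or .to(...) by hand.
        z = torch.rand(2, 3, device=self.device)
        return self.layer(z)


model = TinyModel()
out = model.step()
print(out.device)  # cpu

if torch.cuda.is_available():
    # Moving the module moves its parameters, so .device follows along.
    gpu_model = TinyModel().to("cuda")
    print(gpu_model.step().device)
```

Because .device is read from the parameters at call time, calling .to(...) on the module is enough; no tensor-creation code has to change.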