Custom batch class is not moved to the correct device

Hi there!

In my collate_fn, I am trying to return an instance of my own custom batch class instead of a tuple of tensors. The problem is that the Trainer does not move the tensors stored inside CustomBatch to the correct device.

```python
import torch


class CustomBatch:
    """Holds the collated tensors for one batch."""

    def __init__(self,
                 x: torch.Tensor,
                 y: torch.Tensor,
                 pad_x: torch.Tensor,
                 pad_y: torch.Tensor):
        self.x = x
        self.y = y
        self.pad_x = pad_x
        self.pad_y = pad_y


def collate_fn(batch):
    # Concatenate the variable-length sequences along the first dimension.
    X = torch.cat([seq.x for seq in batch])
    Y = torch.cat([seq.y for seq in batch])
    # For every row of X and Y, record the index of the sequence it came from.
    pad_x = []
    pad_y = []
    for i, seq in enumerate(batch):
        pad_x.extend([i] * seq.x.size(0))
        pad_y.extend([i] * seq.y.size(0))
    pad_x = torch.tensor(pad_x, dtype=torch.long)
    pad_y = torch.tensor(pad_y, dtype=torch.long)
    return CustomBatch(X, Y, pad_x, pad_y)
```

If I instead return a plain tuple of tensors, it works fine:

```python
return X, Y, pad_x, pad_y
```

Do I need to add a special method to CustomBatch so that the Trainer knows how to handle its contents?

Best,
Arturo

hey @artuntun

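You might want to look at the transfer_batch_to_device hook. Override it in your LightningModule to move the tensors inside CustomBatch to the target device yourself.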
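Here is a minimal sketch of that approach, assuming the CustomBatch fields from the question. It gives CustomBatch a to() method (a hypothetical addition, not part of the original class) and overrides transfer_batch_to_device to call it. The override is written as a small hypothetical mixin, CustomBatchTransferMixin, so it can be tried without a Trainer; you can equally paste the method straight into your LightningModule. The hook's signature has changed across Lightning versions (later releases also pass a dataloader_idx argument), so the sketch accepts extra arguments and forwards them unchanged.

```python
import torch


class CustomBatch:
    """Collated batch from the question, plus a hypothetical to() method."""

    def __init__(self, x: torch.Tensor, y: torch.Tensor,
                 pad_x: torch.Tensor, pad_y: torch.Tensor):
        self.x = x
        self.y = y
        self.pad_x = pad_x
        self.pad_y = pad_y

    def to(self, *args, **kwargs) -> "CustomBatch":
        # Forward the arguments (e.g. a device) to Tensor.to for every field,
        # then return self so callers can write `batch = batch.to(device)`.
        self.x = self.x.to(*args, **kwargs)
        self.y = self.y.to(*args, **kwargs)
        self.pad_x = self.pad_x.to(*args, **kwargs)
        self.pad_y = self.pad_y.to(*args, **kwargs)
        return self


class CustomBatchTransferMixin:
    """Hypothetical mixin; list it before LightningModule in the bases."""

    def transfer_batch_to_device(self, batch, device, *args, **kwargs):
        if isinstance(batch, CustomBatch):
            # Our own container: move each of its tensors explicitly.
            return batch.to(device)
        # Any other batch type: defer to the next class in the MRO
        # (LightningModule's default implementation).
        return super().transfer_batch_to_device(batch, device, *args, **kwargs)
```

To use it, list the mixin first, e.g. class LitModel(CustomBatchTransferMixin, pl.LightningModule), so that super() falls through to Lightning's default transfer for every other batch type. Also worth checking: the Lightning documentation for this hook lists "anything that implements .to(...)" among the types it moves out of the box, so depending on your version, the to() method alone may be enough and the override unnecessary.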

Also, we have moved discussions to GitHub Discussions. You might want to post there instead to get a quicker response. The forums will be made read-only at some point.

Thank you