I am trying to integrate pytorch-metric-learning with pytorch-lightning (PL). I have noticed that in several places the package creates tensors and transfers them to a device (e.g., in BaseMetricLossFunction, the base class of all loss functions). This is of course needed so that everything works in vanilla PyTorch.
I know that explicitly moving tensors to devices is strongly discouraged in PL, so my question is how to handle such a case. Obviously, I could fork the package and remove all tensor transfers, but I would like to avoid that.
If there is no way to avoid this, is it going to be a serious problem for PL?
Note: the explicit transfer mentioned above moves one tensor to the device of another tensor (i.e., the device is not hard-coded). The initial tensor would be handled by PL in the LightningModule, so at least there is no risk of ending up with tensors on different devices. My real question is: are such explicit transfers problematic for PL in any way?
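For reference, here is a minimal sketch of the device-relative pattern I mean, using plain PyTorch only. The function name and the toy loss are hypothetical and are not pytorch-metric-learning's actual code; the point is that every new tensor is created on, or moved to, `embeddings.device` instead of a hard-coded device such as "cuda:0". It assumes `labels` arrives on the same device as `embeddings`, as it would if PL moved the whole batch.

```python
import torch


def hypothetical_pair_loss(embeddings: torch.Tensor, labels: torch.Tensor, margin: float = 0.2) -> torch.Tensor:
    """Toy contrastive-style loss (hypothetical, for illustration only).

    Every tensor it creates follows embeddings.device, so it runs unchanged
    whether the caller placed the batch on CPU or GPU.
    """
    # Pairwise Euclidean distances; the result inherits the device of `embeddings`.
    dists = torch.cdist(embeddings, embeddings)
    # Positive-pair mask from labels (assumed to be on the same device as embeddings).
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    # New tensor created explicitly, but relative to the input's device, not hard-coded.
    eye = torch.eye(labels.shape[0], dtype=torch.bool, device=embeddings.device)
    pos_mask = same & ~eye
    neg_mask = ~same
    # Fallback constant moved with .to(), again following the input tensor's device.
    zero = torch.tensor(0.0).to(embeddings.device)
    pos = dists[pos_mask].mean() if pos_mask.any() else zero
    neg = (margin - dists[neg_mask]).clamp(min=0).mean() if neg_mask.any() else zero
    return pos + neg


# Usage: the loss lands on whatever device the inputs were on.
device = "cuda" if torch.cuda.is_available() else "cpu"
emb = torch.randn(4, 8, device=device)
lbl = torch.tensor([0, 0, 1, 1], device=device)
loss = hypothetical_pair_loss(emb, lbl)
print(loss.device == emb.device)  # True
```

Because nothing here names a concrete device, the function never decides placement itself; it only follows the device of its inputs.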