Using DDP and loading a checkpoint from a non-Lightning model

Hi, I would like to use DDP for a Lightning model that I'm building. Part of that model is another model that is not a LightningModule, so I only have a “.pt” checkpoint for it and can't use LightningModule.load_from_checkpoint. However, doing this:

net.load_state_dict(torch.load(load_path)['model_state_dict'])

This actually spawns extra processes on GPU 0. There are n of them, where n is the number of GPUs I've specified.
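For context: by default `torch.load` restores each tensor to the device it was saved from, so if the checkpoint was written from `cuda:0`, every DDP process that loads it also touches GPU 0. A common workaround is to load onto CPU and let Lightning move the sub-model together with the rest of the LightningModule. A minimal sketch; `TinyNet`, `load_submodel`, and the file name are hypothetical stand-ins for the real sub-model and checkpoint:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for the non-Lightning sub-model.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def load_submodel(load_path: str) -> nn.Module:
    net = TinyNet()
    # map_location="cpu" stops torch.load from restoring the tensors onto the
    # GPU they were saved from (e.g. cuda:0) in every DDP process.
    checkpoint = torch.load(load_path, map_location="cpu")
    net.load_state_dict(checkpoint["model_state_dict"])
    return net


if __name__ == "__main__":
    # Write a checkpoint with the same {"model_state_dict": ...} layout as above.
    path = os.path.join(tempfile.mkdtemp(), "submodel.pt")
    torch.save({"model_state_dict": TinyNet().state_dict()}, path)
    net = load_submodel(path)
    print(next(net.parameters()).device)  # cpu
```

Assigning the result as an attribute in the LightningModule's `__init__` (e.g. `self.net = load_submodel(load_path)`) registers it as a submodule, so Lightning moves it to each process's GPU along with the rest of the model.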

In plain PyTorch, the solution is to do:

net.load_state_dict(torch.load(load_path, map_location='cuda:{}'.format(gpu))['model_state_dict'])

where gpu is the process index that mp.spawn() passes to the worker function. Since Lightning abstracts all of this away, I don't know what to do here. Is there a way to get the current GPU for each process and do the same thing as above, or is there a better Lightning way to do this?
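If you do want to map straight to each process's own GPU, the per-process counterpart of the `gpu` argument from `mp.spawn()` is the local rank. The sketch below is a hypothetical helper that assumes the launcher exports the `LOCAL_RANK` environment variable (torchrun does, and Lightning's script-based DDP launcher sets it for the processes it starts) and that local rank k uses `cuda:k`, which does not hold if you pass a device subset such as `devices=[2, 3]`. Inside a LightningModule, `self.local_rank` and `self.device` expose the same information once the trainer has set up the process.

```python
import os


def map_location_for_this_process() -> str:
    """Hypothetical helper: CUDA device string for the current DDP process.

    Assumes LOCAL_RANK is exported by the launcher and that local rank k
    uses cuda:k. Falls back to rank 0 when LOCAL_RANK is unset, e.g. in a
    single-process run or the launching process itself.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return f"cuda:{local_rank}"


# Usage, mirroring the mp.spawn version from the question:
# net.load_state_dict(
#     torch.load(load_path, map_location=map_location_for_this_process())["model_state_dict"]
# )
```

In most cases the CPU-then-move approach is simpler, since it does not depend on how ranks map to device indices.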

Thanks