Model Parallel Layer

Hi, in my research code I have a huge dense classification layer with millions of classes, and it does not fit well in GPU memory. The recommended practice is to use a model-parallel dense layer, in which the matrix multiplication is split across GPUs. This matters a lot for performance, since the huge dense layer otherwise becomes the bottleneck.
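To make the idea of splitting the matrix multiplication concrete, here is a minimal single-process sketch, assuming the common "column-parallel" layout in which each GPU owns a contiguous slice of the classes (the output dimension of the weight matrix). Plain Python lists stand in for GPUs and tensors, and the helper names `matmul`, `split_columns` and `column_parallel_logits` are hypothetical, not from any library.

```python
def matmul(x, w):
    """Compute x @ w for x: [batch][in] and w: [in][out], both nested lists."""
    return [
        [sum(xi[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
        for xi in x
    ]


def split_columns(w, num_shards):
    """Split the class (output) dimension of w into contiguous shards.

    Each shard plays the role of the weight slice held by one hypothetical GPU.
    """
    out = len(w[0])
    size = (out + num_shards - 1) // num_shards  # ceil division
    return [[row[s:s + size] for row in w] for s in range(0, out, size)]


def column_parallel_logits(x, shards):
    """Each shard computes logits only for its own classes.

    Concatenating the per-shard logits (an all-gather on real GPUs) reproduces
    the full x @ w; in practice the loss is usually computed on the shards
    directly so the full logit matrix never has to be materialized.
    """
    local = [matmul(x, shard) for shard in shards]  # one local matmul per "GPU"
    return [sum((part[i] for part in local), []) for i in range(len(x))]
```

For example, with `w` of shape 2 x 3 split into two shards, `column_parallel_logits(x, split_columns(w, 2))` returns exactly the same values as `matmul(x, w)`, while each shard only ever stores and multiplies its own slice of the weights.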

So far I have implemented a model-parallel loss that distributes the computation across workers using the NCCL/Gloo backends and propagates gradients correctly. In this layer, each GPU holds different parameters. Unfortunately, this does not work well with the DDP distributed backend plugin in PyTorch Lightning (pytorch-lightning/ddp_plugin.py at master · PyTorchLightning/pytorch-lightning · GitHub), because it all-reduces (averages) gradients across processes during the backward pass, assuming every parameter is identical on all GPUs.
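The post does not show its loss implementation, so the following is only a sketch of one common way a model-parallel softmax cross-entropy can be computed, simulated in a single process. Each inner list stands in for the logits held by one GPU, and the steps commented as "all-reduce" would be `torch.distributed` collectives in a real setup. The function name `sharded_cross_entropy` and its arguments are hypothetical. The sketch also shows why DDP's averaging is a problem: the gradient of each shard is computed locally and belongs to different classes, so averaging it with other ranks would mix unrelated parameters.

```python
import math


def sharded_cross_entropy(local_logits, class_offsets, target):
    """Cross-entropy for one sample whose logits are split across shards.

    local_logits[r] holds the logits for classes
    class_offsets[r] .. class_offsets[r] + len(local_logits[r]) - 1.
    Returns the loss and the gradient w.r.t. each shard's logits.
    """
    # all-reduce (MAX): global maximum, used for numerical stability
    global_max = max(max(l) for l in local_logits)
    # all-reduce (SUM): global softmax denominator
    sum_exp = sum(sum(math.exp(v - global_max) for v in l) for l in local_logits)
    # all-reduce (SUM): only the shard owning the target class contributes its logit
    target_logit = 0.0
    for l, off in zip(local_logits, class_offsets):
        if off <= target < off + len(l):
            target_logit += l[target - off]
    loss = math.log(sum_exp) + global_max - target_logit

    # The gradient w.r.t. each shard's logits (softmax - one_hot) needs only the
    # reduced scalars above, so every shard computes it locally for its own classes.
    grads = []
    for l, off in zip(local_logits, class_offsets):
        g = [math.exp(v - global_max) / sum_exp for v in l]
        if off <= target < off + len(l):
            g[target - off] -= 1.0
        grads.append(g)
    return loss, grads
```

Only a few scalars per sample cross the GPU boundary, which is what keeps this approach cheap compared with gathering the full logit matrix onto one device.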

What would you suggest as a workaround for this problem, if there is one?

You might try the Sharded version; that should help:
Introducing PyTorch Lightning Sharded: Train SOTA Models, With Half The Memory