How to distribute a huge final fc layer with more than a million classes


I am quite new here. I am currently using a vision transformer (ViT) for face recognition with more than 1,000,000 classes. After extracting features with the ViT, I need to feed them into a linear classifier (fc) whose output dimension is in the millions. Plain PyTorch DDP cannot handle this, because DDP keeps a full replica of the model, including the whole fc, on every GPU.
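To put a number on "cannot afford this", here is a rough back-of-envelope sketch. It assumes a 512-dimensional embedding (the post does not give the ViT's output size), fp32 values, no bias, and Adam, so each weight is stored four times (weights, gradients, two moment buffers). The helper name `fc_memory_bytes` is hypothetical.

```python
def fc_memory_bytes(num_classes, embed_dim, world_size=1, bytes_per_value=4, copies=4):
    """Per-GPU memory for the fc weight plus its training state (bias ignored).

    copies=4 assumes fp32 Adam: weights, gradients, and two moment buffers.
    world_size > 1 models splitting the class dimension evenly across GPUs.
    """
    classes_per_gpu = -(-num_classes // world_size)  # ceiling division
    return classes_per_gpu * embed_dim * bytes_per_value * copies


NUM_CLASSES = 1_000_000
EMBED_DIM = 512  # assumed ViT embedding size; substitute your model's value

replicated = fc_memory_bytes(NUM_CLASSES, EMBED_DIM)             # DDP: full copy per GPU
sharded = fc_memory_bytes(NUM_CLASSES, EMBED_DIM, world_size=8)  # classes split over 8 GPUs
print(f"replicated: {replicated / 1e9:.2f} GB per GPU")  # replicated: 8.19 GB per GPU
print(f"sharded:    {sharded / 1e9:.2f} GB per GPU")     # sharded:    1.02 GB per GPU
```

This ignores activations: under DDP each GPU also materializes a batch_size × 1,000,000 logit matrix (and its gradient), which shrinks by the same factor when the class dimension is split.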

My question is: what should I do? Should I wrap the ViT + fc together with model parallelism, or use DDP for the ViT and apply model parallelism only to the fc?
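For reference, the second option (DDP for the ViT, fc split by class) is the layout used, for example, by InsightFace's Partial FC: each GPU holds only a slice of the class weights, and the softmax normalizer is assembled with collectives. Below is a minimal single-process simulation of that math in pure Python, not a real distributed implementation. The "ranks" are just list entries, the function names are hypothetical, and only the forward loss is shown. In actual PyTorch code the global max and sums would come from `torch.distributed.all_reduce` with `ReduceOp.MAX` / `ReduceOp.SUM`, and features would first be all-gathered so every rank scores the whole global batch.

```python
import math


def split_classes(weight, world_size):
    """Give each rank a contiguous block of class rows (the last block may be shorter)."""
    per_rank = math.ceil(len(weight) / world_size)
    return [weight[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def local_logits(feature, weight_shard):
    """Logits for the classes owned by one rank: dot product with each class row."""
    return [sum(f * w for f, w in zip(feature, row)) for row in weight_shard]


def sharded_cross_entropy(logits_shards, target):
    """Cross-entropy for one sample whose logits are split across ranks by class."""
    # Local max per rank, then a MAX all-reduce -> global max (for numerical stability).
    global_max = max(max(shard, default=-math.inf) for shard in logits_shards)
    # Local sum of exp(z - max) per rank, then a SUM all-reduce -> softmax denominator.
    global_sum = sum(sum(math.exp(z - global_max) for z in shard) for shard in logits_shards)
    # Only the rank owning the target class contributes its logit; others add 0 (SUM all-reduce).
    target_logit, offset = 0.0, 0
    for shard in logits_shards:
        if offset <= target < offset + len(shard):
            target_logit += shard[target - offset]
        offset += len(shard)
    return math.log(global_sum) + global_max - target_logit


def full_cross_entropy(logits, target):
    """Reference: ordinary cross-entropy over the unsplit logits."""
    m = max(logits)
    return math.log(sum(math.exp(z - m) for z in logits)) + m - logits[target]


feature = [0.5, -1.0, 2.0]
weight = [[0.1 * i, -0.2 * i, 0.05 * i] for i in range(10)]  # 10 toy classes
shards = split_classes(weight, 4)                            # shard sizes 3, 3, 3, 1
loss = sharded_cross_entropy([local_logits(feature, s) for s in shards], target=7)
print(abs(loss - full_cross_entropy(local_logits(feature, weight), 7)) < 1e-9)  # True
```

The point of the design is that no rank ever holds the full 1,000,000-column weight or logit row; only a few scalars per sample cross the network for the loss.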