How to implement a gradient reversal layer (GRL) in PL?

I am training a domain adaptation (DA) network where I use a GRL in the discriminator to train the encoder.
Is a GRL implemented in PL the same way as in plain PyTorch? Here is my PyTorch version:

```python
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # Identity in the forward pass; remember alpha for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back to the encoder.
        output = grad_output * -ctx.alpha
        # alpha is not a tensor, so it gets no gradient.
        return output, None


def grad_reverse(x, alpha):
    return GradReverse.apply(x, alpha)


# In the discriminator's forward pass:
def forward(self, y, alpha):
    self.alpha = alpha
    y = grad_reverse(y, self.alpha)
    y = self.classifier3(y)
    return y
```
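One way to confirm the layer behaves as intended, independent of Lightning, is a quick gradient check: the forward output should equal the input, and the input's gradient should be the incoming gradient multiplied by -alpha. This is a minimal sketch assuming only torch; it repeats the GradReverse class so it runs on its own, and the tensor shape and alpha=0.5 are arbitrary illustration values.

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # Identity in the forward pass; store alpha for backward.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reversed, scaled gradient for x; None for the non-tensor alpha.
        return grad_output * -ctx.alpha, None


def grad_reverse(x, alpha=1.0):
    return GradReverse.apply(x, alpha)


x = torch.ones(3, requires_grad=True)
y = grad_reverse(x, 0.5)   # illustrative alpha
y.sum().backward()         # d(sum)/dy is all ones
print(y.detach())          # unchanged input values
print(x.grad)              # each entry is -0.5 (ones * -alpha)
```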

Hey @Abubakr5,

If this works in plain PyTorch, it should work the same way with PyTorch Lightning, since Lightning doesn't change any of the core training logic internally: your custom autograd Function still runs during the normal backward pass.

Also, we have moved the discussions to GitHub Discussions. You might want to check that out instead to get a quick response. The forums will be marked read-only soon.

Thank you