Customize Backward for DINO Model

Hello,

I am trying to implement the DINO self-supervised model in PyTorch Lightning (I haven't seen any public implementation yet).

I have two models, a teacher and a student. In native PyTorch, the authors backpropagate the loss like this:

Student update:

    optimizer.zero_grad()
    param_norms = None
    loss.backward()
    if args.clip_grad:
        param_norms = utils.clip_gradients(student, args.clip_grad)
    utils.cancel_gradients_last_layer(epoch, student,
                                      args.freeze_last_layer)
    optimizer.step()

EMA update for the teacher:

    with torch.no_grad():
        m = momentum_schedule[it]  # momentum parameter
        for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
            param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
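
To make the teacher step concrete, here is a minimal standalone sketch of the same EMA rule in plain PyTorch. It assumes both networks are ordinary `nn.Module`s with no DDP wrapper (so no `.module`), and the name `ema_update` and the toy `nn.Linear` models are mine, not from the authors' code.

```python
import torch
from torch import nn


@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, m: float) -> None:
    """Update the teacher in place: teacher = m * teacher + (1 - m) * student."""
    for param_q, param_k in zip(student.parameters(), teacher.parameters()):
        # Same arithmetic as the loop above, without going through .data
        param_k.mul_(m).add_(param_q.detach(), alpha=1.0 - m)


# Toy models standing in for the real student / teacher (illustration only)
student = nn.Linear(2, 2)
teacher = nn.Linear(2, 2)
with torch.no_grad():
    for p in student.parameters():
        p.fill_(1.0)
    for p in teacher.parameters():
        p.fill_(0.0)

# The teacher is never trained by backprop, only through the EMA
for p in teacher.parameters():
    p.requires_grad_(False)

ema_update(student, teacher, m=0.9)
```

With these toy values every teacher parameter moves a tenth of the way from 0 towards the student's 1. In the real training loop, `m` would come from `momentum_schedule[it]` and the call would happen right after `optimizer.step()`.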

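How can I do this in PyTorch Lightning? Specifically, I need to clip and cancel gradients between `loss.backward()` and `optimizer.step()`, and then run the EMA update of the teacher after each optimizer step.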