Where to clamp weights

I would like to clamp the weights to [-1, 1] and make sure they never go beyond that range. Where is the right place to do this in Lightning?

There are two places in a LightningModule where you can do this. One is optimizer_step:

def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    # clamp the weights here
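As a sketch of what "clamp the weights here" could contain, here is a small plain-PyTorch helper. The name clamp_parameters and its default bounds are hypothetical, chosen for illustration. Since a LightningModule is a subclass of torch.nn.Module, you could call it as clamp_parameters(self) from inside either hook. The clamp runs under torch.no_grad() so the in-place edit of the weights is not recorded by autograd.

```python
import torch
from torch import nn


def clamp_parameters(module: nn.Module, min_val: float = -1.0, max_val: float = 1.0) -> None:
    """Clamp every parameter of `module` in place to [min_val, max_val].

    Hypothetical helper for illustration; not part of Lightning's API.
    """
    # no_grad: this is a direct edit of the weights, not part of the graph
    with torch.no_grad():
        for param in module.parameters():
            param.clamp_(min_val, max_val)


if __name__ == "__main__":
    layer = nn.Linear(4, 2)
    with torch.no_grad():
        layer.weight.fill_(3.0)   # deliberately above the range
        layer.bias.fill_(-5.0)    # deliberately below the range
    clamp_parameters(layer)
    print(layer.weight.max().item(), layer.bias.min().item())
```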

Or in the on_before_zero_grad hook:

def on_before_zero_grad(self, *args, **kwargs):
    # clamp the weights here
    pass

I'd suggest the latter, on_before_zero_grad.
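To illustrate the ordering the clamp depends on, here is a plain PyTorch loop without Lightning. The weights are clamped after optimizer.step() has updated them and before the next forward pass. The model size, synthetic data, and learning rate are arbitrary assumptions. The targets are deliberately built so that the unconstrained best-fit weights (3.0) lie outside [-1, 1], which means the clamp actually has something to do.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Synthetic data whose best-fit weights (3.0 each) lie outside [-1, 1]
x = torch.randn(64, 8)
y = x.sum(dim=1, keepdim=True) * 3.0

for step in range(100):
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()          # 1. update the weights
    with torch.no_grad():     # 2. clamp them, between the step and zero_grad
        for param in model.parameters():
            param.clamp_(-1.0, 1.0)
    optimizer.zero_grad()     # 3. clear gradients before the next iteration

print(model.weight.min().item(), model.weight.max().item())
```

Without the clamp, the weights in this setup would drift toward 3.0. With it, every parameter stays inside [-1, 1] after each update.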
