I would like to clamp my model's weights to the range [-1, 1] and make sure they never go outside it. Where is the right place to do this in Lightning?
There are two places in the `LightningModule` where you can do this:
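In `optimizer_step`, after calling the parent implementation so the update has already been applied:

```python
def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    # clamp the weights in place once the optimizer has updated them
    with torch.no_grad():
        for param in self.parameters():
            param.clamp_(-1.0, 1.0)
```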
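Or in the `on_before_zero_grad` hook:

```python
def on_before_zero_grad(self, *args, **kwargs):
    with torch.no_grad():  # clamp the weights in place, outside autograd
        for param in self.parameters():
            param.clamp_(-1.0, 1.0)
```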
I'd suggest the latter one, `on_before_zero_grad`.
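To see the clamping pattern on its own, outside Lightning, here is a minimal plain-PyTorch sketch that mirrors the `optimizer_step` placement: clamp in place, under `torch.no_grad()`, right after `optimizer.step()`. The helper `clamp_weights_`, the toy `nn.Linear` model, the synthetic data and the learning rate are all hypothetical, chosen only so that the unclamped weights would want to move well past 1.

```python
import torch
from torch import nn


def clamp_weights_(module: nn.Module, low: float = -1.0, high: float = 1.0) -> None:
    """Hypothetical helper: clamp every parameter of `module` in place to [low, high]."""
    # no_grad: this edits the parameter values directly and must not be tracked by autograd
    with torch.no_grad():
        for param in module.parameters():
            param.clamp_(low, high)


def train(steps: int = 20) -> nn.Module:
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(32, 4)
    # toy target whose ideal weights are 5, i.e. far outside [-1, 1]
    y = 5.0 * x.sum(dim=1, keepdim=True)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        clamp_weights_(model)  # clamp immediately after the update
    return model


model = train()
print(max(p.abs().max().item() for p in model.parameters()))
```

Clamping after the step (rather than before it) means every forward pass sees weights that are already inside the range, which is the guarantee the question asks for.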