Hi, I would like to implement gradient skipping in PyTorch Lightning, i.e. skipping training updates whose gradient norm exceeds a certain threshold.

In other words,

- Calculate the gradient norm of the model parameters
- If the gradient norm > threshold, skip the `optimizer.step()` call for that batch
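For context, here is a minimal sketch of the logic I have in mind, written in plain PyTorch. In Lightning I assume this would live in `training_step` with manual optimization enabled (`self.automatic_optimization = False`), with `self.manual_backward(loss)` in place of `loss.backward()`; the `grad_norm` helper and the threshold value are just placeholders:

```python
import torch
from torch import nn


def grad_norm(parameters):
    # Total L2 norm over all parameter gradients (placeholder helper).
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    return torch.cat(grads).norm(2).item()


torch.manual_seed(0)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
threshold = 10.0  # hypothetical skip threshold

# One "training step": forward, backward, then a conditional update.
x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
opt.zero_grad()
loss.backward()  # in Lightning manual optimization: self.manual_backward(loss)

norm = grad_norm(model.parameters())
if norm <= threshold:
    opt.step()  # update only when the gradient norm is within the threshold
else:
    print(f"skipping update: grad norm {norm:.2f} > {threshold}")
```

This keeps the skip decision entirely in user code, which manual optimization is designed for; with automatic optimization, hooks like `on_before_optimizer_step` let you inspect gradients but are less suited to skipping the step outright.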

Any advice on the recommended way to implement this in a LightningModule?