Compute loss on entire dataset

I don't think this quite fits the gradient accumulation functionality.

We are trying to compute a loss with two terms. The first term depends on one dataset sample at a time. The second term depends on all pairs of adjacent dataset samples (a smoothing term across time). What would be an idiomatic way to implement this?

For example, I was thinking of implementing training_step to compute the first term, agnostic to batch size since it treats samples independently. Each training step would also output some data that a later hook (e.g. on_epoch_end) concatenates, and an error would then be computed on all adjacent pairs.
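To make the idea concrete, here is a minimal sketch of that flow in plain PyTorch. It is not a real LightningModule: the method names only mirror the hooks, and the class name `TwoTermLossSketch`, the linear model, the MSE per-sample term, the `smooth_weight` factor, and the assumption that each batch carries its dataset indices are all hypothetical choices made for illustration.

```python
import torch
from torch import nn


class TwoTermLossSketch(nn.Module):
    """Plain-PyTorch stand-in; method names only mirror the Lightning hooks."""

    def __init__(self, smooth_weight: float = 0.1):
        super().__init__()
        self.net = nn.Linear(4, 2)          # hypothetical model
        self.smooth_weight = smooth_weight  # hypothetical weighting factor
        self.epoch_preds = []               # (sample indices, predictions) per step

    def training_step(self, batch):
        # Assumption: the dataset yields each sample's index along with x and y.
        idx, x, y = batch
        pred = self.net(x)
        # Per-sample term: a mean over the batch, so it works for any batch size.
        loss = nn.functional.mse_loss(pred, y)
        # Store detached predictions with their indices for the epoch-level term.
        self.epoch_preds.append((idx, pred.detach()))
        return loss

    def on_train_epoch_end(self):
        idx = torch.cat([i for i, _ in self.epoch_preds])
        preds = torch.cat([p for _, p in self.epoch_preds])
        self.epoch_preds.clear()
        # Restore temporal order in case the loader shuffled the samples.
        preds = preds[torch.argsort(idx)]
        # Smoothing term over all adjacent pairs (t, t+1).
        return self.smooth_weight * (preds[1:] - preds[:-1]).pow(2).mean()


# Usage: one simulated epoch over 10 shuffled samples in batches of up to 4.
torch.manual_seed(0)
model = TwoTermLossSketch()
x, y = torch.randn(10, 4), torch.randn(10, 2)
for chunk in torch.randperm(10).split(4):
    step_loss = model.training_step((chunk, x[chunk], y[chunk]))
smooth_term = model.on_train_epoch_end()
```

One caveat with this sketch: the stored predictions are detached, so the epoch-level value carries no gradient and only works as a monitored quantity. Training on the smoothing term would need a different design, for example batches made of temporally contiguous samples so that adjacent pairs can be penalized inside training_step.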