How to pass hidden states freely when training and testing

Hello! I want to combine Hugging Face's Transformer-XL with PyTorch Lightning, but I don't know how to pass the hidden states from one training/testing step to the next. It seems the only way to activate the 'hiddens' argument of 'training_step' is to set 'truncate_bptt_steps' in the Trainer. However, 'hiddens' is reset after 'truncate_bptt_steps' steps of forward propagation, which is not what I want. Is there a way to pass the hidden states freely? Thank you so much!

You can simply add them to the output dictionary of your training step:

loss = ...
return {'loss': loss, 'hiddens': any_you_want}

See our tests: lightning/test_cpu.py at commit 05f25f3a543dd06382feb7f19761e452af572288 in the Lightning-AI/lightning repository on GitHub.
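If the state has to survive across batches rather than only within one batch, another option is to keep it on the module itself and detach it after every step. This is a different technique from the returned dictionary above. The sketch below is a rough illustration in plain PyTorch, not Lightning's own API: 'StatefulModel', 'reset_hiddens', the GRU stand-in for Transformer-XL, and all sizes are hypothetical names and values chosen for the example, and the manual loop only imitates what a trainer would do.

```python
import torch
from torch import nn


class StatefulModel(nn.Module):
    """Hypothetical model that carries its recurrent state across steps."""

    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        # A GRU stands in for Transformer-XL here (assumption for brevity).
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)
        self.hiddens = None  # state left behind by the previous step

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Reuse the previous state; None makes the GRU start from zeros.
        out, new_hiddens = self.rnn(x, self.hiddens)
        loss = nn.functional.mse_loss(self.head(out[:, -1]), y)
        # Detach so gradients do not flow back into earlier steps.
        self.hiddens = new_hiddens.detach()
        return {"loss": loss, "hiddens": self.hiddens}

    def reset_hiddens(self):
        # Call this when a new, unrelated sequence starts (e.g. each epoch).
        self.hiddens = None


torch.manual_seed(0)
model = StatefulModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Manual loop imitating consecutive training steps on random data.
for batch_idx in range(3):
    batch = (torch.randn(4, 5, 8), torch.randn(4, 1))
    output = model.training_step(batch, batch_idx)
    optimizer.zero_grad()
    output["loss"].backward()
    optimizer.step()
```

Note that this only works if the batch size stays constant between steps, because the stored state has a fixed batch dimension. The state should also be reset whenever consecutive batches no longer belong to the same sequences.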