How to pass hidden states freely when training and testing

Hello! I want to combine Hugging Face's Transformer-XL with PyTorch Lightning, but I don't know how to pass the hidden states from one training/testing step to the next. It seems the only option is to set ‘truncate_bptt_steps’ in the Trainer, which activates the ‘hiddens’ argument of ‘training_step’. However, ‘hiddens’ is reset after every ‘truncate_bptt_steps’ steps of forward propagation, which is not what I want. Is there a way to pass the hidden states freely between steps? Thank you so much!

You can simply add them to the output dictionary:

loss = ...
return {'loss': loss, 'hiddens': any_you_want}

See our tests in the PyTorchLightning/pytorch-lightning repository on GitHub (at commit 05f25f3a543dd06382feb7f19761e452af572288).
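If the ‘hiddens’ argument still gets reset at the start of each batch, one possible workaround (an assumption on my side, not something confirmed in this thread) is to also cache the state on the module itself and fall back to it whenever ‘hiddens’ arrives as None. Below is a minimal, framework-free sketch of that idea. `CarryStateModule`, `toy_model`, and the `mems` attribute are hypothetical names; the class only mimics the `training_step(batch, batch_idx, hiddens)` signature and does not import Lightning, and the "state" is a plain integer rather than a tensor.

```python
from typing import Dict, List, Optional, Tuple, Union


def toy_model(batch: List[int], mems: Optional[int]) -> Tuple[float, int]:
    """Hypothetical stand-in for a recurrent model: the state is a running sum."""
    new_mems = (mems or 0) + sum(batch)
    loss = float(new_mems)  # dummy loss, only here so the dict has a 'loss' key
    return loss, new_mems


class CarryStateModule:
    """Mimics the training_step signature of a LightningModule (no Lightning import)."""

    def __init__(self) -> None:
        # State cached on the module so it survives across batches.
        self.mems: Optional[int] = None

    def training_step(
        self, batch: List[int], batch_idx: int, hiddens: Optional[int] = None
    ) -> Dict[str, Union[float, int]]:
        # Prefer the state handed in by the trainer; otherwise use the cached one.
        mems = hiddens if hiddens is not None else self.mems
        loss, new_mems = toy_model(batch, mems)
        # With real tensors you would probably detach them before caching,
        # so that gradients do not flow back across batches.
        self.mems = new_mems
        return {"loss": loss, "hiddens": new_mems}


# Simulate a trainer that never passes `hiddens` between batches.
module = CarryStateModule()
batches = [[1, 2], [3], [4, 5]]
outputs = [module.training_step(b, i) for i, b in enumerate(batches)]
```

In this sketch the state keeps accumulating across all three batches even though no `hiddens` value is ever passed in. For testing, the same fallback could be used in the test step, and the cached state would need to be cleared (set back to None) whenever a new, unrelated sequence starts.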