How to implement linear probing for the first N epochs and then switch to fine-tuning?

Hello, I’m wondering how to implement a training technique from the paper “Fine-Tuning can Distort Pretrained Features and Underperform Out-of-Distribution”. Essentially, the authors freeze all model weights except the softmax layer at the beginning of training and then switch to fine-tuning. I’m working with BERT-like models from transformers. Also, how could I make the switch to fine-tuning gradual (let’s say, unfreeze one top transformer layer every epoch)?

Hi @Konrad, you can use the BaseFinetuning callback to achieve this.

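You will need to override two methods: freeze_before_training, to freeze everything except the classification head, and finetune_function, which is called at the start of each training epoch, with logic to unfreeze one top layer per epoch once the linear probing phase is over. Let me know if you run into any issues while implementing it.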

