Performance on CIFAR10/STL10 DataModules drops from Lightning 0.8.5/Bolts 0.1.0 to Lightning 0.9.0/Bolts 0.1.1

I wasn’t sure how exactly to categorize this, but my guess is that it’s some change in Bolts leading up to the SimCLR release, so I’ve put it here for now.

I’m training MoCo with the standard self-supervised linear-evaluation setup, and after updating versions this weekend I’m seeing a substantial performance drop. It seems to be in the supervised (evaluation) stage rather than the self-supervised pretraining. That makes me think it’s something about the Bolts DataModules rather than base Lightning, though I could be interpreting this incorrectly.

With the STL10 Bolts DataModule, top-1 validation accuracy drops by 16 points (88% -> 72%) and cross-entropy loss rises by 0.5 (0.3 -> 0.8). On CIFAR10 the drop is smaller but consistent: 2 points of top-1 validation accuracy. Both runs use exactly the same code, and the only change is switching conda environments. (I changed both the Lightning and the Bolts versions because the two seem to need upgrading together for compatibility.)

I have some more minor details here (https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues/175; sorry for the double post, I didn’t realize Bolts-related issues could go here) and am happy to provide the self-supervised pretrained models and hparam configs to speed things along. If something changed in the DataModules, or elsewhere in Bolts, that would lead to this kind of drop, it would be great to be pointed to what changed and why.

We found that the previous DataModules leaked validation samples into training because of how the random split function was used. We fixed the leak. The performance drop is likely because the earlier numbers were inflated by overfitting to those leaked samples.
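To illustrate how a random split can leak validation data, here is a minimal, framework-free sketch. It assumes the leak pattern described later in this thread: the training and validation loaders each re-split the full dataset with independent randomness. `split_indices` is a hypothetical stand-in for `torch.utils.data.random_split`, and the seeds 1 and 2 simply simulate two unrelated, unseeded draws.

```python
import random


def split_indices(n, lengths, rng):
    """Hypothetical stand-in for torch.utils.data.random_split:
    shuffle range(n) with `rng`, then cut it into consecutive chunks of `lengths`."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to n")
    perm = list(range(n))
    rng.shuffle(perm)
    chunks, start = [], 0
    for length in lengths:
        chunks.append(perm[start:start + length])
        start += length
    return chunks


N, LENGTHS = 1000, [900, 100]

# Leaky pattern (assumed): each loader re-splits with its own randomness,
# so the train loader and the val loader see two *different* partitions.
train_idx, _ = split_indices(N, LENGTHS, random.Random(1))  # e.g. inside train_dataloader
_, val_idx = split_indices(N, LENGTHS, random.Random(2))    # e.g. inside val_dataloader

leaked = set(train_idx) & set(val_idx)
print(f"{len(leaked)} of {len(val_idx)} validation samples also appear in training")
```

With a 90/10 split, each validation index independently has roughly a 90% chance of also landing in the other call's training chunk, so most of the "validation" set was effectively seen during training.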

Ah, that’d do it, thanks. Which DataModules were affected? Could you point me to roughly where the issue and the fix are, and how the problem was caused? (I’ve made some DataModules based on those and want to make sure they don’t have the same leak.) I figure it’s somewhere around commit c180b703d9b820c49a87641e27806c4344867ba8? Any pointers appreciated.

I did a bit more digging, and my best guess is this commit: https://github.com/PyTorchLightning/pytorch-lightning-bolts/commit/156fc650910e249ba5ba0baaf9b183ae2be7901b#diff-514d5723861e1f9e69849f6dd13f7ee1 If I’m understanding correctly, the fix adds a seed to both random_split calls so that train_dataloader and val_dataloader produce the same split. Before, each call made a different random split, so the training and validation subsets overlapped. If that’s the case, any DataModules using random_split are affected, and all others (e.g. ImageNet) should not be. Does this sound about right?
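The fix described above (seeding both split calls identically) can be sketched without any framework, along with a disjointness check you could adapt to audit your own DataModules. `split_indices` is again a hypothetical stand-in for `random_split`, and `ToyDataModule` with its `train_indices`/`val_indices` methods is a hypothetical skeleton, not the Bolts or Lightning API.

```python
import random


def split_indices(n, lengths, rng):
    """Hypothetical stand-in for torch.utils.data.random_split."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to n")
    perm = list(range(n))
    rng.shuffle(perm)
    chunks, start = [], 0
    for length in lengths:
        chunks.append(perm[start:start + length])
        start += length
    return chunks


class ToyDataModule:
    """Hypothetical skeleton: both loaders still re-split, but with one shared seed."""

    def __init__(self, n=1000, val_size=100, seed=42):
        self.n, self.val_size, self.seed = n, val_size, seed

    def _split(self):
        # A fresh generator with the same seed yields the same permutation every call.
        return split_indices(self.n, [self.n - self.val_size, self.val_size],
                             random.Random(self.seed))

    def train_indices(self):
        return self._split()[0]

    def val_indices(self):
        return self._split()[1]


dm = ToyDataModule()
overlap = set(dm.train_indices()) & set(dm.val_indices())
assert not overlap, "train/val leak detected"
print("overlap:", len(overlap))  # prints "overlap: 0"
```

In PyTorch itself, one way to get the same effect is to pass a seeded generator, e.g. `random_split(dataset, lengths, generator=torch.Generator().manual_seed(seed))`. Splitting only once (for example in the DataModule’s `setup` hook) and reusing the stored subsets in both loaders avoids relying on matching seeds at all.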

Yes, that’s correct!