Is there a PyTorch / PyTorch Lightning implementation of SimCLR?


The code + docs are here

Direct usage looks like this:

Getting the model

First, install Bolts:
pip install pytorch-lightning-bolts


import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
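SimCLR learns by comparing two independently augmented views of the same image, and the train transform assigned above is what produces those views inside Bolts. The idea can be sketched without any library. This is a minimal illustration, not Bolts' code: `TwoViewTransform` and `make_jitter` are hypothetical names, and a random jitter on a list of floats stands in for real image augmentations such as crops and color distortion.

```python
import random
from typing import Callable, Generic, List, Tuple, TypeVar

T = TypeVar("T")


class TwoViewTransform(Generic[T]):
    """Hypothetical wrapper: run one random augmentation twice, independently."""

    def __init__(self, augment: Callable[[T], T]) -> None:
        self.augment = augment

    def __call__(self, sample: T) -> Tuple[T, T]:
        # Each call draws fresh randomness, so the two views differ.
        return self.augment(sample), self.augment(sample)


def make_jitter(seed: int, max_delta: float = 0.1) -> Callable[[List[float]], List[float]]:
    """Hypothetical stand-in augmentation: add uniform noise to each value."""
    rng = random.Random(seed)

    def jitter(xs: List[float]) -> List[float]:
        return [x + rng.uniform(-max_delta, max_delta) for x in xs]

    return jitter


two_views = TwoViewTransform(make_jitter(seed=0))
view_1, view_2 = two_views([1.0, 2.0, 3.0])
```

Both views come from the same sample, so the model is trained to map them close together in embedding space.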

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)

# fit
trainer = pl.Trainer()
trainer.fit(model, dm)
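During `fit`, SimCLR optimizes a contrastive loss (NT-Xent in the SimCLR paper): each view's positive is the other view of the same image, and every other view in the batch is a negative. Below is a plain-Python sketch of that loss for small batches. It is illustrative only, not the Bolts implementation, and the `temperature=0.5` default is an assumed example value.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    # Scale to unit length so dot products become cosine similarities.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent_loss(z_a: Sequence[Sequence[float]],
                 z_b: Sequence[Sequence[float]],
                 temperature: float = 0.5) -> float:
    """z_a[i] and z_b[i] are projections of two views of sample i."""
    n = len(z_a)
    z = [_normalize(v) for v in list(z_a) + list(z_b)]  # 2N unit vectors
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # index of the other view of sample i
        sims = [sum(a * b for a, b in zip(z[i], z[k])) / temperature
                for k in range(2 * n)]
        # The denominator covers every other view, excluding self-similarity.
        denom = sum(math.exp(sims[k]) for k in range(2 * n) if k != i)
        total += -math.log(math.exp(sims[positive]) / denom)
    return total / (2 * n)


aligned = nt_xent_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = nt_xent_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

When the two views of each sample agree (`aligned`), the loss is lower than when they point at the wrong partner (`swapped`). That gap is what training pushes the encoder to widen.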

Or use a pre-trained model:

from pl_bolts.models.self_supervised import SimCLR

weight_path = ''  # path to a pre-trained SimCLR checkpoint
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)