Hey teddy,
I was returning the predictions, the input, the target and the loss (all packaged in a dict) and then accessing them in `training_step_end` and `training_epoch_end`. This probably caused the OOM. Looking at the pseudocode in the docs, I realized that in the training/validation/test phase PL iterates over the batches and performs a `training_step` for each one. The outputs are stored in a list that can be accessed in `..._epoch_end`. If I return the predictions and then access them in `..._epoch_end`, this list can become very large, depending on the dataset and batch size. I think I managed to solve the problem by returning only the loss and storing the computed metrics in an `nn.Module` within the LightningModule instead. I’ll share the code here:
The UNet model can be downloaded from my GitHub gists here.
# %% imports
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
# %% custom metric class
class CustomMetric(torch.nn.Module):
    """Accumulate an unreduced (per-class) metric over batches, ignoring absent classes.

    ``metric(prediction, target, **kwargs)`` must return a tensor with one score
    per class (e.g. an F1 function called with ``average='none'``). For every
    batch, the classes that actually occur in ``target`` are recorded. Batch and
    epoch averages are then taken only over classes that were present, so a
    class missing from a batch does not drag the average down with a 0 score.

    Args:
        metric: callable returning per-class scores.
        metric_name: name returned by ``repr()``.
        **kwargs: extra keyword arguments forwarded to ``metric`` on every call.
    """

    def __init__(self, metric, metric_name, **kwargs):
        super().__init__()
        self.metric = metric
        self.metric_name = metric_name
        self.kwargs = kwargs
        # per-batch history of the current epoch
        self.scores = []
        self.valid_classes = []
        self.valid_matrices = []
        # values of the most recent batch
        self.score = None
        self.valid_class = None
        self.valid_matrix = None
        # history of the last finished epoch (filled in by ``epoch()``)
        self.last_scores = None
        self.last_valid_classes = None
        self.last_valid_matrices = None

    def batch(self, prediction, target):
        """Score one batch and append the result to the epoch history.

        Args:
            prediction: model output passed straight to ``metric``.
            target: integer class labels; its unique values define the present classes.
        """
        # Detach so the epoch-long history never keeps an autograd graph (and the
        # activations it references) alive until the epoch ends.
        self.score = self.metric(prediction, target, **self.kwargs).detach()
        # classes that occur in this batch's target
        self.valid_class = target.unique()
        # boolean mask over the per-class scores: True where the class is present
        valid_matrix = torch.zeros_like(self.score, dtype=torch.bool)
        valid_matrix[self.valid_class] = True
        self.valid_matrix = valid_matrix
        self.scores.append(self.score)
        self.valid_classes.append(self.valid_class)
        self.valid_matrices.append(self.valid_matrix)

    def get_metrics_batch(self, mean=True):
        """Return the last batch's scores restricted to the classes present in its target.

        Args:
            mean: if True, return the mean over the present classes; otherwise
                return one score per present class (in ``valid_class`` order).
        """
        present_scores = self.score[self.valid_class]
        return present_scores.mean() if mean else present_scores

    def get_metrics_epoch(self, mean=True, last=False):
        """Aggregate the accumulated batch scores.

        For each class, the scores are averaged over the batches in which that
        class was present. A class that never occurred gets NaN in the per-class
        result and is excluded from the overall mean.

        Args:
            mean: if True, return the mean over classes that occurred at least
                once; otherwise return the per-class tensor.
            last: if True, use the history of the last finished epoch instead
                of the current one.
        """
        if last:
            scores, masks = self.last_scores, self.last_valid_matrices
        else:
            scores, masks = self.scores, self.valid_matrices
        # (num_batches, num_classes) -> (num_classes, num_batches): one row per class
        scores = torch.stack(scores).T
        masks = torch.stack(masks).T
        # per class: average only over batches in which the class was present
        per_class = torch.stack([s[m].mean() for s, m in zip(scores, masks)])
        if not mean:
            return per_class
        # skip classes that never occurred (their per-class mean is NaN)
        present = masks.any(dim=1)
        return per_class[present].mean()

    def epoch(self):
        """Finish the epoch: keep its history in ``last_*``, reset, and return the epoch mean."""
        self.last_scores = self.scores
        self.last_valid_classes = self.valid_classes
        self.last_valid_matrices = self.valid_matrices
        result = self.get_metrics_epoch(mean=True)
        self.reset()
        return result

    def reset(self):
        """Clear the per-batch history of the current epoch."""
        self.scores = []
        self.valid_classes = []
        self.valid_matrices = []

    def __repr__(self):
        return self.metric_name
# %% lightningModule
class Segmentation_Lightning(pl.LightningModule):
    """LightningModule for semantic segmentation with cross-entropy loss and per-class F1 logging.

    Only the loss is returned from the step methods. F1 scores are accumulated in
    ``CustomMetric`` modules, so Lightning does not collect predictions in the
    ``*_epoch_end`` output lists.

    Args:
        model: network returning raw logits with one channel per class.
        lr: learning rate for Adam.
        num_classes: number of segmentation classes (also used by the F1 metrics).
    """

    def __init__(self, model, lr, num_classes):
        super().__init__()
        self.model = model
        self.lr = lr
        self.num_classes = num_classes
        # CrossEntropyLoss = LogSoftmax + NLLLoss, so it is fed the raw logits
        self.ce_loss = torch.nn.CrossEntropyLoss()
        # one F1 accumulator per phase; average='none' yields one score per class
        self.f1_train = CustomMetric(metric=pl.metrics.functional.f1,
                                     metric_name='F1_Train',
                                     num_classes=num_classes,
                                     average='none')
        self.f1_valid = CustomMetric(metric=pl.metrics.functional.f1,
                                     metric_name='F1_Valid',
                                     num_classes=num_classes,
                                     average='none')
        self.f1_test = CustomMetric(metric=pl.metrics.functional.f1,
                                    metric_name='F1_Test',
                                    num_classes=num_classes,
                                    average='none')
        # NOTE(review): this also records `model` (an nn.Module) as a hyperparameter;
        # consider saving only `lr` and `num_classes`.
        self.save_hyperparameters()

    def shared_step(self, batch):
        """Run the model on ``batch['x']`` and compute the loss against ``batch['y']``.

        Returns:
            The batch dict extended with ``'pred'`` (softmax probabilities over
            dim 1) and ``'loss'`` (cross entropy on the raw logits).
        """
        x, y = batch['x'], batch['y']
        out = self.model(x)
        out_soft = torch.nn.functional.softmax(out, dim=1)
        loss = self.ce_loss(out, y)
        return {**batch, 'pred': out_soft, 'loss': loss}

    def _log_batch_metrics(self, metric, name):
        """Log the batch F1 (mean over present classes) and the F1 of every present class."""
        experiment = self.logger.experiment
        experiment.log_metric(f'{name}/F1/Batch', metric.get_metrics_batch(mean=True).item())
        # .tolist() turns the 0-d tensors into plain ints/floats; formatting a tensor
        # directly would put "tensor(0, device=...)" into the metric name
        class_ids = metric.valid_class.tolist()
        class_scores = metric.get_metrics_batch(mean=False).tolist()
        for class_idx, value in zip(class_ids, class_scores):
            experiment.log_metric(f'{name}/F1/Batch/Class/{class_idx}', value)

    def _log_epoch_metrics(self, metric, name):
        """Log per-class and total epoch F1, then reset ``metric`` for the next epoch."""
        experiment = self.logger.experiment
        # NOTE(review): classes absent for the whole epoch are logged as NaN
        for class_idx, value in enumerate(metric.get_metrics_epoch(mean=False).tolist()):
            experiment.log_metric(f'{name}/F1/Epoch/Class/{class_idx}', value)
        # epoch() must come last: it resets the accumulated history
        experiment.log_metric(f'{name}/F1/Epoch', metric.epoch().item())

    def training_step(self, batch, batch_idx):
        step_output = self.shared_step(batch)
        self.f1_train.batch(step_output['pred'], step_output['y'])
        self._log_batch_metrics(self.f1_train, 'Train')
        return step_output['loss']

    def training_epoch_end(self, outputs):
        self._log_epoch_metrics(self.f1_train, 'Train')

    def validation_step(self, batch, batch_idx):
        step_output = self.shared_step(batch)
        self.f1_valid.batch(step_output['pred'], step_output['y'])
        self._log_batch_metrics(self.f1_valid, 'Valid')
        # for checkpointing; Lightning aggregates this over the epoch (mean of the
        # batch values), which is not identical to the epoch F1 logged at epoch end
        self.log('checkpoint_valid_f1_epoch', self.f1_valid.get_metrics_batch(mean=True))
        return step_output['loss']

    def validation_epoch_end(self, outputs):
        self._log_epoch_metrics(self.f1_valid, 'Valid')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return optimizer
# %% dataset class
class RandomDataSet(data.Dataset):
    """Synthetic segmentation dataset of random input/target pairs, generated once and cached.

    Args:
        num_samples: number of pairs to generate.
        size: spatial size of every sample, e.g. ``(512, 512)``; any sequence of ints.
        num_classes: targets are drawn uniformly from ``[0, num_classes)``.
        inputs_dtype: dtype the inputs are cast to in ``__getitem__``.
        targets_dtype: dtype the targets are cast to in ``__getitem__``.
        num_channels: number of input channels (defaults to 3, the previous fixed value).
    """

    def __init__(self,
                 num_samples,
                 size,
                 num_classes=4,
                 inputs_dtype=torch.float32,
                 targets_dtype=torch.long,
                 num_channels=3,
                 ):
        self.num_samples = num_samples
        # normalise to a tuple so `(num_channels,) + size` also works for lists
        self.size = tuple(size)
        self.num_classes = num_classes
        self.inputs_dtype = inputs_dtype
        self.targets_dtype = targets_dtype
        self.num_channels = num_channels
        self.cached_data = []
        for _ in range(self.num_samples):
            # inputs ~ U(0, 1) with shape (num_channels, *size); cast lazily in __getitem__
            inp = torch.from_numpy(np.random.uniform(low=0, high=1, size=(num_channels,) + self.size))
            # integer labels with shape `size`
            tar = torch.randint(low=0, high=num_classes, size=self.size)
            self.cached_data.append((inp, tar))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index: int):
        """Return ``{'x': input, 'y': target}`` cast to the configured dtypes."""
        x, y = self.cached_data[index]
        x, y = x.type(self.inputs_dtype), y.type(self.targets_dtype)
        return {'x': x, 'y': y}
# %% dataloader
# shared settings for the synthetic data
size = (512, 512)
batch_size = 8
num_classes = 4

# random train / validation splits (train is generated first and consumes the RNG first)
dataset_train = RandomDataSet(num_samples=40, size=size, num_classes=num_classes)
dataset_valid = RandomDataSet(num_samples=16, size=size, num_classes=num_classes)

# the training loader shuffles; the validation loader keeps a fixed order
dataloader_training = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=0)

# sanity check: pull one training batch to inspect the input/target tensors
batch = next(iter(dataloader_training))
x = batch['x']
y = batch['y']
# %% model
from unet import UNet  # project-local module (the UNet from the author's gist)

# The network must output one logit map per class, so out_channels is tied to
# num_classes instead of a separate hard-coded 4 that could drift out of sync.
model = UNet(in_channels=3,
             out_channels=num_classes,
             n_blocks=4,
             start_filters=32,
             activation='relu',
             normalization='group8',
             conv_mode='same',
             dim=2,
             up_mode='transposed')

# %% task init
task = Segmentation_Lightning(model=model,
                              lr=0.001,
                              num_classes=num_classes)
# %% logger init
from pytorch_lightning.loggers.neptune import NeptuneLogger
from api_key_neptune import get_api_key  # local helper module that provides the Neptune API key

api_key = get_api_key()
neptune_logger = NeptuneLogger(
    api_key=api_key,
    project_name='johschmidt42/Test',  # the project must already exist, otherwise an error is thrown
    experiment_name='testing',
)

# %% trainer init
# max_epochs is passed to the constructor (the documented way) instead of
# mutating the trainer attribute after it has been built.
trainer = Trainer(gpus=1,
                  precision=32,
                  benchmark=True,
                  checkpoint_callback=False,
                  logger=neptune_logger,
                  log_every_n_steps=1,
                  num_sanity_val_steps=0,
                  enable_pl_optimizer=False,
                  max_epochs=10,
                  )

# %% start training
trainer.fit(task,
            train_dataloader=dataloader_training,
            val_dataloaders=dataloader_valid)
If you notice anything unusual or any bad practice for PL, please let me know!
Some comments:
- I prefer to use `logger.experiment.log_metric()` instead of `log()`, because `on_step=True` in `validation_step()` doesn’t seem to work (or is intentionally not supported).
- `CustomMetric` is basically a wrapper that performs the steps needed to compute the metrics.
- My goal is to compute the metrics correctly for a batch that may be missing one or several classes, e.g. there is no instance/pixel of class 1 in the batch. Computing the unreduced F1 metric (one F1 score per class) could then result in something like `[0.5, 0.0, 0.25, 0.25]`. Taking the mean of that would return 0.25, which is wrong, because class 1 has to be ignored. The result should be (0.5 + 0.25 + 0.25) / 3 ≈ 0.333. Unfortunately, unlike IoU, F1 has no `ignore_index` flag.
- I could implement a proper metric (with `update()` and `compute()` functions) to support multi-GPU training, but I only use one GPU right now, so I think this is fine for now.