Logging stops when adding ModelCheckpoint callback to trainer

I’m trying to add checkpoint saving capabilities in my Lightning code, but that somehow stops the terminal logging.

Following the doc, I added the ModelCheckpoint callback to my trainer as follows:

def get_callbacks(args):
	"""Build the list of trainer callbacks (currently just checkpointing).

	Keeps the three best checkpoints by epoch-level validation loss and
	always saves the most recent one as well.
	"""
	return [
		ModelCheckpoint(
			monitor='valid_loss_epoch',
			# dirpath=f'saved_models',
			# filename=f'{args.name}',
			save_top_k=3,
			mode='min',
			verbose=True,
			save_last=True,
		)
	]

args.default_root_dir = 'saved_models/'
# Pass the callbacks at construction time. Assigning `trainer.callbacks`
# after __init__ REPLACES Lightning's internally-added callbacks — including
# the progress-bar callback that produces the terminal logging — which is
# exactly why the logging disappeared.
trainer = pl.Trainer.from_argparse_args(args, callbacks=get_callbacks(args))

This makes the log in the terminal disappear, i.e. the program doesn’t log to the terminal anymore. Commenting out the trainer.callbacks = get_callbacks(args) line restores the terminal logging. I’ve tried combinations of adding and commenting out dirpath and default_root_dir, but none of them helps with the logging. Any help would be great!

Details of my Lightning Module step/epoch functions where I have the self.log calls:

def training_step(self, batch, batch_idx, optimizer_idx):
	"""One manual-optimization training step: forward, loss/accuracy,
	backward, and an update of both optimizers, with step-level logging.

	Returns a dict with the step loss, predictions, and targets so they
	can be aggregated in `training_epoch_end`.
	"""
	(opt1, opt2)	= self.optimizers()
	outputs			= self(batch)
	preds			= torch.argmax(outputs, axis=1)
	targets			= batch["target"]
	loss			= self.calc_loss(outputs, targets)
	acc 			= self.calc_acc(preds, targets)

	# NOTE(review): newer Lightning versions expect manual_backward(loss)
	# without an optimizer argument — confirm against the installed version.
	self.manual_backward(loss, opt1)
	opt1.step()
	opt1.zero_grad()
	# NOTE(review): opt2.step() consumes whatever gradients the single
	# backward pass above populated — confirm this ordering is intended.
	opt2.step()
	opt2.zero_grad()

	self.log('train_loss_step', loss, prog_bar=True)
	self.log('train_acc_step', acc, prog_bar=True)
	return {'loss': loss, "preds": preds, "targets": targets}

def backward(self, loss, optimizer, optimizer_idx):
	"""Custom backward hook: plain backward pass on the loss.

	`optimizer` and `optimizer_idx` are ignored; gradients accumulate on
	every parameter that participated in computing `loss`.
	"""
	loss.backward()

def training_epoch_end(self, outputs):
	"""Aggregate per-step training results and log epoch-level metrics."""
	all_preds = torch.cat([step_out['preds'] for step_out in outputs])
	all_targets = torch.cat([step_out['targets'] for step_out in outputs])
	mean_loss = torch.stack([step_out['loss'] for step_out in outputs]).mean()
	self.log('train_loss_epoch', mean_loss.item())
	self.log('train_acc_epoch', self.calc_acc(all_preds, all_targets))

def validation_step(self, batch, batch_idx, dataloader_idx=0):
	"""One validation step: forward pass, loss/accuracy, step-level logging."""
	logits = self(batch)
	labels = batch["target"]
	step_loss = self.calc_loss(logits, labels)
	predictions = torch.argmax(logits, axis=1)
	step_acc = self.calc_acc(predictions, labels)

	self.log('valid_loss_step', step_loss, prog_bar=True)
	self.log('valid_acc_step', step_acc, prog_bar=True)
	return {'loss': step_loss, "preds": predictions, "targets": labels}

def validation_epoch_end(self, outputs):
	"""Aggregate per-step validation results and log epoch-level metrics.

	These epoch-level keys are what the ModelCheckpoint callback monitors.
	"""
	all_preds = torch.cat([step_out['preds'] for step_out in outputs])
	all_targets = torch.cat([step_out['targets'] for step_out in outputs])
	mean_loss = torch.stack([step_out['loss'] for step_out in outputs]).mean()
	self.log('valid_loss_epoch', mean_loss.item())
	self.log('valid_acc_epoch', self.calc_acc(all_preds, all_targets))

try this:

trainer = pl.Trainer.from_argparse_args(args, callbacks=get_callbacks(args))

The progress bar used for terminal logging is itself a callback, and it is added during the trainer’s __init__. If you then do trainer.callbacks = get_callbacks(args), you overwrite the trainer’s callback list and thereby remove the progress-bar callback — which is why the terminal logging stops.

1 Like