The warning in the title keeps appearing every time my model runs validation. This is my implementation of `validation_step`
(I have not implemented the `validation_step_end`
method):
def calculate_stats(self, outputs):
    """Accumulate one batch into the accuracy/IoU metrics and bundle the results.

    Args:
        outputs: dict with keys "preds", "target" and "loss". "preds" is
            reduced with argmax over dim 1, so it is assumed to hold class
            scores along that axis (TODO confirm against the model output).

    Returns:
        dict with "acc" -> self.accuracy, "iou" -> self.iou (the metric
        objects themselves, not computed values) and "loss" ->
        outputs["loss"]. Returning the metric objects presumably lets the
        caller pass them to self.log / self.log_dict so that they are
        computed and reset once per epoch.
    """
    preds = outputs["preds"].argmax(dim=1)
    target = outputs["target"]
    # update() only accumulates metric state. Calling the metric (forward)
    # would additionally compute, and possibly sync, a per-batch value that
    # is discarded here, because only epoch-level values are logged.
    self.accuracy.update(preds, target)
    self.iou.update(preds, target)
    # NOTE(review): if the same metric instances are also updated/logged in
    # training_step, train and val results get mixed; consider separate
    # instances (e.g. self.val_accuracy) per stage.
    return {
        "acc": self.accuracy,
        "iou": self.iou,
        "loss": outputs["loss"],
    }
def validation_step(self, batch, batch_index, *args, **kwargs):
    """Evaluate one validation batch and log epoch-level accuracy and IoU.

    Args:
        batch: dict with "img" (model input) and "target" (labels with a
            singleton dim 1 that is squeezed away and cast to int64).
        batch_index: index of the batch (unused here).

    Returns:
        The validation loss for this batch.
    """
    images = batch["img"]
    labels = batch["target"].squeeze(1).long()

    with torch.no_grad():
        probs = self.forward(images)

    # forward() is assumed to return probabilities (TODO confirm). The 1e-8
    # keeps log() finite where a probability is exactly 0, and the
    # log-probabilities are what self.criterion receives.
    loss = self.criterion(torch.log(probs + 1e-8), labels)

    metrics = self.calculate_stats({"loss": loss, "preds": probs, "target": labels})

    # NOTE(review): `loss` is returned but never logged; add e.g. "val_loss"
    # here if a callback needs to monitor it.
    epoch_logs = {"val_acc": metrics["acc"], "val_iou": metrics["iou"]}
    self.log_dict(epoch_logs, on_step=False, on_epoch=True, sync_dist=True)

    return loss
Am I missing something?