I am trying to implement WGAN-GP (Wasserstein GAN with gradient penalty) by defining the following function, which is called inside the training_step() method.
```python
import torch


def wgan_gradient_penalty(
    real: torch.Tensor,
    fake: torch.Tensor,
    discriminator: torch.nn.Module,
) -> torch.Tensor:
    # random interpolation between real and fake samples
    alpha = torch.rand(real.size(0), 1, 1, 1).type_as(real)
    x_hat = alpha * real + (1 - alpha) * fake.detach()
    x_hat.requires_grad = True

    # calc. d_hat: discriminator output on x_hat
    d_hat = discriminator(x_hat)

    # calc. gradients of d_hat w.r.t. x_hat
    grads = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones(d_hat.size()).type_as(real),
        create_graph=True,
        retain_graph=True,
    )[0]

    # penalty: (||grad||_2 - 1)^2, averaged over the batch
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```
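For reference, the quantity this function is building toward is the standard gradient penalty term from the WGAN-GP paper (Gulrajani et al., 2017):

```
gp = lambda * E[ (|| grad_{x_hat} D(x_hat) ||_2 - 1)^2 ]
```

with lambda usually set to 10. The error described below happens before the penalty itself is ever computed.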
But it seems that the output of the network is detached from the computation graph: when I check d_hat.grad_fn, it is None.
```
(Pdb) print(d_hat.grad_fn)
None
```
Because grad_fn is not defined, the torch.autograd.grad call fails with the following error:

```
*** RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
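For comparison, the same sequence of operations works when I run it as a standalone script outside of Lightning's training_step(). The toy discriminator below is something I made up just for this repro:

```python
import torch

# toy discriminator, used only for this standalone repro
disc = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 1))

real = torch.randn(8, 1, 4, 4)
fake = torch.randn(8, 1, 4, 4)

# same interpolation as in wgan_gradient_penalty()
alpha = torch.rand(real.size(0), 1, 1, 1).type_as(real)
x_hat = alpha * real + (1 - alpha) * fake.detach()
x_hat.requires_grad = True

d_hat = disc(x_hat)
print(d_hat.grad_fn)  # a valid backward node here, not None

grads = torch.autograd.grad(
    outputs=d_hat,
    inputs=x_hat,
    grad_outputs=torch.ones(d_hat.size()).type_as(real),
    create_graph=True,
    retain_graph=True,
)[0]
print(grads.shape)  # torch.Size([8, 1, 4, 4])
```

So the penalty code itself seems fine in isolation, which makes me suspect the discriminator forward inside training_step() is somehow running with gradient tracking disabled.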