Source code for torchx.nn.Loss

# -*- coding: utf-8 -*-
import torch

from ..utils import lerp


[docs]class WGAN_ACGAN(torch.nn.Module): """WGAN + AC-GAN Loss Function Used as a loss function for training generators in GANs Note: References: `WGAN <https://arxiv.org/pdf/1704.00028.pdf>`_, `AC-GAN <https://arxiv.org/pdf/1610.09585.pdf>`_ """ def __init__(self, cond_weight: float = 1.0): super().__init__() self.cond_weight = cond_weight def forward(self, x: torch.Tensor, target: torch.Tensor): return -x.mean()
[docs]class WGANGP_ACGAN(torch.nn.Module): """WGAN-GP + AC-GAN Loss Function Used as a loss function for training discriminators in GANs Note: References: `WGAN-GP <https://arxiv.org/pdf/1704.00028.pdf>`_, `AC-GAN <https://arxiv.org/pdf/1610.09585.pdf>`_ """ def __init__( self, generator, discriminator, drift: float = 0.001, use_gp: bool = False ): super().__init__() self.generator = generator self.discriminator = discriminator self.device = next(discriminator.parameters()).device self.drift = drift self.use_gp = use_gp def forward( self, real_images: torch.Tensor, fake_images: torch.Tensor, alpha=0, reg_lambda=10, ): real_outputs = self.discriminator(real_images, alpha) fake_outputs = self.discriminator(fake_images, alpha) loss = fake_outputs.mean() - real_outputs.mean() if self.drift != 0: loss += self.drift * real_outputs.pow(2).mean() if self.use_gp: minibatch_size = real_images.shape[0] mixing_factors = torch.rand((minibatch_size, 1, 1, 1), device=self.device) mixed_images = lerp(real_images, fake_images, mixing_factors) mixed_images.requires_grad_(True) mixed_outputs = self.discriminator(mixed_images, alpha) gradient = torch.autograd.grad( outputs=mixed_outputs, inputs=mixed_images, grad_outputs=torch.ones_like(mixed_outputs), create_graph=True, retain_graph=True, only_inputs=True, )[0] gradient = gradient.view(gradient.shape[0], -1) gradient_penalty = ( reg_lambda * (gradient.norm(p=2, dim=1) - 1).pow(2).mean() ) loss += gradient_penalty return loss