Source code for torchx.nn.Util

# -*- coding: utf-8 -*-
import torch

from ..utils import minibatch_stddev_layer


class Cond(torch.nn.Module):
    """Choose between two callables at call time, similar to ``tf.cond``.

    Args:
        cond: zero-argument callable, re-evaluated on every forward pass;
            its truthiness selects which branch runs.
        a: callable applied to the input when ``cond()`` is truthy.
        b: callable applied to the input when ``cond()`` is falsy.
    """

    def __init__(self, cond, a, b):
        super().__init__()
        self.cond = cond
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor):
        # The predicate is evaluated before either branch is looked up.
        branch = self.a if self.cond() else self.b
        return branch(x)
class MinibatchStddev(torch.nn.Module):
    """Module form of ``minibatch_stddev_layer``.

    The original docs describe this as increasing variation using the
    minibatch standard deviation. The exact output layout is defined by
    ``minibatch_stddev_layer`` (NOTE(review): confirm there).

    Args:
        group_size: forwarded unchanged as the second argument of
            ``minibatch_stddev_layer``. Defaults to 4.
    """

    def __init__(self, group_size: int = 4):
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor):
        size = self.group_size
        return minibatch_stddev_layer(x, size)
class PrintShape(torch.nn.Module):
    """Identity module that prints the input's shape, as a debugging aid.

    The input tensor is returned unchanged (same object), so this can be
    dropped anywhere into a ``Sequential`` without affecting the data.

    Args:
        format: ``str.format`` template with one positional field, which is
            filled with ``x.shape``. Defaults to ``"{}"``.
    """

    # NOTE: the parameter name shadows the builtin ``format``. It is kept
    # because it is part of the public constructor signature.
    def __init__(self, format="{}"):
        super().__init__()
        self.format = format

    def forward(self, x: torch.Tensor):
        message = self.format.format(x.shape)
        print(message)
        return x
class View(torch.nn.Module):
    """Reshape the input to a fixed shape with ``Tensor.view``.

    Args:
        *shape: target shape, passed unchanged to ``x.view``. A single ``-1``
            entry lets torch infer that dimension.
    """

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        # ``view`` never copies. It raises if ``x``'s layout is incompatible
        # with ``shape`` (e.g. some non-contiguous tensors).
        return x.view(*self.shape)

    def extra_repr(self) -> str:
        # nn.Module.__repr__ wraps this as "<ClassName>(...)". Subclasses
        # therefore report their own name instead of a hard-coded "View".
        return ", ".join(map(str, self.shape))