Source code for torchx.utils.Batch

# -*- coding: utf-8 -*-
import torch


[docs]def minibatch_stddev_layer(x, group_size=4): """Appends a feature map containing the standard deviation of the minibatch. Note: Implemented as described in `this paper <https://arxiv.org/pdf/1710.10196.pdf>`_. `Reference <https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L127-L139>`_. """ # noqa: E501 # Minibatch must be divisible by (or smaller than) group_size. group_size = min(group_size, x.shape[0]) s = x.shape # [NCHW] Input shape. # [GMCHW] Split minibatch into M groups of size G. y = x.reshape(group_size, -1, s[1], s[2], s[3]) y = y.float() # [GMCHW] Cast to FP32. y -= y.mean(axis=0, keepdims=True) # [GMCHW] Subtract mean over group. y = y.pow(2).mean(axis=0) # [MCHW] Calc variance over group. y = (y + 1e-8).sqrt() # [MCHW] Calc stddev over group. # [M111] Take average over fmaps and pixels. y = y.mean(axis=[1, 2, 3], keepdims=True) y = y.to(x.dtype) # [M111] Cast back to original data type. y = y.repeat(group_size, 1, s[2], s[3]) # [N1HW] Replicate over group and pixels. return torch.cat([x, y], axis=1) # [NCHW] Append as new fmap.