Source code for torchx.utils.Norm
# -*- coding: utf-8 -*-
import torch
[docs]def pixel_norm(x: torch.Tensor, epsilon: float = 1e-8):
"""Applies a pixel-wise normalization.
Note:
Implemented as described in `this paper <https://arxiv.org/pdf/1710.10196.pdf>`_.
`Reference <https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120-L122>`_.
""" # noqa: E501
return x * (x.pow(2).mean(axis=1, keepdim=True) + epsilon).rsqrt()