# Source code for torchx.nn.Norm
# -*- coding: utf-8 -*-
import torch
from ..utils import pixel_norm
class PixelwiseNorm(torch.nn.Module):
    """Torch module encapsulating the pixel norm operator.

    This is a thin ``torch.nn.Module`` wrapper: it stores ``epsilon`` at
    construction time and, on every forward pass, hands the input tensor and
    that stored value to the ``pixel_norm`` helper imported from ``..utils``.
    The module defines no parameters or buffers of its own.

    Args:
        epsilon: Value forwarded unchanged as the second argument of
            ``pixel_norm``. Presumably a small constant for numerical
            stability -- confirm against the ``pixel_norm`` implementation.
    """

    def __init__(self, epsilon: float = 1e-8):
        super().__init__()
        # Plain Python attribute (not a registered buffer/parameter), so it is
        # not part of the module's state_dict.
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``pixel_norm`` to ``x`` using the stored ``epsilon``.

        Args:
            x: Input tensor, passed through to ``pixel_norm`` as-is. Any
                shape/layout requirements are those of ``pixel_norm`` itself
                (not visible from this module -- TODO confirm).

        Returns:
            Whatever ``pixel_norm(x, self.epsilon)`` returns; annotated as a
            tensor on the assumption that the helper yields one.
        """
        return pixel_norm(x, self.epsilon)