Source code for torchx.nn.Interpolate

# -*- coding: utf-8 -*-
import torch

from ..utils import lerp


class Lerp(torch.nn.Module):
    """Module that linearly interpolates between the outputs of two callables.

    On each call the same input ``x`` is fed to both ``a`` and ``b``. The two
    results are then combined with ``lerp`` from ``..utils`` using the stored
    weight ``t``.

    Args:
        a: Callable applied to the input to produce the first endpoint. If it
            is a ``torch.nn.Module``, attribute assignment registers it as a
            submodule.
        b: Callable applied to the input to produce the second endpoint.
            It is registered the same way as ``a``, after it.
        t: Interpolation weight, forwarded unchanged as the third argument to
            ``lerp``. Which endpoint t=0 / t=1 selects depends on
            ``..utils.lerp``; presumably 0 -> ``a`` and 1 -> ``b`` (confirm).
    """

    def __init__(self, a, b, t):
        super().__init__()
        # The tuple assignment still sets attributes left to right (a, then b,
        # then t), so submodule registration order matches the original.
        # NOTE(review): a plain tensor ``t`` becomes an ordinary attribute,
        # not a buffer, so it will not follow ``.to(device)``. Confirm this is
        # intended.
        self.a, self.b, self.t = a, b, t

    def forward(self, x: torch.Tensor):
        """Interpolate between ``self.a(x)`` and ``self.b(x)`` with weight ``self.t``."""
        # Evaluate a before b, matching the original argument evaluation order.
        start = self.a(x)
        end = self.b(x)
        return lerp(start, end, self.t)