torchx

nn

An extension to the standard torch.nn neural network library that provides a number of useful classes and functions for building neural networks.

Classes:

Conv2d(in_channels, out_channels, …)

Applies a 2D convolution over an input signal composed of several input planes.

ConvTranspose2d(in_channels, out_channels, …)

Applies a 2D transposed convolution over an input signal composed of several input planes.

Lerp(a, b, t)

A module that encapsulates the Linear Interpolation function

Linear(in_features, out_features, bias, …)

Applies a linear transformation to the incoming data.

WGAN_ACGAN(cond_weight)

WGAN + AC-GAN Loss Function

WGANGP_ACGAN(generator, discriminator, …)

WGAN-GP + AC-GAN Loss Function

Module()

Convenient intermediary parent class that implements useful module functions

Cond(cond, a, b)

Conditional module similar to tf.cond: selects between a and b based on cond

MinibatchStddev(group_size)

Increases sample variation by appending the minibatch standard deviation as an extra feature map

PixelwiseNorm(epsilon)

Torch module encapsulating the pixel norm operator

PrintShape([format])

Prints the shape of a tensor and then forwards the tensor to the next module.

View(*shape)

Sets the view (shape) of a tensor inside a module

Functions:

Conv2dBatch(in_channels, out_channels[, …])

A 2D convolution followed by batch normalization and ReLU activation.

ConvTranspose2dBatch(in_channels, out_channels)

A 2D transposed convolution followed by batch normalization and ReLU activation.

Conv2dGroup(in_channels, out_channels[, …])

A 2D convolution followed by group normalization and ReLU activation.

DSConv(in_channels, out_channels[, stride])

A depth-wise separable convolution followed by a 2D convolution, each followed by batch normalization and ReLU activation.

DWConv(in_channels, out_channels[, stride])

A depth-wise separable convolution followed by batch normalization and ReLU activation.

class torchx.nn.Cond(cond, a, b)[source]

Conditional module similar to tf.cond: selects between a and b based on cond

class torchx.nn.Conv2d(in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0, dilation: int = 1, bias: bool = True, gain: float = 1.4142135623730951, use_wscale: bool = False, fan_in: float = None)[source]

Applies a 2D convolution over an input signal composed of several input planes.

A simpler, modified version of the standard torch.nn.Conv2d that supports an equalized learning rate by scaling the weights dynamically in each forward pass. Implemented as described in https://arxiv.org/pdf/1710.10196.pdf (reference implementation: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L23-L29).

If use_wscale is True, the weight parameter is initialized from a standard normal distribution. The bias parameter is initialized to zero.

Parameters
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int or tuple) – Size of the convolving kernel. Default: 3

  • stride (int or tuple) – Stride of the convolution. Default: 1

  • padding (int or tuple) – Zero-padding added to both sides of the input. Default: 0

  • dilation (int or tuple) – Spacing between kernel elements. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • gain (float) – The gain for the scaled weight. Default: sqrt(2)

  • use_wscale (bool) – If True, scales the weights in each forward pass. Default: False

  • fan_in (float) – Fan-in used to compute the weight scale (see the notes below). Default: None

Note

If fan_in is not provided, it is computed as \(\text{fan\_in} = \text{in\_channels} \times \text{kernel\_size}^2\)

Note

The wscale is computed as \(\text{wscale} = \frac{\text{gain}}{\sqrt{\text{fan\_in}}}\)
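The two notes above define the equalized-learning-rate scale. As a minimal NumPy sketch (not torchx's code; the helper name conv2d_wscale is hypothetical), the weights are stored as standard-normal samples and multiplied by wscale at run time, so the effective weights have standard deviation gain / sqrt(fan_in):

```python
import math

import numpy as np


def conv2d_wscale(in_channels, kernel_size, gain=math.sqrt(2), fan_in=None):
    """Hypothetical helper: runtime weight scale for an equalized-LR conv."""
    if fan_in is None:
        # Default from the note: fan_in = in_channels * kernel_size ** 2
        fan_in = in_channels * kernel_size ** 2
    return gain / math.sqrt(fan_in)


rng = np.random.default_rng(0)
in_channels, out_channels, k = 64, 128, 3
# Stored parameter: standard-normal samples (use_wscale=True)
raw_weight = rng.standard_normal((out_channels, in_channels, k, k))
wscale = conv2d_wscale(in_channels, k)
# What the convolution would use in each forward pass
effective_weight = raw_weight * wscale

print(round(wscale, 4))  # sqrt(2) / sqrt(64 * 9) = sqrt(2) / 24 -> 0.0589
```

Scaling at run time rather than at initialization is what makes the per-layer learning rate "equalized" in the paper: an optimizer like Adam then sees parameters of comparable magnitude in every layer.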

Note

See torch.nn.Conv2d for more details on the 2D convolution operator.

Methods:

extra_repr()

Set the extra representation of the module

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

torchx.nn.Conv2dBatch(in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0, bias: bool = True, leaky: float = None, **kwargs)[source]

A 2D convolution followed by batch normalization and ReLU activation.
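With its defaults (kernel_size=3, stride=1, padding=0), Conv2dBatch shrinks each spatial dimension by 2. The sketch below assumes the convolution follows the output-size formula documented for torch.nn.Conv2d; the helper name conv2d_out_size is hypothetical:

```python
def conv2d_out_size(size, kernel_size=3, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, per the torch.nn.Conv2d formula."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


print(conv2d_out_size(32))                       # defaults -> 30
print(conv2d_out_size(32, padding=1))            # size-preserving for k=3 -> 32
print(conv2d_out_size(32, stride=2, padding=1))  # downsample by 2 -> 16
```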

torchx.nn.Conv2dGroup(in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0, bias: bool = True, num_groups=1, **kwargs)[source]

A 2D convolution followed by group normalization and ReLU activation.
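Group normalization splits the channels into num_groups groups and normalizes each group per sample; with the default num_groups=1, every sample is normalized over all of its channels and pixels. This NumPy sketch assumes Conv2dGroup uses torch.nn.GroupNorm semantics (eps=1e-5) and omits the learnable affine parameters; group_norm is a hypothetical name:

```python
import numpy as np


def group_norm(x, num_groups=1, eps=1e-5):
    """Normalize an (N, C, H, W) array over groups of channels (no affine)."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must be divisible by num_groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    # Statistics are taken per sample and per group
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(n, c, h, w)


x = np.random.default_rng(0).normal(5.0, 3.0, size=(2, 8, 4, 4))
y = group_norm(x)  # num_groups=1: each sample has zero mean, unit variance
```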

class torchx.nn.ConvTranspose2d(in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 0, bias: bool = True, gain: float = 1.4142135623730951, use_wscale: bool = False, fan_in: float = None)[source]

Applies a 2D transposed convolution over an input signal composed of several input planes.

A simpler, modified version of the standard torch.nn.ConvTranspose2d that supports an equalized learning rate by scaling the weights dynamically in each forward pass. Implemented as described in https://arxiv.org/pdf/1710.10196.pdf (reference implementation: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L23-L29).

If use_wscale is True, the weight parameter is initialized from a standard normal distribution. The bias parameter is initialized to zero.

Parameters
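  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int) – Size of the convolving kernel. Default: 3

  • stride (int) – Stride of the convolution. Default: 1

  • padding (int) – Padding as defined by torch.nn.ConvTranspose2d; it removes padding rows and columns from each side of the output. Default: 0

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • gain (float) – The gain for the scaled weight. Default: sqrt(2)

  • use_wscale (bool) – If True, scales the weights in each forward pass. Default: False

  • fan_in (float) – Fan-in used to compute the weight scale (see the notes below). Default: None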

Note

If fan_in is not provided, it is computed as \(\text{fan\_in} = \text{in\_channels} \times \text{kernel\_size}^2\)

Note

The wscale is computed as \(\text{wscale} = \frac{\text{gain}}{\sqrt{\text{fan\_in}}}\)

Note

See torch.nn.ConvTranspose2d for more details on the 2D transposed convolution operator.

Methods:

extra_repr()

Set the extra representation of the module

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

torchx.nn.ConvTranspose2dBatch(in_channels: int, out_channels: int, kernel_size: int = 4, stride: int = 2, padding: int = 0, bias: bool = False, leaky: float = None, **kwargs)[source]

A 2D transposed convolution followed by batch normalization and ReLU activation.
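The ConvTranspose2dBatch defaults (kernel_size=4, stride=2, padding=0) map a spatial size H to 2H + 2, not 2H; passing padding=1 gives exact 2x upsampling. This assumes the layer follows the output-size formula documented for torch.nn.ConvTranspose2d; conv_transpose2d_out_size is a hypothetical helper:

```python
def conv_transpose2d_out_size(size, kernel_size=4, stride=2, padding=0,
                              dilation=1, output_padding=0):
    """Output length along one axis, per the torch.nn.ConvTranspose2d formula."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


print(conv_transpose2d_out_size(16))             # defaults -> 34
print(conv_transpose2d_out_size(16, padding=1))  # exact 2x upsampling -> 32
```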

torchx.nn.DSConv(in_channels: int, out_channels: int, stride: int = 1, **kwargs)[source]

A depth-wise separable convolution followed by a 2D convolution, each followed by batch normalization and ReLU activation.

torchx.nn.DWConv(in_channels: int, out_channels: int, stride: int = 1, **kwargs)[source]

A depth-wise separable convolution followed by batch normalization and ReLU activation.
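The usual motivation for depth-wise separable blocks is the reduction in weights. The arithmetic below assumes the common MobileNet-style factorization (a 3×3 depth-wise convolution with groups=in_channels, then a 1×1 point-wise convolution). The kernel sizes torchx actually uses are not stated here, bias and batch-norm parameters are ignored, and both helper names are hypothetical:

```python
def standard_conv_params(c_in, c_out, k=3):
    # One k x k filter for every (input channel, output channel) pair
    return k * k * c_in * c_out


def depthwise_separable_params(c_in, c_out, k=3):
    depthwise = k * k * c_in  # one k x k filter per input channel
    pointwise = c_in * c_out  # 1 x 1 convolution that mixes channels
    return depthwise + pointwise


print(standard_conv_params(128, 256))        # 294912
print(depthwise_separable_params(128, 256))  # 1152 + 32768 = 33920
```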

class torchx.nn.Lerp(a, b, t)[source]

A module that encapsulates the Linear Interpolation function
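Linear interpolation returns a when t = 0 and b when t = 1, moving along the straight line between them in between. The sketch shows the standard formula in NumPy; whether Lerp receives a, b and t at construction or in forward is not specified here (torch.lerp computes the same quantity for tensors):

```python
import numpy as np


def lerp(a, b, t):
    # a + (b - a) * t: t=0 gives a, t=1 gives b
    return a + (b - a) * t


a = np.zeros(3)
b = np.array([2.0, 4.0, 6.0])
print(lerp(a, b, 0.25))  # [0.5 1.  1.5]
```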

class torchx.nn.Linear(in_features: int, out_features: int, bias: bool = True, gain: float = 1.4142135623730951, use_wscale: bool = False, fan_in: float = None)[source]

Applies a linear transformation to the incoming data.

A simpler, modified version of the standard torch.nn.Linear that supports an equalized learning rate by scaling the weights dynamically in each forward pass. Implemented as described in https://arxiv.org/pdf/1710.10196.pdf (reference implementation: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L23-L29).

If use_wscale is True, the weight parameter is initialized from a standard normal distribution. The bias parameter is initialized to zero.

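Parameters
  • in_features (int) – Size of each input sample

  • out_features (int) – Size of each output sample

  • bias (bool) – If True, the layer adds a learnable additive bias. Default: True

  • gain (float) – The gain for the scaled weight. Default: sqrt(2)

  • use_wscale (bool) – If True, scales the weights in each forward pass. Default: False

  • fan_in (float) – Fan-in used to compute the weight scale. Default: None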
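For a linear layer, the same runtime scaling keeps the output variance near gain² for unit-variance inputs. This NumPy sketch assumes fan_in defaults to in_features, which the documentation does not state for Linear; equalized_linear is a hypothetical name:

```python
import math

import numpy as np


def equalized_linear(x, weight, bias, gain=math.sqrt(2), fan_in=None):
    """y = x @ (weight * wscale).T + bias, with weight stored as N(0, 1)."""
    out_features, in_features = weight.shape
    if fan_in is None:
        fan_in = in_features  # assumption: fan-in of a linear layer
    wscale = gain / math.sqrt(fan_in)
    return x @ (weight * wscale).T + bias


rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 512))  # unit-variance inputs
w = rng.standard_normal((256, 512))   # standard-normal init (use_wscale=True)
b = np.zeros(256)                     # bias initialized to zero
y = equalized_linear(x, w, b)
print(y.std())  # close to gain = sqrt(2), about 1.414
```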

Methods:

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

class torchx.nn.MinibatchStddev(group_size: int = 4)[source]

Increases sample variation by appending the minibatch standard deviation as an extra feature map

class torchx.nn.Module[source]

Convenient intermediary parent class that implements useful module functions

class torchx.nn.PixelwiseNorm(epsilon: float = 1e-08)[source]

Torch module encapsulating the pixel norm operator

class torchx.nn.PrintShape(format='{}')[source]

Prints the shape of a tensor and then forwards the tensor to the next module.

Intended for debugging.

class torchx.nn.View(*shape)[source]

Sets the view (shape) of a tensor inside a module

class torchx.nn.WGANGP_ACGAN(generator, discriminator, drift: float = 0.001, use_gp: bool = False)[source]

WGAN-GP + AC-GAN Loss Function

Used as a loss function for training discriminators in GANs

Note

References: WGAN-GP, AC-GAN
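The discriminator (critic) objective can be sketched from its ingredients: the Wasserstein estimate, a drift penalty on real scores (the default drift=0.001 matches ε_drift in the progressive-GAN paper), and an AC-GAN classification term on real and generated samples. How torchx weights and combines these terms is an assumption here. The gradient penalty enabled by use_gp is omitted because it needs autograd. The function names are hypothetical:

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true class (softmax over axis 1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def critic_loss(real_score, fake_score, real_logits, real_labels,
                fake_logits, fake_labels, drift=0.001):
    wasserstein = fake_score.mean() - real_score.mean()  # critic minimizes this
    drift_term = drift * np.mean(real_score ** 2)        # keeps scores near zero
    # AC-GAN auxiliary term: classify both real and generated samples
    aux = (cross_entropy(real_logits, real_labels)
           + cross_entropy(fake_logits, fake_labels))
    return wasserstein + drift_term + aux
```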

class torchx.nn.WGAN_ACGAN(cond_weight: float = 1.0)[source]

WGAN + AC-GAN Loss Function

Used as a loss function for training generators in GANs

Note

References: WGAN, AC-GAN
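On the generator side, the WGAN term rewards raising the critic's scores on generated samples, and the AC-GAN term rewards samples the auxiliary classifier assigns to their conditioning labels. This sketch assumes cond_weight multiplies the classification term, which the name suggests but the documentation does not confirm; the function names are hypothetical:

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true class (softmax over axis 1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def generator_loss(fake_score, fake_logits, fake_labels, cond_weight=1.0):
    adversarial = -fake_score.mean()  # higher critic scores -> lower loss
    classification = cross_entropy(fake_logits, fake_labels)
    return adversarial + cond_weight * classification
```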

optim

params

Classes:

Parameters(param_file_path, num_epochs, …)

Reads parameters from a YAML file and encapsulates them.

class torchx.params.Parameters(param_file_path: str = None, num_epochs: int = 100, epoch_start: int = 0, batch_size: int = 1, checkpoint_step: int = 2, validation_step: int = 2, num_validation: int = 1000, num_workers: int = 1, learning_rate: float = 0.001, cuda: str = '0', use_gpu: bool = True, pretrained_model_path: float = None, save_model_path: str = './.checkpoints', log_file: str = './model.log', **params)[source]

Reads parameters from a YAML file and encapsulates them.

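Parameters
  • param_file_path (str) – Path to the YAML parameter file. Default: None

  • num_epochs (int) – Number of epochs to train for. Default: 100

  • epoch_start (int) – Epoch number to start counting from. Default: 0

  • batch_size (int) – Number of images in each batch. Default: 1

  • checkpoint_step (int) – How often to save checkpoints, in epochs. Default: 2

  • validation_step (int) – How often to run validation, in epochs. Default: 2

  • num_validation (int) – Number of validation images to use. Default: 1000

  • num_workers (int) – Number of workers. Default: 1

  • learning_rate (float) – Learning rate used for training. Default: 0.001

  • cuda (str) – GPU IDs used for training. Default: '0'

  • use_gpu (bool) – Whether to use the GPU for training. Default: True

  • pretrained_model_path – Path to a pretrained model. Default: None

  • save_model_path (str) – Path to save the model to. Default: './.checkpoints'

  • log_file (str) – Path to the log file. Default: './model.log'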

utils

Functions:

minibatch_stddev_layer(x[, group_size])

Appends a feature map containing the standard deviation of the minibatch.

hex_to_rgb(hex_colour)

Converts a colour from hex to RGB

pixel_norm(x[, epsilon])

Applies a pixel-wise normalization.

encode_array(array, encode_list)

Encodes the provided array using the provided encoding list.

decode_array(array, decode_list)

Decodes the provided array using the provided decoding list.

torchx.utils.decode_array(array: numpy.ndarray, decode_list: Tuple[int, numpy.ndarray])[source]

Decodes the provided array using the provided decoding list.

Note: this function is vectorized and is thus very fast.

Parameters
  • array – The array to decode

  • decode_list – A tuple of decoding value, index pairs. Each index must be an integer and each decoding value must be a numpy array of the same size as the last axis in the array that is to be decoded.

Raises

AssertionError – if the decoding values are not all the same size

torchx.utils.encode_array(array: numpy.ndarray, encode_list: Tuple[int, numpy.ndarray])[source]

Encodes the provided array using the provided encoding list.

Parameters
  • array – The array to encode

  • encode_list – A tuple of encoding value, index pairs. Each index must be an integer and each encoding value must be a numpy array of the same size as the last axis in the array that is to be encoded.

Raises

AssertionError – if the size of the last axis of the array does not match the sizes of the encoding values

Note

This function is vectorized and is thus very fast.
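A typical use for this pair of functions is converting colour-coded segmentation masks to integer labels and back. The sketch below is a hypothetical reimplementation, not torchx's code. It takes (value, index) pairs in the order the prose describes, although the Tuple[int, numpy.ndarray] type hint suggests the opposite order. Pixels that match no value default to 0. The names PALETTE, encode_sketch and decode_sketch are illustrative:

```python
import numpy as np

# Hypothetical (value, index) pairs: RGB colours mapped to class indices
PALETTE = [
    (np.array([0, 0, 0]), 0),    # background
    (np.array([255, 0, 0]), 1),  # class 1
    (np.array([0, 255, 0]), 2),  # class 2
]


def encode_sketch(array, encode_list):
    """Replace each last-axis vector equal to a value with that value's index."""
    out = np.zeros(array.shape[:-1], dtype=np.int64)
    for value, index in encode_list:
        assert value.shape[-1] == array.shape[-1], "value size must match last axis"
        out[np.all(array == value, axis=-1)] = index  # vectorized over pixels
    return out


def decode_sketch(array, decode_list):
    """Inverse of encode_sketch: replace each index with its value vector."""
    sizes = {value.shape[-1] for value, _ in decode_list}
    assert len(sizes) == 1, "decoding values must all be the same size"
    out = np.zeros(array.shape + (sizes.pop(),), dtype=decode_list[0][0].dtype)
    for value, index in decode_list:
        out[array == index] = value
    return out


image = np.array([[[255, 0, 0], [0, 0, 0]],
                  [[0, 255, 0], [255, 0, 0]]])
labels = encode_sketch(image, PALETTE)
print(labels)  # [[1 0]
               #  [2 1]]
```

Looping over the short palette while masking whole arrays at once is what keeps this kind of function fast: the per-pixel work happens inside NumPy.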

torchx.utils.hex_to_rgb(hex_colour: str)[source]

Converts a colour from hex to RGB

Parameters
  • hex_colour – A hex string matching the regex ‘[0-9a-fA-F]{6}’

Returns

A numpy array containing the RGB values
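Each pair of hex digits encodes one 8-bit channel (RR, GG, BB). A minimal sketch under the documented input format (six hex digits, no leading '#'). The name hex_to_rgb_sketch is hypothetical, and the dtype of torchx's returned array is not specified here:

```python
import re

import numpy as np


def hex_to_rgb_sketch(hex_colour: str) -> np.ndarray:
    # Expect exactly six hex digits, e.g. 'ff8000'
    assert re.fullmatch(r"[0-9a-fA-F]{6}", hex_colour), "expected 6 hex digits"
    # Parse the RR, GG and BB pairs as base-16 integers
    return np.array([int(hex_colour[i:i + 2], 16) for i in (0, 2, 4)])


print(hex_to_rgb_sketch("ff8000"))  # [255 128   0]
```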

torchx.utils.minibatch_stddev_layer(x, group_size=4)[source]

Appends a feature map containing the standard deviation of the minibatch.

Note

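Implemented as described in https://arxiv.org/pdf/1710.10196.pdf (reference implementation: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py).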
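The layer splits the batch into groups of at most group_size samples and computes the standard deviation of every feature across each group. It averages those values into one scalar per group and appends that scalar as an extra constant feature map, so the discriminator can see how varied the batch is. This NumPy sketch is modeled on the progressive-GAN reference layer for NCHW data. The name minibatch_stddev_sketch and eps=1e-8 are assumptions, and the batch size is assumed divisible by the group size:

```python
import numpy as np


def minibatch_stddev_sketch(x, group_size=4, eps=1e-8):
    """Append one minibatch-stddev feature map to an (N, C, H, W) array."""
    n, c, h, w = x.shape
    g = min(group_size, n)                        # group cannot exceed the batch
    y = x.reshape(g, -1, c, h, w)                 # [G, M, C, H, W], M = N // G
    y = y - y.mean(axis=0, keepdims=True)         # subtract the group mean
    y = np.sqrt((y ** 2).mean(axis=0) + eps)      # stddev over group -> [M, C, H, W]
    y = y.mean(axis=(1, 2, 3))                    # average features/pixels -> [M]
    y = np.tile(y.reshape(-1, 1, 1, 1), (g, 1, h, w))  # back to [N, 1, H, W]
    return np.concatenate([x, y.astype(x.dtype)], axis=1)


x = np.random.default_rng(0).standard_normal((8, 3, 4, 4))
out = minibatch_stddev_sketch(x)  # shape (8, 4, 4, 4)
```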

torchx.utils.pixel_norm(x: torch.Tensor, epsilon: float = 1e-08)[source]

Applies a pixel-wise normalization.

Note

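Implemented as described in https://arxiv.org/pdf/1710.10196.pdf (reference implementation: https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py).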
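Pixel-wise normalization rescales the feature vector at every spatial position to unit root-mean-square across channels: \(b_{x,y} = a_{x,y} / \sqrt{\frac{1}{N}\sum_{j=1}^{N} (a^j_{x,y})^2 + \epsilon}\), with N the number of channels and the same \(\epsilon = 10^{-8}\) as the default epsilon. A minimal NumPy sketch for NCHW arrays; pixel_norm_sketch is a hypothetical name:

```python
import numpy as np


def pixel_norm_sketch(x, epsilon=1e-8):
    # Divide each pixel's channel vector (axis 1) by its root-mean-square
    return x / np.sqrt(np.mean(x ** 2, axis=1, keepdims=True) + epsilon)


x = np.random.default_rng(0).standard_normal((2, 16, 4, 4)) * 5.0
y = pixel_norm_sketch(x)  # mean of y**2 over channels is ~1 at every pixel
```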