Source code for torchx.params.Parameters

# -*- coding: utf-8 -*-
import os
import logging

import torch

from yaml import load

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


[docs]class Parameters: """Handles reading and encapsulation of parameters in a yaml file. Parameters: num_epochs: Number of epochs to train for epoch_start: Start counting epochs from this number batch_size: Number of images in each batch checkpoint_step: How often to save checkpoints (epochs) validation_step: How often to perform validation (epochs) num_validation: How many validation images to use num_workers: Number of workers learning_rate: learning rate used for training cuda: GPU ids used for training use_gpu: whether to user gpu for training pretrained_model_path: path to pretrained model save_model_path: path to save model log_file: path to log file """ def __init__( self, param_file_path: str = None, num_epochs: int = 100, epoch_start: int = 0, batch_size: int = 1, checkpoint_step: int = 2, validation_step: int = 2, num_validation: int = 1000, num_workers: int = 1, learning_rate: float = 0.001, cuda: str = "0", use_gpu: bool = True, pretrained_model_path: float = None, save_model_path: str = "./.checkpoints", log_file: str = "./model.log", **params ): self.num_epochs = num_epochs self.epoch_start = epoch_start self.batch_size = batch_size self.checkpoint_step = checkpoint_step self.validation_step = validation_step self.num_validation = num_validation self.num_workers = num_workers self.learning_rate = learning_rate self.cuda = cuda self.use_gpu = use_gpu self.pretrained_model_path = pretrained_model_path self.save_model_path = save_model_path self.log_file = log_file self.__dict__.update(**params) if param_file_path is not None and os.path.isfile(param_file_path): with open(param_file_path) as file: self.__dict__.update(**load(file, Loader=Loader)) self.use_gpu = self.use_gpu and torch.cuda.is_available() def get_logger(self, name: str, overwrite: bool = False, level: int = logging.INFO): log = logging.getLogger(name) if overwrite: with open(self.log_file, "w"): pass hdlr = logging.FileHandler(self.log_file) formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") hdlr.setFormatter(formatter) log.addHandler(hdlr) log.setLevel(level) return log