"""Training / validation loop helpers (labml_helpers.train_valid)."""

from typing import Optional, Dict, List, Callable, Any

import torch.optim
import torch.utils.data
from torch import nn

import labml.utils.pytorch as pytorch_utils
from labml import tracker, monit
from labml.configs import option, meta_config

from .device import DeviceConfigs
from .metrics import StateModule
from .training_loop import TrainingLoopConfigs

class ModeState:
    """
    Holds the current run-mode flags (training, logging, optimizing) and
    supports temporary, nested overrides through a rollback stack.

    Overrides are applied with `_enter` and undone with `_exit`; user code
    should go through `update()`, which wraps the pair in a `Mode`
    context manager.
    """

    def __init__(self):
        # Stack of {flag_name: previous_value} dicts, one per active override
        self._rollback_stack = []

        self.is_train = False
        self.is_log_activations = False
        self.is_log_parameters = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, any]) -> int:
        """
        Apply the flag overrides in `mode` and push rollback info.

        A value of `None` means "leave this flag unchanged".
        Returns the stack depth, which must be passed back to `_exit`.
        """
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        """Undo the most recent `_enter`; `n` must equal the depth it returned."""
        assert n == len(self._rollback_stack)

        # Pop so nested overrides unwind in LIFO order
        rollback = self._rollback_stack.pop()

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_log_parameters: Optional[bool] = None,
               is_log_activations: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        """Return a `Mode` context manager applying the given flag overrides."""
        return Mode(self,
                    is_train=is_train,
                    is_log_parameters=is_log_parameters,
                    is_log_activations=is_log_activations,
                    is_optimize=is_optimize)

class Mode:
    """
    Context manager that applies flag overrides to a `ModeState` on entry
    and rolls them back on exit. Created via `ModeState.update()`.
    """

    def __init__(self, mode: ModeState, **kwargs: any):
        self.mode = mode
        # Keep only the flags that are actually being overridden;
        # `None` means "leave unchanged"
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        # Rollback-stack depth returned by `_enter`; -1 until entered
        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always roll back, even when the body raised
        self.mode._exit(self.idx)

class ForwardHook:
    """
    Forward hook that stores a module's outputs with labml when
    `mode.is_log_activations` is set.

    Constructing an instance registers it on the module, so creating the
    hook is sufficient — no separate registration step is needed.
    """

    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
        self.mode = mode
        self.model_name = model_name
        self.name = name
        self.module = module

        # Register self so PyTorch invokes `__call__` after each forward pass
        module.register_forward_hook(self)

    def save(self, name: str, output):
        """Store a tensor output; tuples are stored element-wise as `name.i`."""
        if isinstance(output, torch.Tensor):
            pytorch_utils.store_var(name, output)
        elif isinstance(output, tuple):
            for i, o in enumerate(output):
                self.save(f"{name}.{i}", o)

    def __call__(self, module, i, o):
        # Skip entirely unless activation logging is currently enabled
        if not self.mode.is_log_activations:
            return

        self.save(f"module.{self.model_name}.{self.name}", o)

def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
    """Attach a `ForwardHook` to every sub-module of `model`.

    The root module (whose registered name is the empty string) is
    recorded under the name ``'full'``.
    """
    for module_name, module in model.named_modules():
        # ForwardHook registers itself on the module in its constructor
        ForwardHook(mode, model_name, module_name or 'full', module)

class Trainer:
    """
    Runs `step` over a data loader, splitting each epoch into
    `inner_iterations` slices so that training and validation can be
    interleaved within an epoch.

    Calling the instance runs one inner-iteration slice; the underlying
    iterator is re-created when the previous epoch has completed.
    """

    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 step: Callable[[any, 'BatchIndex'], None]):
        self.mode = mode
        self.name = name
        self.step = step
        self.data_loader = data_loader
        self.state_modules = state_modules
        self.__iterable = None
        # One persistent state object per state module, re-attached each call
        self.__states = [sm.create_state() for sm in self.state_modules]
        self._batch_index = BatchIndex(len(data_loader), inner_iterations)

    def reset(self):
        """Drop the current data iterator so the next call starts a fresh epoch."""
        self.__iterable = None

    def __call__(self):
        """Run one inner-iteration slice of the epoch."""
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        # Start a new epoch when there is no iterator yet or the last one finished.
        # NOTE: `or` short-circuits, so `completed` is not read before the
        # first reset initializes the batch index.
        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset()
            for sm in self.state_modules:
                sm.on_epoch_start()

        # Gradients are only tracked in training mode
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        """Consume batches until the current inner iteration is complete."""
        with monit.section(self.name, is_partial=True):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    """
    Tracks position within an epoch: `idx` of `total` batches, and
    `iteration` of `total_iterations` inner iterations. Used by `Trainer`
    to split an epoch into interleaved training/validation slices.
    """
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total
        # Start from a well-defined position so `completed` etc. are safe to read
        self.reset()

    def is_interval(self, interval: int) -> bool:
        """
        Whether the current batch falls on `interval` (1-based), always
        `True` on the last batch so interval-gated work runs at least once
        per epoch. `interval <= 0` disables the check entirely.
        """
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        return (self.idx + 1) % interval == 0

    @property
    def is_last(self) -> bool:
        """Whether this is the last batch of the epoch."""
        return self.idx + 1 == self.total

    @property
    def completed(self) -> bool:
        """Whether all inner iterations of the epoch have finished."""
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self) -> bool:
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self) -> float:
        """Fraction of the epoch processed so far, in [0, 1)."""
        return self.idx / self.total

    def step(self):
        """Advance to the next batch."""
        self.idx += 1

    def step_inner(self):
        """Mark one inner iteration as finished."""
        self.iteration += 1

    def reset(self):
        """Rewind to the start of an epoch."""
        self.idx = 0
        self.iteration = 0

class TrainValidConfigs(TrainingLoopConfigs):
    """
    Configurable training-and-validation loop.

    Each epoch is split into `inner_iterations` slices; every slice runs
    `sample`, then the trainer (in training mode), then the validator.
    Subclasses implement `step` and may override `init` and `sample`.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer

    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    # Training-loop wiring: one loop step per epoch
    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    def init(self):
        """One-time setup hook for subclasses; default does nothing."""
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        """Process a single batch; must be implemented by subclasses."""
        raise NotImplementedError

    def run_step(self):
        """Run one epoch: sample, train, then validate, per inner iteration."""
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        # Touch the lazy config options so they are materialized before the loop
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        """Hook for generating samples for logging; default does nothing."""
        pass
@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    """Default trainer: runs `c.step` over the training loader."""
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    """Default validator: runs `c.step` over the validation loader."""
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    """One training-loop step per epoch."""
    return c.epochs


class SimpleTrainValidConfigs(TrainValidConfigs):
    """
    Ready-to-use train/valid loop for the common supervised case:
    a single model, a single optimizer, and `(data, target)` batches.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    # Optimize every `update_batches` batches (gradient accumulation)
    update_batches: int = 1
    log_params_updates: int = 2 ** 32  # 0 if not
    log_activations_batches: int = 2 ** 32  # 0 if not
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        """Forward/backward pass for one `(data, target)` batch."""
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Global step counts samples, so it only advances while training
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
        with monit.section("model"):
            with self.mode.update(is_log_activations=is_log_activations):
                output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            # Step the optimizer only on accumulation boundaries
            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                if batch_idx.is_interval(self.log_params_updates):
                    tracker.add('model', self.model)
                self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches,
            SimpleTrainValidConfigs.log_params_updates,
            SimpleTrainValidConfigs.log_activations_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    """Default optimizer config wired to the model's parameters."""
    from labml_helpers.optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf