import signal
import typing
from typing import Optional, Tuple, Any, Collection
from labml import tracker, logger, experiment, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
class TrainingLoopIterator(Collection):
    """
    Iterator that yields the step values of a training loop.

    When ``step`` is given, iteration starts at ``start`` and advances by
    ``step`` until the value reaches ``total``; the internal counter is the
    value that gets yielded. When ``step`` is ``None``, the internal counter
    runs from ``0`` up to ``total`` in increments of one and every iteration
    yields ``tracker.get_global_step()`` instead of the counter.

    Arguments:
        start (int): First value yielded when ``step`` is not ``None``.
        total (int): Exclusive upper bound for the internal counter.
        step (Optional[int]): Increment per iteration, or ``None`` to count
            iterations and yield the tracker's global step.
            Assumed to be a positive integer when given.
    """

    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        # Current counter value; ``None`` until the first ``__next__`` call.
        self.i = None

    def __iter__(self):
        # Restart from the beginning every time iteration is (re)started.
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        return self.i

    def __len__(self) -> int:
        """
        Return the number of values a full iteration yields.

        The count is rounded up so that a trailing partial stride is included
        (e.g. ``start=0, total=10, step=3`` yields ``0, 3, 6, 9``), and it is
        clamped at zero because ``__len__`` must never return a negative value.
        """
        if self.step is None:
            return max(self.total, 0)

        remaining = self.total - self.start
        if remaining <= 0:
            return 0
        # Ceiling division using integer arithmetic only.
        return max(-(-remaining // self.step), 0)

    def __contains__(self, x: object) -> bool:
        # Membership tests are not supported; always reports "not contained".
        return False
class TrainingLoop:
    """
    Iterable training loop that keeps the tracker's global step up to date.

    Each iteration pulls the next step from a ``monit.loop`` wrapped around a
    ``TrainingLoopIterator``, stores it with ``tracker.set_global_step`` and,
    whenever the configured intervals have elapsed, calls ``tracker.save``,
    ``tracker.new_line`` and ``experiment.save_checkpoint``. While iterating,
    a ``SIGINT`` handler is installed so the loop can be stopped cleanly.
    """

    # Quoted so these names are not evaluated while the class body executes.
    _iter: "Optional[TrainingLoopIterator]"
    __loop: "Loop"
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 is_save_models: bool,
                 log_new_line_interval: int,
                 log_write_interval: int,
                 save_models_interval: int,
                 is_loop_on_interrupt: bool):
        """
        Arguments:
            loop_count (int): Upper bound passed to the iterator as ``total``.
            loop_step (Optional[int]): Step passed to the iterator.
            is_save_models (bool): Whether checkpoints are saved.
            log_new_line_interval (int): Steps between ``tracker.new_line`` calls.
            log_write_interval (int): Steps between ``tracker.save`` calls.
            save_models_interval (int): Steps between checkpoint saves.
            is_loop_on_interrupt (bool): If true, a ``SIGINT`` is remembered and
                the loop stops at the next ``__next__`` call instead of
                forwarding the signal immediately.
        """
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__is_save_models = is_save_models
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__save_models_interval = save_models_interval
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None
        # Previous SIGINT handler; stays ``None`` when our handler could not be
        # installed, so ``__finish`` knows there is nothing to restore.
        self.old_handler = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            # ``signal.signal`` raises ValueError when called outside the main
            # thread; the loop then simply runs without interrupt handling.
            pass
        return self

    @property
    def idx(self):
        """
        Position of the loop: ``0`` before the first step, otherwise the
        iterator's counter, divided by ``loop_step`` when a step is set.
        """
        # Explicit ``None`` check: truth-testing the iterator would call its
        # ``__len__`` and treat a zero-length iterator as missing.
        if self._iter is None:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        """Restore the previous SIGINT handler, flush the tracker and save a checkpoint if enabled."""
        if self.old_handler is not None:
            try:
                signal.signal(signal.SIGINT, self.old_handler)
            except ValueError:
                pass
        tracker.save()
        tracker.new_line()
        if self.__is_save_models:
            logger.log("Saving model...")
            experiment.save_checkpoint()

    def __next__(self):
        # A delayed interrupt was recorded by the handler: stop the loop now.
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration:
            self.__finish()
            raise

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step
        if (self.__is_save_models and
                global_step - self.__last_save_step >= self.__save_models_interval):
            experiment.save_checkpoint()
            self.__last_save_step = global_step

        return global_step

    def __handler(self, sig, frame):
        """SIGINT handler installed for the duration of the loop."""
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"
class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if it involves a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint`
         every ``save_models_interval`` steps and once more when the loop finishes.
         Defaults to ``False``.
        save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
         Defaults to ``1``.
        log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
         Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an
         iteration is complete before stopping the loop. Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    is_save_models: bool = False
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    save_models_interval: int = 1
    is_loop_on_interrupt: bool = False

    # Built from the values above by the option registered for this attribute.
    training_loop: TrainingLoop
@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    """Create the ``TrainingLoop`` for ``training_loop`` from the values set on ``c``."""
    loop_kwargs = dict(
        loop_count=c.loop_count,
        loop_step=c.loop_step,
        is_save_models=c.is_save_models,
        log_new_line_interval=c.log_new_line_interval,
        log_write_interval=c.log_write_interval,
        save_models_interval=c.save_models_interval,
        is_loop_on_interrupt=c.is_loop_on_interrupt,
    )
    return TrainingLoop(**loop_kwargs)
# Register the loop settings through ``meta_config``.
# NOTE(review): presumably marks them as meta configs (not hyper-parameters) — confirm against labml docs.
meta_config(
    TrainingLoopConfigs.loop_step,
    TrainingLoopConfigs.loop_count,
    TrainingLoopConfigs.is_save_models,
    TrainingLoopConfigs.log_new_line_interval,
    TrainingLoopConfigs.log_write_interval,
    TrainingLoopConfigs.save_models_interval,
    TrainingLoopConfigs.is_loop_on_interrupt,
)