Source code for labml_helpers.schedule

from typing import List, Optional, Tuple

class Schedule:
    """
    Base class for schedules.

    A schedule is invoked like a function with a single argument `x`.
    This base implementation has no value to return; subclasses are
    expected to override `__call__`.
    """

    def __call__(self, x):
        """Raise `NotImplementedError`; the base class defines no value for `x`."""
        raise NotImplementedError
class Flat(Schedule):
    """Schedule that returns the same value for every `x`."""

    def __init__(self, value):
        """Remember `value`, the constant this schedule will always return."""
        # Double-underscore name: mangled by Python to keep it private to `Flat`.
        self.__value = value

    def __call__(self, x):
        """Return the stored constant; `x` is ignored."""
        constant = self.__value
        return constant

    def __str__(self):
        """Describe the schedule as `Schedule(<value>)`."""
        return f"Schedule({self.__value})"
class Dynamic(Schedule):
    """Schedule whose value is set explicitly and can be replaced via `update`."""

    def __init__(self, value):
        """Start the schedule at `value`."""
        # Double-underscore name: mangled by Python to keep it private to `Dynamic`.
        self.__value = value

    def __call__(self, x):
        """Return the most recently stored value; `x` is ignored."""
        current = self.__value
        return current

    def update(self, value):
        """Replace the stored value, which later calls will return."""
        self.__value = value

    def __str__(self):
        """Return the fixed label `Dynamic` (the current value is not included)."""
        return "Dynamic"
class Piecewise(Schedule):
    """
    ## Piecewise schedule

    Linearly interpolates between `(x, y)` endpoints and returns a fixed
    fallback value for any `x` that no segment covers.
    """

    def __init__(self, endpoints: List[Tuple[float, float]], outside_value: Optional[float] = None):
        """
        ### Initialize

        `endpoints` is list of pairs `(x, y)`, ordered by non-decreasing `x`.
        The values between endpoints are linearly interpolated.
        `y` values outside the range covered by `x` are `outside_value`
        (`None` unless given).

        An `AssertionError` is raised when the `x` values are not sorted.
        """
        # Take a private (shallow) copy: the schedule must not change, or become
        # unsorted, if the caller mutates its list afterwards. Copying first also
        # lets a one-shot iterable be validated and stored from a single pass.
        points = list(endpoints)

        # `(x, y)` pairs should be sorted
        indexes = [point[0] for point in points]
        assert indexes == sorted(indexes), "endpoints must be sorted by `x`"

        self._outside_value = outside_value
        self._endpoints = points

    def __call__(self, x):
        """
        ### Find `y` for given `x`

        Each segment is half-open, `x1 <= x < x2`, so a breakpoint belongs to
        the segment that starts there. As a consequence the `x` of the final
        endpoint is not covered by any segment and yields `outside_value`.
        """
        points = self._endpoints

        # iterate through each pair of consecutive endpoints;
        # `zip` stops at the shorter input, so one shifted slice is enough
        for (x1, y1), (x2, y2) in zip(points, points[1:]):
            # interpolate if `x` is within the segment; a zero-width segment
            # (`x1 == x2`) can never match, so the division below is safe
            if x1 <= x < x2:
                dx = float(x - x1) / (x2 - x1)
                return y1 + dx * (y2 - y1)

        # return outside value otherwise
        return self._outside_value

    def __str__(self):
        """List every endpoint followed by the outside value."""
        endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints])
        return f"Schedule[{endpoints}, {self._outside_value}]"
class RelativePiecewise(Piecewise):
    """
    Piecewise schedule whose `x` coordinates are given relative to
    `total_steps` and converted to integer step indexes.
    """

    def __init__(self, relative_endpoits: List[Tuple[float, float]], total_steps: int):
        """
        ### Initialize

        `relative_endpoits` is a list of pairs `(fraction, y)`; each `fraction`
        is multiplied by `total_steps` and truncated with `int` to get the `x`
        of the endpoint. The `y` of the last pair is used as the outside value.

        (The parameter name keeps its original spelling so keyword callers
        continue to work.)
        """
        absolute_points = []
        for pair in relative_endpoits:
            step = int(total_steps * pair[0])
            assert step >= 0
            absolute_points.append((step, pair[1]))

        # Outside the covered range the schedule holds the final `y`.
        final_value = relative_endpoits[-1][1]
        super().__init__(absolute_points, outside_value=final_value)