Source code for labml_helpers.metrics.simple_state

from typing import Generic, TypeVar, Optional

from . import StateModule

T = TypeVar('T')


[docs]class SimpleState(Generic[T]): state: Optional[T] def __init__(self): self.state = None def get(self) -> T: return self.state def set(self, data: T): self.state = data def reset(self): self.state = None
[docs]class SimpleStateModule(StateModule, Generic[T]): data: SimpleState[T] def __init__(self): super().__init__() def set(self, data: T): self.data.set(data) def get(self) -> T: return self.data.get() def create_state(self): return SimpleState() def set_state(self, data: any): self.data = data def on_epoch_start(self): self.data.reset() def on_epoch_end(self): pass