Source code for labml.utils.fastai

from fastai.callback.core import Callback, patch
from fastai.learner import Learner

from labml import tracker


class LabMLFastAICallback(Callback):
    """
    FastAI callback integration.

    Pass an instance of this class to FastAI learner as argument ``cbs``.
    FastAI will call relevant methods of this class to log metrics.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _add_metric_values(self, metrics):
        """Copy every non-``None`` value from ``self.learn.metrics`` into *metrics*.

        The dictionary is updated in place, keyed by each metric's ``name``.
        Collection is best-effort: if reading the metrics fails, whatever was
        gathered before the failure is kept and the error is ignored so that
        logging never interrupts training.
        """
        try:
            for m in self.learn.metrics:
                value = m.value
                if value is not None:
                    metrics[m.name] = value
        except Exception:
            # Deliberately swallowed (best-effort logging). ``Exception`` rather
            # than a bare ``except`` so KeyboardInterrupt/SystemExit still
            # propagate and the user can stop a training run.
            pass

    def before_fit(self):
        pass

    def after_batch(self):
        """Advance the global step and save the batch loss plus metric values.

        The loss is stored under ``loss.train`` when ``self.training`` is
        truthy and under ``loss.valid`` otherwise.
        """
        tracker.add_global_step()

        loss_key = 'loss.train' if self.training else 'loss.valid'
        metrics = {loss_key: self.learn.loss}
        self._add_metric_values(metrics)

        tracker.save(metrics)

    def after_epoch(self):
        """Save the current metric values and start a new tracker line."""
        metrics = {}
        self._add_metric_values(metrics)

        tracker.save(metrics)
        tracker.new_line()

    def after_fit(self):
        pass
@patch
def labml_configs(self: Learner):
    """Return a dictionary of learner settings to record as configs.

    The result may contain:

    * ``n_epoch`` -- the learner's ``n_epoch`` attribute
    * ``model_class`` -- ``str(type(model))`` of the learner's model

    Each entry is optional: an attribute that is not available is left out
    instead of raising, and a missing ``n_epoch`` does not prevent
    ``model_class`` from being reported.
    """
    # This function is patched onto ``Learner``, so ``self`` is already the
    # learner. The ``learn`` attribute is normally set on callbacks rather than
    # on the learner itself (see the fastai Learner/Callback documentation), so
    # reading ``self.learn`` unconditionally failed and produced an empty dict.
    # Prefer ``self.learn`` when it exists (backward compatible), else ``self``.
    learner = getattr(self, 'learn', self)

    configs = {}
    try:
        configs['n_epoch'] = learner.n_epoch
    except AttributeError:
        # NOTE(review): presumably only set once fitting has started -- confirm.
        pass

    try:
        configs['model_class'] = str(type(learner.model))
    except AttributeError:
        pass

    return configs