Source code for labml.utils.keras

from typing import Any, Dict, Optional

import tensorflow as tf

from labml import tracker, logger

_MAP = {
    'size': None,
    'batch': None,
    'loss': 'loss.train',
    'accuracy': 'accuracy.train',
    'val_loss': 'loss.valid',
    'val_accuracy': 'accuracy.valid'
}


class LabMLKerasCallback(tf.keras.callbacks.Callback):
    """
    Keras callback integration.

    Pass an instance of this class to the Keras model ``fit`` method as the
    ``callbacks`` argument. Keras will call the relevant methods of this
    class to log metrics.

    Arguments:
        save_batch_frequency (int, optional): the tracked batch metrics are
            saved whenever the batch index Keras reports is a multiple of
            this value. Must be a positive integer. Defaults to ``1``
            (save after every batch).
    """

    def __init__(self, save_batch_frequency: int = 1):
        super().__init__()
        # ``on_train_batch_end`` computes ``batch % save_batch_frequency``;
        # a zero would raise ZeroDivisionError in the middle of training, and
        # a negative frequency has no meaning, so reject both up front.
        if save_batch_frequency < 1:
            raise ValueError(
                f"save_batch_frequency must be a positive integer, "
                f"got {save_batch_frequency!r}")
        self.save_batch_frequency = save_batch_frequency

    @staticmethod
    def _parse_logs(logs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Rename Keras log keys to labml indicator names.

        Keys found in ``_MAP`` are renamed to their mapped value, or dropped
        when the mapped value is ``None`` (e.g. ``size`` and ``batch``).
        Keys not in ``_MAP`` are kept under their original name.

        Arguments:
            logs: the ``logs`` dict Keras passes to a callback hook, or
                ``None``.

        Returns:
            A new dict of indicator name to value; empty when ``logs`` is
            ``None``.
        """
        data = {}
        if logs is None:
            return data
        for k, v in logs.items():
            if k in _MAP:
                k = _MAP[k]
                if k is None:
                    # Mapped to None: not a metric, skip it.
                    continue
            data[k] = v
        return data

    def on_epoch_end(self, epoch, logs=None):
        """
        Save the epoch-level metrics Keras reports (including any ``val_*``
        entries, renamed to ``*.valid``), then call ``logger.log()``.

        NOTE(review): ``logger.log()`` with no arguments presumably ends the
        current console output line — confirm against labml docs.
        """
        tracker.save(self._parse_logs(logs))
        logger.log()

    def on_train_batch_end(self, batch, logs=None):
        """
        Advance the labml global step, record this batch's metrics, and save
        them whenever ``batch`` is a multiple of ``save_batch_frequency``.
        """
        tracker.add_global_step()
        tracker.add(self._parse_logs(logs))
        if batch % self.save_batch_frequency == 0:
            tracker.save()