Source code for labml.utils.lightning

from argparse import Namespace
from typing import Any, Dict, Optional, Union

import torch
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only

from labml import experiment, tracker, lab
from labml.internal.experiment import experiment_singleton


[docs]class LabMLLightningLogger(LightningLoggerBase): """ PyTorch Lightening logger integration. Pass an instance of this class to ``pytorch_lightning.Training`` as argument ``logger``. PyTorch Lightening will call relavent mehtods of this class to log hyper-parameters and metrics. """ def __init__(self): super().__init__() @property @rank_zero_experiment def experiment(self): return None @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = self._convert_params(params) params = self._flatten_dict(params) experiment.configs(params) @rank_zero_only def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None: if step is None: tracker.add_global_step() tracker.save(metrics) else: tracker.save(step, metrics) def reset_experiment(self): pass @property def save_dir(self) -> Optional[str]: return str(lab.get_experiments_path()) @property def name(self) -> str: return str(experiment_singleton().run.name) @property def version(self) -> str: return experiment_singleton().run.uuid