Source code for labml.utils.cache

import functools
import json
import pickle
from pathlib import Path
from typing import Callable, Any, Optional

from labml import lab


def _cache_load(path: Path, file_type: str):
    """
    Read a cached value back from disk.

    Arguments:
        path (Path): location of the cache file
        file_type (str): serialization format, either ``json`` or ``pickle``

    Returns:
        The deserialized content of the file.

    Raises:
        ValueError: if ``file_type`` is neither ``json`` nor ``pickle``.
    """
    # Reject unsupported formats up front, before any file is opened.
    if file_type not in ('json', 'pickle'):
        raise ValueError(f'Unknown file type: {file_type}')

    if file_type == 'pickle':
        # Pickle is a binary format, so the file is read in binary mode.
        with open(str(path), 'rb') as stream:
            return pickle.load(stream)

    with open(str(path), 'r') as stream:
        return json.load(stream)


def _cache_save(path: Path, value: Any, file_type: str):
    """
    Serialize ``value`` and write it to ``path``.

    The value is serialized in memory *before* the file is opened. If
    serialization fails, no file is created and an existing cache file is
    left untouched, so a failed save cannot leave behind a corrupt entry
    that later loads would trip over.

    Arguments:
        path (Path): location of the cache file
        value (any): data to be stored
        file_type (str): serialization format, either ``json`` or ``pickle``

    Raises:
        ValueError: if ``file_type`` is neither ``json`` nor ``pickle``.

    Errors raised by ``json.dumps`` / ``pickle.dumps`` for values that cannot
    be serialized propagate to the caller.
    """
    if file_type == 'json':
        # Serialize first: opening with 'w' truncates the file immediately.
        payload = json.dumps(value)
        with open(str(path), 'w') as f:
            f.write(payload)
    elif file_type == 'pickle':
        payload = pickle.dumps(value)
        with open(str(path), 'wb') as f:
            f.write(payload)
    else:
        raise ValueError(f'Unknown file type: {file_type}')


def _get_cache_path(name: str, file_type: str):
    """
    Build the path of the cache file for ``name``.

    The file lives in the ``cache`` folder under the path returned by
    ``lab.get_data_path()``; the folder is created if it does not exist.

    Arguments:
        name (str): name of the cache
        file_type (str): serialization format, used as the file extension

    Returns:
        Path of the form ``<data path>/cache/<name>.<file_type>``.
    """
    cache_path = lab.get_data_path() / 'cache'
    # ``exist_ok=True`` removes the check-then-create race: another process
    # may create the folder between an ``exists()`` test and ``mkdir()``.
    cache_path.mkdir(parents=True, exist_ok=True)
    return cache_path / f'{name}.{file_type}'


def _cache_wrap(name: str, loader: Callable[[], Any], *, file_type: str) -> Any:
    """
    Return the cached value for ``name``, generating and storing it on a miss.

    Arguments:
        name (str): name of the cache
        loader (Callable[[], Any]): zero-argument function that produces the
            value when no cache file exists

    Keyword Arguments:
        file_type (str): serialization format of the cache file

    Returns:
        The value read from the cache file if it exists, otherwise the value
        returned by ``loader`` (which is saved before being returned).
    """
    cache_file = _get_cache_path(name, file_type)

    # Cache miss: compute, persist, and hand back the fresh value.
    if not cache_file.exists():
        fresh = loader()
        _cache_save(cache_file, fresh, file_type)
        return fresh

    return _cache_load(cache_file, file_type)


def cache(name: str, loader: Optional[Callable[[], Any]] = None, *, file_type: str = 'json') -> Any:
    """
    This caches results of a function. Can be used as a decorator or you can pass a lambda function to it
    that takes no arguments.

    *It doesn't cache by arguments.*

    Arguments:
        name (str): name of the cache
        loader (Callable[[], Any], optional): the function that generates the data to be cached

    Keyword Arguments:
        file_type (str, optional): The file type to store the data. Defaults to ``json``.

    Returns:
        When ``loader`` is given, the cached (or freshly generated) value.
        Otherwise a decorator that wraps a function so that its result is
        cached under ``name``.
    """
    if loader is not None:
        return _cache_wrap(name, loader, file_type=file_type)

    def decorator_func(f: Callable):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Arguments are only forwarded on a cache miss; they are not part
            # of the cache key, so every call shares the same entry.
            return _cache_wrap(name, lambda: f(*args, **kwargs), file_type=file_type)

        return wrapper

    return decorator_func


def cache_get(name: str, file_type: str = 'json') -> Any:
    """
    Get cached data.

    Arguments:
        name (str): name of the cache
        file_type (str, optional): The file type to store the data. Defaults to ``json``.

    Returns:
        The cached value, or ``None`` if there is no cache file for ``name``.
        A stored ``None`` is therefore indistinguishable from a missing entry.
    """
    path = _get_cache_path(name, file_type)
    if not path.exists():
        return None

    return _cache_load(path, file_type)


def cache_set(name: str, value: Any, file_type: str = 'json') -> None:
    """
    Save data in cache.

    Any existing entry stored under ``name`` with the same ``file_type`` is
    overwritten.

    Arguments:
        name (str): name of the cache
        value (any): data to be cached
        file_type (str, optional): The file type to store the data. Defaults to ``json``.
    """
    path = _get_cache_path(name, file_type)
    _cache_save(path, value, file_type)