Source code for labml_helpers.datasets.remote.server

import pickle

import uvicorn
from fastapi import FastAPI, Request, Response
from torch.utils.data import Dataset

class _ServerDataset:
    """
    Wrap a single dataset and provide the HTTP handlers that serve it.

    Every handler returns its payload serialized with :func:`pickle.dumps`
    and tagged with the ``binary/pickle`` media type.

    NOTE(review): a client must ``pickle.loads`` these bodies to decode them.
    Unpickling can execute arbitrary code, so clients should only talk to a
    trusted server.
    """

    def __init__(self, name: str, dataset: Dataset):
        """
        Arguments:
            name (str): name under which the dataset is served
            dataset (Dataset): dataset to serve; must support ``len()`` and
                integer indexing
        """
        # Assign separately: a chained ``self.dataset = dataset = name`` would
        # overwrite the dataset with the name string.
        self.dataset = dataset
        self.name = name

    @staticmethod
    def _pickled(obj) -> Response:
        """Serialize ``obj`` with pickle and wrap it in an HTTP response."""
        return Response(pickle.dumps(obj), media_type='binary/pickle')

    def len_handler(self, request: Request):
        """
        Return the pickled length of the dataset.

        ``request`` is unused. It is accepted so the framework can inject
        it into the route handler.
        """
        return self._pickled(len(self.dataset))

    def item_handler(self, request: Request, idx: str):
        """
        Return the pickled sample at index ``idx``.

        ``idx`` arrives as a string path parameter and is converted with
        ``int()``, which raises ``ValueError`` for non-integer input.
        """
        return self._pickled(self.dataset[int(idx)])

class DatasetServer:
    r"""
    Remote dataset server.

    Serves one or more datasets over HTTP with FastAPI. For each dataset
    registered through :meth:`add_dataset` under ``name`` it adds these routes:

    * ``GET /{name}/len`` -- pickled length of the dataset
    * ``GET /{name}/item/{idx}`` -- pickled sample at integer index ``idx``

    See ``_test`` in this module for a sample usage of the server.
    """

    def __init__(self):
        # FastAPI application on which the per-dataset routes are registered
        self.app = FastAPI()
        # dataset name -> ``_ServerDataset`` wrapper serving it
        self.datasets = {}

    def add_dataset(self, name: str, dataset: Dataset):
        """
        Add a dataset

        Arguments:
            name (str): name of the data set; used as the URL path prefix
            dataset (Dataset): dataset to be served
        """
        assert name not in self.datasets, f"Dataset {name!r} is already registered"
        sd = _ServerDataset(name, dataset)
        self.datasets[name] = sd
        self.app.add_api_route("/" + name + "/len", sd.len_handler, methods=["GET"])
        self.app.add_api_route("/" + name + "/item/{idx}", sd.item_handler, methods=["GET"])

    def start(self, host: str = "", port: int = 8000):
        """
        Start the server (blocks until uvicorn exits)

        Arguments:
            host (str): hostname of the server. NOTE(review): an empty string
                is passed straight to uvicorn; asyncio treats ``""`` as
                "all interfaces" -- confirm for the installed uvicorn version.
            port (int): server port
        """
        uvicorn.run(self.app, host=host, port=port)
def _test():
    """
    Manual smoke test: serve the MNIST training set as ``mnist_train``.

    Downloads MNIST into the labml data path if needed, then starts the
    server with its default host and port.
    """
    from labml import lab
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        # presumably the MNIST per-channel mean / std -- TODO confirm
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset = datasets.MNIST(str(lab.get_data_path()),
                             train=True,
                             download=True,
                             transform=transform)

    s = DatasetServer()
    s.add_dataset('mnist_train', dataset)
    s.start()


if __name__ == '__main__':
    _test()