Source code for labml_helpers.datasets.mnist

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from labml import lab
from labml.configs import BaseConfigs, aggregate, option


def _dataset(is_train, transform):
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)


[docs]class MNISTConfigs(BaseConfigs): """ Configurable MNIST data set. Arguments: dataset_name (str): name of the data set, ``MNIST`` dataset_transforms (torchvision.transforms.Compose): image transformations train_dataset (torchvision.datasets.MNIST): training dataset valid_dataset (torchvision.datasets.MNIST): validation dataset train_loader (torch.utils.data.DataLoader): training data loader valid_loader (torch.utils.data.DataLoader): validation data loader train_batch_size (int): training batch size valid_batch_size (int): validation batch size train_loader_shuffle (bool): whether to shuffle training data valid_loader_shuffle (bool): whether to shuffle validation data """ dataset_name: str = 'MNIST' dataset_transforms: transforms.Compose train_dataset: datasets.MNIST valid_dataset: datasets.MNIST train_loader: DataLoader valid_loader: DataLoader train_batch_size: int = 64 valid_batch_size: int = 1024 train_loader_shuffle: bool = True valid_loader_shuffle: bool = False
@option(MNISTConfigs.dataset_transforms) def mnist_transforms(): return transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) @option(MNISTConfigs.train_dataset) def mnist_train_dataset(c: MNISTConfigs): return _dataset(True, c.dataset_transforms) @option(MNISTConfigs.valid_dataset) def mnist_valid_dataset(c: MNISTConfigs): return _dataset(False, c.dataset_transforms) @option(MNISTConfigs.train_loader) def mnist_train_loader(c: MNISTConfigs): return DataLoader(c.train_dataset, batch_size=c.train_batch_size, shuffle=c.train_loader_shuffle) @option(MNISTConfigs.valid_loader) def mnist_valid_loader(c: MNISTConfigs): return DataLoader(c.valid_dataset, batch_size=c.valid_batch_size, shuffle=c.valid_loader_shuffle) aggregate(MNISTConfigs.dataset_name, 'MNIST', (MNISTConfigs.dataset_transforms, 'mnist_transforms'), (MNISTConfigs.train_dataset, 'mnist_train_dataset'), (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'), (MNISTConfigs.train_loader, 'mnist_train_loader'), (MNISTConfigs.valid_loader, 'mnist_valid_loader'))