Source code for labml_helpers.datasets.mnist

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from labml import lab
from labml.configs import BaseConfigs, aggregate, option


def _dataset(is_train, transform):
    """Load the MNIST train/validation split, downloading it to the lab data path if needed."""
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)


class MNISTConfigs(BaseConfigs):
    dataset_name: str = 'MNIST'
    dataset_transforms: transforms.Compose

    train_dataset: datasets.MNIST
    valid_dataset: datasets.MNIST

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
    return _dataset(True, c.dataset_transforms)


@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
    return _dataset(False, c.dataset_transforms)


@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


aggregate(MNISTConfigs.dataset_name, 'MNIST',
          (MNISTConfigs.dataset_transforms, 'mnist_transforms'),
          (MNISTConfigs.train_dataset, 'mnist_train_dataset'),
          (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
          (MNISTConfigs.train_loader, 'mnist_train_loader'),
          (MNISTConfigs.valid_loader, 'mnist_valid_loader'))
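
# ---------------------------------------------------------------------------
# Usage sketch (not part of this module): one way these configs are typically
# consumed, assuming labml's experiment API (`experiment.create`,
# `experiment.configs`, `experiment.start`). The experiment name and the
# batch-size override below are illustrative only; unspecified options are
# computed by the `@option` functions above.
#
#     from labml import experiment
#
#     conf = MNISTConfigs()
#     experiment.create(name='mnist_example')              # hypothetical name
#     experiment.configs(conf, {'train_batch_size': 128})  # override a default
#     with experiment.start():
#         for data, target in conf.train_loader:
#             ...                                           # training step here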