"""Source code for ``labml_helpers.datasets.cifar10``: configurable CIFAR-10 datasets and data loaders."""

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from labml import lab
from labml.configs import BaseConfigs


def _dataset(is_train, transform):
    """Create a torchvision CIFAR-10 dataset rooted at the labml data directory.

    Args:
        is_train: ``True`` selects the training split, ``False`` the test split.
        transform: transformation handed to ``datasets.CIFAR10`` and applied to
            each image.

    Returns:
        A ``datasets.CIFAR10`` instance. ``download=True`` is always passed, so
        torchvision fetches the archive when it is not already under the root.
    """
    data_root = str(lab.get_data_path())
    return datasets.CIFAR10(
        data_root,
        train=is_train,
        download=True,
        transform=transform,
    )


class CIFAR10Configs(BaseConfigs):
    """
    Configurable CIFAR 10 data set.

    Options without a default value are computed by the calculators
    registered on this class with ``CIFAR10Configs.calc`` in this module.

    Arguments:
        dataset_name (str): name of the data set, ``CIFAR10``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.CIFAR10): training dataset
        valid_dataset (torchvision.datasets.CIFAR10): validation dataset
            (built from the CIFAR-10 test split, i.e. ``train=False``)
        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader
        train_batch_size (int): training batch size (default ``64``)
        valid_batch_size (int): validation batch size (default ``1024``)
        train_loader_shuffle (bool): whether to shuffle training data
            (default ``True``)
        valid_loader_shuffle (bool): whether to shuffle validation data
            (default ``False``)
    """

    dataset_name: str = 'CIFAR10'

    # Computed options: declared by annotation only; their values come from
    # the calculators registered via ``CIFAR10Configs.calc``.
    dataset_transforms: transforms.Compose
    train_dataset: datasets.CIFAR10
    valid_dataset: datasets.CIFAR10
    train_loader: DataLoader
    valid_loader: DataLoader

    # Plain options with defaults, read by the loader calculators.
    train_batch_size: int = 64
    valid_batch_size: int = 1024
    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False
@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
    """Default image transform: tensor conversion followed by per-channel normalization.

    ``Normalize`` uses mean ``0.5`` and std ``0.5`` for each of the three
    channels, which maps the ``[0, 1]`` range produced by ``ToTensor`` onto
    ``[-1, 1]``.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
    """Training dataset: the CIFAR-10 training split with ``c.dataset_transforms``."""
    return _dataset(True, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
    """Validation dataset: the CIFAR-10 test split (``train=False``) with ``c.dataset_transforms``."""
    return _dataset(False, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
    """Training loader over ``c.train_dataset``, batched and shuffled per the configs."""
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@CIFAR10Configs.calc(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
    """Validation loader over ``c.valid_dataset``, batched and shuffled per the configs."""
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


# Aggregate keyed on ``dataset_name == 'CIFAR10'``: each option is paired
# with the name of the calculator defined above that should compute it.
CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
                         (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
                         (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
                         (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
                         (CIFAR10Configs.train_loader, 'cifar10_train_loader'),
                         (CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))