
torchdata

load_data_cifar10(data_dir='./data')

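Load the CIFAR-10 dataset. This function loads CIFAR-10 with the torchvision library, downloading it to data_dir if it is not already present, and returns the standard training split (50,000 images) and test split (10,000 images). Both splits are converted to tensors and normalized per channel with mean 0.5 and standard deviation 0.5.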

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_dir | str | The directory where the data is stored. Defaults to "./data". | './data' |
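The data_dir value is passed to torchvision as the dataset root. Assuming torchvision's current layout, CIFAR10 downloads its archive into that root and extracts the batches into a cifar-10-batches-py subfolder; verify the folder name against your installed torchvision version. The sketch below uses hypothetical helpers, cifar10_extract_dir and is_cifar10_extracted (not part of spotpython or torchvision), to compute where the extracted batches are expected, which can help when several runs share one data directory.

```python
from pathlib import Path

# Assumption: torchvision's CIFAR10 extracts its batch files into this subfolder of `root`.
CIFAR10_BASE_FOLDER = "cifar-10-batches-py"


def cifar10_extract_dir(data_dir: str = "./data") -> Path:
    """Hypothetical helper: folder where the extracted CIFAR-10 batches are expected."""
    return Path(data_dir) / CIFAR10_BASE_FOLDER


def is_cifar10_extracted(data_dir: str = "./data") -> bool:
    """Hypothetical helper: True if the expected extraction folder exists.

    This checks existence only; torchvision itself verifies file checksums
    before deciding to skip the download.
    """
    return cifar10_extract_dir(data_dir).is_dir()


print(cifar10_extract_dir())
print(is_cifar10_extracted("/nonexistent/cifar/root"))
```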

Returns:

| Type | Description |
| --- | --- |
| Tuple[CIFAR10, CIFAR10] | A tuple containing the training set and the test set. |

Examples:

>>> trainset, testset = load_data_cifar10()
>>> print(f"Training set size: {len(trainset)}")
Training set size: 50000
>>> print(f"Test set size: {len(testset)}")
Test set size: 10000
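A common next step is to wrap the returned sets in torch.utils.data.DataLoader. For a map-style dataset with the default sampler, a DataLoader yields ceil(N / batch_size) batches per epoch, or floor(N / batch_size) when drop_last=True. The pure-Python sketch below reproduces that arithmetic for the split sizes shown in the example; batches_per_epoch is a hypothetical helper, and the batch size of 64 is an arbitrary illustration rather than a spotpython default.

```python
# Split sizes returned by load_data_cifar10 (see the example above).
TRAIN_SIZE = 50_000
TEST_SIZE = 10_000


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: batches a DataLoader yields in one pass over num_samples items."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept (integer ceiling division).
    return (num_samples + batch_size - 1) // batch_size


batch_size = 64  # illustrative choice only
print(batches_per_epoch(TRAIN_SIZE, batch_size))
print(batches_per_epoch(TEST_SIZE, batch_size))
```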
Source code in spotpython/data/torchdata.py
from typing import Tuple

from torchvision import datasets, transforms


def load_data_cifar10(data_dir: str = "./data") -> Tuple[datasets.CIFAR10, datasets.CIFAR10]:
    """Load the CIFAR-10 dataset.

    This function loads the CIFAR-10 dataset using the torchvision library,
    downloading it to `data_dir` if necessary, and returns the standard
    training and test splits.

    Args:
        data_dir (str):
            The directory where the data is stored. Defaults to "./data".

    Returns:
        Tuple[datasets.CIFAR10, datasets.CIFAR10]:
            A tuple containing the training set and the test set.

    Examples:
        >>> trainset, testset = load_data_cifar10()
        >>> print(f"Training set size: {len(trainset)}")
        Training set size: 50000
        >>> print(f"Test set size: {len(testset)}")
        Test set size: 10000

    """
    # Convert images to tensors in [0, 1], then normalize each RGB channel
    # with mean 0.5 and std 0.5, which maps values to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Standard 50,000-image training split; downloaded only if not already present.
    trainset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)

    # Standard 10,000-image test split, preprocessed with the same transform.
    testset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset