cifar10datamodule
CIFAR10DataModule
Bases: LightningDataModule
A LightningDataModule for handling the CIFAR-10 dataset.
Torchvision provides many built-in datasets in the torchvision.datasets module, as well as utility classes for building your own datasets. All datasets are subclasses of torch.utils.data.Dataset, i.e., they implement the `__getitem__` and `__len__` methods. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers, see [1].
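The map-style protocol described above can be illustrated without any framework code. This is a sketch under stated assumptions, not PyTorch's or torchvision's implementation: `ToyImageDataset` is a hypothetical class that only implements `__getitem__` and `__len__`, and `iterate_batches` is a hypothetical helper that mimics how a DataLoader groups samples by `batch_size` when shuffling is off and the last, smaller batch is kept.

```python
from typing import Iterator, List, Tuple

Sample = Tuple[List[float], int]  # (fake flattened "image", class label)


class ToyImageDataset:
    """Hypothetical map-style dataset: only __getitem__ and __len__ are required."""

    def __init__(self, n_samples: int, n_classes: int = 10) -> None:
        self.n_samples = n_samples
        self.n_classes = n_classes

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, index: int) -> Sample:
        if not 0 <= index < self.n_samples:
            raise IndexError(index)
        # A 3-value stand-in for pixel data and a label cycling through the classes.
        return [float(index)] * 3, index % self.n_classes


def iterate_batches(dataset, batch_size: int) -> Iterator[List[Sample]]:
    """Yield consecutive batches using only len() and integer indexing."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


ds = ToyImageDataset(n_samples=10)
batches = list(iterate_batches(ds, batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]
```

A real DataLoader additionally collates each batch into stacked tensors and, when `num_workers > 0`, fetches samples in worker processes.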
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | The number of samples per batch. | required |
data_dir | str | The directory where the data is stored. | './data' |
num_workers | int | The number of worker processes used for data loading; 0 loads data in the main process. | 0 |
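Because `batch_size` fixes how many samples each loader step returns, one pass over a split of `n` samples takes `ceil(n / batch_size)` steps, or `floor(n / batch_size)` if the incomplete last batch is dropped. The helper below is a hypothetical illustration; this page does not state whether the module's loaders drop the last batch. CIFAR-10 has 50,000 training and 10,000 test images, and since the division of the training images into `data_train` and `data_val` is not documented here, the example uses the test split size.

```python
def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a loader yields for one pass over n_samples."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return n_samples // batch_size  # the incomplete final batch is discarded
    return (n_samples + batch_size - 1) // batch_size  # integer ceiling division


CIFAR10_TEST_SIZE = 10_000  # images in the official CIFAR-10 test split

print(batches_per_epoch(CIFAR10_TEST_SIZE, 32))                  # 313
print(batches_per_epoch(CIFAR10_TEST_SIZE, 32, drop_last=True))  # 312
```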
Attributes:
Name | Type | Description |
---|---|---|
data_train | Dataset | The training dataset. |
data_val | Dataset | The validation dataset. |
data_test | Dataset | The test dataset. |
References
Source code in spotpython/light/cifar10/cifar10datamodule.py
prepare_data()
Prepares the data for use.
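In PyTorch Lightning, `prepare_data()` is the hook meant for one-time work such as downloading. Lightning runs it on a single process (per node by default), so it should not assign state such as `self.data_train` that other processes would need. The method body is not reproduced on this page, so the following is only a framework-free sketch of an idempotent preparation step: `prepare_data_dir` and the `fetch` callback are hypothetical, and the `cifar-10-batches-py` folder name is an assumption based on the layout torchvision's CIFAR10 dataset extracts to.

```python
import tempfile
from pathlib import Path
from typing import Callable

EXPECTED_SUBDIR = "cifar-10-batches-py"  # assumed name of the extracted folder


def prepare_data_dir(data_dir: str, fetch: Callable[[Path], None]) -> bool:
    """Run `fetch` only if the data folder is missing; return True if it ran.

    Repeated calls are cheap and safe, and nothing is stored on an object,
    which are the properties a prepare_data() hook should have.
    """
    target = Path(data_dir) / EXPECTED_SUBDIR
    if target.is_dir():
        return False  # already prepared: skip the download
    fetch(target)
    return True


def fake_fetch(target: Path) -> None:
    # Stand-in for a real download: just create the expected folder.
    target.mkdir(parents=True)


with tempfile.TemporaryDirectory() as tmp:
    first = prepare_data_dir(tmp, fake_fetch)
    second = prepare_data_dir(tmp, fake_fetch)
    print(first, second)  # True False
```

A typical torchvision-based implementation instead calls `CIFAR10(data_dir, train=True, download=True)` and `CIFAR10(data_dir, train=False, download=True)`, which skip the download when the files are already present and pass their integrity check.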
setup(stage=None)
Sets up the data for use.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage | Optional[str] | The current stage, e.g. "fit" or "test". | None |
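Lightning calls `setup(stage)` on every process with `stage` set to `"fit"`, `"validate"`, `"test"`, or `"predict"`, and a `None` stage conventionally means "set up everything". The sketch below shows one common way to fill the `data_train`, `data_val`, and `data_test` attributes from that argument. It is hypothetical: `ToyDataModule` does not use Lightning, plain lists stand in for `Dataset` objects, and the 90/10 train/validation split with a fixed seed is an illustrative assumption, not the split this module uses.

```python
import random
from typing import List, Optional


class ToyDataModule:
    """Hypothetical stand-in showing stage-dependent setup (no Lightning needed)."""

    def __init__(self, n_train: int = 100, n_test: int = 20,
                 val_fraction: float = 0.1, seed: int = 0) -> None:
        self.n_train = n_train
        self.n_test = n_test
        self.val_fraction = val_fraction
        self.seed = seed
        self.data_train: Optional[List[int]] = None
        self.data_val: Optional[List[int]] = None
        self.data_test: Optional[List[int]] = None

    def setup(self, stage: Optional[str] = None) -> None:
        if stage in ("fit", "validate", None):
            full = list(range(self.n_train))  # stands in for the full training set
            rng = random.Random(self.seed)    # fixed seed: same split on every process
            rng.shuffle(full)
            n_val = int(len(full) * self.val_fraction)
            self.data_val, self.data_train = full[:n_val], full[n_val:]
        if stage in ("test", None):
            self.data_test = list(range(self.n_test))


dm = ToyDataModule()
dm.setup("fit")
print(len(dm.data_train), len(dm.data_val), dm.data_test)  # 90 10 None
```

The fixed seed matters because `setup()` runs on every process; with an unseeded split, different processes could end up with different train/validation partitions.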
test_dataloader()
Returns the test dataloader.
Returns:
Name | Type | Description |
---|---|---|
DataLoader | DataLoader | The test dataloader. |
train_dataloader()
Returns the training dataloader.
Returns:
Name | Type | Description |
---|---|---|
DataLoader | DataLoader | The training dataloader. |
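A training loader is commonly built with shuffling enabled so that each epoch visits the samples in a new order, while validation and test loaders keep a fixed order; this page does not show whether this module follows that convention. The sketch below reproduces the idea with the standard library: `epoch_batches` is a hypothetical helper that draws a fresh permutation of indices per epoch from a seeded generator, roughly what `DataLoader(..., shuffle=True)` does through its random sampler.

```python
import random
from typing import Iterator, List


def epoch_batches(n_samples: int, batch_size: int, shuffle: bool,
                  rng: random.Random) -> Iterator[List[int]]:
    """Yield index batches for one epoch; reshuffle when shuffle=True."""
    order = list(range(n_samples))
    if shuffle:
        rng.shuffle(order)  # a new permutation on each call, i.e. each epoch
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]


rng = random.Random(42)
epoch1 = [i for b in epoch_batches(10, 4, shuffle=True, rng=rng) for i in b]
epoch2 = [i for b in epoch_batches(10, 4, shuffle=True, rng=rng) for i in b]
evaluation = [i for b in epoch_batches(10, 4, shuffle=False, rng=rng) for i in b]

print(sorted(epoch1) == list(range(10)))  # True: every sample appears once per epoch
print(evaluation)                         # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```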
val_dataloader()
Returns the validation dataloader.
Returns:
Name | Type | Description |
---|---|---|
DataLoader | DataLoader | The validation dataloader. |