lightcrossvalidationdatamodule
LightCrossValidationDataModule
Bases: LightningDataModule
A LightningDataModule for handling cross-validation data splits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | The size of the batch. Defaults to 64. | 64 |
dataset | Dataset | The dataset from the torch.utils.data Dataset class. It must implement three methods: __init__, __len__, and __getitem__. | None |
data_full_train | Dataset | The full training dataset from which training and validation sets will be derived. | None |
data_test | Dataset | The separate test dataset that will be used for testing. | None |
k | int | The fold number. Defaults to 1. | 1 |
split_seed | int | The random seed for splitting the data. Defaults to 42. | 42 |
num_splits | int | The number of splits for cross-validation. Defaults to 10. | 10 |
data_dir | str | The path to the dataset. Defaults to './data'. | './data' |
num_workers | int | The number of workers for data loading. Defaults to 0. | 0 |
pin_memory | bool | Whether to pin memory for data loading. Defaults to False. | False |
verbosity | int | The verbosity level. Defaults to 0. | 0 |
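The dataset parameter expects an object that follows the map-style dataset protocol. The following is a minimal sketch of such a class, not code from spotpython: the name SquaresDataset and its contents are hypothetical, and the torch.utils.data.Dataset base class is left out so that the sketch runs without PyTorch installed.

```python
class SquaresDataset:
    """Hypothetical map-style dataset: item i is the pair (i, i * i).

    In real use this class would subclass torch.utils.data.Dataset;
    the base class is omitted here so the sketch has no dependencies.
    """

    def __init__(self, size: int) -> None:
        # Store whatever is needed to produce samples on demand.
        self.size = size

    def __len__(self) -> int:
        # Number of samples; splitters and loaders rely on this value.
        return self.size

    def __getitem__(self, index: int) -> tuple[int, int]:
        # Return one (input, target) sample for an integer index.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index


dataset = SquaresDataset(size=5)
print(len(dataset))  # 5
print(dataset[3])    # (3, 9)
```

Raising IndexError for out-of-range indices also lets plain Python iteration over the object terminate.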
Attributes:
Name | Type | Description |
---|---|---|
data_train | Optional[Dataset] | The training dataset. |
data_val | Optional[Dataset] | The validation dataset. |
Examples:
>>> from spotpython.light import LightCrossValidationDataModule
>>> data_module = LightCrossValidationDataModule()
>>> data_module.setup()
>>> print(f"Training set size: {len(data_module.data_train)}")
Training set size: 45000
>>> print(f"Validation set size: {len(data_module.data_val)}")
Validation set size: 5000
>>> print(f"Test set size: {len(data_module.data_test)}")
Test set size: 10000
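The sizes above are consistent with a full training set of 50000 samples split into num_splits=10 folds, with one fold held out for validation. The sketch below illustrates how k, num_splits, and split_seed could interact. It is not the library's implementation: the helper kfold_indices is hypothetical, it uses only the standard library, and it treats k as a zero-based fold index, which is an assumption.

```python
import random


def kfold_indices(n_samples: int, num_splits: int, k: int, split_seed: int):
    """Return (train_indices, val_indices) for fold k of a seeded k-fold split.

    Assumption: k is zero-based here; the library may number folds differently.
    """
    if not 0 <= k < num_splits:
        raise ValueError("k must satisfy 0 <= k < num_splits")
    indices = list(range(n_samples))
    # Same seed -> same shuffle, so all folds come from one consistent partition.
    random.Random(split_seed).shuffle(indices)
    # Spread any remainder over the first folds so fold sizes differ by at most 1.
    base, extra = divmod(n_samples, num_splits)
    start = k * base + min(k, extra)
    stop = start + base + (1 if k < extra else 0)
    val = indices[start:stop]
    train = indices[:start] + indices[stop:]
    return train, val


train_idx, val_idx = kfold_indices(n_samples=50_000, num_splits=10, k=1, split_seed=42)
print(len(train_idx), len(val_idx))  # 45000 5000
```

Because the shuffle depends only on split_seed, running the helper once per fold yields validation sets that are disjoint and together cover every sample exactly once.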
Source code in spotpython/data/lightcrossvalidationdatamodule.py
prepare_data()
Prepares the data for use.
setup(stage=None)
Sets up the data for use, populating data_train and data_val for the selected fold.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage | Optional[str] | The current stage. Defaults to None. | None |
train_dataloader()
Returns the training dataloader.
Returns:
Type | Description |
---|---|
DataLoader | The training dataloader. |
Examples:
>>> from spotpython.light import LightCrossValidationDataModule
>>> data_module = LightCrossValidationDataModule()
>>> data_module.setup()
>>> train_dataloader = data_module.train_dataloader()
>>> print(f"Training set size: {len(train_dataloader.dataset)}")
Training set size: 45000
val_dataloader()
Returns the validation dataloader.
Returns:
Type | Description |
---|---|
DataLoader | The validation dataloader. |
Examples:
>>> from spotpython.light import LightCrossValidationDataModule
>>> data_module = LightCrossValidationDataModule()
>>> data_module.setup()
>>> val_dataloader = data_module.val_dataloader()
>>> print(f"Validation set size: {len(val_dataloader.dataset)}")
Validation set size: 5000
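The dataset sizes in these examples, combined with batch_size, determine how many batches each loader yields per epoch. The sketch below works through that arithmetic for the default batch_size of 64. The helper num_batches is hypothetical, and the assumption that the returned loaders keep the final partial batch (PyTorch's default, drop_last=False) is not confirmed by this page.

```python
import math


def num_batches(n_samples: int, batch_size: int = 64, drop_last: bool = False) -> int:
    """Batches per epoch for a loader over n_samples items."""
    if drop_last:
        # The final partial batch is discarded.
        return n_samples // batch_size
    # The final partial batch is kept.
    return math.ceil(n_samples / batch_size)


print(num_batches(45_000))  # 704
print(num_batches(5_000))   # 79
```

With drop_last=True the counts would be 703 and 78, since the last, smaller batch is dropped.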