lightcrossvalidationdatamodule
LightCrossValidationDataModule
Bases: LightningDataModule
A LightningDataModule for handling cross-validation data splits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | The size of the batch. Defaults to 64. | 64 |
k | int | The fold number. Defaults to 1. | 1 |
split_seed | int | The random seed for splitting the data. Defaults to 42. | 42 |
num_splits | int | The number of splits for cross-validation. Defaults to 10. | 10 |
data_dir | str | The path to the dataset. Defaults to "./data". | './data' |
num_workers | int | The number of workers for data loading. Defaults to 0. | 0 |
pin_memory | bool | Whether to pin memory for data loading. Defaults to False. | False |
Attributes:
Name | Type | Description |
---|---|---|
data_train | Optional[Dataset] | The training dataset. |
data_val | Optional[Dataset] | The validation dataset. |
Examples:
>>> from spotpython.light import LightCrossValidationDataModule
>>> data_module = LightCrossValidationDataModule()
>>> data_module.setup()
>>> print(f"Training set size: {len(data_module.data_train)}")
Training set size: 45000
>>> print(f"Validation set size: {len(data_module.data_val)}")
Validation set size: 5000
>>> print(f"Test set size: {len(data_module.data_test)}")
Test set size: 10000
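The training and validation sizes in the example above are consistent with splitting 50,000 samples into `num_splits=10` folds and holding one fold out for validation. The following is a minimal, standard-library sketch of such a seeded k-fold split, not the class's actual implementation. The helper name `kfold_indices`, the 50,000 sample count (inferred from 45000 + 5000), and the 0-based interpretation of `k` are assumptions made for illustration.

```python
import random


def kfold_indices(n_samples, num_splits=10, k=1, split_seed=42):
    """Return (train_idx, val_idx) for fold k of a shuffled k-fold split.

    Hypothetical helper: k is treated as a 0-based fold index here, which
    is an assumption; this reference does not state the class's convention.
    """
    if not 0 <= k < num_splits:
        raise ValueError("k must satisfy 0 <= k < num_splits")
    indices = list(range(n_samples))
    # A dedicated Random instance keeps the shuffle reproducible per seed.
    random.Random(split_seed).shuffle(indices)
    # Spread any remainder over the first folds so sizes differ by at most 1.
    base, extra = divmod(n_samples, num_splits)
    start = k * base + min(k, extra)
    stop = start + base + (1 if k < extra else 0)
    val_idx = indices[start:stop]
    train_idx = indices[:start] + indices[stop:]
    return train_idx, val_idx


train_idx, val_idx = kfold_indices(50_000)
print(len(train_idx), len(val_idx))  # 45000 5000
```

Because the shuffle depends only on `split_seed`, calling the helper with the same seed and different values of `k` yields validation folds that are disjoint and together cover every sample, which is the property cross-validation relies on.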
Source code in spotpython/data/lightcrossvalidationdatamodule.py
prepare_data()
Prepares the data for use.
Source code in spotpython/data/lightcrossvalidationdatamodule.py
setup(stage=None)
Sets up the data for use.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stage | Optional[str] | The current stage. Defaults to None. | None |
Source code in spotpython/data/lightcrossvalidationdatamodule.py
train_dataloader()
Returns the training dataloader.
Returns:
Name | Type | Description |
---|---|---|
DataLoader | DataLoader | The training dataloader. |
Examples:
>>> from spotpython.light import LightCrossValidationDataModule
>>> data_module = LightCrossValidationDataModule()
>>> data_module.setup()
>>> train_dataloader = data_module.train_dataloader()
>>> print(f"Training set size: {len(train_dataloader.dataset)}")
Training set size: 45000
Source code in spotpython/data/lightcrossvalidationdatamodule.py
val_dataloader()
Returns the validation dataloader.
Returns:
Name | Type | Description |
---|---|---|
DataLoader | DataLoader | The validation dataloader. |
Examples:
>>> from spotpython.light import LightCrossValidationDataModule
>>> data_module = LightCrossValidationDataModule()
>>> data_module.setup()
>>> val_dataloader = data_module.val_dataloader()
>>> print(f"Validation set size: {len(val_dataloader.dataset)}")
Validation set size: 5000
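For a sense of how the dataset sizes in these examples translate into iterations per epoch, the arithmetic can be sketched as below. The helper `num_batches` is hypothetical, and the calculation assumes the dataloaders keep the final partial batch (PyTorch's `DataLoader` default, `drop_last=False`); whether this class overrides that setting is not stated in this reference.

```python
import math


def num_batches(n_samples, batch_size=64, drop_last=False):
    """Number of batches one pass over n_samples yields (hypothetical helper)."""
    if drop_last:
        # The incomplete final batch is discarded.
        return n_samples // batch_size
    # The incomplete final batch is kept as a smaller batch.
    return math.ceil(n_samples / batch_size)


print(num_batches(45_000))  # 704 training batches at batch_size=64
print(num_batches(5_000))   # 79 validation batches at batch_size=64
```

With the default `batch_size` of 64, the last training batch would hold 45000 - 703 * 64 = 8 samples and the last validation batch 5000 - 78 * 64 = 8 samples under this assumption.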
Source code in spotpython/data/lightcrossvalidationdatamodule.py