Skip to content

lightcrossvalidationdatamodule

LightCrossValidationDataModule

Bases: LightningDataModule

A LightningDataModule for handling cross-validation data splits.

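Parameters:

- batch_size (int): The size of the batch. Default: 64.
- dataset (Dataset): The full dataset that is split into training and validation folds. Default: None.
- k (int): Zero-based index of the fold used for validation; must satisfy 0 <= k < num_splits. Default: 1.
- split_seed (int): The random seed for splitting the data. Default: 42.
- num_splits (int): The number of splits for cross-validation. Default: 10.
- data_dir (str): The path to the dataset. Default: "./data".
- num_workers (int): The number of workers for data loading. Default: 0.
- pin_memory (bool): Whether to pin memory for data loading. Default: False.
- scaler (Optional[object]): Optional scaler with fit and transform methods; it is fitted on the training inputs and then applied to the training and validation inputs. Default: None.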

Attributes:

- data_train (Optional[Dataset]): The training dataset.
- data_val (Optional[Dataset]): The validation dataset.

Examples:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
>>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
>>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
>>> data_module.setup()
Train Dataset Size: 90
Val Dataset Size: 10
>>> print(f"Training set size: {len(data_module.data_train)}")
Training set size: 90
>>> print(f"Validation set size: {len(data_module.data_val)}")
Validation set size: 10
Source code in spotpython/data/lightcrossvalidationdatamodule.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class LightCrossValidationDataModule(L.LightningDataModule):
    """
    A LightningDataModule that serves one fold of a K-fold cross-validation split.

    The full dataset is split with scikit-learn's ``KFold`` (shuffled, seeded by
    ``split_seed``). Fold ``k`` becomes the validation set and the remaining folds
    form the training set. If a ``scaler`` is given, it is fitted on the training
    inputs only and then applied to both training and validation inputs exactly
    once, even if ``setup`` is called several times.

    Args:
        batch_size (int): The size of the batch. Defaults to 64.
        dataset (Dataset): The full dataset to split. Must not be None when
            ``setup`` performs the split. When a scaler is used, each item must be
            an ``(input, target)`` pair of tensors. Defaults to None.
        k (int): Zero-based index of the fold used for validation; must satisfy
            ``0 <= k < num_splits``. Defaults to 1.
        split_seed (int): The random seed for splitting the data. Defaults to 42.
        num_splits (int): The number of splits for cross-validation. Defaults to 10.
        data_dir (str): The path to the dataset. Stored only; not read by this
            class. Defaults to "./data".
        num_workers (int): The number of workers for data loading. Defaults to 0.
        pin_memory (bool): Whether to pin memory for data loading. Defaults to False.
        scaler (Optional[object]): Object exposing ``fit(tensor)`` and
            ``transform(tensor)``. It is fitted on the stacked training inputs and
            applied per sample to training and validation inputs. Defaults to None.

    Attributes:
        data_train (Optional[Dataset]): The training dataset (a ``Subset`` of
            ``dataset``, or a ``TensorDataset`` after scaling).
        data_val (Optional[Dataset]): The validation dataset (a ``Subset`` of
            ``dataset``, or a ``TensorDataset`` after scaling).

    Examples:
        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
        >>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
        >>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
        >>> data_module.setup()
        Train Dataset Size: 90
        Val Dataset Size: 10
        >>> print(f"Training set size: {len(data_module.data_train)}")
        Training set size: 90
        >>> print(f"Validation set size: {len(data_module.data_val)}")
        Validation set size: 10
    """

    def __init__(
        self,
        batch_size=64,
        dataset=None,
        k: int = 1,
        split_seed: int = 42,
        num_splits: int = 10,
        data_dir: str = "./data",
        num_workers: int = 0,
        pin_memory: bool = False,
        scaler: Optional[object] = None,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.data_full = dataset
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.k = k
        self.split_seed = split_seed
        self.num_splits = num_splits
        self.pin_memory = pin_memory
        self.scaler = scaler
        self.save_hyperparameters(logger=False)
        # NOTE(review): ``assert`` is stripped under ``python -O``; the fold index is
        # then only checked implicitly by list indexing (negative ``k`` would wrap).
        assert 0 <= self.k < self.num_splits, "incorrect fold number"

        # no data transformations
        self.transforms = None

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        # Set once the scaler has been fitted and applied, so repeated ``setup``
        # calls do not rescale already-scaled data.
        self._scaled = False

    def prepare_data(self) -> None:
        """Prepares the data for use. Nothing is downloaded; this is a no-op."""
        # download
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Sets up the data for use.

        On the first call, splits ``self.data_full`` into the training and
        validation subsets for fold ``k``. If a scaler is configured, it is fitted
        on the training inputs and applied to both subsets. The scaling happens only
        once; later calls keep the already-scaled datasets.

        Args:
            stage (Optional[str]): The current stage. Not used. Defaults to None.
        """
        if self.data_train is None and self.data_val is None:
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            # KFold yields the folds lazily; pick the requested one.
            train_indexes, val_indexes = list(kf.split(self.data_full))[self.hparams.k]
            self.data_train = Subset(self.data_full, train_indexes.tolist())
            print(f"Train Dataset Size: {len(self.data_train)}")
            self.data_val = Subset(self.data_full, val_indexes.tolist())
            print(f"Val Dataset Size: {len(self.data_val)}")

        # ``setup`` may be invoked more than once (e.g. once per stage). Refitting on
        # already-scaled data and transforming again would scale the data twice.
        if self.scaler is None or getattr(self, "_scaled", False):
            return

        train_inputs = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))]).squeeze(1)
        self.scaler.fit(train_inputs)
        self.data_train = self._scale_dataset(self.data_train)
        self.data_val = self._scale_dataset(self.data_val)
        self._scaled = True

    def _scale_dataset(self, dataset) -> TensorDataset:
        """Transform every input of ``dataset`` with the fitted scaler and pack the result.

        Args:
            dataset: Indexable dataset whose items are ``(input, target)`` tensor pairs.

        Returns:
            TensorDataset: Scaled inputs (stacked, dim 1 squeezed) and the unchanged targets.
        """
        pairs = [dataset[i] for i in range(len(dataset))]
        inputs = torch.stack([self.scaler.transform(x).clone().detach() for x, _ in pairs]).squeeze(1)
        targets = torch.stack([y.clone().detach() for _, y in pairs])
        return TensorDataset(inputs, targets)

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training dataloader (shuffled).

        Returns:
            DataLoader: The training dataloader.

        Examples:
            >>> import torch
            >>> from torch.utils.data import TensorDataset
            >>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
            >>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
            >>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
            >>> data_module.setup()
            Train Dataset Size: 90
            Val Dataset Size: 10
            >>> train_dataloader = data_module.train_dataloader()
            >>> print(f"Training set size: {len(train_dataloader.dataset)}")
            Training set size: 90
        """
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation dataloader (not shuffled).

        Returns:
            DataLoader: The validation dataloader.

        Examples:
            >>> import torch
            >>> from torch.utils.data import TensorDataset
            >>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
            >>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
            >>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
            >>> data_module.setup()
            Train Dataset Size: 90
            Val Dataset Size: 10
            >>> val_dataloader = data_module.val_dataloader()
            >>> print(f"Validation set size: {len(val_dataloader.dataset)}")
            Validation set size: 10
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )

prepare_data()

Prepares the data for use.

Source code in spotpython/data/lightcrossvalidationdatamodule.py
69
70
71
72
def prepare_data(self) -> None:
    """Hook for one-time data preparation such as downloads.

    Intentionally does nothing: no data is fetched or written here.
    """
    return None

setup(stage=None)

Sets up the data for use.

Parameters:

- stage (Optional[str]): The current stage. Default: None.
Source code in spotpython/data/lightcrossvalidationdatamodule.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def setup(self, stage: Optional[str] = None) -> None:
    """
    Sets up the data for use.

    Args:
        stage (Optional[str]): The current stage. Defaults to None.
    """
    if not self.data_train and not self.data_val:
        dataset_full = self.data_full
        kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
        all_splits = [k for k in kf.split(dataset_full)]
        train_indexes, val_indexes = all_splits[self.hparams.k]
        train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()
        self.data_train = Subset(dataset_full, train_indexes)
        print(f"Train Dataset Size: {len(self.data_train)}")
        self.data_val = Subset(dataset_full, val_indexes)
        print(f"Val Dataset Size: {len(self.data_val)}")

    if self.scaler is not None:
        # Fit the scaler on training data and transform both train and val data
        scaler_train_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))]).squeeze(1)
        self.scaler.fit(scaler_train_data)
        self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
        data_tensors_train = [data.clone().detach() for data, target in self.data_train]
        target_tensors_train = [target.clone().detach() for data, target in self.data_train]
        self.data_train = TensorDataset(
            torch.stack(data_tensors_train).squeeze(1), torch.stack(target_tensors_train)
        )
        self.data_val = [(self.scaler.transform(data), target) for data, target in self.data_val]
        data_tensors_val = [data.clone().detach() for data, target in self.data_val]
        target_tensors_val = [target.clone().detach() for data, target in self.data_val]
        self.data_val = TensorDataset(torch.stack(data_tensors_val).squeeze(1), torch.stack(target_tensors_val))

train_dataloader()

Returns the training dataloader.

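Returns:

- DataLoader: The training dataloader.

Examples:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
>>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
>>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
>>> data_module.setup()
Train Dataset Size: 90
Val Dataset Size: 10
>>> train_dataloader = data_module.train_dataloader()
>>> print(f"Training set size: {len(train_dataloader.dataset)}")
Training set size: 90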
Source code in spotpython/data/lightcrossvalidationdatamodule.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def train_dataloader(self) -> DataLoader:
    """
    Build the loader used for training; samples are reshuffled every epoch.

    Batch size, worker count and memory pinning are read from ``self.hparams``.

    Returns:
        DataLoader: The training dataloader over ``self.data_train``.

    Examples:
        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
        >>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
        >>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
        >>> data_module.setup()
        Train Dataset Size: 90
        Val Dataset Size: 10
        >>> train_dataloader = data_module.train_dataloader()
        >>> print(f"Training set size: {len(train_dataloader.dataset)}")
        Training set size: 90
    """
    hp = self.hparams
    loader_options = {
        "batch_size": hp.batch_size,
        "num_workers": hp.num_workers,
        "pin_memory": hp.pin_memory,
        "shuffle": True,
    }
    return DataLoader(self.data_train, **loader_options)

val_dataloader()

Returns the validation dataloader.

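Returns:

- DataLoader: The validation dataloader.

Examples:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
>>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
>>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
>>> data_module.setup()
Train Dataset Size: 90
Val Dataset Size: 10
>>> val_dataloader = data_module.val_dataloader()
>>> print(f"Validation set size: {len(val_dataloader.dataset)}")
Validation set size: 10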
Source code in spotpython/data/lightcrossvalidationdatamodule.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def val_dataloader(self) -> DataLoader:
    """
    Build the loader used for validation; samples are served in dataset order.

    Batch size, worker count and memory pinning are read from ``self.hparams``.

    Returns:
        DataLoader: The validation dataloader over ``self.data_val``.

    Examples:
        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> from spotpython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
        >>> dataset = TensorDataset(torch.randn(100, 3), torch.randn(100))
        >>> data_module = LightCrossValidationDataModule(dataset=dataset, k=0, num_splits=10)
        >>> data_module.setup()
        Train Dataset Size: 90
        Val Dataset Size: 10
        >>> val_dataloader = data_module.val_dataloader()
        >>> print(f"Validation set size: {len(val_dataloader.dataset)}")
        Validation set size: 10
    """
    hp = self.hparams
    loader_options = {
        "batch_size": hp.batch_size,
        "num_workers": hp.num_workers,
        "pin_memory": hp.pin_memory,
    }
    return DataLoader(self.data_val, **loader_options)