lightcrossvalidationdatamodule

LightCrossValidationDataModule

Bases: LightningDataModule

A LightningDataModule for handling cross-validation data splits.

Parameters:

Name         Type     Description                                                  Default
batch_size   int      The size of the batch.                                       64
dataset      Dataset  The full dataset to be split into train/validation folds.    None
k            int      The fold index used for validation (0 <= k < num_splits).    1
split_seed   int      The random seed for splitting the data.                      42
num_splits   int      The number of splits (folds) for cross-validation.           10
data_dir     str      The path to the dataset directory.                           './data'
num_workers  int      The number of workers for data loading.                      0
pin_memory   bool     Whether to pin memory for data loading.                      False

Attributes:

Name        Type               Description
data_train  Optional[Dataset]  The training dataset.
data_val    Optional[Dataset]  The validation dataset.

Examples:

>>> from spotPython.light import LightCrossValidationDataModule
>>> # full_dataset: any torch Dataset; with 50,000 samples and the default
>>> # 10 splits, one fold yields the sizes shown below
>>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
>>> data_module.setup()
>>> print(f"Training set size: {len(data_module.data_train)}")
Training set size: 45000
>>> print(f"Validation set size: {len(data_module.data_val)}")
Validation set size: 5000
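
A complete k-fold run instantiates one data module per fold and trains a fresh model on each. The loop below is a minimal sketch of that pattern; the synthetic TensorDataset and the TinyLitModel class are placeholders for illustration only and are not part of spotPython.

import torch
import lightning as L
from torch.utils.data import TensorDataset
from spotPython.light import LightCrossValidationDataModule

# Placeholder dataset: 500 samples with 10 features and a binary target.
full_dataset = TensorDataset(torch.randn(500, 10), torch.randint(0, 2, (500,)))

class TinyLitModel(L.LightningModule):
    """Minimal placeholder model: logistic regression on 10 features."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)
        self.loss = torch.nn.CrossEntropyLoss()
    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss(self.layer(x), y)
    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", self.loss(self.layer(x), y))
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

num_splits = 5
results = []
for fold in range(num_splits):
    dm = LightCrossValidationDataModule(
        dataset=full_dataset,
        k=fold,                 # fold used for validation in this iteration
        num_splits=num_splits,  # total number of folds
        split_seed=42,          # same seed => identical partition in every iteration
        batch_size=32,
    )
    model = TinyLitModel()      # fresh model per fold
    trainer = L.Trainer(max_epochs=10, enable_progress_bar=False)
    trainer.fit(model, datamodule=dm)
    results.append(trainer.validate(model, datamodule=dm)[0]["val_loss"])

Averaging the per-fold values in results then gives the cross-validated estimate.
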
Source code in spotPython/data/lightcrossvalidationdatamodule.py
class LightCrossValidationDataModule(L.LightningDataModule):
    """
    A LightningDataModule for handling cross-validation data splits.

    Args:
        batch_size (int): The size of the batch. Defaults to 64.
        dataset (Dataset): The full dataset to be split into train/validation folds. Defaults to None.
        k (int): The fold index used for validation (0 <= k < num_splits). Defaults to 1.
        split_seed (int): The random seed for splitting the data. Defaults to 42.
        num_splits (int): The number of splits (folds) for cross-validation. Defaults to 10.
        data_dir (str): The path to the dataset directory. Defaults to "./data".
        num_workers (int): The number of workers for data loading. Defaults to 0.
        pin_memory (bool): Whether to pin memory for data loading. Defaults to False.

    Attributes:
        data_train (Optional[Dataset]): The training dataset.
        data_val (Optional[Dataset]): The validation dataset.

    Examples:
        >>> from spotPython.light import LightCrossValidationDataModule
        >>> # full_dataset: any torch Dataset; with 50,000 samples and the default
        >>> # 10 splits, one fold yields the sizes shown below
        >>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
        >>> data_module.setup()
        >>> print(f"Training set size: {len(data_module.data_train)}")
        Training set size: 45000
        >>> print(f"Validation set size: {len(data_module.data_val)}")
        Validation set size: 5000
    """

    def __init__(
        self,
        batch_size=64,
        dataset=None,
        k: int = 1,
        split_seed: int = 42,
        num_splits: int = 10,
        data_dir: str = "./data",
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.data_full = dataset
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.k = k
        self.split_seed = split_seed
        self.num_splits = num_splits
        self.pin_memory = pin_memory
        self.save_hyperparameters(logger=False)
        assert 0 <= self.k < self.num_splits, "incorrect fold number"

        # no data transformations
        self.transforms = None

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    def prepare_data(self) -> None:
        """Prepares the data for use."""
        # download
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Splits the full dataset into training and validation subsets for the selected fold.

        Args:
            stage (Optional[str]): The current stage. Defaults to None.
        """
        if not self.data_train and not self.data_val:
            dataset_full = self.data_full
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = [k for k in kf.split(dataset_full)]
            train_indexes, val_indexes = all_splits[self.hparams.k]
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()
            self.data_train = Subset(dataset_full, train_indexes)
            print(f"Train Dataset Size: {len(self.data_train)}")
            self.data_val = Subset(dataset_full, val_indexes)
            print(f"Val Dataset Size: {len(self.data_val)}")

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training dataloader.

        Returns:
            DataLoader: The training dataloader.

        Examples:
            >>> from spotPython.light import LightCrossValidationDataModule
            >>> # full_dataset: any torch Dataset (placeholder), 50,000 samples here
            >>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
            >>> data_module.setup()
            >>> train_dataloader = data_module.train_dataloader()
            >>> print(f"Training set size: {len(train_dataloader.dataset)}")
            Training set size: 45000

        """
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation dataloader.

        Returns:
            DataLoader: The validation dataloader.

        Examples:
            >>> from spotPython.light import LightCrossValidationDataModule
            >>> # full_dataset: any torch Dataset (placeholder), 50,000 samples here
            >>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
            >>> data_module.setup()
            >>> val_dataloader = data_module.val_dataloader()
            >>> print(f"Validation set size: {len(val_dataloader.dataset)}")
            Validation set size: 5000
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
        )

prepare_data()

Placeholder hook for data preparation (e.g., downloading); currently a no-op.

Source code in spotPython/data/lightcrossvalidationdatamodule.py
def prepare_data(self) -> None:
    """Prepares the data for use."""
    # download
    pass

setup(stage=None)

Splits the full dataset into training and validation subsets for the selected fold.

Parameters:

Name   Type           Description         Default
stage  Optional[str]  The current stage.  None
Source code in spotPython/data/lightcrossvalidationdatamodule.py
def setup(self, stage: Optional[str] = None) -> None:
    """
    Splits the full dataset into training and validation subsets for the selected fold.

    Args:
        stage (Optional[str]): The current stage. Defaults to None.
    """
    if not self.data_train and not self.data_val:
        dataset_full = self.data_full
        kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
        all_splits = [k for k in kf.split(dataset_full)]
        train_indexes, val_indexes = all_splits[self.hparams.k]
        train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()
        self.data_train = Subset(dataset_full, train_indexes)
        print(f"Train Dataset Size: {len(self.data_train)}")
        self.data_val = Subset(dataset_full, val_indexes)
        print(f"Val Dataset Size: {len(self.data_val)}")

train_dataloader()

Returns the training dataloader.

Returns:

Type        Description
DataLoader  The training dataloader.

Examples:

>>> from spotPython.light import LightCrossValidationDataModule
>>> # full_dataset: any torch Dataset (placeholder), 50,000 samples here
>>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
>>> data_module.setup()
>>> train_dataloader = data_module.train_dataloader()
>>> print(f"Training set size: {len(train_dataloader.dataset)}")
Training set size: 45000
Source code in spotPython/data/lightcrossvalidationdatamodule.py
def train_dataloader(self) -> DataLoader:
    """
    Returns the training dataloader.

    Returns:
        DataLoader: The training dataloader.

    Examples:
        >>> from spotPython.light import LightCrossValidationDataModule
        >>> # full_dataset: any torch Dataset (placeholder), 50,000 samples here
        >>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
        >>> data_module.setup()
        >>> train_dataloader = data_module.train_dataloader()
        >>> print(f"Training set size: {len(train_dataloader.dataset)}")
        Training set size: 45000

    """
    return DataLoader(
        dataset=self.data_train,
        batch_size=self.hparams.batch_size,
        num_workers=self.hparams.num_workers,
        pin_memory=self.hparams.pin_memory,
        shuffle=True,
    )

val_dataloader()

Returns the validation dataloader.

Returns:

Type        Description
DataLoader  The validation dataloader.

Examples:

>>> from spotPython.light import LightCrossValidationDataModule
>>> # full_dataset: any torch Dataset (placeholder), 50,000 samples here
>>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
>>> data_module.setup()
>>> val_dataloader = data_module.val_dataloader()
>>> print(f"Validation set size: {len(val_dataloader.dataset)}")
Validation set size: 5000
Source code in spotPython/data/lightcrossvalidationdatamodule.py
def val_dataloader(self) -> DataLoader:
    """
    Returns the validation dataloader.

    Returns:
        DataLoader: The validation dataloader.

    Examples:
        >>> from spotPython.light import LightCrossValidationDataModule
        >>> # full_dataset: any torch Dataset (placeholder), 50,000 samples here
        >>> data_module = LightCrossValidationDataModule(dataset=full_dataset)
        >>> data_module.setup()
        >>> val_dataloader = data_module.val_dataloader()
        >>> print(f"Validation set size: {len(val_dataloader.dataset)}")
        Validation set size: 5000
    """
    return DataLoader(
        dataset=self.data_val,
        batch_size=self.hparams.batch_size,
        num_workers=self.hparams.num_workers,
        pin_memory=self.hparams.pin_memory,
    )
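
Once setup() has run, both dataloaders behave like ordinary PyTorch DataLoaders and can be iterated directly. A minimal sketch, assuming the import path used in the examples above; the synthetic TensorDataset only stands in for a real dataset:

import torch
from torch.utils.data import TensorDataset
from spotPython.light import LightCrossValidationDataModule

# Placeholder data: 200 samples, 4 features, one regression target.
full_dataset = TensorDataset(torch.randn(200, 4), torch.randn(200, 1))

dm = LightCrossValidationDataModule(dataset=full_dataset, k=0, num_splits=5, batch_size=16)
dm.setup()  # prints the train/validation sizes (160 and 40 for these settings)

for x, y in dm.train_dataloader():  # shuffled batches from the training split of fold 0
    print(x.shape, y.shape)         # torch.Size([16, 4]) torch.Size([16, 1])
    break

for x, y in dm.val_dataloader():    # validation batches, served in a fixed order
    print(x.shape, y.shape)
    break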