import torch
from spotpython.utils.device import getDevice
from math import inf
# Run-time parameters for this chapter.
WORKERS = 0  # number of workers, later passed on as num_workers
PREFIX = "030"  # prefix for the experiment name and the log-file name
DEVICE = getDevice()  # device Lightning would select; used for information only
DEVICES = 1  # presumably the number of devices for the Trainer -- TODO confirm
TEST_SIZE = 0.4  # size of the test set (a float is a fraction of the data)
30 HPT PyTorch Lightning: Data
In this tutorial, we will show how spotpython
can be integrated into the PyTorch
Lightning training workflow.
This chapter describes the data preparation and processing in spotpython. The Diabetes data set is used as an example. It is a PyTorch Dataset for regression, based on a toy data set from scikit-learn. Ten baseline variables (age, sex, body mass index, average blood pressure, and six blood serum measurements) were obtained for each of n = 442 diabetes patients, as well as the response of interest, a quantitative measure of disease progression one year after baseline.
30.1 Setup
- Before we consider the detailed experimental setup, we select the parameters that affect run time, initial design size, etc.
- The parameter WORKERS specifies the number of workers.
- The prefix PREFIX is used for the experiment name and the name of the log file.
- The parameter DEVICE specifies the device to use for training.
- No .cuda() or .to(device) calls are required, because Lightning does these for you (see LIGHTNINGMODULE). Still, we would like to know which device is used. Therefore, we imitate the LightningModule behaviour, which selects the highest available device.
- The method spotpython.utils.device.getDevice() returns the device that is used by Lightning.
30.2 Initialization of the fun_control
Dictionary
spotpython
uses a Python dictionary for storing the information required for the hyperparameter tuning process.
from spotpython.utils.init import fun_control_init
import numpy as np

# Collect all settings for the hyperparameter-tuning run in one dictionary.
fun_control = fun_control_init(
    _L_in=10,  # ten input features (see the data set description)
    _L_out=1,  # one regression target
    _torchmetric="mean_squared_error",
    PREFIX=PREFIX,
    device=DEVICE,
    enable_progress_bar=False,
    num_workers=WORKERS,
    show_progress=True,
    test_size=TEST_SIZE,
)
30.3 Loading the Diabetes Data Set
Here, we load the Diabetes data set from spotpython's data module.
from spotpython.data.diabetes import Diabetes

# Load the Diabetes regression data; target_type=torch.float makes the
# targets float tensors.
dataset = Diabetes(target_type=torch.float)
print(len(dataset))
442
30.3.1 Data Set and Data Loader
As shown below, a DataLoader from torch.utils.data
can be used to check the data.
# Set batch size for DataLoader
batch_size = 5
# Create DataLoader
from torch.utils.data import DataLoader

# shuffle=False keeps the original sample order, so the preview is reproducible.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Iterate over the data in the DataLoader; only the first batch is shown.
for batch in dataloader:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
Batch Size: 5
Inputs Shape: torch.Size([5, 10])
Targets Shape: torch.Size([5])
---------------
Inputs: tensor([[ 0.0381, 0.0507, 0.0617, 0.0219, -0.0442, -0.0348, -0.0434, -0.0026,
0.0199, -0.0176],
[-0.0019, -0.0446, -0.0515, -0.0263, -0.0084, -0.0192, 0.0744, -0.0395,
-0.0683, -0.0922],
[ 0.0853, 0.0507, 0.0445, -0.0057, -0.0456, -0.0342, -0.0324, -0.0026,
0.0029, -0.0259],
[-0.0891, -0.0446, -0.0116, -0.0367, 0.0122, 0.0250, -0.0360, 0.0343,
0.0227, -0.0094],
[ 0.0054, -0.0446, -0.0364, 0.0219, 0.0039, 0.0156, 0.0081, -0.0026,
-0.0320, -0.0466]])
Targets: tensor([151., 75., 141., 206., 135.])
30.3.2 Preparing Training, Validation, and Test Data
The following code shows how to split the data into training, validation, and test sets. Then a Lightning Trainer is used to train (fit) the model, validate it, and test it.
from torch.utils.data import DataLoader
from spotpython.data.diabetes import Diabetes
from spotpython.light.regression.netlightregression import NetLightRegression
from torch import nn
import lightning as L
import torch

BATCH_SIZE = 8

dataset = Diabetes(target_type=torch.float)
# Split 60/40 into (train + val) / test, then split the first part 60/40 into
# train / val. No generator is passed, so the splits differ between runs.
train1_set, test_set = torch.utils.data.random_split(dataset, [0.6, 0.4])
train_set, val_set = torch.utils.data.random_split(train1_set, [0.6, 0.4])
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

# Peek at one training batch to check the tensor shapes.
batch_x, batch_y = next(iter(train_loader))
print(f"batch_x.shape: {batch_x.shape}")
print(f"batch_y.shape: {batch_y.shape}")

net_light_base = NetLightRegression(
    l1=128,
    epochs=10,
    batch_size=BATCH_SIZE,
    initialization="Default",
    act_fn=nn.ReLU(),
    optimizer="Adam",
    dropout_prob=0.1,
    lr_mult=0.1,
    patience=5,
    _L_in=10,  # ten input features
    _L_out=1,  # one regression target
    _torchmetric="mean_squared_error",
)
trainer = L.Trainer(max_epochs=10, enable_progress_bar=False)
trainer.fit(net_light_base, train_loader)
trainer.validate(net_light_base, val_loader)
trainer.test(net_light_base, test_loader)
batch_x.shape: torch.Size([8, 10])
batch_y.shape: torch.Size([8])
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Validate metric ┃ DataLoader 0 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ hp_metric │ 32421.490234375 │ │ val_loss │ 32421.490234375 │ └───────────────────────────┴───────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Test metric ┃ DataLoader 0 ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ hp_metric │ 25781.0234375 │ │ val_loss │ 25781.0234375 │ └───────────────────────────┴───────────────────────────┘
[{'val_loss': 25781.0234375, 'hp_metric': 25781.0234375}]
30.3.3 Dataset for spotpython
spotpython handles the data set, which is added to the fun_control dictionary with the key data_set as follows:
from spotpython.hyperparameters.values import set_control_key_value
from spotpython.data.diabetes import Diabetes

dataset = Diabetes(target_type=torch.float)
# Store the data set in fun_control under the key "data_set",
# replacing any value that is already there.
set_control_key_value(
    control_dict=fun_control,
    key="data_set",
    value=dataset,
    replace=True,
)
print(len(dataset))
442
If the data set is in the fun_control dictionary, it is used to create a LightDataModule object. This object is used to create the data loaders for the training, validation, and test sets. Therefore, the following information must be provided in the fun_control dictionary:
- data_set: the data set
- batch_size: the batch size
- num_workers: the number of workers
- test_size: the size of the test set
- test_seed: the seed for the test set
from spotpython.utils.init import fun_control_init
import numpy as np

# Settings that LightDataModule reads from fun_control.
fun_control = fun_control_init(
    data_set=dataset,
    device="cpu",
    enable_progress_bar=False,
    num_workers=0,
    show_progress=True,
    test_size=0.4,  # 40 % of the data is used for the test set
    test_seed=42,  # seed that makes the random split reproducible
)
from spotpython.data.lightdatamodule import LightDataModule

# Build the data module from the values stored in fun_control.
dm = LightDataModule(
    dataset=fun_control["data_set"],
    batch_size=8,
    num_workers=fun_control["num_workers"],
    test_size=fun_control["test_size"],
    test_seed=fun_control["test_seed"],
)
# setup() creates the train/validation/test splits.
dm.setup()
print(f"train_model(): Test set size: {len(dm.data_test)}")
print(f"train_model(): Train set size: {len(dm.data_train)}")
train_model(): Test set size: 177
train_model(): Train set size: 160
30.4 The LightDataModule
The steps described above are handled by the LightDataModule class. This class is used to create the data loaders for the training, validation, and test sets. The LightDataModule class is part of the spotpython package. It provides the following methods:
- prepare_data(): This method is used to prepare the data set.
- setup(): This method is used to split the data set into the training, validation, and test sets.
- train_dataloader(): This method is used to return the data loader for the training set.
- val_dataloader(): This method is used to return the data loader for the validation set.
- test_dataloader(): This method is used to return the data loader for the test set.
- predict_dataloader(): This method is used to return the data loader for the prediction set.
30.4.1 The prepare_data()
Method
The prepare_data() method is used to prepare the data set. This method is called only once and on a single process. It can be used to download the data set. In our case, the data set is already available, so this method uses a simple pass statement.
30.4.2 The setup()
Method
Splits the data for use in training, validation, and testing. It uses torch.utils.data.random_split() to split the data. Splitting is based on test_size and test_seed. test_size can be a float (a fraction of the data) or an int (a number of samples).
30.4.2.1 Determine the Sizes of the Data Sets
from torch.utils.data import random_split

data_full = dataset
test_size = fun_control["test_size"]
test_seed = fun_control["test_seed"]
# if test_size is float, then train_size is 1 - test_size
if isinstance(test_size, float):
    # Fractions: the validation part takes the same share (test_size) of the
    # remaining data as the test part takes of the full data.
    full_train_size = round(1.0 - test_size, 2)
    val_size = round(full_train_size * test_size, 2)
    train_size = round(full_train_size - val_size, 2)
else:
    # if test_size is int, then train_size is len(data_full) - test_size
    full_train_size = len(data_full) - test_size
    val_size = int(full_train_size * test_size / len(data_full))
    train_size = full_train_size - val_size

print(f"LightDataModule setup(): full_train_size: {full_train_size}")
print(f"LightDataModule setup(): val_size: {val_size}")
print(f"LightDataModule setup(): train_size: {train_size}")
print(f"LightDataModule setup(): test_size: {test_size}")
LightDataModule setup(): full_train_size: 0.6
LightDataModule setup(): val_size: 0.24
LightDataModule setup(): train_size: 0.36
LightDataModule setup(): test_size: 0.4
The stage argument determines which data sets are created. The stage can be None, fit, test, or predict. If stage is None, the method creates the training and validation (fit), test (test), and prediction (predict) data sets.
30.4.2.2 Stage “fit”
stage = "fit"
if stage == "fit" or stage is None:
    # Seeding the generator with test_seed makes the split reproducible.
    generator_fit = torch.Generator().manual_seed(test_seed)
    # Split into train / val / test parts; the test part is discarded here.
    data_train, data_val, _ = random_split(data_full, [train_size, val_size, test_size], generator=generator_fit)
print(f"LightDataModule setup(): Train set size: {len(data_train)}")
print(f"LightDataModule setup(): Validation set size: {len(data_val)}")
LightDataModule setup(): Train set size: 160
LightDataModule setup(): Validation set size: 106
30.4.2.3 Stage “test”
stage = "test"
if stage == "test" or stage is None:
    generator_test = torch.Generator().manual_seed(test_seed)
    # NOTE(review): same seed and same data length as the "fit" stage, so
    # random_split draws the same permutation. The test set here is the
    # *first* test_size part of that permutation, while the "fit" stage uses
    # the first train_size part for training. The two subsets therefore
    # overlap (with test_size=0.4 the whole training set is contained in the
    # test set). A disjoint split would take the last part instead, e.g.
    #   _, _, data_test = random_split(data_full, [train_size, val_size, test_size], generator=generator_test)
    data_test, _ = random_split(data_full, [test_size, full_train_size], generator=generator_test)
print(f"LightDataModule setup(): Test set size: {len(data_test)}")
# Set batch size for DataLoader
batch_size = 5
# Create DataLoader
from torch.utils.data import DataLoader

dataloader = DataLoader(data_test, batch_size=batch_size, shuffle=False)
# Iterate over the data in the DataLoader; only the first batch is shown.
for batch in dataloader:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
LightDataModule setup(): Test set size: 177
Batch Size: 5
Inputs Shape: torch.Size([5, 10])
Targets Shape: torch.Size([5])
---------------
Inputs: tensor([[ 0.0562, -0.0446, -0.0579, -0.0080, 0.0521, 0.0491, 0.0560, -0.0214,
-0.0283, 0.0445],
[ 0.0018, -0.0446, -0.0709, -0.0229, -0.0016, -0.0010, 0.0266, -0.0395,
-0.0225, 0.0072],
[-0.0527, -0.0446, 0.0542, -0.0263, -0.0552, -0.0339, -0.0139, -0.0395,
-0.0741, -0.0591],
[ 0.0054, -0.0446, -0.0482, -0.0126, 0.0012, -0.0066, 0.0634, -0.0395,
-0.0514, -0.0591],
[-0.0527, -0.0446, -0.0094, -0.0057, 0.0397, 0.0447, 0.0266, -0.0026,
-0.0181, -0.0135]])
Targets: tensor([158., 49., 142., 96., 59.])
30.4.2.4 Stage “predict”
Prediction and testing use the same data set.
stage = "predict"
if stage == "predict" or stage is None:
    # Same seed and same split lengths as the "test" stage, so data_predict
    # contains exactly the same samples as data_test.
    generator_predict = torch.Generator().manual_seed(test_seed)
    data_predict, _ = random_split(
        data_full, [test_size, full_train_size], generator=generator_predict
    )
print(f"LightDataModule setup(): Predict set size: {len(data_predict)}")
# Set batch size for DataLoader
batch_size = 5
# Create DataLoader
from torch.utils.data import DataLoader

dataloader = DataLoader(data_predict, batch_size=batch_size, shuffle=False)
# Iterate over the data in the DataLoader; only the first batch is shown.
for batch in dataloader:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
LightDataModule setup(): Predict set size: 177
Batch Size: 5
Inputs Shape: torch.Size([5, 10])
Targets Shape: torch.Size([5])
---------------
Inputs: tensor([[ 0.0562, -0.0446, -0.0579, -0.0080, 0.0521, 0.0491, 0.0560, -0.0214,
-0.0283, 0.0445],
[ 0.0018, -0.0446, -0.0709, -0.0229, -0.0016, -0.0010, 0.0266, -0.0395,
-0.0225, 0.0072],
[-0.0527, -0.0446, 0.0542, -0.0263, -0.0552, -0.0339, -0.0139, -0.0395,
-0.0741, -0.0591],
[ 0.0054, -0.0446, -0.0482, -0.0126, 0.0012, -0.0066, 0.0634, -0.0395,
-0.0514, -0.0591],
[-0.0527, -0.0446, -0.0094, -0.0057, 0.0397, 0.0447, 0.0266, -0.0026,
-0.0181, -0.0135]])
Targets: tensor([158., 49., 142., 96., 59.])
30.4.3 The train_dataloader()
Method
Returns the training dataloader, i.e., a PyTorch DataLoader instance using the training dataset. It simply returns a DataLoader with the data_train set that was created in the setup() method, as described in Section 30.4.2.2.
def train_dataloader(self) -> DataLoader:
    """Build the DataLoader for the training split.

    Wraps ``self.data_train`` (created by ``setup()``) using the module's
    ``batch_size`` and ``num_workers`` settings.
    """
    training_split = self.data_train
    return DataLoader(
        training_split,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
    )
The train_dataloader()
method can be used as follows:
from spotpython.data.lightdatamodule import LightDataModule
from spotpython.data.diabetes import Diabetes

dataset = Diabetes(target_type=torch.float)
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.4)
# setup() must run first: it creates data_train, which train_dataloader() wraps.
data_module.setup()
print(f"Training set size: {len(data_module.data_train)}")
dl = data_module.train_dataloader()
# Iterate over the data in the DataLoader; only the first batch is shown.
for batch in dl:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
Training set size: 160
Batch Size: 5
Inputs Shape: torch.Size([5, 10])
Targets Shape: torch.Size([5])
---------------
Inputs: tensor([[ 0.0562, -0.0446, -0.0579, -0.0080, 0.0521, 0.0491, 0.0560, -0.0214,
-0.0283, 0.0445],
[ 0.0018, -0.0446, -0.0709, -0.0229, -0.0016, -0.0010, 0.0266, -0.0395,
-0.0225, 0.0072],
[-0.0527, -0.0446, 0.0542, -0.0263, -0.0552, -0.0339, -0.0139, -0.0395,
-0.0741, -0.0591],
[ 0.0054, -0.0446, -0.0482, -0.0126, 0.0012, -0.0066, 0.0634, -0.0395,
-0.0514, -0.0591],
[-0.0527, -0.0446, -0.0094, -0.0057, 0.0397, 0.0447, 0.0266, -0.0026,
-0.0181, -0.0135]])
Targets: tensor([158., 49., 142., 96., 59.])
30.4.4 The val_dataloader()
Method
Returns the validation dataloader, i.e., a PyTorch DataLoader instance using the validation dataset. It simply returns a DataLoader with the data_val set that was created in the setup() method, as described in Section 30.4.2.2.
def val_dataloader(self) -> DataLoader:
    """Return a DataLoader over the validation split.

    ``self.data_val`` is produced by ``setup()``; batch size and worker count
    come from the module's ``batch_size`` and ``num_workers`` attributes.
    """
    validation_split = self.data_val
    size_per_batch = self.batch_size
    worker_count = self.num_workers
    return DataLoader(validation_split, batch_size=size_per_batch, num_workers=worker_count)
The val_dataloader()
method can be used as follows:
from spotpython.data.lightdatamodule import LightDataModule
from spotpython.data.diabetes import Diabetes

dataset = Diabetes(target_type=torch.float)
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.4)
# setup() must run first: it creates data_val, which val_dataloader() wraps.
data_module.setup()
print(f"Validation set size: {len(data_module.data_val)}")
dl = data_module.val_dataloader()
# Iterate over the data in the DataLoader; only the first batch is shown.
for batch in dl:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
Validation set size: 106
Batch Size: 5
Inputs Shape: torch.Size([5, 10])
Targets Shape: torch.Size([5])
---------------
Inputs: tensor([[ 0.0163, -0.0446, 0.0736, -0.0412, -0.0043, -0.0135, -0.0139, -0.0011,
0.0429, 0.0445],
[ 0.0453, -0.0446, 0.0714, 0.0012, -0.0098, -0.0010, 0.0155, -0.0395,
-0.0412, -0.0715],
[ 0.0308, 0.0507, 0.0326, 0.0494, -0.0401, -0.0436, -0.0692, 0.0343,
0.0630, 0.0031],
[ 0.0235, 0.0507, -0.0396, -0.0057, -0.0484, -0.0333, 0.0118, -0.0395,
-0.1016, -0.0674],
[-0.0091, 0.0507, 0.0013, -0.0022, 0.0796, 0.0701, 0.0339, -0.0026,
0.0267, 0.0818]])
Targets: tensor([275., 141., 208., 78., 142.])
30.4.5 The test_dataloader()
Method
Returns the test dataloader, i.e., a PyTorch DataLoader instance using the test dataset. It simply returns a DataLoader with the data_test set that was created in the setup() method, as described in Section 30.4.2.3.
def test_dataloader(self) -> DataLoader:
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)
The test_dataloader()
method can be used as follows:
from spotpython.data.lightdatamodule import LightDataModule
from spotpython.data.diabetes import Diabetes

dataset = Diabetes(target_type=torch.float)
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.4)
# setup() must run first: it creates data_test, which test_dataloader() wraps.
data_module.setup()
print(f"Test set size: {len(data_module.data_test)}")
dl = data_module.test_dataloader()
# Iterate over the data in the DataLoader; only the first batch is shown.
for batch in dl:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
Test set size: 177
Batch Size: 5
Inputs Shape: torch.Size([5, 10])
Targets Shape: torch.Size([5])
---------------
Inputs: tensor([[ 0.0562, -0.0446, -0.0579, -0.0080, 0.0521, 0.0491, 0.0560, -0.0214,
-0.0283, 0.0445],
[ 0.0018, -0.0446, -0.0709, -0.0229, -0.0016, -0.0010, 0.0266, -0.0395,
-0.0225, 0.0072],
[-0.0527, -0.0446, 0.0542, -0.0263, -0.0552, -0.0339, -0.0139, -0.0395,
-0.0741, -0.0591],
[ 0.0054, -0.0446, -0.0482, -0.0126, 0.0012, -0.0066, 0.0634, -0.0395,
-0.0514, -0.0591],
[-0.0527, -0.0446, -0.0094, -0.0057, 0.0397, 0.0447, 0.0266, -0.0026,
-0.0181, -0.0135]])
Targets: tensor([158., 49., 142., 96., 59.])
30.4.6 The predict_dataloader()
Method
Returns the prediction dataloader, i.e., a PyTorch DataLoader instance using the prediction dataset. It simply returns a DataLoader with the data_predict set that was created in the setup() method, as described in Section 30.4.2.4. The batch_size is set to the length of the data_predict set, so the whole prediction set is returned as a single batch.
def predict_dataloader(self) -> DataLoader:
    """Return a DataLoader that yields the whole prediction split as one batch.

    The batch size equals ``len(self.data_predict)``, so a single batch covers
    the entire prediction set. It is clamped to at least 1 because
    ``DataLoader`` rejects ``batch_size=0``. An empty prediction set therefore
    yields no batches instead of raising ``ValueError``.
    """
    batch_size = max(1, len(self.data_predict))
    return DataLoader(self.data_predict, batch_size=batch_size, num_workers=self.num_workers)
The predict_dataloader()
method can be used as follows:
from spotpython.data.lightdatamodule import LightDataModule
from spotpython.data.diabetes import Diabetes

dataset = Diabetes(target_type=torch.float)
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.4)
# setup() must run first: it creates data_predict, which predict_dataloader() wraps.
data_module.setup()
print(f"Test set size: {len(data_module.data_predict)}")
dl = data_module.predict_dataloader()
# predict_dataloader() returns the whole prediction set as one batch.
for batch in dl:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
Test set size: 177
Batch Size: 177
Inputs Shape: torch.Size([177, 10])
Targets Shape: torch.Size([177])
---------------
Inputs: tensor([[ 0.0562, -0.0446, -0.0579, ..., -0.0214, -0.0283, 0.0445],
[ 0.0018, -0.0446, -0.0709, ..., -0.0395, -0.0225, 0.0072],
[-0.0527, -0.0446, 0.0542, ..., -0.0395, -0.0741, -0.0591],
...,
[ 0.0090, -0.0446, -0.0321, ..., -0.0764, -0.0119, -0.0384],
[-0.0273, -0.0446, -0.0666, ..., -0.0395, -0.0358, -0.0094],
[ 0.0817, 0.0507, 0.0067, ..., 0.0919, 0.0547, 0.0072]])
Targets: tensor([158., 49., 142., 96., 59., 74., 137., 136., 39., 66., 310., 198.,
235., 116., 55., 177., 59., 246., 53., 135., 88., 198., 186., 217.,
51., 118., 153., 180., 51., 229., 84., 72., 237., 142., 185., 91.,
88., 148., 179., 144., 25., 89., 42., 60., 124., 170., 215., 263.,
178., 245., 202., 97., 321., 71., 123., 220., 132., 243., 61., 102.,
187., 70., 242., 134., 63., 72., 88., 219., 127., 146., 122., 143.,
220., 293., 59., 317., 60., 140., 65., 277., 90., 96., 109., 190.,
90., 52., 160., 233., 230., 175., 68., 272., 144., 70., 68., 163.,
71., 93., 263., 118., 220., 90., 232., 120., 163., 88., 85., 52.,
181., 232., 212., 332., 81., 214., 145., 268., 115., 93., 64., 156.,
128., 200., 281., 103., 220., 66., 48., 246., 42., 150., 125., 109.,
129., 97., 265., 97., 173., 216., 237., 121., 42., 151., 31., 68.,
137., 221., 283., 124., 243., 150., 69., 306., 182., 252., 132., 258.,
121., 110., 292., 101., 275., 141., 208., 78., 142., 185., 167., 258.,
144., 89., 225., 140., 303., 236., 87., 77., 131.])
30.5 Using the LightDataModule
in the train_model()
Method
First, a LightDataModule
object is created and the setup()
method is called.
# Data module for one train_model() call: the batch size comes from the
# hyperparameter configuration, everything else from fun_control.
dm = LightDataModule(
    dataset=fun_control["data_set"],
    batch_size=config["batch_size"],
    num_workers=fun_control["num_workers"],
    test_size=fun_control["test_size"],
    test_seed=fun_control["test_seed"],
)
dm.setup()
Then, the Trainer
is initialized.
# Init trainer
trainer = L.Trainer(
    # Separate checkpoint directory per configuration.
    default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),
    max_epochs=model.hparams.epochs,
    accelerator=fun_control["accelerator"],
    devices=fun_control["devices"],
    logger=TensorBoardLogger(
        save_dir=fun_control["TENSORBOARD_PATH"],
        version=config_id,
        default_hp_metric=True,
        log_graph=fun_control["log_graph"],
    ),
    callbacks=[
        # Stop when val_loss has not improved for config["patience"] epochs.
        EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
    ],
    enable_progress_bar=enable_progress_bar,
)
Next, the fit()
method is called to train the model.
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model=model, datamodule=dm)
Finally, the validate() method is called to validate the model. It returns a list containing one dictionary of metrics, from which the validation loss (val_loss) is extracted and returned.
# Test best model on validation and test set
# result = trainer.validate(model=model, datamodule=dm, ckpt_path="last")
= trainer.validate(model=model, datamodule=dm)
result # unlist the result (from a list of one dict)
= result[0]
result return result["val_loss"]
30.6 Further Information
30.6.1 Preprocessing
Preprocessing is handled by Lightning
and PyTorch
. It is described in the LIGHTNINGDATAMODULE documentation. Here you can find information about the transforms
methods.