28  HPT PyTorch Lightning Transformer: Diabetes

In this tutorial, we will show how spotPython can be integrated into the PyTorch Lightning training workflow for a regression task.

This chapter describes the hyperparameter tuning of a PyTorch Lightning network on the Diabetes data set, a PyTorch Dataset for regression based on a toy data set from scikit-learn. Ten baseline variables (age, sex, body mass index, average blood pressure, and six blood serum measurements) were obtained for each of n = 442 diabetes patients, as well as the response of interest, a quantitative measure of disease progression one year after baseline.
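For orientation, the raw data can also be inspected directly with scikit-learn (a minimal shape check; the spotPython wrapper is used in Step 3 below):

from sklearn.datasets import load_diabetes
# Load the raw data to verify the dimensions: 442 samples, 10 features.
X, y = load_diabetes(return_X_y=True)
print(X.shape, y.shape)  # (442, 10) (442,)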

28.1 Step 1: Setup

  • Before we consider the detailed experimental setup, we select the parameters that affect run time, initial design size, etc.
  • The parameter MAX_TIME specifies the maximum run time in minutes.
  • The parameter INIT_SIZE specifies the initial design size.
  • The parameter WORKERS specifies the number of workers.
  • The prefix PREFIX is used for the experiment name and the name of the log file.
  • The parameter DEVICE specifies the device to use for training.
from spotPython.utils.device import getDevice
from math import inf

MAX_TIME = 1
FUN_EVALS = inf
INIT_SIZE = 5
WORKERS = 0
PREFIX="036"
DEVICE = getDevice()
DEVICES = 1
TEST_SIZE = 0.3
TORCH_METRIC = "mean_squared_error"
Caution: Run time and initial design size should be increased for real experiments
  • MAX_TIME is set to one minute for demonstration purposes. For real experiments, this should be increased to at least 1 hour.
  • INIT_SIZE is set to 5 for demonstration purposes. For real experiments, this should be increased to at least 10.
  • WORKERS is set to 0 for demonstration purposes. For real experiments, this should be increased. See the warnings that are printed when the number of workers is set to 0.
Note: Device selection
  • Although no .cuda() or .to(device) calls are required, because Lightning handles device placement for us (see LIGHTNINGMODULE), we would like to know which device is used. Therefore, we imitate the LightningModule behaviour, which selects the best available device; a sketch of this logic is shown after this note.
  • The method spotPython.utils.device.getDevice() returns the device that is used by Lightning.
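The following sketch imitates this selection with plain torch calls (an illustration only, not the getDevice() implementation; the preference order cuda > mps > cpu is an assumption):

import torch

def get_device_fallback() -> str:
    # Prefer CUDA, then Apple MPS, then fall back to the CPU.
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(get_device_fallback())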

28.2 Step 2: Initialization of the fun_control Dictionary

spotPython uses a Python dictionary for storing the information required for the hyperparameter tuning process.

from spotPython.utils.init import fun_control_init
import numpy as np
fun_control = fun_control_init(
    _L_in=10,
    _L_out=1,
    _torchmetric=TORCH_METRIC,
    PREFIX=PREFIX,
    TENSORBOARD_CLEAN=True,
    device=DEVICE,
    enable_progress_bar=False,
    fun_evals=FUN_EVALS,
    log_level=10,
    max_time=MAX_TIME,
    num_workers=WORKERS,
    show_progress=True,
    test_size=TEST_SIZE,
    tolerance_x=np.sqrt(np.spacing(1)),
    )

28.3 Step 3: Loading the Diabetes Data Set

from spotPython.hyperparameters.values import set_control_key_value
from spotPython.data.diabetes import Diabetes
dataset = Diabetes()
set_control_key_value(control_dict=fun_control,
                        key="data_set",
                        value=dataset,
                        replace=True)
print(len(dataset))
Note: Data Set and Data Loader
  • As shown below, a DataLoader from torch.utils.data can be used to check the data.
# Set batch size for DataLoader
batch_size = 5
# Create DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Iterate over the data in the DataLoader
for batch in dataloader:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break

28.4 Step 4: Preprocessing

Preprocessing is handled by Lightning and PyTorch. It is described in the LIGHTNINGDATAMODULE documentation, where you can also find information about the transform methods.
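As an illustration of where such transforms hook in, the following is a minimal LightningDataModule sketch (the class name DiabetesDataModule and the z-scoring step are illustrative assumptions, not spotPython code):

import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

class DiabetesDataModule(L.LightningDataModule):
    # Hypothetical data module that standardizes the inputs in setup().
    def __init__(self, dataset, batch_size=16):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Example transform: z-score the inputs of the whole data set.
        X = torch.stack([x for x, _ in self.dataset])
        y = torch.stack([t for _, t in self.dataset])
        X = (X - X.mean(0)) / X.std(0)
        self.train, self.val = random_split(TensorDataset(X, y), [0.7, 0.3])

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)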

28.5 Step 5: Select the Core Model (algorithm) and core_model_hyper_dict

spotPython includes the TransformerLightRegression class [SOURCE] for configurable neural networks. The class is imported here. It inherits from the class Lightning.LightningModule, which is the base class for all models in Lightning. Lightning.LightningModule is a subclass of torch.nn.Module and provides additional functionality for the training and testing of neural networks. The class Lightning.LightningModule is described in the Lightning documentation.
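To make this relationship concrete, a stripped-down LightningModule for regression could look as follows (an illustrative sketch, not the TransformerLightRegression implementation; the class name TinyRegressor is made up):

import lightning as L
import torch
from torch import nn

class TinyRegressor(L.LightningModule):
    # Minimal regression module: forward pass, loss, and optimizer.
    def __init__(self, l_in=10, l_out=1):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(l_in, 32), nn.ReLU(), nn.Linear(32, l_out))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Lightning moves the batch to the device and runs the loop for us.
        return nn.functional.mse_loss(self(x).squeeze(-1), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)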

  • Here we simply add the NN Model to the fun_control dictionary by calling the function add_core_model_to_fun_control:
from spotPython.light.regression.transformerlightregression import TransformerLightRegression
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
from spotPython.hyperparameters.values import add_core_model_to_fun_control
add_core_model_to_fun_control(fun_control=fun_control,
                              core_model=TransformerLightRegression,
                              hyper_dict=LightHyperDict)

The hyperparameters of the model are specified in the core_model_hyper_dict dictionary [SOURCE].

28.6 Step 6: Modify hyper_dict Hyperparameters for the Selected Algorithm aka core_model

spotPython provides functions for modifying the hyperparameters, their bounds, and their factors, as well as for activating and deactivating hyperparameters, without re-compilation of the Python source code.

Caution: Small number of epochs for demonstration purposes
  • epochs and patience are set to small values for demonstration purposes. These values are too small for a real application.
  • More reasonable values are, e.g.:
    • set_control_hyperparameter_value(fun_control, "epochs", [7, 9]) and
    • set_control_hyperparameter_value(fun_control, "patience", [2, 7])
from spotPython.hyperparameters.values import set_control_hyperparameter_value

# set_control_hyperparameter_value(fun_control, "l1", [2, 3])
# set_control_hyperparameter_value(fun_control, "epochs", [5, 7])
# set_control_hyperparameter_value(fun_control, "batch_size", [3, 4])
# set_control_hyperparameter_value(fun_control, "optimizer", [
#                 "Adadelta",
#                 "Adagrad",
#                 "Adam",
#                 "Adamax",                
#             ])
# set_control_hyperparameter_value(fun_control, "dropout_prob", [0.01, 0.1])
# set_control_hyperparameter_value(fun_control, "lr_mult", [0.5, 5.0])
# set_control_hyperparameter_value(fun_control, "patience", [3, 5])
# set_control_hyperparameter_value(fun_control, "act_fn",[
#                 "ReLU",
#                 "LeakyReLU",
#             ] )
set_control_hyperparameter_value(fun_control, "initialization",["Default"] )

Now, the dictionary fun_control contains all information needed for the hyperparameter tuning. Before the hyperparameter tuning is started, it is recommended to take a look at the experimental design. The method gen_design_table [SOURCE] generates a design table as follows:

from spotPython.utils.eda import gen_design_table
print(gen_design_table(fun_control))

This allows us to check whether all information is available and correct.

Note: Hyperparameters of the Tuned Model and the fun_control Dictionary

The updated fun_control dictionary can be shown with the command fun_control["core_model_hyper_dict"].
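For example:

import pprint
pprint.pprint(fun_control["core_model_hyper_dict"])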

28.7 Step 7: Data Splitting, the Objective (Loss) Function and the Metric

28.7.1 Evaluation

The evaluation procedure requires the specification of two elements:

  1. how the data is split into a train and a test set, and
  2. the loss function (and a metric).
Caution: Data Splitting in Lightning

The data splitting is handled by Lightning.

28.7.2 Loss Function

The loss function is specified in the configurable network class [SOURCE]. We will use the mean squared error (MSE).
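For reference, the MSE on a small batch can be computed by hand (a toy illustration using torch):

import torch
from torch.nn.functional import mse_loss

y_hat = torch.tensor([150.0, 95.0])
y = torch.tensor([140.0, 100.0])
# MSE = ((150-140)^2 + (95-100)^2) / 2 = (100 + 25) / 2 = 62.5
print(mse_loss(y_hat, y))  # tensor(62.5000)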

28.7.3 Metric

  • Similar to the loss function, the metric is specified in the configurable network class [SOURCE].
Caution: Loss Function and Metric in Lightning
  • The loss function and the metric are not hyperparameters that can be tuned with spotPython.
  • They are handled by Lightning.

28.8 Step 8: Calling the SPOT Function

28.8.1 Preparing the SPOT Call

from spotPython.utils.init import design_control_init, surrogate_control_init
design_control = design_control_init(init_size=INIT_SIZE)

surrogate_control = surrogate_control_init(noise=True,
                                            n_theta=2)
Note: Modifying Values in the Control Dictionaries
  • The values in the control dictionaries can be modified with the function set_control_key_value [SOURCE], for example:
set_control_key_value(control_dict=surrogate_control,
                        key="noise",
                        value=True,
                        replace=True)                       
set_control_key_value(control_dict=surrogate_control,
                        key="n_theta",
                        value=2,
                        replace=True)      

28.8.2 The Objective Function fun

The objective function fun from the class HyperLight [SOURCE] is selected next. It implements an interface from PyTorch’s training, validation, and testing methods to spotPython.

from spotPython.fun.hyperlight import HyperLight
fun = HyperLight(log_level=10).fun
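Spot evaluates this objective internally: fun follows the generic spotPython objective signature and maps an (n, k) array of hyperparameter vectors to n objective values. The commented call below is a hypothetical direct evaluation and assumes that signature:

# Hypothetical direct call (not needed in the normal workflow;
# Spot performs these evaluations internally):
# y = fun(X, fun_control=fun_control)  # X has shape (n, k)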

28.8.3 Showing the fun_control Dictionary

import pprint
pprint.pprint(fun_control)

28.8.4 Starting the Hyperparameter Tuning

The spotPython hyperparameter tuning is started by calling the Spot function [SOURCE].

from spotPython.spot import spot
spot_tuner = spot.Spot(fun=fun,
                       fun_control=fun_control,
                       design_control=design_control,
                       surrogate_control=surrogate_control)
spot_tuner.run()

28.9 Step 9: Tensorboard

The textual output shown in the console (or code cell) can be visualized with Tensorboard.

tensorboard --logdir="runs/"
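In a Jupyter notebook, TensorBoard can also be started inline (this assumes the tensorboard notebook extension is available):

%load_ext tensorboard
%tensorboard --logdir="runs/"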

Further information can be found in the PyTorch Lightning documentation for Tensorboard.

28.10 Step 10: Results

After the hyperparameter tuning run is finished, the results can be analyzed.

spot_tuner.plot_progress(log_y=False,
    filename="./figures/" + PREFIX +"_progress.png")
from spotPython.utils.eda import gen_design_table
print(gen_design_table(fun_control=fun_control, spot=spot_tuner))
spot_tuner.plot_importance(threshold=50,
    filename="./figures/" + PREFIX + "_importance.png")

28.10.1 Get the Tuned Architecture

from spotPython.hyperparameters.values import get_tuned_architecture
config = get_tuned_architecture(spot_tuner, fun_control)
print(config)
  • Test on the full data set
from spotPython.light.testmodel import test_model
test_model(config, fun_control)
from spotPython.light.loadmodel import load_light_from_checkpoint

model_loaded = load_light_from_checkpoint(config, fun_control)
# filename = "./figures/" + PREFIX
filename = None
spot_tuner.plot_important_hyperparameter_contour(filename=filename, threshold=50)

28.10.2 Parallel Coordinates Plot

spot_tuner.parallel_plot()

28.10.3 Cross Validation With Lightning

  • The KFold class from sklearn.model_selection is used to generate the folds for cross-validation.
  • This mechanism is used to generate the folds for the final evaluation of the model.
  • The CrossValidationDataModule class [SOURCE] is used to generate the folds for the hyperparameter tuning process.
  • It is called from the cv_model function [SOURCE].
from spotPython.light.cvmodel import cv_model
set_control_key_value(control_dict=fun_control,
                        key="k_folds",
                        value=2,
                        replace=True)
set_control_key_value(control_dict=fun_control,
                        key="test_size",
                        value=0.6,
                        replace=True)
cv_model(config, fun_control)

28.10.4 Plot all Combinations of Hyperparameters

  • Warning: this may take a while.
PLOT_ALL = False
if PLOT_ALL:
    n = spot_tuner.k
    # Use the range of the observed objective values for the z-axis.
    min_z = min(spot_tuner.y)
    max_z = max(spot_tuner.y)
    for i in range(n-1):
        for j in range(i+1, n):
            spot_tuner.plot_contour(i=i, j=j, min_z=min_z, max_z=max_z)

28.10.5 Visualizing the Activation Distribution (Under Development)

After we have trained the models, we can look at the actual activation values that occur inside the model. For instance, how many neurons are set to zero by ReLU? Where do we find most values with Tanh? To answer these questions, we can write a simple function which takes a trained model, applies it to a batch of data, and plots the histogram of the activations inside the network:

from spotPython.torch.activation import Sigmoid, Tanh, ReLU, LeakyReLU, ELU, Swish
act_fn_by_name = {"sigmoid": Sigmoid, "tanh": Tanh, "relu": ReLU, "leakyrelu": LeakyReLU, "elu": ELU, "swish": Swish}
from spotPython.hyperparameters.values import get_one_config_from_X
X = spot_tuner.to_all_dim(spot_tuner.min_X.reshape(1,-1))
config = get_one_config_from_X(X, fun_control)
model = fun_control["core_model"](**config, _L_in=10, _L_out=1, _torchmetric=TORCH_METRIC)
model
# from spotPython.utils.eda import visualize_activations
# visualize_activations(model, color=f"C{0}")
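Since visualize_activations is still under development, the following sketch shows the idea with plain forward hooks (the helper name plot_activation_histograms, the checked torch.nn activation types, and the use of matplotlib are illustrative assumptions):

import torch
import matplotlib.pyplot as plt

def plot_activation_histograms(model, batch):
    # Collect the outputs of every activation layer via forward hooks.
    activations, hooks = {}, []
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.ReLU, torch.nn.Tanh, torch.nn.Sigmoid)):
            hooks.append(module.register_forward_hook(
                lambda m, i, o, name=name: activations.update({name: o.detach()})))
    with torch.no_grad():
        model(batch)
    for h in hooks:
        h.remove()
    # Plot one histogram per recorded activation layer.
    for name, act in activations.items():
        plt.hist(act.flatten().cpu().numpy(), bins=50)
        plt.title(f"Activations after {name}")
        plt.show()

# Example call with one batch from the DataLoader defined above:
# inputs, _ = next(iter(dataloader))
# plot_activation_histograms(model, inputs.float())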