34 Hyperparameter Tuning with spotpython and PyTorch Lightning for the Diabetes Data Set Using a ResNet Model
In this section, we show how spotpython can be integrated into the PyTorch Lightning training workflow for a regression task. The example walks through the steps needed to tune the hyperparameters of a PyTorch Lightning model with spotpython.
After the necessary libraries are imported, the fun_control dictionary is set up via the fun_control_init function. It contains the following entries:
PREFIX: a unique identifier for the experiment
fun_evals: the number of function evaluations
max_time: the maximum run time in minutes
data_set: the data set. Here we use the Diabetes data set that is provided by spotpython.
core_model_name: the class name of the neural network model. This neural network model is provided by spotpython.
hyperdict: the hyperparameter dictionary. This dictionary is used to define the hyperparameters of the neural network model. It is also provided by spotpython.
_L_in: the number of input features. Since the Diabetes data set has 10 features, _L_in is set to 10.
_L_out: the number of output features. Since we want to predict a single value, _L_out is set to 1.
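The entries listed above can be pictured as an ordinary Python dictionary. The following sketch is hypothetical: `make_fun_control` is not spotpython's `fun_control_init`, the strings for the data set, model, and hyperparameter dictionary are placeholders for the real objects, and the budget values are only examples. It merely shows how the keys relate to each other, e.g. that `_L_in` must equal the number of input features.

```python
from typing import Any, Dict


def make_fun_control(prefix: str, fun_evals: float, max_time: float,
                     n_features: int, n_outputs: int = 1) -> Dict[str, Any]:
    """Hypothetical stand-in for fun_control_init: collects the settings
    described above in a plain dictionary (not spotpython's implementation)."""
    if n_features <= 0 or n_outputs <= 0:
        raise ValueError("number of input and output features must be positive")
    return {
        "PREFIX": prefix,                     # unique identifier of the experiment
        "fun_evals": fun_evals,               # budget: number of function evaluations
        "max_time": max_time,                 # budget: maximum run time in minutes
        "data_set": "Diabetes",               # placeholder for the data set object
        "core_model_name": "ResNetRegressor", # placeholder for the model class name
        "hyperdict": "hyperparameter dictionary",  # placeholder
        "_L_in": n_features,                  # number of input features
        "_L_out": n_outputs,                  # number of predicted values
    }


# Diabetes data: 10 input features and a single regression target
fun_control = make_fun_control(prefix="diabetes_resnet", fun_evals=20,
                               max_time=1, n_features=10)
print(fun_control["_L_in"], fun_control["_L_out"])  # prints: 10 1
```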
The HyperLight class is used to define the objective function fun. It connects the PyTorch Lightning training with spotpython's tuning methods and is provided by spotpython.
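To illustrate the role of such an objective function, the following minimal sketch assumes an interface in which the objective receives a batch of hyperparameter configurations (one per row) and returns one score per configuration, where lower is better. The names `make_objective` and `toy_train_and_validate` are hypothetical; the toy function stands in for training a network and returning its validation loss. This is not how HyperLight is implemented.

```python
from typing import Callable, Dict, List, Sequence


def make_objective(names: Sequence[str],
                   train_and_validate: Callable[[Dict[str, float]], float]
                   ) -> Callable[[List[List[float]]], List[float]]:
    """Return an objective that maps each row of X (one hyperparameter
    configuration) to a score. Lower scores are better."""
    def fun(X: List[List[float]]) -> List[float]:
        scores = []
        for row in X:
            if len(row) != len(names):
                raise ValueError("each row needs one value per hyperparameter")
            config = dict(zip(names, row))  # hyperparameter name -> value
            scores.append(train_and_validate(config))
        return scores
    return fun


# Hypothetical stand-in for "train the model, return the validation loss"
def toy_train_and_validate(config: Dict[str, float]) -> float:
    return (config["learning_rate"] - 1.0) ** 2 + 0.1 * config["dropout"]


fun = make_objective(["learning_rate", "dropout"], toy_train_and_validate)
scores = fun([[1.0, 0.0], [2.0, 0.5]])
print(scores)
```

Separating the generic wrapper from the training routine mirrors the idea that the tuner only sees numbers in and scores out, while the model-specific work happens inside the training function.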
The method set_hyperparameter allows the user to modify default hyperparameter settings. Here we modify some hyperparameters to keep the model small and to decrease the tuning time.
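The idea of restricting the default search space can be sketched with a plain dictionary. The structure below (numeric entries with `lower`/`upper`, factor entries with `levels`) and the function `restrict` are hypothetical illustrations, not spotpython's hyperparameter dictionary or the signature of `set_hyperparameter`; the hyperparameter names and ranges are made up.

```python
import copy
from typing import Any, Dict, Sequence

# Hypothetical default search space: numeric entries carry bounds,
# factor (categorical) entries carry a list of admissible levels.
DEFAULTS: Dict[str, Dict[str, Any]] = {
    "epochs": {"type": "int", "lower": 2, "upper": 10},
    "dropout": {"type": "float", "lower": 0.0, "upper": 0.5},
    "optimizer": {"type": "factor", "levels": ["SGD", "Adam", "Adamax", "Adadelta"]},
}


def restrict(hyperdict: Dict[str, Dict[str, Any]], name: str,
             value: Sequence[Any]) -> Dict[str, Dict[str, Any]]:
    """Return a copy of hyperdict in which `name` is restricted:
    [lower, upper] for numeric entries, a subset of levels for factors."""
    new = copy.deepcopy(hyperdict)  # leave the defaults untouched
    spec = new[name]
    if spec["type"] == "factor":
        unknown = set(value) - set(spec["levels"])
        if unknown:
            raise ValueError(f"unknown levels for {name}: {sorted(unknown)}")
        spec["levels"] = list(value)
    else:
        lower, upper = value
        if lower > upper:
            raise ValueError("lower bound exceeds upper bound")
        spec["lower"], spec["upper"] = lower, upper
    return new


# Keep the model small and the tuning fast
small = restrict(DEFAULTS, "epochs", [3, 5])
small = restrict(small, "dropout", [0.0, 0.025])
small = restrict(small, "optimizer", ["Adam", "Adamax"])
print(small["epochs"], small["optimizer"]["levels"])
```

Returning a modified copy is a design choice of this sketch: it keeps the default dictionary reusable for later experiments.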
After the hyperparameter tuning run is finished, its progress can be visualized with spotpython’s method plot_progress. The black points represent the performance values (score or metric) of the hyperparameter configurations from the initial design, whereas the red points represent the hyperparameter configurations found by the surrogate-model-based optimization.
spot_tuner.plot_progress()
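The information behind such a progress plot can be computed without any plotting library. This sketch is an illustration only: the function `progress` is hypothetical, the loss values are made up, and it assumes that the first `init_size` evaluations stem from the initial design while all later ones come from the surrogate-model-based search. It labels each evaluation with its phase and tracks the best (lowest) value found so far.

```python
from typing import List, Tuple


def progress(y: List[float], init_size: int) -> List[Tuple[int, float, str, float]]:
    """Label each evaluation with its phase and the best (lowest) value so far.
    Assumes the first `init_size` values stem from the initial design."""
    if not 0 <= init_size <= len(y):
        raise ValueError("init_size must lie between 0 and len(y)")
    rows = []
    best = float("inf")
    for i, value in enumerate(y):
        best = min(best, value)
        phase = "initial design" if i < init_size else "surrogate"
        rows.append((i + 1, value, phase, best))
    return rows


y = [3.2, 2.9, 4.1, 2.5, 2.7, 2.1]  # made-up validation losses
for step, value, phase, best in progress(y, init_size=3):
    print(f"{step:2d}  {value:4.1f}  {phase:<14}  best so far: {best:4.1f}")
```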
34.1.2 Tuned Hyperparameters and Their Importance
Results can be printed in tabular form.
from spotpython.utils.eda import gen_design_table
print(gen_design_table(fun_control=fun_control, spot=spot_tuner))
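A table of this kind lists, for each hyperparameter, its search range next to the tuned value. The sketch below only illustrates the tabular layout: `design_table` is a hypothetical helper, the column selection is an assumption, and the values are made up rather than taken from a tuning run. The exact columns printed by gen_design_table may differ.

```python
from typing import Any, Dict, List


def design_table(rows: List[Dict[str, Any]], columns: List[str]) -> str:
    """Render a list of row dictionaries as a plain-text table."""
    # Convert every cell to a string; missing entries become empty cells
    cells = [[str(row.get(col, "")) for col in columns] for row in rows]
    # Each column is as wide as its longest entry (header included)
    widths = [max([len(col)] + [len(r[j]) for r in cells])
              for j, col in enumerate(columns)]

    def fmt(values: List[str]) -> str:
        return " | ".join(v.ljust(w) for v, w in zip(values, widths)).rstrip()

    lines = [fmt(columns), "-+-".join("-" * w for w in widths)]
    lines += [fmt(r) for r in cells]
    return "\n".join(lines)


# Made-up example values, not results of an actual tuning run
rows = [
    {"name": "epochs", "type": "int", "lower": 3, "upper": 5, "tuned": 4},
    {"name": "dropout", "type": "float", "lower": 0.0, "upper": 0.025, "tuned": 0.01},
]
table = design_table(rows, ["name", "type", "lower", "upper", "tuned"])
print(table)
```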
# If the directory "./runs" exists, delete it so that TensorBoard starts with empty logs
import os
import shutil
if os.path.exists("./runs"):
    shutil.rmtree("./runs")
# Enable TensorBoard logging by setting the key "tensorboard_log" to True in fun_control
fun_control.update({"tensorboard_log": True})
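from spotpython.light.testmodel import test_model
from spotpython.utils.init import get_feature_names
# config is assumed to hold the tuned hyperparameter configuration from spot_tuner
# Train and test the model with this configuration
test_model(config, fun_control)
# Return the names of the input features of the data set
get_feature_names(fun_control)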
This section presented an introduction to the basic setup of hyperparameter tuning with spotpython and PyTorch Lightning using a ResNet model for the Diabetes data set.