27 Hyperparameter Tuning with PyTorch Lightning and User Data Sets
In this section, we will show how user-specified data can be used in the PyTorch Lightning hyperparameter tuning workflow with spotpython.
27.1 Loading a User Specified Data Set
Using a user-specified data set is straightforward.
The user only needs to provide a data set and load it as a spotpython CSVDataset() object by specifying the path, filename, and target column.
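To make the three ingredients (path, filename, target column) concrete, the following is a minimal sketch of a CSV-backed map-style data set written with only the Python standard library. It is not the spotpython implementation: the class name SketchCSVDataset, its arguments, and the column names x1 and x2 are hypothetical and chosen for illustration only. The sketch returns plain Python lists rather than tensors, so it is not a drop-in replacement for CSVDataset().

```python
import csv
import os
import tempfile


class SketchCSVDataset:
    """Hypothetical stand-in for a CSV-backed data set (not spotpython code)."""

    def __init__(self, directory, filename, target_column):
        path = os.path.join(directory, filename)
        with open(path, newline="") as handle:
            reader = csv.DictReader(handle)
            fieldnames = reader.fieldnames  # reads the header row
            rows = list(reader)
        # Every column except the target is treated as a numeric feature.
        self.feature_names = [name for name in fieldnames if name != target_column]
        self.features = [[float(row[name]) for name in self.feature_names] for row in rows]
        self.targets = [float(row[target_column]) for row in rows]

    def __len__(self):
        # Number of samples (rows) in the file.
        return len(self.targets)

    def __getitem__(self, index):
        # One sample: (feature values, target value).
        return self.features[index], self.targets[index]


# Usage with a throwaway file; the contents are made up for illustration.
with tempfile.TemporaryDirectory() as directory:
    with open(os.path.join(directory, "data.csv"), "w", newline="") as handle:
        handle.write("x1,x2,target\n1.0,2.0,10.0\n3.0,4.0,20.0\n")
    sketch = SketchCSVDataset(directory, "data.csv", target_column="target")
    print(len(sketch))
    print(sketch[1])
```

The two methods __len__ and __getitem__ are what make an object usable as a map-style data set: the first reports the number of rows, the second returns one (features, target) pair.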
Consider the following example, where the user has a data set stored in the userData directory. The data set is stored in a file named data.csv. The target column is named target. To show the data, it is loaded as a pandas data frame and the first 5 rows are displayed. This step is not necessary for the hyperparameter tuning process, but it is useful for understanding the data.
# load the csv data set as a pandas dataframe and display the first 5 rows
import pandas as pd
data = pd.read_csv("./userData/data.csv")
print(data.head())
The next step is also optional, but it is a useful check of the data. The data set, assumed here to be already loaded as data_set via CSVDataset(), is wrapped in a DataLoader from torch.utils.data, and the first batch is printed.
from torch.utils.data import DataLoader
# Set batch size for DataLoader
batch_size = 5
# Create DataLoader
dataloader = DataLoader(data_set, batch_size=batch_size, shuffle=False)
# Iterate over the data in the DataLoader
for batch in dataloader:
    inputs, targets = batch
    print(f"Batch Size: {inputs.size(0)}")
    print(f"Inputs Shape: {inputs.shape}")
    print(f"Targets Shape: {targets.shape}")
    print("---------------")
    print(f"Inputs: {inputs}")
    print(f"Targets: {targets}")
    break
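With shuffle=False and the default drop_last=False, a DataLoader visits the samples in their stored order, and the last batch may be smaller than batch_size. The following pure-Python sketch mimics only this index bookkeeping. The helper name sequential_batches and the row count of 12 are hypothetical and not taken from the user data.

```python
def sequential_batches(n_samples, batch_size):
    """Yield index lists in the order a non-shuffled loader that keeps the last batch visits them."""
    for start in range(0, n_samples, batch_size):
        yield list(range(start, min(start + batch_size, n_samples)))


# Hypothetical data set with 12 rows, batched in groups of 5.
batches = list(sequential_batches(n_samples=12, batch_size=5))
for indices in batches:
    print(len(indices), indices)
```

Because the loop over the DataLoader stops after the first batch, the printed inputs and targets should correspond to the first five samples of the data set, which makes them easy to compare with the output of data.head(), provided that no rows were dropped while loading.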
The hyperparameter tuning setup is defined as in Section 26.1, with one change: the user data set replaces the Diabetes data set. To this end, the data_set parameter is set to the user data set when the fun_control dictionary is created via the fun_control_init function.
Note that we have set the fun_evals parameter to 12 and init_size to 7 to reduce the computation time for this example.
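As a rough guide to what these two numbers mean for the run, the sketch below splits the evaluation budget into the initial design and the remaining surrogate-guided evaluations. It assumes that fun_evals counts all objective function evaluations, including those of the initial design, and that no evaluations are repeated. The helper split_budget is hypothetical and not part of spotpython.

```python
def split_budget(fun_evals, init_size):
    """Return (initial design evaluations, remaining sequential evaluations)."""
    if init_size > fun_evals:
        raise ValueError("init_size must not exceed fun_evals")
    return init_size, fun_evals - init_size


initial, sequential = split_budget(fun_evals=12, init_size=7)
print(f"initial design: {initial}, sequential evaluations: {sequential}")
```

Under these assumptions, only a handful of configurations are proposed by the surrogate model after the initial design, which keeps the example fast but is too small a budget for a serious tuning run.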
This section showed how to use user-specified data sets for the hyperparameter tuning process with spotpython. The user needs to provide the data set and load it as a spotpython CSVDataset() object.