Skip to content

trainmodel

train_model(config, fun_control, timestamp=True)

Trains a model using the given configuration and function control parameters.

Parameters:

config (dict, required): A dictionary containing the configuration parameters for the model.

fun_control (dict, required): A dictionary containing the function control parameters.

timestamp (bool, default True): A boolean value indicating whether to include a timestamp in the config id. If False, no timestamp is used and the string "_TRAIN" is appended to the config id instead.

Returns:

float: The validation loss of the trained model.

Examples:

>>> from spotPython.utils.init import fun_control_init
    from spotPython.light.netlightregression import NetLightRegression
    from spotPython.hyperdict.light_hyper_dict import LightHyperDict
    from spotPython.hyperparameters.values import (
        add_core_model_to_fun_control,
        get_default_hyperparameters_as_array)
    from spotPython.data.diabetes import Diabetes
    from spotPython.hyperparameters.values import set_control_key_value
    from spotPython.hyperparameters.values import get_var_name, assign_values, generate_one_config_from_var_dict
    from spotPython.light.trainmodel import train_model
    fun_control = fun_control_init(
        _L_in=10,
        _L_out=1,)
    # Select a dataset
    dataset = Diabetes()
    set_control_key_value(control_dict=fun_control,
                        key="data_set",
                        value=dataset)
    # Select a model
    add_core_model_to_fun_control(core_model=NetLightRegression,
                                fun_control=fun_control,
                                hyper_dict=LightHyperDict)
    # Select hyperparameters
    X = get_default_hyperparameters_as_array(fun_control)
    var_dict = assign_values(X, get_var_name(fun_control))
    for config in generate_one_config_from_var_dict(var_dict, fun_control):
        y = train_model(config, fun_control)
        break
    | Name   | Type       | Params | In sizes | Out sizes
    -------------------------------------------------------------
    0 | layers | Sequential | 157    | [16, 10] | [16, 1]
    -------------------------------------------------------------
    157       Trainable params
    0         Non-trainable params
    157       Total params
    0.001     Total estimated model params size (MB)
    Train_model(): Test set size: 266
    ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Validate metric           DataLoader 0
    ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
            hp_metric             27462.841796875
            val_loss              27462.841796875
    ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    train_model result: {'val_loss': 27462.841796875, 'hp_metric': 27462.841796875}
Source code in spotPython/light/trainmodel.py (lines 11–146):
def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> float:
    """
    Trains a model using the given configuration and function control parameters.

    A new instance of ``fun_control["core_model"]`` is built from ``config``,
    optionally re-initialized (Xavier or Kaiming), trained with early stopping
    on ``val_loss``, and finally evaluated once with ``trainer.validate``.

    Args:
        config (dict):
            A dictionary containing the configuration parameters for the model.
            All entries are forwarded as keyword arguments to the core model;
            in addition, the keys "initialization", "batch_size" and "patience"
            are read directly by this function.
        fun_control (dict):
            A dictionary containing the function control parameters. The keys
            "_L_in", "_L_out", "_torchmetric", "core_model", "data_set",
            "num_workers", "test_size", "test_seed", "CHECKPOINT_PATH",
            "accelerator", "devices", "TENSORBOARD_PATH" and "log_graph" are
            required. "enable_progress_bar" is optional; a missing entry or
            None disables the progress bar.
        timestamp (bool):
            Whether to include a timestamp in the config id. Default is True.
            If False, the id is generated without a timestamp, the string
            "_TRAIN" is appended to it, and a ModelCheckpoint callback saves
            the last checkpoint under CHECKPOINT_PATH/<config id>.

    Returns:
        float: The validation loss of the trained model.

    Raises:
        KeyError: If a required key is missing from ``config`` or ``fun_control``.

    Examples:
        >>> from spotPython.utils.init import fun_control_init
            from spotPython.light.netlightregression import NetLightRegression
            from spotPython.hyperdict.light_hyper_dict import LightHyperDict
            from spotPython.hyperparameters.values import (
                add_core_model_to_fun_control,
                get_default_hyperparameters_as_array)
            from spotPython.data.diabetes import Diabetes
            from spotPython.hyperparameters.values import set_control_key_value
            from spotPython.hyperparameters.values import get_var_name, assign_values, generate_one_config_from_var_dict
            from spotPython.light.trainmodel import train_model
            fun_control = fun_control_init(
                _L_in=10,
                _L_out=1,)
            # Select a dataset
            dataset = Diabetes()
            set_control_key_value(control_dict=fun_control,
                                key="data_set",
                                value=dataset)
            # Select a model
            add_core_model_to_fun_control(core_model=NetLightRegression,
                                        fun_control=fun_control,
                                        hyper_dict=LightHyperDict)
            # Select hyperparameters
            X = get_default_hyperparameters_as_array(fun_control)
            var_dict = assign_values(X, get_var_name(fun_control))
            for config in generate_one_config_from_var_dict(var_dict, fun_control):
                y = train_model(config, fun_control)
                break
            | Name   | Type       | Params | In sizes | Out sizes
            -------------------------------------------------------------
            0 | layers | Sequential | 157    | [16, 10] | [16, 1]
            -------------------------------------------------------------
            157       Trainable params
            0         Non-trainable params
            157       Total params
            0.001     Total estimated model params size (MB)
            Train_model(): Test set size: 266
            ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                Validate metric           DataLoader 0
            ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                    hp_metric             27462.841796875
                    val_loss              27462.841796875
            ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
            train_model result: {'val_loss': 27462.841796875, 'hp_metric': 27462.841796875}

    """
    _L_in = fun_control["_L_in"]
    _L_out = fun_control["_L_out"]
    _torchmetric = fun_control["_torchmetric"]
    # A missing entry is treated like an explicit None: no progress bar.
    enable_progress_bar = fun_control.get("enable_progress_bar")
    if enable_progress_bar is None:
        enable_progress_bar = False

    if timestamp:
        # The timestamp makes the id unique per call, so every run gets its own
        # log/checkpoint directory. No checkpoint is saved in this mode.
        config_id = generate_config_id(config, timestamp=True)
    else:
        # Without a timestamp the id is not time-dependent: repeated runs with the
        # same config presumably produce the same id (reproducible, NOT unique).
        # That is what allows the checkpoint saved below to be located and
        # reloaded later.
        config_id = generate_config_id(config, timestamp=False) + "_TRAIN"

    model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
    initialization = config["initialization"]
    if initialization == "Xavier":
        xavier_init(model)
    elif initialization == "Kaiming":
        kaiming_init(model)
    # Any other value keeps the model's default initialization.

    dm = LightDataModule(
        dataset=fun_control["data_set"],
        batch_size=config["batch_size"],
        num_workers=fun_control["num_workers"],
        test_size=fun_control["test_size"],
        test_seed=fun_control["test_seed"],
    )

    # Both the trainer's root dir and the optional checkpoint callback use this path.
    checkpoint_dir = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)

    callbacks = [
        EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
    ]
    if not timestamp:
        # Only reproducible (non-timestamped) ids get a checkpoint; save the last one.
        callbacks.append(ModelCheckpoint(dirpath=checkpoint_dir, save_last=True))

    trainer = L.Trainer(
        # Where to save models
        default_root_dir=checkpoint_dir,
        max_epochs=model.hparams.epochs,
        accelerator=fun_control["accelerator"],
        devices=fun_control["devices"],
        logger=TensorBoardLogger(
            save_dir=fun_control["TENSORBOARD_PATH"],
            version=config_id,
            default_hp_metric=True,
            log_graph=fun_control["log_graph"],
        ),
        callbacks=callbacks,
        enable_progress_bar=enable_progress_bar,
    )
    # Passing the datamodule to trainer.fit overrides the model's data hooks.
    trainer.fit(model=model, datamodule=dm)
    # Validate the in-memory weights as they are after fit() returns, i.e. from
    # the last trained epoch. This is not a restored "best" checkpoint.
    result = trainer.validate(model=model, datamodule=dm)
    # trainer.validate returns a list with one metrics dict per dataloader; use the first.
    result = result[0]
    print(f"train_model result: {result}")
    return result["val_loss"]