Skip to content

predictmodel

predict_model(config, fun_control)

Trains a model with the given configuration and function control parameters and returns its predictions.

Parameters:

Name Type Description Default
config dict

A dictionary containing the configuration parameters for the model.

required
fun_control dict

A dictionary containing the function control parameters.

required

Returns:

Type Description
list

list: The predictions returned by trainer.predict(datamodule=dm, ckpt_path="best") after the model has been trained.

Notes
  • predict_model trains the model with trainer.fit(model=model, datamodule=dm). A ModelCheckpoint with save_last=True writes the checkpoints to CHECKPOINT_PATH/<config_id>.
  • Predictions are then computed from the best checkpoint with the following call: trainer.predict(datamodule=dm, ckpt_path="best").

Examples:

>>> from spotpython.utils.init import fun_control_init
>>> from spotpython.light.netlightregression import NetLightRegression
>>> from spotpython.hyperdict.light_hyper_dict import LightHyperDict
>>> from spotpython.hyperparameters.values import (add_core_model_to_fun_control,
...     get_default_hyperparameters_as_array)
>>> from spotpython.data.diabetes import Diabetes
>>> from spotpython.hyperparameters.values import set_control_key_value
>>> from spotpython.hyperparameters.values import (get_var_name, assign_values,
...     generate_one_config_from_var_dict)
>>> import spotpython.light.predictmodel as pm
>>> fun_control = fun_control_init(
...     _L_in=10,
...     _L_out=1,
...     _torchmetric="mean_squared_error")
>>> dataset = Diabetes()
>>> set_control_key_value(control_dict=fun_control,
...                       key="data_set",
...                       value=dataset)
>>> add_core_model_to_fun_control(core_model=NetLightRegression,
...                               fun_control=fun_control,
...                               hyper_dict=LightHyperDict)
>>> X = get_default_hyperparameters_as_array(fun_control)
>>> var_dict = assign_values(X, get_var_name(fun_control))
>>> for config in generate_one_config_from_var_dict(var_dict, fun_control):
...     predictions = pm.predict_model(config, fun_control)
Source code in spotpython/light/predictmodel.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def predict_model(config: dict, fun_control: dict) -> list:
    """
    Trains a model for the given configuration and returns its predictions.

    A model of class ``fun_control["core_model"]`` is built from the
    hyperparameters in ``config`` and trained with ``trainer.fit``. Then
    ``trainer.predict`` runs on the data module using the "best" checkpoint
    written during training, not the in-memory model.

    Args:
        config (dict): Model hyperparameters. Every entry is passed as a keyword
            argument to ``fun_control["core_model"]``. ``config["batch_size"]``
            is used for the data module when one is built here, and
            ``config["patience"]`` is the early-stopping patience. The number of
            epochs is read from the model's ``hparams.epochs``.
        fun_control (dict): Function control parameters. Keys read here:
            ``_L_in``, ``_L_out``, ``_L_cond``, ``_torchmetric``,
            ``enable_progress_bar``, ``data_module``, ``core_model``,
            ``CHECKPOINT_PATH``, ``TENSORBOARD_PATH``, ``log_graph``,
            ``accelerator``, ``devices``, ``strategy``, ``num_nodes`` and
            ``precision``. If ``data_module`` is None, a ``LightDataModule`` is
            built, and ``data_set``, ``data_full_train``, ``data_test``,
            ``num_workers``, ``test_size``, ``test_seed``, ``scaler`` and
            ``verbosity`` are read as well.

    Returns:
        list: The value returned by
        ``trainer.predict(datamodule=dm, ckpt_path="best")``. This is presumably
        a list of per-batch predictions; see the Lightning ``Trainer.predict``
        documentation.

    Raises:
        KeyError: If a required key is missing from ``config`` or ``fun_control``.

    Notes:
        * The model is trained with ``trainer.fit(model=model, datamodule=dm)``.
          A ``ModelCheckpoint`` with ``save_last=True`` writes checkpoints to
          ``CHECKPOINT_PATH/<config_id>``, where ``config_id`` is generated
          without a timestamp and suffixed with ``"_TEST"``.
        * Predictions are computed with
          ``trainer.predict(datamodule=dm, ckpt_path="best")``.

    Examples:
        >>> from spotpython.utils.init import fun_control_init
        >>> from spotpython.light.netlightregression import NetLightRegression
        >>> from spotpython.hyperdict.light_hyper_dict import LightHyperDict
        >>> from spotpython.hyperparameters.values import (add_core_model_to_fun_control,
        ...     get_default_hyperparameters_as_array)
        >>> from spotpython.data.diabetes import Diabetes
        >>> from spotpython.hyperparameters.values import set_control_key_value
        >>> from spotpython.hyperparameters.values import (get_var_name, assign_values,
        ...     generate_one_config_from_var_dict)
        >>> import spotpython.light.predictmodel as pm
        >>> fun_control = fun_control_init(
        ...     _L_in=10,
        ...     _L_out=1,
        ...     _torchmetric="mean_squared_error")
        >>> dataset = Diabetes()
        >>> set_control_key_value(control_dict=fun_control,
        ...                       key="data_set",
        ...                       value=dataset)
        >>> add_core_model_to_fun_control(core_model=NetLightRegression,
        ...                               fun_control=fun_control,
        ...                               hyper_dict=LightHyperDict)
        >>> X = get_default_hyperparameters_as_array(fun_control)
        >>> var_dict = assign_values(X, get_var_name(fun_control))
        >>> for config in generate_one_config_from_var_dict(var_dict, fun_control):
        ...     predictions = pm.predict_model(config, fun_control)
    """
    _L_in = fun_control["_L_in"]
    _L_out = fun_control["_L_out"]
    _L_cond = fun_control["_L_cond"]
    _torchmetric = fun_control["_torchmetric"]
    # A None setting means "no progress bar".
    enable_progress_bar = fun_control["enable_progress_bar"]
    if enable_progress_bar is None:
        enable_progress_bar = False
    # The config id must be unique. Because predictions are made from a
    # checkpoint, the id is generated without a timestamp and gets a "_TEST"
    # suffix. This differs from the ids generated in cvmodel.py and trainmodel.py.
    config_id = generate_config_id(config, timestamp=False) + "_TEST"
    if fun_control["data_module"] is None:
        dm = LightDataModule(
            dataset=fun_control["data_set"],
            data_full_train=fun_control["data_full_train"],
            data_test=fun_control["data_test"],
            batch_size=config["batch_size"],
            num_workers=fun_control["num_workers"],
            test_size=fun_control["test_size"],
            test_seed=fun_control["test_seed"],
            scaler=fun_control["scaler"],
            verbosity=fun_control["verbosity"],
        )
    else:
        dm = fun_control["data_module"]
    # TODO: Check if this is necessary.
    # NOTE(review): Lightning's standard stages are "fit", "validate", "test"
    # and "predict"; confirm that dm.setup handles stage="train" as intended.
    dm.setup(stage="train")
    # Init model from datamodule's attributes
    model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _L_cond=_L_cond, _torchmetric=_torchmetric)

    # Both the trainer's root dir and the checkpoint callback write here.
    checkpoint_dir = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)
    trainer = L.Trainer(
        default_root_dir=checkpoint_dir,
        max_epochs=model.hparams.epochs,
        accelerator=fun_control["accelerator"],
        devices=fun_control["devices"],
        strategy=fun_control["strategy"],
        num_nodes=fun_control["num_nodes"],
        precision=fun_control["precision"],
        logger=TensorBoardLogger(
            save_dir=fun_control["TENSORBOARD_PATH"],
            version=config_id,
            default_hp_metric=True,
            log_graph=fun_control["log_graph"],
        ),
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False),
            ModelCheckpoint(dirpath=checkpoint_dir, save_last=True),  # Save the last checkpoint
        ],
        enable_progress_bar=enable_progress_bar,
    )
    # Pass the datamodule as arg to trainer.fit to override model hooks.
    trainer.fit(model=model, datamodule=dm)

    dm.setup(stage="predict")

    # Changed in spotpython 0.19.5: predict from the "best" checkpoint, not the
    # in-memory model (0.18.12 used ckpt_path="last").
    # NOTE(review): the ModelCheckpoint above sets no `monitor`, so "best"
    # presumably resolves to the most recently saved checkpoint. Confirm this.
    return trainer.predict(datamodule=dm, ckpt_path="best")