predictmodel
predict_model(config, fun_control)
Predicts using the given configuration and function control parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | dict | A dictionary containing the configuration parameters for the model. | required |
fun_control | dict | A dictionary containing the function control parameters. | required |
Returns:
Type | Description |
---|---|
Tuple[float, float] | The validation loss and the hyperparameter metric of the tested model. |
Notes:
- test_model saves the last checkpoint of the model from the training phase, which is run with trainer.fit(model=model, datamodule=dm).
- The test result is evaluated with trainer.test(datamodule=dm, ckpt_path="last").
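A small stdlib-only sketch can make the fit, save-last-checkpoint, and test-from-"last" flow concrete. `MiniTrainer`, its `fit`/`test` methods, and the one-variable linear model are hypothetical stand-ins. They are not the PyTorch Lightning `Trainer` or spotpython code. The sketch only shows that `ckpt_path="last"` evaluates the weights saved at the end of training, not whatever state the live model object holds afterwards.

```python
from typing import Dict, List, Tuple

Data = List[Tuple[float, float]]  # (x, y) pairs


class MiniTrainer:
    """Hypothetical trainer that stores the final-epoch weights as the "last" checkpoint."""

    def __init__(self, max_epochs: int = 500, lr: float = 0.5) -> None:
        self.max_epochs = max_epochs
        self.lr = lr
        self.checkpoints: Dict[str, Dict[str, float]] = {}

    def fit(self, model: Dict[str, float], train: Data) -> None:
        n = len(train)
        for _ in range(self.max_epochs):
            # Full-batch gradient descent on the mean squared error of y = w*x + b.
            grad_w = sum(2 * (model["w"] * x + model["b"] - y) * x for x, y in train) / n
            grad_b = sum(2 * (model["w"] * x + model["b"] - y) for x, y in train) / n
            model["w"] -= self.lr * grad_w
            model["b"] -= self.lr * grad_b
        # Save a copy of the weights after the final epoch as the "last" checkpoint.
        self.checkpoints["last"] = dict(model)

    def test(self, test: Data, ckpt_path: str = "last") -> float:
        # Evaluate with weights restored from the requested checkpoint, not the live model.
        weights = self.checkpoints[ckpt_path]
        return sum((weights["w"] * x + weights["b"] - y) ** 2 for x, y in test) / len(test)


train = [(x / 10, 3 * (x / 10) + 1) for x in range(10)]
test = [(x / 10 + 0.05, 3 * (x / 10 + 0.05) + 1) for x in range(10)]

model = {"w": 0.0, "b": 0.0}
trainer = MiniTrainer()
trainer.fit(model, train)
model["w"] = 999.0  # changing the live model after fit does not touch the "last" checkpoint
mse = trainer.test(test, ckpt_path="last")
```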
Examples:
>>> from spotpython.utils.init import fun_control_init
>>> from spotpython.light.netlightregression import NetLightRegression
>>> from spotpython.hyperdict.light_hyper_dict import LightHyperDict
>>> from spotpython.hyperparameters.values import (
...     add_core_model_to_fun_control, assign_values,
...     generate_one_config_from_var_dict, get_default_hyperparameters_as_array,
...     get_var_name, set_control_key_value)
>>> from spotpython.data.diabetes import Diabetes
>>> from spotpython.light.predictmodel import predict_model
>>> # Set up fun_control for a regression task with 10 inputs and 1 output.
>>> fun_control = fun_control_init(
...     _L_in=10,
...     _L_out=1,
...     _torchmetric="mean_squared_error")
>>> dataset = Diabetes()
>>> set_control_key_value(control_dict=fun_control,
...     key="data_set",
...     value=dataset)
>>> add_core_model_to_fun_control(core_model=NetLightRegression,
...     fun_control=fun_control,
...     hyper_dict=LightHyperDict)
>>> # Build a configuration from the default hyperparameters and run the prediction.
>>> X = get_default_hyperparameters_as_array(fun_control)
>>> var_dict = assign_values(X, get_var_name(fun_control))
>>> for config in generate_one_config_from_var_dict(var_dict, fun_control):
...     y_pred = predict_model(config, fun_control)
Source code in spotpython/light/predictmodel.py