model_selection.grid_search

model_selection.grid_search

Functions

| Name | Description |
| --- | --- |
| grid_search_forecaster | Exhaustive grid search over parameter values for a Forecaster. |

grid_search_forecaster

model_selection.grid_search.grid_search_forecaster(
    forecaster,
    y,
    cv,
    param_grid,
    metric,
    exog=None,
    lags_grid=None,
    return_best=True,
    n_jobs='auto',
    verbose=False,
    show_progress=True,
    suppress_warnings=False,
    output_file=None,
)

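Exhaustive grid search over parameter values for a Forecaster. Every combination of the lag configurations in `lags_grid` and the estimator parameters in `param_grid` is backtested with the folds defined by `cv` and scored with `metric`. The result is a DataFrame with one row per combination. When `return_best=True`, the forecaster is refitted on the whole series using the best-found lags and parameters. In the example below, rows appear in ascending order of the error metric.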

Examples

import warnings
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from spotforecast2_safe.forecaster.recursive import ForecasterRecursive
from spotforecast2_safe.splitter import TimeSeriesFold
from spotforecast2.model_selection.grid_search import grid_search_forecaster

# Reproducible synthetic hourly series: 120 draws from N(0, 1) with a fixed
# seed, so the table printed below is stable across runs.
generator = np.random.default_rng(0)
hourly_index = pd.date_range("2020-01-01", periods=120, freq="h")
y = pd.Series(generator.normal(0, 1, 120), index=hourly_index)

# Starting model: a Ridge regressor fed by the last 3 lags of the series.
forecaster = ForecasterRecursive(estimator=Ridge(), lags=3)

# Backtesting scheme: 3-step folds, an initial training window of the first
# 90 observations, and refit=False.
cv = TimeSeriesFold(steps=3, initial_train_size=90, refit=False)

# Search space: 2 lag configurations x 2 Ridge penalties = 4 candidate models.
candidate_lags = [3, 5]
param_grid = {"alpha": [0.1, 1.0]}
n_candidates = len(candidate_lags) * len(param_grid["alpha"])

with warnings.catch_warnings():
    # Ignore every warning raised inside this context; the search itself is
    # additionally asked to suppress its own warnings via suppress_warnings.
    warnings.simplefilter("ignore")
    results = grid_search_forecaster(
        forecaster=forecaster,
        y=y,
        cv=cv,
        param_grid=param_grid,
        metric="mean_absolute_error",
        lags_grid=candidate_lags,
        return_best=True,
        n_jobs=1,
        verbose=False,
        show_progress=False,
        suppress_warnings=True,
    )

print(results[["lags_label", "params", "mean_absolute_error"]].head())
# One row per candidate model, with the metric reported in a column named
# after the `metric` argument.
assert results.shape == (n_candidates, 5)
assert "mean_absolute_error" in results.columns
Number of models compared: 4. Training models...
`Forecaster` refitted using the best-found lags and parameters, and the whole data set: 
  Lags: [1 2 3 4 5] 
  Parameters: {'alpha': 1.0}
  Backtesting metric: 0.7679862927050902
        lags_label          params  mean_absolute_error
0  [1, 2, 3, 4, 5]  {'alpha': 1.0}             0.767986
1  [1, 2, 3, 4, 5]  {'alpha': 0.1}             0.767986
2        [1, 2, 3]  {'alpha': 1.0}             0.772652
3        [1, 2, 3]  {'alpha': 0.1}             0.772652