preprocessing.forecaster_config

Forecaster configuration utilities.

This module provides functions for initializing and validating forecaster configuration parameters such as lags and weights.

Functions

Name Description
check_select_fit_kwargs Check that fit_kwargs is a dict and keep only the keys accepted by the estimator’s fit method.
initialize_lags Validate and normalize lag specification for forecasting.
initialize_weights Validate and initialize weight function configuration for forecasting.

check_select_fit_kwargs

preprocessing.forecaster_config.check_select_fit_kwargs(
    estimator,
    fit_kwargs=None,
)

Check that fit_kwargs is a dict and keep only the keys accepted by the estimator’s fit method.

This function validates that fit_kwargs is a dictionary, warns about unused arguments, removes ‘sample_weight’ (which should be handled via weight_func), and returns a dictionary containing only the arguments accepted by the estimator’s fit method.

Parameters

Name Type Description Default
estimator Any Scikit-learn compatible estimator. required
fit_kwargs Optional[dict] Dictionary of arguments to pass to the estimator’s fit method. None

Returns

Name Type Description
dict Dictionary with only the arguments accepted by the estimator’s fit method.

Raises

Name Type Description
TypeError If fit_kwargs is not a dict.

Warns

Issues an IgnoredArgumentWarning if fit_kwargs contains keys that the estimator’s fit method does not accept, or if ‘sample_weight’ is present (it is removed from the returned dictionary).

Examples

import warnings
from sklearn.linear_model import Ridge
from spotforecast2_safe.preprocessing.forecaster_config import check_select_fit_kwargs

estimator = Ridge()

# sample_weight is removed (should be passed via weight_func in forecaster);
# invalid_arg is ignored; both trigger IgnoredArgumentWarning
kwargs = {"sample_weight": [1, 1], "invalid_arg": 10}
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    filtered = check_select_fit_kwargs(estimator, kwargs)
assert filtered == {}
print(f"filtered fit_kwargs: {filtered}")
filtered fit_kwargs: {}
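
The filtering described above can be pictured with a minimal sketch built on inspect.signature. It is an illustration only, not the library's implementation: TinyEstimator and select_fit_kwargs_sketch are hypothetical names, the warning category is simplified to UserWarning, and returning an empty dict when fit_kwargs is None is an assumption.

```python
import inspect
import warnings


class TinyEstimator:
    """Hypothetical stand-in for a scikit-learn style estimator."""

    def fit(self, X, y, sample_weight=None, check_input=True):
        return self


def select_fit_kwargs_sketch(estimator, fit_kwargs=None):
    """Illustrative only: keep the keys that estimator.fit accepts."""
    if fit_kwargs is None:
        return {}  # assumption: None is treated as "no extra arguments"
    if not isinstance(fit_kwargs, dict):
        raise TypeError(f"`fit_kwargs` must be a dict. Got {type(fit_kwargs)}.")

    # Parameter names of the bound fit method (self is not included).
    accepted = set(inspect.signature(estimator.fit).parameters)

    unused = [key for key in fit_kwargs if key not in accepted]
    if unused:
        warnings.warn(f"Ignored fit_kwargs keys: {unused}", UserWarning)
    if "sample_weight" in fit_kwargs:
        warnings.warn("`sample_weight` is dropped; use `weight_func`.", UserWarning)

    # Keep accepted keys, but never pass sample_weight through.
    return {
        key: value
        for key, value in fit_kwargs.items()
        if key in accepted and key != "sample_weight"
    }


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    selected = select_fit_kwargs_sketch(
        TinyEstimator(),
        {"check_input": False, "sample_weight": [1, 1], "invalid_arg": 10},
    )

print(selected)     # only check_input survives
print(len(caught))  # one warning per problem: unused key, sample_weight
```

Recording the warnings with catch_warnings(record=True), rather than silencing them, makes it possible to confirm in a test that both conditions were reported.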

initialize_lags

preprocessing.forecaster_config.initialize_lags(forecaster_name, lags)

Validate and normalize lag specification for forecasting.

This function converts various lag specifications (int, list, tuple, range, ndarray) into a standardized format: sorted numpy array, lag names, and maximum lag value.

Parameters

Name Type Description Default
forecaster_name str Name of the forecaster class for error messages. required
lags Any Lag specification. Accepted formats: int (creates lags from 1 to lags, e.g., 5 → [1, 2, 3, 4, 5]); list, tuple, or range (converted to a numpy array); numpy.ndarray (validated and used directly); None (returns (None, None, None)). required

Returns

Name Type Description
Tuple[Optional[np.ndarray], Optional[List[str]], Optional[int]] Tuple containing:
- lags (Optional[np.ndarray]): Sorted numpy array of lag values (or None)
- lags_names (Optional[List[str]]): List of lag names like [‘lag_1’, ‘lag_2’, …] (or None)
- max_lag (Optional[int]): Maximum lag value (or None)

Raises

Name Type Description
ValueError If lags is an integer less than 1, or if lags is an array that is empty, is not 1-dimensional, or contains a value less than 1.
TypeError If lags is not in a format accepted by the given forecaster, or if the array contains non-integer values.

Examples

import numpy as np
from spotforecast2_safe.preprocessing.forecaster_config import initialize_lags

# Integer input
lags, names, max_lag = initialize_lags("ForecasterRecursive", 3)
assert list(lags) == [1, 2, 3]
assert names == ["lag_1", "lag_2", "lag_3"]
assert max_lag == 3

# List input
lags, names, max_lag = initialize_lags("ForecasterRecursive", [1, 3, 5])
assert list(lags) == [1, 3, 5]
assert names == ["lag_1", "lag_3", "lag_5"]

# Range input
lags, names, max_lag = initialize_lags("ForecasterRecursive", range(1, 4))
assert list(lags) == [1, 2, 3]

# None input
lags, names, max_lag = initialize_lags("ForecasterRecursive", None)
assert lags is None
print("initialize_lags: valid inputs OK")
initialize_lags: valid inputs OK

from spotforecast2_safe.preprocessing.forecaster_config import initialize_lags

# Invalid: lags < 1
try:
    initialize_lags("ForecasterRecursive", 0)
except ValueError as e:
    print(f"ValueError: {e}")

# Invalid: negative lag in list
try:
    initialize_lags("ForecasterRecursive", [1, -2, 3])
except ValueError as e:
    print(f"ValueError: {e}")
ValueError: Minimum value of lags allowed is 1.
ValueError: Minimum value of lags allowed is 1.
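
The normalization rules documented above can be summarized in a short NumPy sketch. It is an assumption-laden illustration rather than the library's code: normalize_lags_sketch is a hypothetical name, the forecaster_name argument is omitted, and every error message except the minimum-value one shown in the output above is invented for the example.

```python
import numpy as np


def normalize_lags_sketch(lags):
    """Illustrative only: mirror the documented int/sequence/None handling."""
    if lags is None:
        return None, None, None

    if isinstance(lags, (int, np.integer)):
        if lags < 1:
            raise ValueError("Minimum value of lags allowed is 1.")
        lags = np.arange(1, lags + 1)  # 3 -> [1, 2, 3]
    elif isinstance(lags, (list, tuple, range, np.ndarray)):
        lags = np.asarray(lags)
    else:
        raise TypeError(f"Unsupported type for `lags`: {type(lags)}.")

    if lags.ndim != 1:
        raise ValueError("`lags` must be 1-dimensional.")
    if lags.size == 0:
        raise ValueError("`lags` must not be empty.")
    if not np.issubdtype(lags.dtype, np.integer):
        raise TypeError("All values in `lags` must be integers.")
    if lags.min() < 1:
        raise ValueError("Minimum value of lags allowed is 1.")

    lags = np.sort(lags)
    names = [f"lag_{int(lag)}" for lag in lags]
    return lags, names, int(lags.max())


# Unsorted input comes back sorted, with matching names.
lags, names, max_lag = normalize_lags_sketch([5, 1, 3])
print(lags.tolist(), names, max_lag)
```

Checking for an empty array before checking the dtype matters in this sketch, because np.asarray([]) produces a float array and would otherwise be reported as a type problem.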

initialize_weights

preprocessing.forecaster_config.initialize_weights(
    forecaster_name,
    estimator,
    weight_func,
    series_weights,
)

Validate and initialize weight function configuration for forecasting.

This function validates weight_func and series_weights, extracts source code from weight functions for serialization, and checks if the estimator supports sample weights in its fit method.

Parameters

Name Type Description Default
forecaster_name str Name of the forecaster class. required
estimator Any Scikit-learn compatible estimator or pipeline. required
weight_func Any Weight function specification. Accepted values: Callable (a single weight function); dict (a dictionary of weight functions, for MultiSeries forecasters); None (no weighting). required
series_weights Any Series-level weights (for MultiSeries forecasters). Accepted values: dict (maps series names to weight values); None (no series weighting). required

Returns

Name Type Description
Tuple[Any, Optional[Union[str, dict]], Any] Tuple containing:
- weight_func (Any): Validated weight function (or None if invalid)
- source_code_weight_func (Optional[Union[str, dict]]): Source code of weight function(s) for serialization (or None)
- series_weights (Any): Validated series weights (or None if invalid)

Raises

Name Type Description
TypeError If weight_func is not a Callable or a dict (the accepted type depends on the forecaster), or if series_weights is not a dict.

Warns

If the estimator’s fit method does not accept sample_weight.

Examples

import numpy as np
from sklearn.linear_model import Ridge
from spotforecast2_safe.preprocessing.forecaster_config import initialize_weights

def custom_weights(index):
    return np.ones(len(index))

estimator = Ridge()

# Valid callable weight function
wf, source, sw = initialize_weights(
    "ForecasterRecursive", estimator, custom_weights, None
)
assert wf is not None
assert isinstance(source, str)
assert sw is None

# No weight function
wf, source, sw = initialize_weights(
    "ForecasterRecursive", estimator, None, None
)
assert wf is None
assert source is None

# Invalid type raises TypeError
try:
    initialize_weights("ForecasterRecursive", estimator, "invalid", None)
except TypeError as e:
    print(f"TypeError: {e}")
TypeError: Argument `weight_func` must be a Callable. Got <class 'str'>.
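
The check for sample_weight support can be sketched with the standard library alone. This is a simplified illustration, not the library's implementation: the two estimator classes and both helper functions are hypothetical, pipelines and dict-valued weight_func are not handled, and returning None after the warning is an assumption about how an unsupported estimator is treated.

```python
import inspect
import warnings


class WeightedEstimator:
    """Hypothetical estimator whose fit accepts sample_weight."""

    def fit(self, X, y, sample_weight=None):
        return self


class UnweightedEstimator:
    """Hypothetical estimator whose fit has no sample_weight parameter."""

    def fit(self, X, y):
        return self


def supports_sample_weight(estimator):
    """Return True if estimator.fit has a sample_weight parameter."""
    return "sample_weight" in inspect.signature(estimator.fit).parameters


def resolve_weight_func_sketch(estimator, weight_func):
    """Illustrative only: validate weight_func against the estimator."""
    if weight_func is None:
        return None
    if not callable(weight_func):
        raise TypeError(
            f"Argument `weight_func` must be a Callable. Got {type(weight_func)}."
        )
    if not supports_sample_weight(estimator):
        # Assumption: the weight function is discarded after the warning.
        warnings.warn("Estimator does not accept `sample_weight`.", UserWarning)
        return None
    return weight_func


def uniform_weights(index):
    """Give every observation the same weight."""
    return [1.0] * len(index)


kept = resolve_weight_func_sketch(WeightedEstimator(), uniform_weights)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dropped = resolve_weight_func_sketch(UnweightedEstimator(), uniform_weights)

print(kept is uniform_weights, dropped, len(caught))
```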