utils.forecaster_config

Forecaster configuration utilities.

This module provides functions for initializing and validating forecaster configuration parameters such as lags, weights, and fit keyword arguments.

Functions

Name Description
check_select_fit_kwargs Check if fit_kwargs is a dict and select only keys used by estimator’s fit.
initialize_lags Validate and normalize lag specification for forecasting.
initialize_weights Validate and initialize weight function configuration for forecasting.

check_select_fit_kwargs

utils.forecaster_config.check_select_fit_kwargs(estimator, fit_kwargs=None)

Check if fit_kwargs is a dict and select only keys used by estimator’s fit.

This function validates that fit_kwargs is a dictionary, warns about unused arguments, removes ‘sample_weight’ (which should be handled via weight_func), and returns a dictionary containing only the arguments accepted by the estimator’s fit method.

Parameters

Name Type Description Default
estimator Any Scikit-learn compatible estimator. required
fit_kwargs Optional[dict] Dictionary of arguments to pass to the estimator’s fit method. None

Returns

Name Type Description
dict Dictionary with only the arguments accepted by the estimator’s fit method.

Raises

Name Type Description
TypeError If fit_kwargs is not a dict.

Warns

If fit_kwargs contains keys that the estimator’s fit method does not accept, or if ‘sample_weight’ is present (in which case it is removed).

Examples

>>> from sklearn.linear_model import Ridge
>>> from spotforecast2.utils.forecaster_config import check_select_fit_kwargs
>>>
>>> estimator = Ridge()
>>> # sample_weight is accepted by Ridge.fit; invalid_arg is not
>>> kwargs = {"sample_weight": [1, 1], "invalid_arg": 10}
>>> # sample_weight is removed (pass weights via weight_func in the forecaster)
>>> # invalid_arg is dropped because Ridge.fit does not accept it
>>> filtered = check_select_fit_kwargs(estimator, kwargs)
>>> filtered
{}
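
The filtering described above can be approximated with the standard library alone. The following is a minimal sketch, not the spotforecast2 implementation: ToyEstimator, select_fit_kwargs, and the warning messages are hypothetical names chosen for illustration, and returning an empty dict for fit_kwargs=None is an assumption.

```python
import inspect
import warnings


class ToyEstimator:
    """Hypothetical stand-in for a scikit-learn style estimator."""

    def fit(self, X, y, sample_weight=None, check_input=True):
        return self


def select_fit_kwargs(estimator, fit_kwargs=None):
    """Illustrative re-creation of the filtering steps, not the library code."""
    if fit_kwargs is None:
        # Assumption: a missing fit_kwargs is treated as "nothing to pass".
        return {}
    if not isinstance(fit_kwargs, dict):
        raise TypeError(
            f"fit_kwargs must be a dict, got {type(fit_kwargs).__name__}."
        )

    # Parameter names of the bound fit method (self is already excluded).
    accepted = inspect.signature(estimator.fit).parameters

    unused = [key for key in fit_kwargs if key not in accepted]
    if unused:
        warnings.warn(f"Arguments {unused} are not used by fit and are ignored.")
    if "sample_weight" in fit_kwargs:
        warnings.warn("'sample_weight' is removed; supply weights via weight_func.")

    # Keep only accepted keys, and always drop sample_weight.
    return {
        key: value
        for key, value in fit_kwargs.items()
        if key in accepted and key != "sample_weight"
    }


# Record the warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    filtered = select_fit_kwargs(
        ToyEstimator(),
        {"check_input": False, "sample_weight": [1, 1], "invalid_arg": 10},
    )
```

In this sketch only check_input survives the filter, and two warnings are recorded: one for the unused key and one for the removed sample_weight.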

initialize_lags

utils.forecaster_config.initialize_lags(forecaster_name, lags)

Validate and normalize lag specification for forecasting.

This function converts various lag specifications (int, list, tuple, range, ndarray) into a standardized format: sorted numpy array, lag names, and maximum lag value.

Parameters

Name Type Description Default
forecaster_name str Name of the forecaster class for error messages. required
lags Any Lag specification. Accepted forms: int (creates lags from 1 to lags, e.g., 5 → [1, 2, 3, 4, 5]); list, tuple, or range (converted to a numpy array); numpy.ndarray (validated and used directly); None (returns (None, None, None)). required

Returns

Name Type Description
Tuple[Optional[np.ndarray], Optional[List[str]], Optional[int]] Tuple containing:
lags Optional[np.ndarray] Sorted numpy array of lag values (or None).
lags_names Optional[List[str]] List of lag names like [‘lag_1’, ‘lag_2’, …] (or None).
max_lag Optional[int] Maximum lag value (or None).

Raises

Name Type Description
ValueError If any lag value is less than 1, or if the lags array is empty or not 1-dimensional.
TypeError If lags is not an integer or another format accepted by the forecaster, or if the array contains non-integer values.

Examples

>>> import numpy as np
>>> from spotforecast2.utils.forecaster_config import initialize_lags
>>>
>>> # Integer input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", 3)
>>> lags
array([1, 2, 3])
>>> names
['lag_1', 'lag_2', 'lag_3']
>>> max_lag
3
>>>
>>> # List input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", [1, 3, 5])
>>> lags
array([1, 3, 5])
>>> names
['lag_1', 'lag_3', 'lag_5']
>>>
>>> # Range input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", range(1, 4))
>>> lags
array([1, 2, 3])
>>>
>>> # None input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", None)
>>> lags is None
True
>>>
>>> # Invalid: lags < 1
>>> try:
...     initialize_lags("ForecasterRecursive", 0)
... except ValueError as e:
...     print("Error: Minimum value of lags allowed is 1")
Error: Minimum value of lags allowed is 1
>>>
>>> # Invalid: negative lags
>>> try:
...     initialize_lags("ForecasterRecursive", [1, -2, 3])
... except ValueError as e:
...     print("Error: Minimum value of lags allowed is 1")
Error: Minimum value of lags allowed is 1
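
The normalization rules listed above (accepted input types, sorting, naming, and the checks that raise) can be sketched in a few lines of NumPy. This is an illustrative approximation, not the spotforecast2 source: normalize_lags is a hypothetical name, the error messages are placeholders, and the order of the checks is an assumption.

```python
import numpy as np


def normalize_lags(lags):
    """Illustrative sketch of lag normalization, not the library code."""
    if lags is None:
        return None, None, None

    if isinstance(lags, (int, np.integer)):
        if lags < 1:
            raise ValueError("Minimum value of lags allowed is 1.")
        # An integer n expands to the lags 1, 2, ..., n.
        lags = np.arange(1, lags + 1)
    elif isinstance(lags, (list, tuple, range, np.ndarray)):
        lags = np.asarray(lags)
    else:
        raise TypeError(
            "lags must be an int, list, tuple, range, or numpy.ndarray."
        )

    if lags.ndim != 1:
        raise ValueError("lags must be 1-dimensional.")
    if lags.size == 0:
        raise ValueError("lags must not be empty.")
    if not np.issubdtype(lags.dtype, np.integer):
        raise TypeError("All lag values must be integers.")
    if lags.min() < 1:
        raise ValueError("Minimum value of lags allowed is 1.")

    lags = np.sort(lags)
    names = [f"lag_{lag}" for lag in lags]
    return lags, names, int(lags.max())


# Unsorted tuple input is converted, sorted, and named.
lags, names, max_lag = normalize_lags((5, 1, 3))
```

The size check runs before the dtype check because an empty list converts to a float array, which would otherwise be reported as a type problem rather than as an empty specification.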

initialize_weights

utils.forecaster_config.initialize_weights(
    forecaster_name,
    estimator,
    weight_func,
    series_weights,
)

Validate and initialize weight function configuration for forecasting.

This function validates weight_func and series_weights, extracts source code from weight functions for serialization, and checks if the estimator supports sample weights in its fit method.

Parameters

Name Type Description Default
forecaster_name str Name of the forecaster class. required
estimator Any Scikit-learn compatible estimator or pipeline. required
weight_func Any Weight function specification. Accepted forms: Callable (a single weight function); dict (a dictionary of weight functions, for MultiSeries forecasters); None (no weighting). required
series_weights Any Series-level weights (for MultiSeries forecasters). Accepted forms: dict (maps series names to weight values); None (no series weighting). required

Returns

Name Type Description
Tuple[Any, Optional[Union[str, dict]], Any] Tuple containing:
weight_func Any Validated weight function (or None if invalid).
source_code_weight_func Optional[Union[str, dict]] Source code of weight function(s) for serialization (or None).
series_weights Any Validated series weights (or None if invalid).

Raises

Name Type Description
TypeError If weight_func is not Callable/dict (depending on forecaster type), or if series_weights is not a dict.

Warns

If the estimator’s fit method does not accept sample_weight.

Examples

>>> import numpy as np
>>> from sklearn.linear_model import Ridge
>>> from spotforecast2.utils.forecaster_config import initialize_weights
>>>
>>> # Simple weight function
>>> def custom_weights(index):
...     return np.ones(len(index))
>>>
>>> estimator = Ridge()
>>> wf, source, sw = initialize_weights(
...     "ForecasterRecursive", estimator, custom_weights, None
... )
>>> wf is not None
True
>>> isinstance(source, str)
True
>>>
>>> # No weight function
>>> wf, source, sw = initialize_weights(
...     "ForecasterRecursive", estimator, None, None
... )
>>> wf is None
True
>>> source is None
True
>>>
>>> # Invalid type for non-MultiSeries forecaster
>>> try:
...     initialize_weights("ForecasterRecursive", estimator, "invalid", None)
... except TypeError as e:
...     print("Error: weight_func must be Callable")
Error: weight_func must be Callable
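
The check for sample_weight support described above can be sketched with inspect.signature. This is a simplified illustration for a single-series forecaster, not the spotforecast2 implementation: resolve_weight_func and both estimator classes are hypothetical, and discarding the weight function after the warning is an assumption based on the "or None if invalid" note in the Returns table. Source-code extraction is left out of the sketch.

```python
import inspect
import warnings


class WeightedEstimator:
    """Hypothetical estimator whose fit accepts sample_weight."""

    def fit(self, X, y, sample_weight=None):
        return self


class UnweightedEstimator:
    """Hypothetical estimator whose fit does not accept sample_weight."""

    def fit(self, X, y):
        return self


def resolve_weight_func(estimator, weight_func):
    """Illustrative sketch of the weight_func checks, not the library code."""
    if weight_func is None:
        return None
    if not callable(weight_func):
        raise TypeError("weight_func must be Callable.")

    fit_params = inspect.signature(estimator.fit).parameters
    if "sample_weight" not in fit_params:
        # Assumption: the weight function is dropped when it cannot be used.
        warnings.warn(
            f"{type(estimator).__name__}.fit does not accept sample_weight; "
            "weight_func is ignored."
        )
        return None
    return weight_func


def ones(index):
    # Uniform weights, one per element of the index.
    return [1.0] * len(index)


kept = resolve_weight_func(WeightedEstimator(), ones)

# Record the warning instead of printing it, so it can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dropped = resolve_weight_func(UnweightedEstimator(), ones)
```

Here kept is the original function, while dropped is None and one warning is recorded.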