utils.forecaster_config
Forecaster configuration utilities.
This module provides functions for initializing and validating forecaster configuration parameters such as lags, weights, and estimator fit arguments.
Functions
| Name | Description |
|---|---|
| check_select_fit_kwargs | Check if fit_kwargs is a dict and select only keys used by estimator’s fit. |
| initialize_lags | Validate and normalize lag specification for forecasting. |
| initialize_weights | Validate and initialize weight function configuration for forecasting. |
check_select_fit_kwargs
utils.forecaster_config.check_select_fit_kwargs(estimator, fit_kwargs=None)
Check if fit_kwargs is a dict and select only keys used by estimator’s fit.
This function validates that fit_kwargs is a dictionary, warns about unused arguments, removes ‘sample_weight’ (which should be handled via weight_func), and returns a dictionary containing only the arguments accepted by the estimator’s fit method.
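The filtering described above can be sketched with inspect.signature, which lists the parameters a bound fit method declares. This is a minimal, hypothetical reimplementation (select_fit_kwargs_sketch and DummyEstimator are illustrative names, not part of spotforecast2_safe); the real function may differ, for example in its warning categories and messages.

```python
import inspect
import warnings
from typing import Any, Optional


def select_fit_kwargs_sketch(estimator: Any, fit_kwargs: Optional[dict] = None) -> dict:
    """Hypothetical sketch: keep only kwargs accepted by estimator.fit."""
    if fit_kwargs is None:
        return {}
    if not isinstance(fit_kwargs, dict):
        raise TypeError(f"Argument `fit_kwargs` must be a dict. Got {type(fit_kwargs)}.")

    # Parameter names declared by the estimator's fit method ('self' is excluded
    # because estimator.fit is a bound method).
    fit_params = inspect.signature(estimator.fit).parameters

    unused = [key for key in fit_kwargs if key not in fit_params]
    if unused:
        warnings.warn(f"Ignoring arguments not used by fit: {unused}")

    if "sample_weight" in fit_kwargs:
        warnings.warn("`sample_weight` is ignored; use `weight_func` instead.")

    # Keep accepted arguments, but never forward sample_weight.
    return {
        key: value
        for key, value in fit_kwargs.items()
        if key in fit_params and key != "sample_weight"
    }


class DummyEstimator:
    """Hypothetical stand-in so the sketch runs without scikit-learn."""

    def fit(self, X, y, sample_weight=None, verbose=False):
        return self


filtered = select_fit_kwargs_sketch(
    DummyEstimator(), {"sample_weight": [1, 1], "verbose": True, "invalid_arg": 10}
)
print(filtered)  # {'verbose': True}
```

Dropping sample_weight even when fit accepts it keeps a single source of truth for weights: per the description above, weighting is expected to come from weight_func.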
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| estimator | Any | Scikit-learn compatible estimator. | required |
| fit_kwargs | Optional[dict] | Dictionary of arguments to pass to the estimator’s fit method. | None |
Returns
| Name | Type | Description |
|---|---|---|
|  | dict | Dictionary containing only the arguments accepted by the estimator’s fit method. |
Raises
| Name | Type | Description |
|---|---|---|
| TypeError | If fit_kwargs is not a dict. |
Warns
If fit_kwargs contains keys not used by the estimator’s fit method, or if it contains ‘sample_weight’ (which is removed).
Examples
>>> from sklearn.linear_model import Ridge
>>> from spotforecast2_safe.utils.forecaster_config import check_select_fit_kwargs
>>>
>>> estimator = Ridge()
>>> # sample_weight is a valid Ridge.fit argument; invalid_arg is not
>>> kwargs = {"sample_weight": [1, 1], "invalid_arg": 10}
>>> # sample_weight is removed (pass weights via weight_func in the forecaster)
>>> # invalid_arg is dropped because Ridge.fit does not accept it
>>> filtered = check_select_fit_kwargs(estimator, kwargs)
>>> filtered
{}
initialize_lags
utils.forecaster_config.initialize_lags(forecaster_name, lags)
Validate and normalize lag specification for forecasting.
This function converts various lag specifications (int, list, tuple, range, ndarray) into a standardized format: sorted numpy array, lag names, and maximum lag value.
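The normalization described above can be sketched as follows. This is a minimal, hypothetical reimplementation (normalize_lags_sketch is an illustrative name); it covers the int, list, tuple, range, ndarray and None inputs and the checks listed under Raises, but not any forecaster-specific formats the real function may accept, and its error messages are assumptions.

```python
from typing import Any, List, Optional, Tuple

import numpy as np


def normalize_lags_sketch(
    lags: Any,
) -> Tuple[Optional[np.ndarray], Optional[List[str]], Optional[int]]:
    """Hypothetical sketch of lag normalization: (lags, lags_names, max_lag)."""
    if lags is None:
        return None, None, None

    if isinstance(lags, (int, np.integer)):
        if lags < 1:
            raise ValueError("Minimum value of lags allowed is 1.")
        # An integer n expands to the lags 1, 2, ..., n.
        lags = np.arange(1, lags + 1)
    elif isinstance(lags, (list, tuple, range)):
        lags = np.array(lags)
    elif not isinstance(lags, np.ndarray):
        raise TypeError("`lags` must be an int, list, tuple, range or numpy ndarray.")

    if lags.ndim != 1:
        raise ValueError("`lags` must be a 1-dimensional array.")
    if lags.size == 0:
        raise ValueError("`lags` must contain at least one value.")
    if not np.issubdtype(lags.dtype, np.integer):
        raise TypeError("All values in `lags` must be integers.")
    if lags.min() < 1:
        raise ValueError("Minimum value of lags allowed is 1.")

    lags = np.sort(lags)
    lags_names = [f"lag_{lag}" for lag in lags]
    max_lag = int(lags.max())
    return lags, lags_names, max_lag


# Unsorted input is sorted; names and max_lag follow the sorted order.
print(normalize_lags_sketch([5, 1, 3]))
```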
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| forecaster_name | str | Name of the forecaster class for error messages. | required |
| lags | Any | Lag specification in one of several formats: int, creates lags from 1 to lags (e.g., 5 → [1, 2, 3, 4, 5]); list/tuple/range, converted to a numpy array; numpy.ndarray, validated and used directly; None, returns (None, None, None). | required |
Returns
A Tuple[Optional[np.ndarray], Optional[List[str]], Optional[int]] of (lags, lags_names, max_lag):
| Name | Type | Description |
|---|---|---|
| lags | Optional[np.ndarray] | Sorted numpy array of lag values, or None. |
| lags_names | Optional[List[str]] | List of lag names such as [‘lag_1’, ‘lag_2’, …], or None. |
| max_lag | Optional[int] | Maximum lag value, or None. |
Raises
| Name | Type | Description |
|---|---|---|
| ValueError | If lags is less than 1, is an empty array, or is not 1-dimensional. |
| TypeError | If lags is not an integer, is not in a format accepted by the forecaster, or is an array containing non-integer values. |
Examples
>>> import numpy as np
>>> from spotforecast2_safe.utils.forecaster_config import initialize_lags
>>>
>>> # Integer input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", 3)
>>> lags
array([1, 2, 3])
>>> names
['lag_1', 'lag_2', 'lag_3']
>>> max_lag
3
>>>
>>> # List input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", [1, 3, 5])
>>> lags
array([1, 3, 5])
>>> names
['lag_1', 'lag_3', 'lag_5']
>>>
>>> # Range input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", range(1, 4))
>>> lags
array([1, 2, 3])
>>>
>>> # None input
>>> lags, names, max_lag = initialize_lags("ForecasterRecursive", None)
>>> lags is None
True
>>>
>>> # Invalid: lags < 1
>>> try:
...     initialize_lags("ForecasterRecursive", 0)
... except ValueError:
...     print("Error: Minimum value of lags allowed is 1")
Error: Minimum value of lags allowed is 1
>>>
>>> # Invalid: negative lags
>>> try:
...     initialize_lags("ForecasterRecursive", [1, -2, 3])
... except ValueError:
...     print("Error: Minimum value of lags allowed is 1")
Error: Minimum value of lags allowed is 1
initialize_weights
utils.forecaster_config.initialize_weights(
forecaster_name,
estimator,
weight_func,
series_weights,
)
Validate and initialize weight function configuration for forecasting.
This function validates weight_func and series_weights, extracts source code from weight functions for serialization, and checks if the estimator supports sample weights in its fit method.
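The validation steps described above can be sketched as follows. This is a minimal, hypothetical reimplementation (init_weights_sketch and the dummy estimators are illustrative names). Two assumptions are worth flagging: MultiSeries forecasters are detected here by the substring "MultiSeries" in forecaster_name, and the sketch inspects estimator.fit directly, so it does not special-case scikit-learn pipelines. It also falls back to None when a function's source cannot be retrieved, which the real function may not do.

```python
import inspect
import warnings
from typing import Any, Callable, Optional, Tuple, Union


def _source_or_none(func: Callable) -> Optional[str]:
    """Return the function's source code, or None if it is unavailable."""
    try:
        return inspect.getsource(func)
    except (OSError, TypeError):
        return None


def init_weights_sketch(
    forecaster_name: str, estimator: Any, weight_func: Any, series_weights: Any
) -> Tuple[Any, Optional[Union[str, dict]], Any]:
    """Hypothetical sketch: validate weights and keep source for serialization."""
    source_code = None
    multiseries = "MultiSeries" in forecaster_name  # assumption: detect by name

    if weight_func is not None:
        if multiseries:
            if not (callable(weight_func) or isinstance(weight_func, dict)):
                raise TypeError("`weight_func` must be a Callable or a dict of Callables.")
        elif not callable(weight_func):
            raise TypeError("`weight_func` must be a Callable.")

        # Store source code so the configuration can be serialized later.
        if isinstance(weight_func, dict):
            source_code = {key: _source_or_none(f) for key, f in weight_func.items()}
        else:
            source_code = _source_or_none(weight_func)

    if series_weights is not None and not isinstance(series_weights, dict):
        raise TypeError("`series_weights` must be a dict.")

    # Weights have no effect if fit() cannot receive sample_weight.
    accepts_weights = "sample_weight" in inspect.signature(estimator.fit).parameters
    if not accepts_weights and (weight_func is not None or series_weights is not None):
        warnings.warn("Estimator does not accept `sample_weight`; weights are ignored.")
        weight_func, source_code, series_weights = None, None, None

    return weight_func, source_code, series_weights


class WeightedEstimator:
    def fit(self, X, y, sample_weight=None):
        return self


class UnweightedEstimator:
    def fit(self, X, y):
        return self


def equal_weights(index):
    return [1.0] * len(index)


wf, src, sw = init_weights_sketch("ForecasterRecursive", WeightedEstimator(), equal_weights, None)
print(wf is equal_weights)  # True
```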
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| forecaster_name | str | Name of the forecaster class. | required |
| estimator | Any | Scikit-learn compatible estimator or pipeline. | required |
| weight_func | Any | Weight function specification: a Callable (single weight function), a dict of weight functions (for MultiSeries forecasters), or None (no weighting). | required |
| series_weights | Any | Series-level weights for MultiSeries forecasters: a dict mapping series names to weight values, or None (no series weighting). | required |
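The parameters above can be illustrated with a weight function that follows the same signature as custom_weights in the examples below: it receives an index and returns one weight per row. This is a hypothetical sketch; it assumes the index is a pandas DatetimeIndex of training rows and that dict keys are series names, and the window dates, series names and weight values are made up for illustration.

```python
import numpy as np
import pandas as pd


def downweight_window(index: pd.Index) -> np.ndarray:
    """Hypothetical weight_func: zero weight inside an anomalous period."""
    # Assumption: `index` is the DatetimeIndex of the training rows.
    weights = np.ones(len(index))
    in_window = (index >= "2020-03-01") & (index <= "2020-03-31")
    weights[in_window] = 0.0
    return weights


# Hypothetical MultiSeries configuration: one weight function per series
# and a relative weight per series, both keyed by series name.
weight_func = {"series_a": downweight_window, "series_b": downweight_window}
series_weights = {"series_a": 1.0, "series_b": 2.0}

idx = pd.date_range("2020-02-28", periods=5, freq="D")
print(downweight_window(idx))  # [1. 1. 0. 0. 0.]
```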
Returns
A Tuple[Any, Optional[Union[str, dict]], Any] of (weight_func, source_code_weight_func, series_weights):
| Name | Type | Description |
|---|---|---|
| weight_func | Any | Validated weight function, or None if not provided or if the estimator does not support sample weights. |
| source_code_weight_func | Optional[Union[str, dict]] | Source code of the weight function(s) for serialization, or None. |
| series_weights | Any | Validated series weights, or None if not provided or if the estimator does not support sample weights. |
Raises
| Name | Type | Description |
|---|---|---|
| TypeError | If weight_func is not Callable/dict (depending on forecaster type), or if series_weights is not a dict. |
Warns
If the estimator’s fit method does not accept a sample_weight argument.
Examples
>>> import numpy as np
>>> from sklearn.linear_model import Ridge
>>> from spotforecast2_safe.utils.forecaster_config import initialize_weights
>>>
>>> # Simple weight function
>>> def custom_weights(index):
...     return np.ones(len(index))
>>>
>>> estimator = Ridge()
>>> wf, source, sw = initialize_weights(
... "ForecasterRecursive", estimator, custom_weights, None
... )
>>> wf is not None
True
>>> isinstance(source, str)
True
>>>
>>> # No weight function
>>> wf, source, sw = initialize_weights(
... "ForecasterRecursive", estimator, None, None
... )
>>> wf is None
True
>>> source is None
True
>>>
>>> # Invalid type for non-MultiSeries forecaster
>>> try:
...     initialize_weights("ForecasterRecursive", estimator, "invalid", None)
... except TypeError:
...     print("Error: weight_func must be Callable")
Error: weight_func must be Callable