preprocessing.forecaster_config
Forecaster configuration utilities.
This module provides functions for initializing and validating forecaster configuration parameters like lags and weights.
Functions
check_select_fit_kwargs
preprocessing.forecaster_config.check_select_fit_kwargs(
estimator,
fit_kwargs=None,
)
Check that fit_kwargs is a dict and keep only the keys accepted by the estimator’s fit method.
This function validates that fit_kwargs is a dictionary, warns about unused arguments, removes ‘sample_weight’ (which should be handled via weight_func), and returns a dictionary containing only the arguments accepted by the estimator’s fit method.
Parameters
estimator (Any, required): Scikit-learn compatible estimator.
fit_kwargs (Optional[dict], default None): Dictionary of arguments to pass to the estimator’s fit method.
Returns
dict: Dictionary containing only the arguments accepted by the estimator’s fit method.
Warns
IgnoredArgumentWarning: If fit_kwargs contains keys not accepted by the estimator’s fit method, or if ‘sample_weight’ is present (it is removed).
Examples
import warnings
from sklearn.linear_model import Ridge
from spotforecast2_safe.preprocessing.forecaster_config import check_select_fit_kwargs
estimator = Ridge()
# sample_weight is removed (should be passed via weight_func in forecaster);
# invalid_arg is ignored; both trigger IgnoredArgumentWarning
kwargs = {"sample_weight": [1, 1], "invalid_arg": 10}
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    filtered = check_select_fit_kwargs(estimator, kwargs)
assert filtered == {}
print(f"filtered fit_kwargs: {filtered}")
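The filtering described above can be sketched with `inspect.signature`. This is a minimal illustration under stated assumptions, not the spotforecast2_safe implementation: `select_fit_kwargs_sketch` and `DummyEstimator` are hypothetical names, the `TypeError` for non-dict input is an assumption, and the sketch emits plain `UserWarning`s instead of `IgnoredArgumentWarning`.

```python
import inspect
import warnings


def select_fit_kwargs_sketch(estimator, fit_kwargs=None):
    """Hypothetical sketch of fit_kwargs filtering; not the library's code."""
    if fit_kwargs is None:
        return {}
    # Assumption: non-dict input is rejected with a TypeError
    if not isinstance(fit_kwargs, dict):
        raise TypeError(f"Argument `fit_kwargs` must be a dict. Got {type(fit_kwargs)}.")

    # Parameter names accepted by the bound fit method (`self` is excluded)
    fit_params = set(inspect.signature(estimator.fit).parameters)

    unused = [k for k in fit_kwargs if k not in fit_params]
    if unused:
        warnings.warn(f"Argument(s) {unused} ignored: not accepted by estimator.fit.")

    # sample_weight is dropped because weighting is configured through weight_func
    if "sample_weight" in fit_kwargs:
        warnings.warn("`sample_weight` removed from fit_kwargs; use weight_func instead.")

    return {
        k: v
        for k, v in fit_kwargs.items()
        if k in fit_params and k != "sample_weight"
    }


class DummyEstimator:
    # Stand-in estimator whose fit accepts sample_weight and a `verbose` flag
    def fit(self, X, y, sample_weight=None, verbose=False):
        return self


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    filtered = select_fit_kwargs_sketch(
        DummyEstimator(), {"sample_weight": [1, 1], "verbose": True, "invalid_arg": 10}
    )
print(filtered)  # {'verbose': True}
```

Here `verbose` survives because `DummyEstimator.fit` declares it, `invalid_arg` is dropped as unknown, and `sample_weight` is dropped even though fit accepts it.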
initialize_lags
preprocessing.forecaster_config.initialize_lags(forecaster_name, lags)
Validate and normalize lag specification for forecasting.
This function converts a lag specification (int, list, tuple, range, or numpy.ndarray) into a standardized tuple: a sorted numpy array of lags, the list of lag names, and the maximum lag.
Parameters
forecaster_name (str, required): Name of the forecaster class, used in error messages.
lags (Any, required): Lag specification in one of several formats:
- int: creates lags from 1 to lags (e.g., 5 → [1, 2, 3, 4, 5])
- list/tuple/range: converted to a numpy array
- numpy.ndarray: validated and used directly
- None: returns (None, None, None)
Raises
ValueError: If any lag is less than 1, the array is empty, or the array is not 1-dimensional.
TypeError: If lags is not in a format accepted by the forecaster, or if the array contains non-integer values.
Examples
import numpy as np
from spotforecast2_safe.preprocessing.forecaster_config import initialize_lags
# Integer input
lags, names, max_lag = initialize_lags("ForecasterRecursive", 3)
assert list(lags) == [1, 2, 3]
assert names == ["lag_1", "lag_2", "lag_3"]
assert max_lag == 3
# List input
lags, names, max_lag = initialize_lags("ForecasterRecursive", [1, 3, 5])
assert list(lags) == [1, 3, 5]
assert names == ["lag_1", "lag_3", "lag_5"]
# Range input
lags, names, max_lag = initialize_lags("ForecasterRecursive", range(1, 4))
assert list(lags) == [1, 2, 3]
# None input
lags, names, max_lag = initialize_lags("ForecasterRecursive", None)
assert lags is None
print("initialize_lags: valid inputs OK")
initialize_lags: valid inputs OK
from spotforecast2_safe.preprocessing.forecaster_config import initialize_lags
# Invalid: lags < 1
try:
    initialize_lags("ForecasterRecursive", 0)
except ValueError as e:
    print(f"ValueError: {e}")
# Invalid: negative lag in list
try:
    initialize_lags("ForecasterRecursive", [1, -2, 3])
except ValueError as e:
    print(f"ValueError: {e}")
ValueError: Minimum value of lags allowed is 1.
ValueError: Minimum value of lags allowed is 1.
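The normalization steps described for initialize_lags can be sketched as follows. `initialize_lags_sketch` is a hypothetical stand-in written from the documented behavior (sorted array, `lag_<n>` names, maximum lag, and the listed errors); the real function's type checks and messages may differ in detail.

```python
import numpy as np


def initialize_lags_sketch(forecaster_name, lags):
    """Hypothetical sketch of lag normalization; not the library's code."""
    if lags is None:
        return None, None, None

    if isinstance(lags, (int, np.integer)):
        if lags < 1:
            raise ValueError("Minimum value of lags allowed is 1.")
        # An integer n means lags 1..n
        lags = np.arange(1, lags + 1)
    elif isinstance(lags, (list, tuple, range)):
        lags = np.array(lags)
    elif not isinstance(lags, np.ndarray):
        raise TypeError(
            f"`lags` argument of {forecaster_name} must be an int, list, tuple, "
            f"range or numpy ndarray. Got {type(lags)}."
        )

    if lags.size == 0:
        raise ValueError("`lags` must not be empty.")
    if lags.ndim != 1:
        raise ValueError("`lags` must be a 1-dimensional array.")
    if not np.issubdtype(lags.dtype, np.integer):
        raise TypeError("All values in `lags` must be integers.")
    if np.any(lags < 1):
        raise ValueError("Minimum value of lags allowed is 1.")

    # Standardized output: sorted lags, matching names, largest lag
    lags = np.sort(lags)
    lags_names = [f"lag_{i}" for i in lags]
    max_lag = int(lags.max())
    return lags, lags_names, max_lag


lags, names, max_lag = initialize_lags_sketch("ForecasterRecursive", [5, 1, 3])
print(names, max_lag)  # ['lag_1', 'lag_3', 'lag_5'] 5
```

Sorting before building the names keeps `lags` and `lags_names` aligned regardless of the order the user supplied.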
initialize_weights
preprocessing.forecaster_config.initialize_weights(
forecaster_name,
estimator,
weight_func,
series_weights,
)
Validate and initialize weight function configuration for forecasting.
This function validates weight_func and series_weights, extracts source code from weight functions for serialization, and checks if the estimator supports sample weights in its fit method.
Parameters
forecaster_name (str, required): Name of the forecaster class.
estimator (Any, required): Scikit-learn compatible estimator or pipeline.
weight_func (Any, required): Weight function specification:
- Callable: a single weight function
- dict: a dictionary of weight functions (for MultiSeries forecasters)
- None: no weighting
series_weights (Any, required): Series-level weights (for MultiSeries forecasters):
- dict: maps series names to weight values
- None: no series weighting
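To make the accepted formats concrete, here is a minimal sketch assuming that a weight function receives the pandas index of the training rows and returns one weight per row, as `custom_weights` in this function's example does with `np.ones(len(index))`. The functions `downweight_holidays` and `uniform_weights`, the series names `"load"` and `"price"`, the date window, and the weight values are all hypothetical.

```python
import numpy as np
import pandas as pd


def downweight_holidays(index):
    """Hypothetical weight_func: weight 0 inside a date window, 1 elsewhere."""
    # index is assumed to be a pandas DatetimeIndex of training rows
    in_window = (index >= "2024-12-24") & (index <= "2024-12-26")
    return np.where(in_window, 0.0, 1.0)


def uniform_weights(index):
    # Every training row gets the same weight
    return np.ones(len(index))


# Single Callable, e.g. for ForecasterRecursive
weight_func = downweight_holidays

# Dict of Callables keyed by series name, for MultiSeries forecasters
weight_func_multi = {"load": downweight_holidays, "price": uniform_weights}

# series_weights: series name -> scalar weight
series_weights = {"load": 1.0, "price": 0.5}

idx = pd.date_range("2024-12-22", periods=7, freq="D")
print(weight_func(idx))  # [1. 1. 0. 0. 0. 1. 1.]
```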
Returns
Tuple[Any, Optional[Union[str, dict]], Any]: A tuple containing:
- weight_func (Any): Validated weight function(s), or None if not provided or if the estimator does not support sample_weight
- source_code_weight_func (Optional[Union[str, dict]]): Source code of the weight function(s) for serialization, or None
- series_weights (Any): Validated series weights, or None if not provided or if the estimator does not support sample_weight
Raises
TypeError: If weight_func is not a Callable (or a dict, for MultiSeries forecasters), or if series_weights is not a dict.
Warns
If the estimator’s fit method does not accept sample_weight.
Examples
import numpy as np
from sklearn.linear_model import Ridge
from spotforecast2_safe.preprocessing.forecaster_config import initialize_weights
def custom_weights(index):
    return np.ones(len(index))
estimator = Ridge()
# Valid callable weight function
wf, source, sw = initialize_weights(
    "ForecasterRecursive", estimator, custom_weights, None
)
assert wf is not None
assert isinstance(source, str)
assert sw is None
# No weight function
wf, source, sw = initialize_weights(
    "ForecasterRecursive", estimator, None, None
)
assert wf is None
assert source is None
# Invalid type raises TypeError
try:
    initialize_weights("ForecasterRecursive", estimator, "invalid", None)
except TypeError as e:
    print(f"TypeError: {e}")
TypeError: Argument `weight_func` must be a Callable. Got <class 'str'>.
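The validation flow of initialize_weights can be approximated as below. `initialize_weights_sketch` is hypothetical, and several details are assumptions of this sketch rather than documented behavior: MultiSeries forecasters are detected by class name, a function's name is stored when its source text is unavailable, and after the (plain `UserWarning`) warning all weights are discarded if the estimator’s fit method has no `sample_weight` parameter.

```python
import inspect
import warnings
from collections.abc import Callable

import numpy as np


def _source_or_name(func):
    # Prefer the source text; fall back to the name when it is unavailable
    # (e.g. for functions defined in an interactive session)
    try:
        return inspect.getsource(func)
    except (OSError, TypeError):
        return func.__name__


def initialize_weights_sketch(forecaster_name, estimator, weight_func, series_weights):
    """Hypothetical sketch of weight validation; not the library's code."""
    source_code_weight_func = None
    # Assumption: MultiSeries forecasters are recognised by their class name
    multiseries = "MultiSeries" in forecaster_name

    if weight_func is not None:
        allowed = (Callable, dict) if multiseries else (Callable,)
        if not isinstance(weight_func, allowed):
            expected = "a Callable or a dict" if multiseries else "a Callable"
            raise TypeError(
                f"Argument `weight_func` must be {expected}. Got {type(weight_func)}."
            )
        # Keep source text so the configuration can be serialized later
        if isinstance(weight_func, dict):
            source_code_weight_func = {k: _source_or_name(f) for k, f in weight_func.items()}
        else:
            source_code_weight_func = _source_or_name(weight_func)

    if series_weights is not None and not isinstance(series_weights, dict):
        raise TypeError(
            f"Argument `series_weights` must be a dict. Got {type(series_weights)}."
        )

    # Weights can only be used if estimator.fit accepts sample_weight
    accepts_sample_weight = "sample_weight" in inspect.signature(estimator.fit).parameters
    if not accepts_sample_weight and (weight_func is not None or series_weights is not None):
        warnings.warn("estimator.fit does not accept `sample_weight`; weights are ignored.")
        weight_func, source_code_weight_func, series_weights = None, None, None

    return weight_func, source_code_weight_func, series_weights


class WeightedEstimator:
    def fit(self, X, y, sample_weight=None):
        return self


class UnweightedEstimator:
    def fit(self, X, y):
        return self


def custom_weights(index):
    return np.ones(len(index))


wf, src, sw = initialize_weights_sketch(
    "ForecasterRecursive", WeightedEstimator(), custom_weights, None
)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    wf2, src2, sw2 = initialize_weights_sketch(
        "ForecasterRecursive", UnweightedEstimator(), custom_weights, None
    )
print(wf is custom_weights, wf2 is None)  # True True
```

Checking the fit signature once, after type validation, means a bad `weight_func` type is reported even when the estimator could not use weights anyway.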