forecaster.metrics.add_y_train_argument

forecaster.metrics.add_y_train_argument(func)

Add a `y_train` argument to a function if it is not already present.
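The behavior described above can be sketched with `inspect.signature` and `functools.wraps`. The function below, `add_y_train_argument_sketch`, is a hypothetical stand-in written for illustration, not the library's implementation. It assumes that the added `y_train` is keyword-only, defaults to `None`, and is simply discarded before the original function is called.

```python
import functools
import inspect
from typing import Callable


def add_y_train_argument_sketch(func: Callable) -> Callable:
    """Hypothetical illustration only; not the spotforecast2_safe source."""
    sig = inspect.signature(func)
    # If the function already declares y_train, return it unchanged.
    if "y_train" in sig.parameters:
        return func

    params = list(sig.parameters.values())
    y_train_param = inspect.Parameter(
        "y_train", inspect.Parameter.KEYWORD_ONLY, default=None
    )
    # Keyword-only parameters must come before **kwargs, if present.
    if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        params.insert(len(params) - 1, y_train_param)
    else:
        params.append(y_train_param)

    @functools.wraps(func)
    def wrapper(*args, y_train=None, **kwargs):
        # Assumption: y_train is accepted for interface compatibility and ignored.
        return func(*args, **kwargs)

    # Advertise the extended signature to inspect-based callers.
    wrapper.__signature__ = sig.replace(parameters=params)
    return wrapper


def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


wrapped = add_y_train_argument_sketch(mae)
print(wrapped([1, 2, 3], [1, 2, 6], y_train=[0, 1, 2]))  # 1.0
print(inspect.signature(wrapped))  # (y_true, y_pred, *, y_train=None)
```

Setting `__signature__` matters because `functools.wraps` otherwise makes `inspect.signature` report the original parameters, hiding the added `y_train`.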

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `func` | `Callable` | Function to which the argument is added. | *required* |

Returns

| Type | Description |
|------|-------------|
| `Callable` | Function with the `y_train` argument added. |

Examples

```python
import numpy as np
from spotforecast2_safe.forecaster.metrics import add_y_train_argument

def my_metric(y_true, y_pred):
    # A plain metric that does not declare a y_train parameter.
    return np.mean(np.abs(y_true - y_pred))

# The returned metric can also be called with a y_train keyword argument.
enhanced_metric = add_y_train_argument(my_metric)
result = enhanced_metric(np.array([1, 2, 3]), np.array([1, 2, 3]), y_train=None)
print(result)  # 0.0
```
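One plausible reason to add `y_train` is to let metrics that need the training series (such as a scaled error) and metrics that do not share a single call signature. The sketch below assumes that use case. `accept_y_train` is a hypothetical minimal stand-in for `add_y_train_argument`, and `mase_like` is an illustrative MASE-style metric; neither is part of the library.

```python
import functools
import inspect

import numpy as np


def accept_y_train(func):
    """Hypothetical minimal stand-in: accept and ignore y_train if func lacks it."""
    if "y_train" in inspect.signature(func).parameters:
        return func

    @functools.wraps(func)
    def wrapper(*args, y_train=None, **kwargs):
        return func(*args, **kwargs)

    return wrapper


def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))


def mase_like(y_true, y_pred, y_train):
    # Scale the MAE by the mean absolute one-step change of the training series.
    scale = np.mean(np.abs(np.diff(y_train)))
    return float(np.mean(np.abs(y_true - y_pred)) / scale)


metrics = [accept_y_train(m) for m in (mae, mase_like)]
y_train = np.array([0.0, 2.0, 4.0, 6.0])
y_true = np.array([8.0, 10.0])
y_pred = np.array([8.0, 11.0])

# Every metric can now be called the same way, whether or not it uses y_train.
scores = {m.__name__: m(y_true, y_pred, y_train=y_train) for m in metrics}
print(scores)  # {'mae': 0.5, 'mase_like': 0.25}
```

With this pattern, evaluation code can always pass `y_train` and leave it to each metric to use or ignore it.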