multitask.predict.PredictTask

multitask.predict.PredictTask(
    config=None,
    *,
    dataframe=None,
    data_test=None,
    cache_home=None,
    log_level=logging.INFO,
    **overrides,
)

Task 5 — Prediction-only task that uses previously saved models.

Loads fitted forecasters that were persisted by a prior LazyTask or DefaultsTask run (or by any task in the spotforecast2 sibling package) and produces predictions for all configured targets. No training or tuning is performed.

If no saved models exist in the cache directory, the run method raises a RuntimeError with an informative message.

Examples

import tempfile
from pathlib import Path
from spotforecast2_safe.multitask import PredictTask
from spotforecast2_safe.configurator.config_multi import ConfigMulti

with tempfile.TemporaryDirectory() as tmp:
    cfg = ConfigMulti(data_frame_name="demo10", predict_size=24, cache_home=Path(tmp))
    task = PredictTask(cfg)
    print(f"Task: {task.TASK}")
    print(f"Predict size: {task.config.predict_size}")
Task: predict
Predict size: 24

Methods

| Name | Description |
| --- | --- |
| run | Run prediction using previously saved models. |

run

multitask.predict.PredictTask.run(
    show=False,
    task_name=None,
    max_age_days=None,
    **kwargs,
)

Run prediction using previously saved models.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| show | bool | If True, invoke the visualisation hooks. | False |
| task_name | Optional[str] | Restrict model loading to a specific source task ("lazy", "defaults", "optuna", or "spotoptim"). None loads the most recent model regardless of source. | None |
| max_age_days | Optional[float] | Maximum age in days for saved models. Models older than this are ignored. None accepts any age. | None |

Returns

| Name | Type | Description |
| --- | --- | --- |
|  | Dict[str, Any] | Aggregated prediction package. Per-target packages are stored on self.results["predict"]. |

Raises

| Name | Type | Description |
| --- | --- | --- |
|  | RuntimeError | If no saved models are found in the cache directory, or if a target has no matching model. |

Examples

import tempfile
import warnings
from pathlib import Path

import numpy as np
import pandas as pd

from spotforecast2_safe.configurator.config_multi import ConfigMulti
from spotforecast2_safe.multitask.lazy import LazyTask
from spotforecast2_safe.multitask.predict import PredictTask

warnings.filterwarnings("ignore")

rng = np.random.default_rng(0)
idx = pd.date_range("2023-01-01", periods=300, freq="h", tz="UTC")
df = pd.DataFrame(
    {"total_load_actual": rng.normal(10_000, 500, 300).clip(8_000, 12_000)},
    index=idx,
)
df.index.name = "DateTime"

with tempfile.TemporaryDirectory() as tmp:
    cfg = ConfigMulti(
        data_frame_name="demo",
        predict_size=12,
        auto_save_models=True,
        cache_home=Path(tmp),
        targets=["total_load_actual"],
        verbose=False,
        use_outlier_detection=False,
        use_exogenous_features=False,
        number_folds=2,
    )
    # Train and save a model with LazyTask first.
    lazy = LazyTask(cfg, dataframe=df)
    lazy.prepare_data()
    lazy.run()

    # Run PredictTask to load the saved model and forecast.
    task = PredictTask(cfg, dataframe=df)
    task.prepare_data()
    result = task.run()

    print(f"Future predictions shape: {result['future_pred'].shape}")
    assert result["future_pred"].shape == (12,)
    assert "metrics_train" in result
Future predictions shape: (12,)