multitask.PredictTask

multitask.PredictTask(
    config=None,
    *,
    dataframe=None,
    data_test=None,
    cache_home=None,
    log_level=logging.INFO,
    **overrides,
)

Task 5 — Predict-only using previously saved models.

Loads fitted forecasters that were persisted by a prior LazyTask, OptunaTask, or SpotOptimTask run and produces predictions for all configured targets. No training or tuning is performed.

If no saved models exist in the cache directory, the run method raises RuntimeError with an informative message.

Examples

from spotforecast2.multitask import PredictTask

task = PredictTask(data_frame_name="demo10", predict_size=24)
print(f"Task: {task.TASK}")
print(f"Predict size: {task.config.predict_size}")
Task: predict
Predict size: 24

Methods

Name Description
run Run prediction using previously saved models.

run

multitask.PredictTask.run(
    show=True,
    task_name=None,
    max_age_days=None,
    **kwargs,
)

Run prediction using previously saved models.

Parameters

Name Type Description Default
show bool If True, display prediction figures. True
task_name Optional[str] Restrict model loading to a specific source task ("lazy", "optuna", or "spotoptim"). None loads the most recent model regardless of source. None
max_age_days Optional[float] Maximum age in days for saved models. Models older than this are ignored. None accepts any age. None

Returns

Name Type Description
Dict[str, Any] Aggregated prediction package. Per-target packages are stored on self.results["predict"].

Raises

Name Type Description
RuntimeError If no saved models are found in the cache directory, or if a target has no matching model.

Examples

import tempfile
from pathlib import Path
from spotforecast2_safe.data.fetch_data import fetch_data, get_package_data_home
from spotforecast2.multitask import LazyTask, PredictTask

demo_df = fetch_data(filename=str(get_package_data_home() / "demo10.csv"))

with tempfile.TemporaryDirectory() as tmp:
    # Train and persist a model for a single target.
    lazy = LazyTask(
        data_frame_name="demo10",
        cache_home=Path(tmp),
        predict_size=24,
        targets=["A"],
        use_exogenous_features=False,
    )
    lazy.prepare_data(demo_data=demo_df)
    lazy.detect_outliers()
    lazy.impute()
    lazy.build_exogenous_features()
    lazy.run(show=False)
    lazy.save_models(task_name="lazy")

    # Load the saved model and produce predictions.
    pred = PredictTask(
        data_frame_name="demo10",
        cache_home=Path(tmp),
        predict_size=24,
        targets=["A"],
        use_exogenous_features=False,
    )
    pred.prepare_data(demo_data=demo_df)
    pred.detect_outliers()
    pred.impute()
    pred.build_exogenous_features()
    result = pred.run(show=False, task_name="lazy")

pkg = pred.results["predict"]["A"]
print(f"future_pred length: {len(pkg['future_pred'])}")
print(f"result keys: {sorted(result.keys())}")
assert len(pkg["future_pred"]) == 24
assert "future_pred" in result
/tmp/ipykernel_3195/494582107.py:17: DeprecationWarning: Derived pipeline fields (start_download, end_download, data_start, data_end, cov_start, cov_end, end_train_ts, start_train_ts) have moved to task.run_state. Reading them from the config is deprecated and will stop working in the next major release. config.targets continues to hold the user input unchanged; read the resolved list from task.run_state.targets.
  lazy.prepare_data(demo_data=demo_df)
/tmp/ipykernel_3195/494582107.py:32: DeprecationWarning: Derived pipeline fields (start_download, end_download, data_start, data_end, cov_start, cov_end, end_train_ts, start_train_ts) have moved to task.run_state. Reading them from the config is deprecated and will stop working in the next major release. config.targets continues to hold the user input unchanged; read the resolved list from task.run_state.targets.
  pred.prepare_data(demo_data=demo_df)
future_pred length: 24
result keys: ['future_actual', 'future_pred', 'metrics_future', 'metrics_future_one_day', 'metrics_train', 'train_actual', 'train_pred', 'validation_passed']