plot
plot_predictionintervals(y_train, y_train_pred, y_train_pred_low, y_train_pred_high, y_test, y_test_pred, y_test_pred_low, y_test_pred_high, suptitle)
Plots prediction intervals for training and testing data. This function generates four subplots arranged in a 2x2 grid:

1. True vs. predicted values, with error bars representing the prediction intervals.
2. Prediction interval width vs. true values.
3. Ordered prediction interval widths for both training and testing data.
4. Histograms of the interval widths for training and testing data.
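The quantities behind these panels can be derived from the three arrays passed per data set. The following is a minimal NumPy sketch, not the spotpython implementation. It assumes that the interval width is `high - low` and that the error bars are the (possibly asymmetric) distances from the prediction to each bound, in the `(2, N)` shape accepted by matplotlib's `errorbar`. The helper name `interval_quantities` is hypothetical.

```python
import numpy as np


def interval_quantities(y_pred, y_low, y_high):
    """Hypothetical helper: error-bar offsets, widths, and sorted widths.

    Returns:
        yerr: array of shape (2, N); row 0 is the distance down to the
            lower bound, row 1 the distance up to the upper bound.
        width: interval width per sample (upper minus lower bound).
        width_sorted: widths in ascending order.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_low = np.asarray(y_low, dtype=float)
    y_high = np.asarray(y_high, dtype=float)
    # Absolute values keep the offsets non-negative even if a prediction
    # falls outside its own interval (matplotlib rejects negative yerr).
    yerr = np.abs(np.vstack([y_pred - y_low, y_high - y_pred]))
    # Width of each prediction interval
    width = y_high - y_low
    # Ascending order, as used for an "ordered widths" panel
    width_sorted = np.sort(width)
    return yerr, width, width_sorted


# Training arrays from the example in this page
yerr, width, width_sorted = interval_quantities(
    [1.1, 2.2, 3.3, 4.4, 5.5],
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [1.2, 2.4, 3.6, 4.8, 6.0],
)
print(width)
```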
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y_train | array-like | True values for the training set. | required |
| y_train_pred | array-like | Predicted values for the training set. | required |
| y_train_pred_low | array-like | Lower bounds of prediction intervals for the training set. | required |
| y_train_pred_high | array-like | Upper bounds of prediction intervals for the training set. | required |
| y_test | array-like | True values for the testing set. | required |
| y_test_pred | array-like | Predicted values for the testing set. | required |
| y_test_pred_low | array-like | Lower bounds of prediction intervals for the testing set. | required |
| y_test_pred_high | array-like | Upper bounds of prediction intervals for the testing set. | required |
| suptitle | str | The title for the entire figure. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| None | None | The function displays the plots but does not return any value. |
Notes
- The first subplot compares true and predicted values with error bars for both training and testing data.
- The second subplot visualizes the width of prediction intervals as a function of true values.
- The third subplot orders the prediction interval widths and displays them for both training and testing data.
- The fourth subplot shows histograms of the interval widths for training and testing data.
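A 2x2 layout like the one described in these notes can be sketched with matplotlib as follows. This is an illustrative sketch, not the spotpython source. The function name `sketch_interval_figure`, the axis labels, and the styling are assumptions. It uses the non-interactive `Agg` backend so that it runs without a display, and it returns the figure instead of showing it (the documented function returns `None`).

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def sketch_interval_figure(y_train, y_train_pred, y_train_low, y_train_high,
                           y_test, y_test_pred, y_test_low, y_test_high,
                           suptitle):
    """Hypothetical re-creation of the four panels; returns the Figure."""
    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    datasets = {
        "train": (y_train, y_train_pred, y_train_low, y_train_high),
        "test": (y_test, y_test_pred, y_test_low, y_test_high),
    }
    for label, arrays in datasets.items():
        y, pred, low, high = (np.asarray(a, dtype=float) for a in arrays)
        # Non-negative distances from the prediction to each bound
        yerr = np.abs(np.vstack([pred - low, high - pred]))
        width = high - low
        # Panel 1: true vs. predicted values with interval error bars
        axs[0, 0].errorbar(y, pred, yerr=yerr, fmt="o", alpha=0.5, label=label)
        # Panel 2: interval width as a function of the true value
        axs[0, 1].scatter(y, width, label=label)
        # Panel 3: interval widths sorted in ascending order
        axs[1, 0].plot(np.sort(width), label=label)
        # Panel 4: histogram of interval widths
        axs[1, 1].hist(width, alpha=0.5, label=label)

    axs[0, 0].set(xlabel="True values", ylabel="Predicted values")
    axs[0, 1].set(xlabel="True values", ylabel="Interval width")
    axs[1, 0].set(xlabel="Rank (ascending width)", ylabel="Interval width")
    axs[1, 1].set(xlabel="Interval width", ylabel="Count")
    for ax in axs.flat:
        ax.legend()
    fig.suptitle(suptitle)
    fig.tight_layout()
    return fig


fig = sketch_interval_figure(
    [1, 2, 3, 4, 5], [1.1, 2.2, 3.3, 4.4, 5.5],
    [1.0, 2.0, 3.0, 4.0, 5.0], [1.2, 2.4, 3.6, 4.8, 6.0],
    [6, 7, 8], [6.1, 7.2, 8.3], [6.0, 7.0, 8.0], [6.2, 7.4, 8.6],
    "Prediction Intervals",
)
```

Returning the `Figure` makes it possible to save the result with `fig.savefig(...)` or close it with `plt.close(fig)` in scripts and tests.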
References
Function adapted from: https://github.com/scikit-learn-contrib/MAPIE/blob/master/notebooks/regression/exoplanets.ipynb
Examples:
>>> import numpy as np
>>> from spotpython.uc.plot import plot_predictionintervals
>>> y_train = np.array([1, 2, 3, 4, 5])
>>> y_train_pred = np.array([1.1, 2.2, 3.3, 4.4, 5.5])
>>> y_train_pred_low = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
>>> y_train_pred_high = np.array([1.2, 2.4, 3.6, 4.8, 6.0])
>>> y_test = np.array([6, 7, 8])
>>> y_test_pred = np.array([6.1, 7.2, 8.3])
>>> y_test_pred_low = np.array([6.0, 7.0, 8.0])
>>> y_test_pred_high = np.array([6.2, 7.4, 8.6])
>>> suptitle = "Prediction Intervals"
>>> plot_predictionintervals(y_train, y_train_pred, y_train_pred_low, y_train_pred_high, y_test, y_test_pred, y_test_pred_low, y_test_pred_high, suptitle)
Source code in spotpython/uc/plot.py