mo.mo_mm.mo_xy_desirability_plot
mo.mo_mm.mo_xy_desirability_plot(
    models,
    X_base,
    J_base,
    d_base,
    phi_base,
    D_overall,
    bounds=None,
    mm_objective=True,
    resolution=50,
    feature_names=None,
    **kwargs,
)
Generates a plot of the desirability landscape. The 2-dimensional X values are plotted as points in the plane and colored according to their desirability values. One plot is generated for each pair of inputs x_i and x_j with i < j.
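With one plot per pair i < j, k inputs produce k(k-1)/2 panels. The sketch below shows one way such pairwise grids could be laid out with NumPy. The helpers `bounds_from_data` and `pair_grid` are hypothetical, and holding the non-plotted inputs at the midpoint of their bounds is an assumption, not necessarily what `mo_xy_desirability_plot` does internally.

```python
from itertools import combinations

import numpy as np


def bounds_from_data(X):
    # Assumption: with bounds=None, use the per-column (min, max) of X_base.
    return [(float(lo), float(hi)) for lo, hi in zip(X.min(axis=0), X.max(axis=0))]


def pair_grid(bounds, i, j, resolution=50):
    """Hypothetical helper: resolution**2 points varying inputs i and j.

    All other inputs are held at the midpoint of their bounds (an assumption).
    Returns the meshgrids XX, YY and the full-dimensional points.
    """
    mid = np.array([(lo + hi) / 2.0 for lo, hi in bounds])
    xi = np.linspace(bounds[i][0], bounds[i][1], resolution)
    xj = np.linspace(bounds[j][0], bounds[j][1], resolution)
    XX, YY = np.meshgrid(xi, xj)
    points = np.tile(mid, (XX.size, 1))  # start every row at the midpoint
    points[:, i] = XX.ravel()            # vary input i along the grid
    points[:, j] = YY.ravel()            # vary input j along the grid
    return XX, YY, points


X_base = np.random.default_rng(0).random((20, 3))
bounds = bounds_from_data(X_base)
pairs = list(combinations(range(X_base.shape[1]), 2))  # one panel per i < j
print(pairs)  # [(0, 1), (0, 2), (1, 2)]
XX, YY, points = pair_grid(bounds, 0, 1, resolution=10)
print(points.shape)  # (100, 3)
```

In a full plot, `points` would be passed to each model's `predict`, the predictions combined by the overall desirability function, and the result reshaped to `(resolution, resolution)` for a contour or image plot over `XX` and `YY`.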
Parameters
models (list, required): List of trained models, one per objective.
X_base (np.ndarray, required): Existing design points.
J_base (np.ndarray, required): Multiplicities of the pairwise distances in X_base.
d_base (np.ndarray, required): Unique pairwise distances in X_base.
phi_base (float, required): Morris-Mitchell metric of X_base.
D_overall (DOverall, required): The overall desirability function.
bounds (list, default None): List of (min, max) tuples, one per dimension. If None, the bounds are derived from X_base.
mm_objective (bool, default True): Whether to include the space-filling (Morris-Mitchell) improvement.
resolution (int, default 50): Grid resolution for the plot.
feature_names (list, default None): Names of the input variables.
**kwargs (Any, default {}): Additional keyword arguments passed to plt.subplots (e.g., figsize).
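J_base, d_base and phi_base summarize the space-filling quality of X_base. As a sketch, assuming the standard Morris-Mitchell criterion Φ_q = (Σ_k J_k d_k^(-q))^(1/q), where smaller values indicate a better spread of points, the three quantities relate as below. The helpers are hypothetical; the "intensive" normalization (dividing the sum by the number of point pairs) is an assumption and may differ from what `mmphi_intensive` computes.

```python
import numpy as np


def distance_multiplicities(X, p=2.0):
    """Unique pairwise p-norm distances d and their multiplicities J."""
    n = X.shape[0]
    i, j = np.triu_indices(n, k=1)  # all index pairs with i < j
    dist = np.sum(np.abs(X[i] - X[j]) ** p, axis=1) ** (1.0 / p)
    d, J = np.unique(dist, return_counts=True)
    return J, d


def phi_q(J, d, q=2.0, intensive=False):
    """Morris-Mitchell criterion (sum_k J_k * d_k**(-q))**(1/q).

    With intensive=True the sum is divided by the number of pairs,
    sum(J); this normalization is an assumption.
    """
    total = np.sum(J * d ** (-q))
    if intensive:
        total /= np.sum(J)
    return total ** (1.0 / q)


# Corners of the unit square: four sides of length 1, two diagonals of sqrt(2).
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
J, d = distance_multiplicities(X)
print(J)                 # [4 2]
print(phi_q(J, d, q=2))  # approximately 2.236, i.e. sqrt(4 + 2 * 0.5)
```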
Examples
>>> import numpy as np
>>> from sklearn.ensemble import RandomForestRegressor
>>> from spotoptim.mo.mo_mm import mo_xy_desirability_plot
>>> from spotoptim.function.mo import mo_conv2_max
>>> from spotoptim.sampling.mm import mmphi_intensive
>>> from spotdesirability.utils.desirability import DMax, DOverall
>>> # X_base in the range [0, 1]
>>> X_base = np.random.rand(500, 2)
>>> y = mo_conv2_max(X_base)
>>> # fit one surrogate model per objective
>>> models = []
>>> for i in range(y.shape[1]):
...     model = RandomForestRegressor(n_estimators=100, random_state=42)
...     model.fit(X_base, y[:, i])
...     models.append(model)
>>> # calculate base Morris-Mitchell stats
>>> phi_base, J_base, d_base = mmphi_intensive(X_base, q=2, p=2)
>>> # one larger-is-better desirability function per objective
>>> d_funcs = []
>>> for i in range(y.shape[1]):
...     d_func = DMax(low=np.min(y[:, i]), high=np.max(y[:, i]))
...     d_funcs.append(d_func)
>>> D_overall = DOverall(*d_funcs)
>>> mo_xy_desirability_plot(models, X_base, J_base, d_base, phi_base, D_overall)
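The D_overall in the example combines one DMax per objective. Conceptually, assuming the standard Derringer and Suich definitions with scale 1, a larger-is-better desirability maps a prediction linearly from 0 (at or below `low`) to 1 (at or above `high`), and the overall desirability is the geometric mean of the individual values. The helpers `d_max` and `d_overall` below are hypothetical illustrations, not the desirability package's API, which may offer further options.

```python
import numpy as np


def d_max(y, low, high, scale=1.0):
    """Larger-is-better desirability: 0 at or below low, 1 at or above high,
    ((y - low) / (high - low))**scale in between."""
    y = np.asarray(y, dtype=float)
    d = np.clip((y - low) / (high - low), 0.0, 1.0)
    return d ** scale


def d_overall(*ds):
    """Overall desirability as the geometric mean of individual desirabilities."""
    ds = np.vstack(ds)  # one row per objective
    return np.prod(ds, axis=0) ** (1.0 / ds.shape[0])


y1 = np.array([0.0, 5.0, 10.0])
y2 = np.array([10.0, 10.0, 10.0])
D = d_overall(d_max(y1, 0.0, 10.0), d_max(y2, 0.0, 10.0))
print(D)  # approximately [0.0, 0.7071, 1.0]
```

The geometric mean makes the overall value 0 whenever any single objective is completely undesirable, which is why points in poor regions of the landscape stand out clearly in the plot.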