from torch.utils.data import DataLoader
from spotPython.utils.init import fun_control_init
from spotPython.hyperparameters.values import set_control_key_value
from spotPython.data.diabetes import Diabetes
from spotPython.light.regression.netlightregression import NetLightRegression
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
from spotPython.hyperparameters.values import add_core_model_to_fun_control
from spotPython.hyperparameters.values import (
    get_default_hyperparameters_as_array,
    get_one_config_from_X,
)
from spotPython.plot.xai import (
    get_activations,
    get_gradients,
    get_weights,
    plot_nn_values_hist,
    plot_nn_values_scatter,
    visualize_weights,
    visualize_gradients,
    visualize_activations,
    visualize_gradient_distributions,
    visualize_weights_distributions,
)

# Set up a regression experiment on the Diabetes data set and build one
# NetLightRegression model whose weights, gradients and activations are
# inspected with the spotPython.plot.xai helpers imported above.

# Control dictionary: 10 input features, 1 output, MSE as the torchmetric.
fun_control = fun_control_init(
    _L_in=10,  # 10: diabetes
    _L_out=1,
    _torchmetric="mean_squared_error",
)

# Store the data set in fun_control under the "data_set" key
# (replace=True overwrites any existing entry).
dataset = Diabetes()
set_control_key_value(
    control_dict=fun_control,
    key="data_set",
    value=dataset,
    replace=True,
)

# Register the model class and its hyperparameter dictionary; the class is
# read back below via fun_control["core_model"].
add_core_model_to_fun_control(
    fun_control=fun_control,
    core_model=NetLightRegression,
    hyper_dict=LightHyperDict,
)

# Build one hyperparameter configuration (a dict, see config["batch_size"]
# below) from the default hyperparameters.
X = get_default_hyperparameters_as_array(fun_control)
config = get_one_config_from_X(X, fun_control)

# Instantiate the model with that configuration plus the I/O sizes and metric.
_L_in = fun_control["_L_in"]
_L_out = fun_control["_L_out"]
_torchmetric = fun_control["_torchmetric"]
model = fun_control["core_model"](
    **config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric
)

# shuffle=False: iterate the data set in its stored order.
batch_size = config["batch_size"]
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# 26 Explainable AI with spotPython and PyTorch
# Per-layer activations of `model` (data presumably taken from fun_control);
# the displayed result is a dict keyed by layer index.
get_activations(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
net: NetLightRegression(
(layers): Sequential(
(0): Linear(in_features=10, out_features=8, bias=True)
(1): ReLU()
(2): Dropout(p=0.01, inplace=False)
(3): Linear(in_features=8, out_features=4, bias=True)
(4): ReLU()
(5): Dropout(p=0.01, inplace=False)
(6): Linear(in_features=4, out_features=4, bias=True)
(7): ReLU()
(8): Dropout(p=0.01, inplace=False)
(9): Linear(in_features=4, out_features=2, bias=True)
(10): ReLU()
(11): Dropout(p=0.01, inplace=False)
(12): Linear(in_features=2, out_features=1, bias=True)
)
)
{0: array([ 1.43207282e-01, 6.29712082e-03, 1.04200497e-01, -3.79188173e-03,
-1.74976081e-01, -7.97475874e-02, -2.00860098e-01, 2.48444736e-01,
1.42530382e-01, -2.86847632e-03, 3.61538231e-02, -5.21567538e-02,
-2.15294853e-01, -1.26742452e-01, -1.79230243e-01, 2.73077697e-01,
1.36738747e-01, 8.57900176e-03, 1.01677164e-01, 3.27536091e-03,
-1.92429125e-01, -7.95854479e-02, -1.84092522e-01, 2.72164375e-01,
1.51459932e-01, 3.70034538e-02, 4.94864434e-02, -6.36564642e-02,
-1.63678646e-01, -1.26617596e-01, -2.05547154e-01, 2.25242063e-01,
1.54910132e-01, 4.92912624e-03, 6.90693632e-02, -3.28048877e-02,
-1.77523270e-01, -1.17699921e-01, -1.95609123e-01, 2.50784487e-01,
1.66618377e-01, 1.22015951e-02, 2.58807316e-02, -8.16192776e-02,
-2.00623482e-01, -1.17052853e-01, -1.86843857e-01, 2.40996510e-01,
1.80479109e-01, 3.72159854e-02, 3.55244167e-02, -3.60636115e-02,
-2.09616780e-01, -1.19843856e-01, -1.44335642e-01, 2.73970902e-01,
1.46006003e-01, -1.83095373e-02, 8.83664042e-02, 2.28608586e-02,
-1.77115664e-01, -1.37761638e-01, -1.90622538e-01, 2.85049856e-01,
1.44436464e-01, 1.36893094e-02, 6.65568933e-02, -2.01083720e-04,
-1.99043870e-01, -1.11171007e-01, -1.76820531e-01, 2.78549373e-01,
1.31597325e-01, 1.31126186e-02, 5.92438355e-02, -6.50760308e-02,
-1.55642599e-01, -1.12090096e-01, -2.32182071e-01, 2.25448400e-01,
2.09733546e-01, 4.48576249e-02, 1.76887661e-02, -7.26176351e-02,
-1.81560591e-01, -1.18118793e-01, -1.55840069e-01, 2.45131850e-01,
1.57539800e-01, 4.57477495e-02, 8.64019692e-02, 1.06538832e-02,
-2.25713193e-01, -8.36062431e-02, -1.51326194e-01, 2.42097050e-01,
1.46130219e-01, -6.08363096e-03, 4.69235368e-02, -4.06553932e-02,
-1.90215483e-01, -1.30105391e-01, -1.91207454e-01, 2.75829703e-01,
1.37035578e-01, 1.32784406e-02, 8.11730623e-02, -2.83420049e-02,
-1.72134370e-01, -1.05717532e-01, -1.93411276e-01, 2.68321246e-01,
1.24822736e-01, -2.49985531e-02, 5.46513572e-02, -3.76938097e-02,
-2.02080101e-01, -1.29510283e-01, -1.99880868e-01, 2.84415126e-01,
1.36025175e-01, 2.10405551e-02, 1.25923336e-01, -1.76883545e-02,
-1.46617338e-01, -1.00234658e-01, -2.21794963e-01, 2.05139250e-01],
dtype=float32),
3: array([-0.09106569, 0.15831017, 0.29874575, -0.05709065, -0.07168067,
0.13238071, 0.29310873, -0.04537551, -0.08868651, 0.15093939,
0.29576218, -0.0508837 , -0.07256822, 0.15756649, 0.29804155,
-0.06024086, -0.07925774, 0.15159754, 0.29655144, -0.05204485,
-0.06510481, 0.14707124, 0.2955585 , -0.05045141, -0.05945833,
0.15397519, 0.28643152, -0.03937227, -0.0780265 , 0.1443048 ,
0.2993904 , -0.04338943, -0.07745007, 0.1438258 , 0.29152495,
-0.04569358, -0.08201659, 0.14775375, 0.3020632 , -0.06361471,
-0.05014775, 0.16657498, 0.28808075, -0.04191205, -0.07614301,
0.16806594, 0.29809946, -0.05615523, -0.07369395, 0.13612927,
0.2925982 , -0.04455032, -0.08367015, 0.14735378, 0.29441217,
-0.05101945, -0.07929114, 0.12925598, 0.29300398, -0.04631315,
-0.09977546, 0.1741175 , 0.30642375, -0.07330882], dtype=float32),
6: array([ 0.02894721, -0.15329668, 0.0478624 , 0.5073338 , 0.03414171,
-0.1624101 , 0.0582981 , 0.5058923 , 0.0301194 , -0.15560818,
0.05099656, 0.5068564 , 0.02897662, -0.15344843, 0.04822758,
0.5072659 , 0.03012998, -0.15550745, 0.05065323, 0.5069246 ,
0.03103478, -0.1570965 , 0.05247599, 0.50667256, 0.02730933,
-0.15252227, 0.0509877 , 0.5065358 , 0.03256607, -0.15896471,
0.05305116, 0.5067359 , 0.03095146, -0.15756464, 0.05418621,
0.5063291 , 0.03229896, -0.1581411 , 0.05142961, 0.50702184,
0.02454497, -0.14787357, 0.04604906, 0.50718296, 0.02638294,
-0.14930864, 0.04427201, 0.5077407 , 0.03309863, -0.16082475,
0.05695035, 0.50603575, 0.03071198, -0.15675312, 0.05250891,
0.5066291 , 0.03489432, -0.16362445, 0.05948593, 0.50574666,
0.02671532, -0.14859803, 0.04098557, 0.5084204 ], dtype=float32),
9: array([0.04397329, 0.23183572, 0.04112439, 0.22675759, 0.0430866 ,
0.23046201, 0.04386139, 0.2317175 , 0.04319487, 0.23055825,
0.04269706, 0.22967225, 0.04286424, 0.23156166, 0.04263979,
0.2289063 , 0.0421553 , 0.2292049 , 0.04312573, 0.22948454,
0.04418794, 0.23408437, 0.04489119, 0.23388621, 0.0414625 ,
0.22755873, 0.0426609 , 0.22978865, 0.04081305, 0.22611658,
0.04594607, 0.23471704], dtype=float32),
12: array([-0.30635476, -0.30988604, -0.3073418 , -0.30644947, -0.30726004,
-0.30787635, -0.306807 , -0.308307 , -0.3082779 , -0.3078606 ,
-0.30507785, -0.3049926 , -0.30935943, -0.30782318, -0.3103186 ,
-0.30425358], dtype=float32)}
# Per-parameter gradients of `model`; the displayed result is a dict keyed by
# parameter name (e.g. 'layers.0.weight').
get_gradients(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
{'layers.0.weight': array([ 0.10417589, -0.04161514, 0.10597268, 0.02180895, 0.12001497,
0.0289035 , 0.01146171, 0.08183315, 0.2495192 , 0.5108763 ,
0.14668097, -0.07902835, 0.00912531, 0.02640062, 0.14108549,
0.06816658, 0.14256881, -0.00347908, 0.07373644, 0.23171763,
0.08313344, -0.0332093 , 0.08456729, 0.01740377, 0.09577318,
0.0230653 , 0.00914656, 0.0653037 , 0.1991189 , 0.4076846 ,
0.04405227, 0.03805925, 0.015035 , 0.0069457 , 0.0094994 ,
0.03021198, -0.01876849, 0.02160799, -0.03238906, -0.02050959,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
-0.05415884, 0.02163483, -0.05509295, -0.01133801, -0.06239325,
-0.01502632, -0.0059587 , -0.04254333, -0.12971975, -0.2655938 ],
dtype=float32),
'layers.3.weight': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
-5.8896484e+00, -6.3058013e-01, -2.5641673e+00, -8.9936234e-02,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, -1.0009734e+01,
5.1539743e-01, 5.5181440e-02, 2.2438775e-01, 7.8702327e-03,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 8.7594193e-01,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
dtype=float32),
'layers.6.weight': array([ 0. , 7.6445217, 15.007772 , 0. , 0. ,
0. , 0. , 0. , 0. , 11.027901 ,
21.650045 , 0. , 0. , 3.458755 , 6.7902493,
0. ], dtype=float32),
'layers.9.weight': array([ -2.3285942, 0. , -3.9471323, -39.11015 , -4.6057286,
0. , -7.8070364, -77.35598 ], dtype=float32),
'layers.12.weight': array([-12.126856, -64.91129 ], dtype=float32)}
# Weights of `model`; the displayed result is a dict keyed 'Layer <index>'.
get_weights(model)
{'Layer 0': array([-0.12895013, 0.01047491, -0.15705723, 0.11925378, -0.26944348,
0.23180884, -0.22984707, -0.25141433, -0.19982024, 0.1432175 ,
-0.11684369, 0.11833665, -0.2683918 , -0.19186287, -0.11611126,
-0.06214499, -0.24123858, 0.20706302, -0.07457636, 0.10150522,
0.22361842, 0.05891513, 0.08647271, 0.3052416 , -0.1426217 ,
0.10016554, -0.14069483, 0.22599207, 0.25255734, -0.29155323,
0.26994652, 0.1510033 , 0.13780165, 0.13018303, 0.26287985,
-0.04175457, -0.26743335, -0.09074122, -0.2227112 , 0.02090477,
-0.05904209, -0.16961981, -0.02875187, 0.2995954 , -0.0249426 ,
0.01004026, -0.04931906, 0.04971322, 0.28176296, 0.19337103,
0.11224869, 0.06871963, 0.07456426, 0.12216929, -0.04086405,
-0.29390487, -0.19555901, 0.2699275 , 0.01890202, -0.25616774,
0.04987781, 0.26129004, -0.29883513, -0.21289697, -0.12594265,
0.0126926 , -0.07375361, -0.03475064, -0.30828732, 0.14808287,
0.2775668 , 0.19329055, -0.22393112, -0.25491226, 0.13131432,
0.00710202, 0.12963155, -0.3090024 , -0.01885445, 0.22301763],
dtype=float32),
'Layer 3': array([ 0.19455571, 0.12364562, -0.2711233 , 0.2728095 , 0.11085409,
0.24458633, -0.13908438, 0.07495222, 0.34520328, 0.23782092,
0.28354865, -0.07424083, 0.26936427, -0.2769144 , 0.03057847,
-0.19906998, -0.08245403, -0.09054411, 0.02645254, 0.32178298,
0.17503859, -0.00149773, 0.2509683 , -0.1811804 , 0.18221132,
-0.03278595, -0.06152213, 0.0413917 , -0.27085608, 0.04085568,
0.11887809, 0.302264 ], dtype=float32),
'Layer 6': array([ 0.4752962 , -0.24824601, 0.22039747, 0.19587505, 0.13966405,
0.39540154, -0.20208222, 0.13140953, 0.00280607, -0.3760708 ,
-0.12140697, -0.33391154, 0.22107768, 0.04494798, 0.04898232,
-0.15168536], dtype=float32),
'Layer 9': array([ 0.07573527, -0.22145915, -0.30541402, 0.03821951, -0.3709231 ,
-0.3758251 , -0.3254385 , -0.1698224 ], dtype=float32),
'Layer 12': array([0.2738903, 0.5417278], dtype=float32)}
# Plot per-layer activations with the "BlueWhiteRed" colormap, keeping signs
# (absolute=False).
visualize_activations(
    model,
    fun_control=fun_control,
    batch_size=batch_size,
    device="cpu",
    cmap="BlueWhiteRed",
    absolute=False,
)
net: NetLightRegression(
(layers): Sequential(
(0): Linear(in_features=10, out_features=8, bias=True)
(1): ReLU()
(2): Dropout(p=0.01, inplace=False)
(3): Linear(in_features=8, out_features=4, bias=True)
(4): ReLU()
(5): Dropout(p=0.01, inplace=False)
(6): Linear(in_features=4, out_features=4, bias=True)
(7): ReLU()
(8): Dropout(p=0.01, inplace=False)
(9): Linear(in_features=4, out_features=2, bias=True)
(10): ReLU()
(11): Dropout(p=0.01, inplace=False)
(12): Linear(in_features=2, out_features=1, bias=True)
)
)
128 values in Layer 0.
16 padding values added.
144 values now in Layer 0.
64 values in Layer 3.
64 values now in Layer 3.
64 values in Layer 6.
64 values now in Layer 6.
32 values in Layer 9.
4 padding values added.
36 values now in Layer 9.
16 values in Layer 12.
16 values now in Layer 12.
# Plot the distribution of the model's weights per layer.
visualize_weights_distributions(model, color=f"C{0}")
n:5
# Plot the distribution of the model's gradients per layer.
visualize_gradient_distributions(model, fun_control, batch_size=batch_size, color=f"C{0}")
n:5
# Plot per-layer weights as images (absolute values, grayscale, 6x6 figure).
visualize_weights(model, absolute=True, cmap="gray", figsize=(6, 6))
80 values in Layer Layer 0.
1 padding values added.
81 values now in Layer Layer 0.
32 values in Layer Layer 3.
4 padding values added.
36 values now in Layer Layer 3.
16 values in Layer Layer 6.
16 values now in Layer Layer 6.
8 values in Layer Layer 9.
1 padding values added.
9 values now in Layer Layer 9.
2 values in Layer Layer 12.
2 padding values added.
4 values now in Layer Layer 12.
# Plot per-parameter gradients as images (absolute values, "BlueWhiteRed"
# colormap, 6x6 figure).
visualize_gradients(model, fun_control, batch_size, absolute=True, cmap="BlueWhiteRed", figsize=(6, 6))
80 values in Layer layers.0.weight.
1 padding values added.
81 values now in Layer layers.0.weight.
32 values in Layer layers.3.weight.
4 padding values added.
36 values now in Layer layers.3.weight.
16 values in Layer layers.6.weight.
16 values now in Layer layers.6.weight.
8 values in Layer layers.9.weight.
1 padding values added.
9 values now in Layer layers.9.weight.
2 values in Layer layers.12.weight.
2 padding values added.
4 values now in Layer layers.12.weight.
# Plot per-layer activations with the helper's default colormap settings.
visualize_activations(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
net: NetLightRegression(
(layers): Sequential(
(0): Linear(in_features=10, out_features=8, bias=True)
(1): ReLU()
(2): Dropout(p=0.01, inplace=False)
(3): Linear(in_features=8, out_features=4, bias=True)
(4): ReLU()
(5): Dropout(p=0.01, inplace=False)
(6): Linear(in_features=4, out_features=4, bias=True)
(7): ReLU()
(8): Dropout(p=0.01, inplace=False)
(9): Linear(in_features=4, out_features=2, bias=True)
(10): ReLU()
(11): Dropout(p=0.01, inplace=False)
(12): Linear(in_features=2, out_features=1, bias=True)
)
)
128 values in Layer 0.
16 padding values added.
144 values now in Layer 0.
64 values in Layer 3.
64 values now in Layer 3.
64 values in Layer 6.
64 values now in Layer 6.
32 values in Layer 9.
4 padding values added.
36 values now in Layer 9.
16 values in Layer 12.
16 values now in Layer 12.
# Plot per-layer activations again with the helper's default colormap settings.
visualize_activations(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
net: NetLightRegression(
(layers): Sequential(
(0): Linear(in_features=10, out_features=8, bias=True)
(1): ReLU()
(2): Dropout(p=0.01, inplace=False)
(3): Linear(in_features=8, out_features=4, bias=True)
(4): ReLU()
(5): Dropout(p=0.01, inplace=False)
(6): Linear(in_features=4, out_features=4, bias=True)
(7): ReLU()
(8): Dropout(p=0.01, inplace=False)
(9): Linear(in_features=4, out_features=2, bias=True)
(10): ReLU()
(11): Dropout(p=0.01, inplace=False)
(12): Linear(in_features=2, out_features=1, bias=True)
)
)
128 values in Layer 0.
16 padding values added.
144 values now in Layer 0.
64 values in Layer 3.
64 values now in Layer 3.
64 values in Layer 6.
64 values now in Layer 6.
32 values in Layer 9.
4 padding values added.
36 values now in Layer 9.
16 values in Layer 12.
16 values now in Layer 12.