26 Explainable AI with SpotPython and PyTorch

from torch.utils.data import DataLoader
from spotPython.utils.init import fun_control_init
from spotPython.hyperparameters.values import set_control_key_value
from spotPython.data.diabetes import Diabetes
from spotPython.light.regression.netlightregression import NetLightRegression
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
from spotPython.hyperparameters.values import add_core_model_to_fun_control
from spotPython.hyperparameters.values import (
        get_default_hyperparameters_as_array, get_one_config_from_X)
from spotPython.hyperparameters.values import set_control_key_value
from spotPython.plot.xai import (get_activations, get_gradients, get_weights, plot_nn_values_hist, plot_nn_values_scatter, visualize_weights, visualize_gradients, visualize_activations, visualize_gradient_distributions, visualize_weights_distributions)
# Control dictionary for a regression network on the Diabetes data:
# 10 inputs, 1 output, and "mean_squared_error" as the metric name.
fun_control = fun_control_init(
    _L_in=10,  # 10: diabetes
    _L_out=1,
    _torchmetric="mean_squared_error",
)

# Store the data set in fun_control under the "data_set" key
# (replace=True presumably overwrites an existing entry -- confirm).
dataset = Diabetes()
set_control_key_value(
    control_dict=fun_control, key="data_set", value=dataset, replace=True
)

# Register the network class and its hyperparameter dictionary.
add_core_model_to_fun_control(
    fun_control=fun_control, core_model=NetLightRegression, hyper_dict=LightHyperDict
)

# Turn the default hyperparameter array into a config dict.
X = get_default_hyperparameters_as_array(fun_control)
config = get_one_config_from_X(X, fun_control)

# Pull the network dimensions and metric back out of fun_control, then
# instantiate the core model with the default config plus those values.
_L_in, _L_out, _torchmetric = (
    fun_control[name] for name in ("_L_in", "_L_out", "_torchmetric")
)
model = fun_control["core_model"](
    **config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric
)

batch_size = config["batch_size"]
# NOTE(review): `dataloader` is not used by any call visible here; the XAI
# helpers below receive fun_control and batch_size instead.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Last expression of the cell so the notebook displays the returned dict.
get_activations(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
net: NetLightRegression(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=8, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.01, inplace=False)
    (3): Linear(in_features=8, out_features=4, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.01, inplace=False)
    (6): Linear(in_features=4, out_features=4, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.01, inplace=False)
    (9): Linear(in_features=4, out_features=2, bias=True)
    (10): ReLU()
    (11): Dropout(p=0.01, inplace=False)
    (12): Linear(in_features=2, out_features=1, bias=True)
  )
)
{0: array([ 1.43207282e-01,  6.29712082e-03,  1.04200497e-01, -3.79188173e-03,
        -1.74976081e-01, -7.97475874e-02, -2.00860098e-01,  2.48444736e-01,
         1.42530382e-01, -2.86847632e-03,  3.61538231e-02, -5.21567538e-02,
        -2.15294853e-01, -1.26742452e-01, -1.79230243e-01,  2.73077697e-01,
         1.36738747e-01,  8.57900176e-03,  1.01677164e-01,  3.27536091e-03,
        -1.92429125e-01, -7.95854479e-02, -1.84092522e-01,  2.72164375e-01,
         1.51459932e-01,  3.70034538e-02,  4.94864434e-02, -6.36564642e-02,
        -1.63678646e-01, -1.26617596e-01, -2.05547154e-01,  2.25242063e-01,
         1.54910132e-01,  4.92912624e-03,  6.90693632e-02, -3.28048877e-02,
        -1.77523270e-01, -1.17699921e-01, -1.95609123e-01,  2.50784487e-01,
         1.66618377e-01,  1.22015951e-02,  2.58807316e-02, -8.16192776e-02,
        -2.00623482e-01, -1.17052853e-01, -1.86843857e-01,  2.40996510e-01,
         1.80479109e-01,  3.72159854e-02,  3.55244167e-02, -3.60636115e-02,
        -2.09616780e-01, -1.19843856e-01, -1.44335642e-01,  2.73970902e-01,
         1.46006003e-01, -1.83095373e-02,  8.83664042e-02,  2.28608586e-02,
        -1.77115664e-01, -1.37761638e-01, -1.90622538e-01,  2.85049856e-01,
         1.44436464e-01,  1.36893094e-02,  6.65568933e-02, -2.01083720e-04,
        -1.99043870e-01, -1.11171007e-01, -1.76820531e-01,  2.78549373e-01,
         1.31597325e-01,  1.31126186e-02,  5.92438355e-02, -6.50760308e-02,
        -1.55642599e-01, -1.12090096e-01, -2.32182071e-01,  2.25448400e-01,
         2.09733546e-01,  4.48576249e-02,  1.76887661e-02, -7.26176351e-02,
        -1.81560591e-01, -1.18118793e-01, -1.55840069e-01,  2.45131850e-01,
         1.57539800e-01,  4.57477495e-02,  8.64019692e-02,  1.06538832e-02,
        -2.25713193e-01, -8.36062431e-02, -1.51326194e-01,  2.42097050e-01,
         1.46130219e-01, -6.08363096e-03,  4.69235368e-02, -4.06553932e-02,
        -1.90215483e-01, -1.30105391e-01, -1.91207454e-01,  2.75829703e-01,
         1.37035578e-01,  1.32784406e-02,  8.11730623e-02, -2.83420049e-02,
        -1.72134370e-01, -1.05717532e-01, -1.93411276e-01,  2.68321246e-01,
         1.24822736e-01, -2.49985531e-02,  5.46513572e-02, -3.76938097e-02,
        -2.02080101e-01, -1.29510283e-01, -1.99880868e-01,  2.84415126e-01,
         1.36025175e-01,  2.10405551e-02,  1.25923336e-01, -1.76883545e-02,
        -1.46617338e-01, -1.00234658e-01, -2.21794963e-01,  2.05139250e-01],
       dtype=float32),
 3: array([-0.09106569,  0.15831017,  0.29874575, -0.05709065, -0.07168067,
         0.13238071,  0.29310873, -0.04537551, -0.08868651,  0.15093939,
         0.29576218, -0.0508837 , -0.07256822,  0.15756649,  0.29804155,
        -0.06024086, -0.07925774,  0.15159754,  0.29655144, -0.05204485,
        -0.06510481,  0.14707124,  0.2955585 , -0.05045141, -0.05945833,
         0.15397519,  0.28643152, -0.03937227, -0.0780265 ,  0.1443048 ,
         0.2993904 , -0.04338943, -0.07745007,  0.1438258 ,  0.29152495,
        -0.04569358, -0.08201659,  0.14775375,  0.3020632 , -0.06361471,
        -0.05014775,  0.16657498,  0.28808075, -0.04191205, -0.07614301,
         0.16806594,  0.29809946, -0.05615523, -0.07369395,  0.13612927,
         0.2925982 , -0.04455032, -0.08367015,  0.14735378,  0.29441217,
        -0.05101945, -0.07929114,  0.12925598,  0.29300398, -0.04631315,
        -0.09977546,  0.1741175 ,  0.30642375, -0.07330882], dtype=float32),
 6: array([ 0.02894721, -0.15329668,  0.0478624 ,  0.5073338 ,  0.03414171,
        -0.1624101 ,  0.0582981 ,  0.5058923 ,  0.0301194 , -0.15560818,
         0.05099656,  0.5068564 ,  0.02897662, -0.15344843,  0.04822758,
         0.5072659 ,  0.03012998, -0.15550745,  0.05065323,  0.5069246 ,
         0.03103478, -0.1570965 ,  0.05247599,  0.50667256,  0.02730933,
        -0.15252227,  0.0509877 ,  0.5065358 ,  0.03256607, -0.15896471,
         0.05305116,  0.5067359 ,  0.03095146, -0.15756464,  0.05418621,
         0.5063291 ,  0.03229896, -0.1581411 ,  0.05142961,  0.50702184,
         0.02454497, -0.14787357,  0.04604906,  0.50718296,  0.02638294,
        -0.14930864,  0.04427201,  0.5077407 ,  0.03309863, -0.16082475,
         0.05695035,  0.50603575,  0.03071198, -0.15675312,  0.05250891,
         0.5066291 ,  0.03489432, -0.16362445,  0.05948593,  0.50574666,
         0.02671532, -0.14859803,  0.04098557,  0.5084204 ], dtype=float32),
 9: array([0.04397329, 0.23183572, 0.04112439, 0.22675759, 0.0430866 ,
        0.23046201, 0.04386139, 0.2317175 , 0.04319487, 0.23055825,
        0.04269706, 0.22967225, 0.04286424, 0.23156166, 0.04263979,
        0.2289063 , 0.0421553 , 0.2292049 , 0.04312573, 0.22948454,
        0.04418794, 0.23408437, 0.04489119, 0.23388621, 0.0414625 ,
        0.22755873, 0.0426609 , 0.22978865, 0.04081305, 0.22611658,
        0.04594607, 0.23471704], dtype=float32),
 12: array([-0.30635476, -0.30988604, -0.3073418 , -0.30644947, -0.30726004,
        -0.30787635, -0.306807  , -0.308307  , -0.3082779 , -0.3078606 ,
        -0.30507785, -0.3049926 , -0.30935943, -0.30782318, -0.3103186 ,
        -0.30425358], dtype=float32)}
# Weight gradients per layer; the displayed dict is keyed by parameter
# name (e.g. 'layers.0.weight') with flattened float32 arrays.
get_gradients(model, fun_control=fun_control, batch_size=batch_size, device="cpu")
{'layers.0.weight': array([ 0.10417589, -0.04161514,  0.10597268,  0.02180895,  0.12001497,
         0.0289035 ,  0.01146171,  0.08183315,  0.2495192 ,  0.5108763 ,
         0.14668097, -0.07902835,  0.00912531,  0.02640062,  0.14108549,
         0.06816658,  0.14256881, -0.00347908,  0.07373644,  0.23171763,
         0.08313344, -0.0332093 ,  0.08456729,  0.01740377,  0.09577318,
         0.0230653 ,  0.00914656,  0.0653037 ,  0.1991189 ,  0.4076846 ,
         0.04405227,  0.03805925,  0.015035  ,  0.0069457 ,  0.0094994 ,
         0.03021198, -0.01876849,  0.02160799, -0.03238906, -0.02050959,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        -0.05415884,  0.02163483, -0.05509295, -0.01133801, -0.06239325,
        -0.01502632, -0.0059587 , -0.04254333, -0.12971975, -0.2655938 ],
       dtype=float32),
 'layers.3.weight': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
        -5.8896484e+00, -6.3058013e-01, -2.5641673e+00, -8.9936234e-02,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00, -1.0009734e+01,
         5.1539743e-01,  5.5181440e-02,  2.2438775e-01,  7.8702327e-03,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  8.7594193e-01,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       dtype=float32),
 'layers.6.weight': array([ 0.       ,  7.6445217, 15.007772 ,  0.       ,  0.       ,
         0.       ,  0.       ,  0.       ,  0.       , 11.027901 ,
        21.650045 ,  0.       ,  0.       ,  3.458755 ,  6.7902493,
         0.       ], dtype=float32),
 'layers.9.weight': array([ -2.3285942,   0.       ,  -3.9471323, -39.11015  ,  -4.6057286,
          0.       ,  -7.8070364, -77.35598  ], dtype=float32),
 'layers.12.weight': array([-12.126856, -64.91129 ], dtype=float32)}
# Weight values of the model's Linear layers; for this network the displayed
# dict is keyed "Layer 0", "Layer 3", ... with flattened float32 arrays.
get_weights(model)
{'Layer 0': array([-0.12895013,  0.01047491, -0.15705723,  0.11925378, -0.26944348,
         0.23180884, -0.22984707, -0.25141433, -0.19982024,  0.1432175 ,
        -0.11684369,  0.11833665, -0.2683918 , -0.19186287, -0.11611126,
        -0.06214499, -0.24123858,  0.20706302, -0.07457636,  0.10150522,
         0.22361842,  0.05891513,  0.08647271,  0.3052416 , -0.1426217 ,
         0.10016554, -0.14069483,  0.22599207,  0.25255734, -0.29155323,
         0.26994652,  0.1510033 ,  0.13780165,  0.13018303,  0.26287985,
        -0.04175457, -0.26743335, -0.09074122, -0.2227112 ,  0.02090477,
        -0.05904209, -0.16961981, -0.02875187,  0.2995954 , -0.0249426 ,
         0.01004026, -0.04931906,  0.04971322,  0.28176296,  0.19337103,
         0.11224869,  0.06871963,  0.07456426,  0.12216929, -0.04086405,
        -0.29390487, -0.19555901,  0.2699275 ,  0.01890202, -0.25616774,
         0.04987781,  0.26129004, -0.29883513, -0.21289697, -0.12594265,
         0.0126926 , -0.07375361, -0.03475064, -0.30828732,  0.14808287,
         0.2775668 ,  0.19329055, -0.22393112, -0.25491226,  0.13131432,
         0.00710202,  0.12963155, -0.3090024 , -0.01885445,  0.22301763],
       dtype=float32),
 'Layer 3': array([ 0.19455571,  0.12364562, -0.2711233 ,  0.2728095 ,  0.11085409,
         0.24458633, -0.13908438,  0.07495222,  0.34520328,  0.23782092,
         0.28354865, -0.07424083,  0.26936427, -0.2769144 ,  0.03057847,
        -0.19906998, -0.08245403, -0.09054411,  0.02645254,  0.32178298,
         0.17503859, -0.00149773,  0.2509683 , -0.1811804 ,  0.18221132,
        -0.03278595, -0.06152213,  0.0413917 , -0.27085608,  0.04085568,
         0.11887809,  0.302264  ], dtype=float32),
 'Layer 6': array([ 0.4752962 , -0.24824601,  0.22039747,  0.19587505,  0.13966405,
         0.39540154, -0.20208222,  0.13140953,  0.00280607, -0.3760708 ,
        -0.12140697, -0.33391154,  0.22107768,  0.04494798,  0.04898232,
        -0.15168536], dtype=float32),
 'Layer 9': array([ 0.07573527, -0.22145915, -0.30541402,  0.03821951, -0.3709231 ,
        -0.3758251 , -0.3254385 , -0.1698224 ], dtype=float32),
 'Layer 12': array([0.2738903, 0.5417278], dtype=float32)}
# Plot per-layer activations with a diverging colormap, keeping signed values
# (absolute=False). The log shows values padded to a square count per layer.
visualize_activations(
    model,
    fun_control=fun_control,
    batch_size=batch_size,
    device="cpu",
    cmap="BlueWhiteRed",
    absolute=False,
)
net: NetLightRegression(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=8, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.01, inplace=False)
    (3): Linear(in_features=8, out_features=4, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.01, inplace=False)
    (6): Linear(in_features=4, out_features=4, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.01, inplace=False)
    (9): Linear(in_features=4, out_features=2, bias=True)
    (10): ReLU()
    (11): Dropout(p=0.01, inplace=False)
    (12): Linear(in_features=2, out_features=1, bias=True)
  )
)
128 values in Layer 0.
16 padding values added.
144 values now in Layer 0.
64 values in Layer 3.
64 values now in Layer 3.
64 values in Layer 6.
64 values now in Layer 6.
32 values in Layer 9.
4 padding values added.
36 values now in Layer 9.
16 values in Layer 12.
16 values now in Layer 12.

# Distribution plot of the layer weights; "C0" is the value the original
# f"C{0}" produced (first colour of matplotlib's default cycle).
visualize_weights_distributions(model, color="C0")
n:5

# Distribution plot of the weight gradients, drawn in colour "C0"
# (same value as the former f"C{0}").
visualize_gradient_distributions(
    model, fun_control, batch_size=batch_size, color="C0"
)
n:5

# Grayscale per-layer weight plots on a 6x6 figure; absolute=True presumably
# plots magnitudes rather than signed values -- confirm in spotPython.plot.xai.
visualize_weights(
    model,
    figsize=(6, 6),
    cmap="gray",
    absolute=True,
)
80 values in Layer Layer 0.
1 padding values added.
81 values now in Layer Layer 0.
32 values in Layer Layer 3.
4 padding values added.
36 values now in Layer Layer 3.
16 values in Layer Layer 6.
16 values now in Layer Layer 6.
8 values in Layer Layer 9.
1 padding values added.
9 values now in Layer Layer 9.
2 values in Layer Layer 12.
2 padding values added.
4 values now in Layer Layer 12.

# Per-layer weight-gradient plots (keys 'layers.N.weight' in the log) using
# the diverging "BlueWhiteRed" colormap on a 6x6 figure.
visualize_gradients(
    model,
    fun_control,
    batch_size,
    absolute=True,
    cmap="BlueWhiteRed",
    figsize=(6, 6),
)
80 values in Layer layers.0.weight.
1 padding values added.
81 values now in Layer layers.0.weight.
32 values in Layer layers.3.weight.
4 padding values added.
36 values now in Layer layers.3.weight.
16 values in Layer layers.6.weight.
16 values now in Layer layers.6.weight.
8 values in Layer layers.9.weight.
1 padding values added.
9 values now in Layer layers.9.weight.
2 values in Layer layers.12.weight.
2 padding values added.
4 values now in Layer layers.12.weight.

# Activation plots without explicit cmap/absolute, i.e. using the
# function's own defaults for both.
visualize_activations(
    model, fun_control=fun_control, batch_size=batch_size, device="cpu"
)
net: NetLightRegression(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=8, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.01, inplace=False)
    (3): Linear(in_features=8, out_features=4, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.01, inplace=False)
    (6): Linear(in_features=4, out_features=4, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.01, inplace=False)
    (9): Linear(in_features=4, out_features=2, bias=True)
    (10): ReLU()
    (11): Dropout(p=0.01, inplace=False)
    (12): Linear(in_features=2, out_features=1, bias=True)
  )
)
128 values in Layer 0.
16 padding values added.
144 values now in Layer 0.
64 values in Layer 3.
64 values now in Layer 3.
64 values in Layer 6.
64 values now in Layer 6.
32 values in Layer 9.
4 padding values added.
36 values now in Layer 9.
16 values in Layer 12.
16 values now in Layer 12.

# NOTE(review): this repeats the preceding visualize_activations call with
# identical arguments (and produces identical log output); presumably a
# different option was meant to be demonstrated here -- confirm.
visualize_activations(
    model, fun_control=fun_control, batch_size=batch_size, device="cpu"
)
net: NetLightRegression(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=8, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.01, inplace=False)
    (3): Linear(in_features=8, out_features=4, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.01, inplace=False)
    (6): Linear(in_features=4, out_features=4, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.01, inplace=False)
    (9): Linear(in_features=4, out_features=2, bias=True)
    (10): ReLU()
    (11): Dropout(p=0.01, inplace=False)
    (12): Linear(in_features=2, out_features=1, bias=True)
  )
)
128 values in Layer 0.
16 padding values added.
144 values now in Layer 0.
64 values in Layer 3.
64 values now in Layer 3.
64 values in Layer 6.
64 values now in Layer 6.
32 values in Layer 9.
4 padding values added.
36 values now in Layer 9.
16 values in Layer 12.
16 values now in Layer 12.