dimensions
extract_linear_dims(model)
Extracts the input and output dimensions of the Linear layers in a PyTorch model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | PyTorch model. | required |
Returns:
| Type | Description |
|---|---|
| array | np.array: Array with the input and output dimensions of the Linear layers. |
Examples:
>>> from spotpython.torch.dimensions import extract_linear_dims
>>> net = NNLinearRegressor()  # NNLinearRegressor must be imported from spotpython first
>>> result = extract_linear_dims(net)
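To illustrate what "extracting the dimensions of the Linear layers" involves, here is a minimal sketch, not the spotpython source. It assumes the result is a flat array of alternating `in_features` and `out_features` values, one pair per `torch.nn.Linear` layer in registration order; the real layout may differ. The function name `extract_linear_dims_sketch` is hypothetical, and a plain `nn.Sequential` stands in for `NNLinearRegressor`.

```python
import numpy as np
import torch.nn as nn


def extract_linear_dims_sketch(model: nn.Module) -> np.ndarray:
    """Hypothetical sketch: collect in/out features of every nn.Linear layer.

    Assumption: dimensions are returned as a flat array
    [in_1, out_1, in_2, out_2, ...] in the order the layers were registered.
    """
    dims = []
    # modules() walks the model recursively, so Linear layers nested
    # inside containers such as nn.Sequential are found as well.
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            dims.extend([layer.in_features, layer.out_features])
    return np.array(dims)


# Usage with a plain nn.Sequential (stand-in for NNLinearRegressor)
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
dims = extract_linear_dims_sketch(net)
print(dims)
```

Non-linear layers such as `nn.ReLU` are skipped, so only the two Linear layers contribute to the array.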
Source code in spotpython/torch/dimensions.py