
dimensions

extract_linear_dims(model)

Extracts the input and output dimensions (`in_features`, `out_features`) of the `nn.Linear` layers in `model.layers` of a PyTorch model.

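Parameters:

| Name  | Type   | Description                                                                  | Default  |
|-------|--------|------------------------------------------------------------------------------|----------|
| model | Module | PyTorch model with an iterable `layers` attribute (e.g. an `nn.Sequential`). | required |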

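Returns:

| Type    | Description                                                                                                                                       |
|---------|---------------------------------------------------------------------------------------------------------------------------------------------------|
| ndarray | np.ndarray: Flat array `[in_1, out_1, in_2, out_2, ...]` with the input and output dimensions of each Linear layer, in the order of `model.layers`. |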
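Because the result is flat, consecutive entries must be paired up to recover per-layer shapes. A minimal NumPy sketch, using a hypothetical return value for a model with `Linear(10, 5)` and `Linear(5, 1)` rather than output from a real spotpython model:

```python
import numpy as np

# Hypothetical return value for a model with Linear(10, 5) and Linear(5, 1).
dims = np.array([10, 5, 5, 1])

# Each consecutive pair is (in_features, out_features) of one Linear layer.
pairs = dims.reshape(-1, 2)
print(pairs)
# [[10  5]
#  [ 5  1]]

# For a purely sequential stack, each layer's output size should match the
# next layer's input size.
consistent = np.array_equal(pairs[:-1, 1], pairs[1:, 0])
print(consistent)  # True
```

If the model contains no `nn.Linear` layers, `np.array([])` produces an empty float64 array, which `reshape(-1, 2)` turns into shape `(0, 2)`.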

Examples:

>>> from spotpython.torch.dimensions import extract_linear_dims
>>> net = NNLinearRegressor()
>>> result = extract_linear_dims(net)
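The example above relies on `NNLinearRegressor`, whose import and constructor are not shown here. The sketch below is self-contained under stated assumptions: it copies the function body from the source listing so it runs without spotpython installed, and `TinyNet` is a hypothetical model whose only relevant feature is an `nn.Sequential` stored as `layers`.

```python
import numpy as np
import torch.nn as nn


def extract_linear_dims(model) -> np.ndarray:
    # Local copy mirroring the source listing, so the sketch runs
    # without spotpython installed.
    dims = []
    for layer in model.layers:
        if isinstance(layer, nn.Linear):
            dims.append(layer.in_features)
            dims.append(layer.out_features)
    return np.array(dims)


class TinyNet(nn.Module):
    # Hypothetical model for illustration; the function only needs an
    # iterable `layers` attribute such as nn.Sequential.
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),  # non-Linear layers are skipped
            nn.Linear(5, 1),
        )

    def forward(self, x):
        return self.layers(x)


result = extract_linear_dims(TinyNet())
print(result)  # [10  5  5  1]
```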
Source code in spotpython/torch/dimensions.py
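import numpy as np
import torch.nn as nn


def extract_linear_dims(model) -> np.ndarray:
    """Extracts the input and output dimensions of the Linear layers in a PyTorch model.

    Args:
        model (nn.Module): PyTorch model with an iterable `layers` attribute
            (e.g. an `nn.Sequential`).

    Returns:
        np.ndarray: Flat array [in_1, out_1, in_2, out_2, ...] with the input and
            output dimensions of each Linear layer, in the order of `model.layers`.

    Examples:
        >>> from spotpython.torch.dimensions import extract_linear_dims
        >>> net = NNLinearRegressor()
        >>> result = extract_linear_dims(net)

    """
    dims = []
    # Iterates only the direct entries of model.layers; Linear layers nested
    # deeper or registered elsewhere in the model are skipped.
    for layer in model.layers:
        if isinstance(layer, nn.Linear):
            # Append input and output features of the Linear layer
            dims.append(layer.in_features)
            dims.append(layer.out_features)
    return np.array(dims)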
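Because the function iterates `model.layers` directly, a model without that attribute raises `AttributeError`, and Linear layers nested inside sub-blocks are not found. Where that matters, one possible alternative is sketched below; `extract_linear_dims_any` and `Nested` are hypothetical names, not part of spotpython, and the sketch walks all submodules with `nn.Module.modules()` instead.

```python
import numpy as np
import torch.nn as nn


def extract_linear_dims_any(model: nn.Module) -> np.ndarray:
    # Hypothetical variant, not part of spotpython: visits every submodule
    # recursively via nn.Module.modules() instead of requiring `model.layers`.
    dims = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            dims.extend([module.in_features, module.out_features])
    # dtype=int keeps the result integer even when no Linear layer is found.
    return np.array(dims, dtype=int)


class Nested(nn.Module):
    # Hypothetical model with no `layers` attribute and a nested block.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(self.encoder(x))


print(extract_linear_dims_any(Nested()))  # [8 4 4 2]
```

`modules()` yields submodules in registration order and visits a shared layer only once. That matches the forward pass for simple stacks, but not necessarily for models whose `forward` reorders or reuses layers.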