multiheadattention

MultiheadAttention

Bases: Module

Source code in spotpython/light/transformer/multiheadattention.py
class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads):
        """Constructor.

        Args:
            input_dim (int): input dimensionality.
            embed_dim (int): embedding dimensionality.
            num_heads (int): number of heads.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        if mask is not None:
            mask = expand_mask(mask)
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o
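
A minimal usage sketch, assuming the class is importable from the module path shown above and that scaled_dot_product returns the per-head attention weights with shape [Batch, Head, SeqLen, SeqLen]:

import torch
from spotpython.light.transformer.multiheadattention import MultiheadAttention

# Project a batch of 8 sequences of length 16 with 64-dimensional inputs
# into a 64-dimensional embedding split across 4 heads (head_dim = 16).
mha = MultiheadAttention(input_dim=64, embed_dim=64, num_heads=4)
x = torch.randn(8, 16, 64)                      # [Batch, SeqLen, InputDim]
out, attention = mha(x, return_attention=True)
print(out.shape)        # torch.Size([8, 16, 64])
print(attention.shape)  # torch.Size([8, 4, 16, 16])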

__init__(input_dim, embed_dim, num_heads)

Constructor.

Parameters:

Name        Type    Description                 Default
input_dim   int     input dimensionality.       required
embed_dim   int     embedding dimensionality.   required
num_heads   int     number of heads.            required
Source code in spotpython/light/transformer/multiheadattention.py
def __init__(self, input_dim, embed_dim, num_heads):
    """Constructor.

    Args:
        input_dim (int): input dimensionality.
        embed_dim (int): embedding dimensionality.
        num_heads (int): number of heads.
    """
    super().__init__()
    assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

    # Stack all weight matrices 1...h together for efficiency
    # Note that in many implementations you see "bias=False" which is optional
    self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
    self.o_proj = nn.Linear(embed_dim, embed_dim)

    self._reset_parameters()
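
A brief sketch of the constraint enforced by the assert above (the dimensions are illustrative): embed_dim must be divisible by num_heads, and each head then operates on head_dim = embed_dim // num_heads features.

from spotpython.light.transformer.multiheadattention import MultiheadAttention

mha = MultiheadAttention(input_dim=64, embed_dim=64, num_heads=4)
print(mha.head_dim)  # 16, since 64 // 4 = 16

# MultiheadAttention(input_dim=64, embed_dim=64, num_heads=3)  # would raise AssertionError (64 % 3 != 0)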

expand_mask(mask)

Helper function to support different mask shapes. Expands the mask to the shape expected by the MultiheadAttention layer, (batch_size, num_heads, seq_length, seq_length). If 2D: broadcast over batch size and number of heads. If 3D: broadcast over number of heads. If 4D: left as is.

Parameters:

Name   Type     Description                                                                               Default
mask   Tensor   Mask tensor of shape (batch_size, seq_length, seq_length) or (seq_length, seq_length).    required
Source code in spotpython/light/transformer/multiheadattention.py
def expand_mask(mask):
    """
    Helper function to support different mask shapes.
    Expands the mask to the correct shape for the MultiheadAttention layer.
    Output shape supports (batch_size, number of heads, seq length, seq length).
    If 2D: broadcasted over batch size and number of heads.
    If 3D: broadcasted over number of heads.
    If 4D: leave as is.

    Args:
        mask (torch.Tensor):
            Mask tensor of shape (batch_size, seq_length, seq_length) or (seq_length, seq_length).
    """
    assert mask.ndim >= 2, "Mask must be >= 2-dim. with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask
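
A minimal sketch of the broadcasting behavior, assuming expand_mask is importable from the module shown above:

import torch
from spotpython.light.transformer.multiheadattention import expand_mask

mask_2d = torch.ones(16, 16)        # (seq_length, seq_length)
print(expand_mask(mask_2d).shape)   # torch.Size([1, 1, 16, 16]) -> broadcast over batch and heads

mask_3d = torch.ones(8, 16, 16)     # (batch_size, seq_length, seq_length)
print(expand_mask(mask_3d).shape)   # torch.Size([8, 1, 16, 16]) -> broadcast over heads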