multiheadattention
MultiheadAttention
Bases: Module
Source code in spotpython/light/transformer/multiheadattention.py, lines 26–76.
__init__(input_dim, embed_dim, num_heads)
Constructor of the multi-head attention module.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
input_dim | int | Input dimensionality. | required |
embed_dim | int | Embedding dimensionality. | required |
num_heads | int | Number of heads. | required |
Source code in spotpython/light/transformer/multiheadattention.py, lines 27–47.
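
The snippet below is a minimal usage sketch of the constructor. The dimensions and the forward call `mha(x)` are assumptions based on typical tutorial-style multi-head attention modules and are not documented on this page; check the linked source for the exact interface.

```python
import torch
from spotpython.light.transformer.multiheadattention import MultiheadAttention

# Hypothetical dimensions, chosen for illustration only;
# embed_dim is typically divisible by num_heads.
mha = MultiheadAttention(input_dim=64, embed_dim=64, num_heads=8)

x = torch.randn(4, 10, 64)  # (batch_size, seq_length, input_dim)
out = mha(x)                # assumed forward signature: mha(x, mask=None)
print(out.shape)            # expected: torch.Size([4, 10, 64])
```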
expand_mask(mask)
Helper function to support different mask shapes. Expands the mask to the shape expected by the MultiheadAttention layer, (batch_size, num_heads, seq_length, seq_length). A 2D mask is broadcast over the batch size and the number of heads, a 3D mask is broadcast over the number of heads, and a 4D mask is left as is.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
mask | Tensor | Mask tensor of shape (batch_size, seq_length, seq_length) or (seq_length, seq_length). | required |
Source code in spotpython/light/transformer/multiheadattention.py, lines 5–23.
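
The following is an equivalent re-implementation sketch of the broadcasting rules described above, not the library code itself; it can be used to check which shape a given mask ends up with.

```python
import torch

def expand_mask_sketch(mask: torch.Tensor) -> torch.Tensor:
    # The mask must be at least 2D: (seq_length, seq_length).
    assert mask.ndim >= 2, "mask must be at least 2-dimensional"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)   # insert the head dimension
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)   # prepend batch (and, for 2D input, head) dimensions
    return mask

print(expand_mask_sketch(torch.ones(10, 10)).shape)     # torch.Size([1, 1, 10, 10])
print(expand_mask_sketch(torch.ones(4, 10, 10)).shape)  # torch.Size([4, 1, 10, 10])
```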