attention
scaled_dot_product(q, k, v, mask=None)
Compute scaled dot product attention.
Args:
q: Queries
k: Keys
v: Values
mask: Mask to apply to the attention logits
Returns:
Tuple of (Values, Attention weights)
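
Note (not part of the original docstring): the function name refers to the standard scaled dot-product attention formulation, which computes

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the key dimension and the optional mask suppresses selected attention logits before the softmax.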
Examples:
>>> import torch
>>> import pytorch_lightning as pl
>>> from spotpython.light.transformer.attention import scaled_dot_product
>>> seq_len, d_k = 1, 2
>>> pl.seed_everything(42)
>>> q = torch.randn(seq_len, d_k)
>>> k = torch.randn(seq_len, d_k)
>>> v = torch.randn(seq_len, d_k)
>>> values, attention = scaled_dot_product(q, k, v)
>>> print("Q\n", q)
>>> print("K\n", k)
>>> print("V\n", v)
>>> print("Values\n", values)
>>> print("Attention\n", attention)
Source code in spotpython/light/transformer/attention.py
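
The source listing itself did not survive extraction. The sketch below is a minimal, hedged reconstruction assuming the standard scaled dot-product attention formulation given above; the actual implementation in spotpython/light/transformer/attention.py may differ in details (e.g. mask convention or the fill value for masked logits).

    import math
    from typing import Optional, Tuple

    import torch
    import torch.nn.functional as F


    def scaled_dot_product_sketch(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Key dimension used for scaling.
        d_k = q.size(-1)
        # Attention logits: Q K^T scaled by sqrt(d_k).
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Assumed convention: positions where mask == 0 are suppressed
            # by assigning a large negative logit before the softmax.
            attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
        # Softmax over the key dimension yields the attention weights.
        attention = F.softmax(attn_logits, dim=-1)
        # Weighted sum of the values.
        values = torch.matmul(attention, v)
        return values, attention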