About
The d9d.module.block.attention package provides optimized attention mechanism implementations.
Features
Scaled Dot-Product Attention Kernels
FlashSdpa - FlashAttention 2 (using the new Torch SDPA API)
Grouped-Query Attention
GroupedQueryAttention is a Grouped-Query Attention implementation.
Thanks to its general formulation, it can also be used as a Multi-Head Attention or Multi-Query Attention module (see the sketch below).
Uses the FlashSdpa kernel.
Uses Rotary Positional Encoding (RoPE).
Supports optional QK Normalization.
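The three head configurations differ only in `num_key_value_heads`. A minimal sketch, assuming `GroupedQueryAttention` is importable from the package path shown below (the constructor signature is documented in the GroupedQueryAttention section):

```python
from d9d.module.block.attention import GroupedQueryAttention  # import path assumed

common = dict(hidden_size=1024, num_attention_heads=16, head_dim=64,
              qk_norm_eps=1e-6, is_causal=True)

mha = GroupedQueryAttention(num_key_value_heads=16, **common)  # Multi-Head: one KV head per Query head
gqa = GroupedQueryAttention(num_key_value_heads=4, **common)   # Grouped-Query: 4 Query heads share each KV head
mqa = GroupedQueryAttention(num_key_value_heads=1, **common)   # Multi-Query: a single KV head for all Query heads
```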
d9d.module.block.attention
Provides attention layer implementations.
GroupedQueryAttention
Bases: Module, ModuleLateInit
Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.
This module performs the full attention mechanism pipeline:
1. Linear projection to Q, K, V.
2. Optional RMS Normalization on Q and K.
3. Rotary Positional Embedding (RoPE) application.
4. Scaled Dot Product Attention (via FlashAttention).
5. Output projection.
Source code in d9d/module/block/attention/grouped_query.py
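To make the five steps concrete, here is a rough sketch in plain PyTorch. It is not the d9d implementation: the tensor layout, the RoPE convention, and the omission of QK normalization are simplifying assumptions.

```python
import torch
import torch.nn.functional as F


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Standard RoPE helper: rotates the two halves of the last dimension."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def gqa_pipeline(x, w_q, w_k, w_v, w_o, cos, sin, n_q_heads, n_kv_heads, head_dim):
    b, s, _ = x.shape
    # 1. Linear projection to Q, K, V (split into heads).
    q = (x @ w_q).view(b, s, n_q_heads, head_dim).transpose(1, 2)
    k = (x @ w_k).view(b, s, n_kv_heads, head_dim).transpose(1, 2)
    v = (x @ w_v).view(b, s, n_kv_heads, head_dim).transpose(1, 2)
    # 2. Optional RMS normalization of Q and K would go here (skipped in this sketch).
    # 3. Rotary positional embeddings applied to Q and K (cos/sin of shape (seq, head_dim)).
    q = q * cos + _rotate_half(q) * sin
    k = k * cos + _rotate_half(k) * sin
    # 4. Scaled dot-product attention; KV heads are repeated to cover the Query groups.
    rep = n_q_heads // n_kv_heads
    out = F.scaled_dot_product_attention(
        q, k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1),
        is_causal=True,
    )
    # 5. Output projection back to the hidden size.
    return out.transpose(1, 2).reshape(b, s, -1) @ w_o
```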
__init__(hidden_size, num_attention_heads, num_key_value_heads, head_dim, qk_norm_eps, is_causal)
Constructs the GroupedQueryAttention layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_size` | `int` | Hidden size. | required |
| `num_attention_heads` | `int` | Number of Query heads. | required |
| `num_key_value_heads` | `int` | Number of Key/Value heads. If less than `num_attention_heads`, each Key/Value head is shared by a group of Query heads (grouped-query attention). | required |
| `head_dim` | `int` | Dimensionality of a single attention head. | required |
| `qk_norm_eps` | `float \| None` | Epsilon for LayerNorm/RMSNorm applied to Q and K. If None, normalization is disabled. | required |
| `is_causal` | `bool` | Whether to apply a causal mask (auto-regressive constraint). | required |
Source code in d9d/module/block/attention/grouped_query.py
forward(hidden_states, attention_mask, position_embeddings)
Computes the attention operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_states` | `Tensor` | Input tensor. | required |
| `attention_mask` | `Tensor \| None` | Optional mask associated with the inputs. | required |
| `position_embeddings` | `tuple[Tensor, Tensor]` | Tuple of rotary embedding (cos, sin) tensors. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The attention output tensor. |
Source code in d9d/module/block/attention/grouped_query.py
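A hedged end-to-end sketch of a forward call. The input shape `(batch, seq, hidden_size)` and the `(seq, head_dim)` layout of the cos/sin tables are assumptions; they are not spelled out on this page.

```python
import torch

from d9d.module.block.attention import GroupedQueryAttention  # import path assumed

attn = GroupedQueryAttention(hidden_size=1024, num_attention_heads=16,
                             num_key_value_heads=4, head_dim=64,
                             qk_norm_eps=None, is_causal=True)

batch, seq, head_dim = 2, 128, 64
hidden_states = torch.randn(batch, seq, 1024)  # assumed (batch, seq, hidden_size)

# Standard RoPE tables (assumed layout: one (seq, head_dim) tensor each for cos and sin).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)

out = attn(hidden_states, attention_mask=None, position_embeddings=(cos, sin))
```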
reset_parameters()
Resets module parameters.
Source code in d9d/module/block/attention/grouped_query.py
d9d.module.block.attention.sdpa
FlashSdpa
Bases: Module
Executes Scaled Dot Product Attention (SDPA) enforcing the FlashAttention backend.
Source code in d9d/module/block/attention/sdpa/flash.py
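FlashSdpa presumably builds on PyTorch's SDPA backend selection. For reference (not the actual FlashSdpa internals), pinning the FlashAttention backend with the public torch API looks like this:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # available in recent PyTorch releases

# FlashAttention requires half-precision inputs on a CUDA device.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict the SDPA dispatcher to the FlashAttention backend; if that backend
# cannot run, the call errors out instead of silently falling back.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```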
__init__()
Constructs the FlashSdpa object.
Source code in d9d/module/block/attention/sdpa/flash.py
forward(query_states, key_states, value_states, attention_mask, is_causal, scale)
Computes Scaled Dot-Product Attention using FlashAttention.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `query_states` | `Tensor` | Query tensor. | required |
| `key_states` | `Tensor` | Key tensor. | required |
| `value_states` | `Tensor` | Value tensor. | required |
| `attention_mask` | `Tensor \| None` | Optional attention mask (usually not needed for FlashAttn with causal=True). | required |
| `is_causal` | `bool` | If True, applies a causal mask (upper triangular masking). | required |
| `scale` | `float` | Scaling factor applied to the dot products (usually `1 / sqrt(head_dim)`). | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The attention output tensor, permuted to channel-last format. |
Source code in d9d/module/block/attention/sdpa/flash.py
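For reference, an equivalent computation with the stock SDPA call (without pinning the backend). The `(batch, num_heads, seq, head_dim)` input layout and the final channel-last permutation are assumptions based on the description above.

```python
import math

import torch
import torch.nn.functional as F

batch, num_heads, seq, head_dim = 2, 8, 128, 64
q = torch.randn(batch, num_heads, seq, head_dim)
k = torch.randn(batch, num_heads, seq, head_dim)
v = torch.randn(batch, num_heads, seq, head_dim)

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,                   # usually unnecessary when is_causal=True
    is_causal=True,                   # upper-triangular causal masking
    scale=1.0 / math.sqrt(head_dim),  # the conventional softmax scaling
)

# Channel-last output as described in the Returns table: (batch, seq, num_heads, head_dim).
out = out.transpose(1, 2)
```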