# Attention Layers

## About

The `d9d.module.block.attention` package provides optimized attention mechanism implementations.
## Features

### Scaled Dot-Product Attention Kernels

- `FlashSdpa`: FlashAttention 2 (via the new Torch SDPA API)
### Grouped-Query Attention

`GroupedQueryAttention` is a Grouped-Query Attention implementation.

Thanks to its general formulation, it can also be used as a Multi-Head Attention or Multi-Query Attention module (by setting `num_key_value_heads` equal to `num_attention_heads`, or to 1, respectively).

- Uses the FlashSDPA kernel.
- Uses Rotary Positional Encoding.
- Supports optional QK Normalization.
### Multi-Head Latent Attention

`MultiHeadLatentAttention` is an implementation of the Multi-Head Latent Attention (MLA) mechanism introduced in DeepSeek-V2.

- Uses the FlashSDPA kernel.
- Uses Rotary Positional Encoding.
## `d9d.module.block.attention`

Provides attention layer implementations.
### GroupedQueryAttention

Bases: `Module`, `ModuleLateInit`

Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.

This module performs the full attention mechanism pipeline:

1. Linear projection to Q, K, V.
2. Optional RMS Normalization on Q and K.
3. Rotary Positional Embedding (RoPE) application.
4. Scaled Dot Product Attention (via FlashAttention).
5. Output projection.

#### `__init__(hidden_size, num_attention_heads, num_key_value_heads, head_dim, qk_norm_eps, is_causal, rope_style)`

Constructs the GroupedQueryAttention layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_size | int | Hidden size. | required |
| num_attention_heads | int | Number of Query heads. | required |
| num_key_value_heads | int | Number of Key/Value heads. If less than `num_attention_heads`, each Key/Value head is shared by a group of Query heads. | required |
| head_dim | int | Dimensionality of a single attention head. | required |
| qk_norm_eps | float \| None | Epsilon for LayerNorm/RMSNorm applied to Q and K. If None, normalization is disabled. | required |
| is_causal | bool | Whether to apply a causal mask (auto-regressive constraint). | required |
| rope_style | RotaryEmbeddingStyle | Rotary embedding layout style alignment. | required |
#### `forward(hidden_states, attention_mask, position_embeddings)`

Computes the attention operation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | Input tensor. Shape: `(batch, seq_len, hidden_size)`. | required |
| attention_mask | Tensor \| None | Optional mask associated with the inputs. | required |
| position_embeddings | tuple[Tensor, Tensor] | Tuple of `(cos, sin)` rotary embedding tensors. | required |
Returns:

| Type | Description |
|---|---|
| Tensor | The attention output tensor. Shape: `(batch, seq_len, hidden_size)`. |

#### `reset_parameters()`

Resets module parameters.
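For orientation, the GQA head-sharing described above can be sketched in plain PyTorch. Everything below (names, sizes, the use of `repeat_interleave`) is an illustrative assumption, not the actual d9d implementation, and RoPE/QK normalization are only marked as comments:

```python
import torch
import torch.nn.functional as F

batch, seq_len = 2, 8
hidden_size, head_dim = 64, 16
num_q_heads, num_kv_heads = 4, 2          # 2 query heads share each KV head

q_proj = torch.nn.Linear(hidden_size, num_q_heads * head_dim, bias=False)
k_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
o_proj = torch.nn.Linear(num_q_heads * head_dim, hidden_size, bias=False)

x = torch.randn(batch, seq_len, hidden_size)

# 1. Project and reshape to (batch, heads, seq_len, head_dim).
q = q_proj(x).view(batch, seq_len, num_q_heads, head_dim).transpose(1, 2)
k = k_proj(x).view(batch, seq_len, num_kv_heads, head_dim).transpose(1, 2)
v = v_proj(x).view(batch, seq_len, num_kv_heads, head_dim).transpose(1, 2)

# 2. (Optional QK RMSNorm and RoPE would be applied to q and k here.)

# 3. Broadcast each KV head across its group of query heads.
group = num_q_heads // num_kv_heads
k = k.repeat_interleave(group, dim=1)
v = v.repeat_interleave(group, dim=1)

# 4. Scaled dot-product attention with a causal mask, then output projection.
attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = o_proj(attn.transpose(1, 2).reshape(batch, seq_len, -1))
print(out.shape)  # torch.Size([2, 8, 64])
```

With `num_kv_heads == num_q_heads` this degenerates to standard MHA, and with `num_kv_heads == 1` to MQA, which is why a single GQA module can cover all three variants.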
### LowRankProjection

Bases: `Module`

Implements a low-rank linear projection with an intermediate normalization layer.

#### `__init__(in_features, bottleneck, out_features, norm_eps)`

Constructs the LowRankProjection object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| bottleneck | int | Intermediate low-rank dimensionality. | required |
| out_features | int | Output dimensionality. | required |
| norm_eps | float | Epsilon value for the intermediate RMSNorm layer. | required |
#### `forward(x)`

Applies the low-rank projection to the input.

#### `reset_parameters()`

Resets module parameters.
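A minimal sketch of such a factorised projection, assuming the structure described above (down-projection, RMSNorm, up-projection); the class name and the hand-rolled RMSNorm are illustrative, not the actual d9d code:

```python
import torch

def rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain RMSNorm (learnable gain omitted for brevity).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

class LowRankProjectionSketch(torch.nn.Module):
    """Factorised projection: (in -> bottleneck) -> RMSNorm -> (bottleneck -> out)."""

    def __init__(self, in_features: int, bottleneck: int,
                 out_features: int, norm_eps: float):
        super().__init__()
        self.down = torch.nn.Linear(in_features, bottleneck, bias=False)
        self.up = torch.nn.Linear(bottleneck, out_features, bias=False)
        self.norm_eps = norm_eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(rms_norm(self.down(x), self.norm_eps))

proj = LowRankProjectionSketch(64, 8, 128, norm_eps=1e-6)
y = proj(torch.randn(2, 64))
print(y.shape)  # torch.Size([2, 128])
```

The factorisation trades a dense `in_features x out_features` weight (here 64 x 128 = 8192 parameters) for two thin ones (64 x 8 + 8 x 128 = 1536), at the cost of restricting the projection to rank `bottleneck`.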
### MultiHeadLatentAttention

Bases: `Module`, `ModuleLateInit`

Implements Multi-Head Latent Attention (MLA) from DeepSeek-V2.
This module performs the full attention mechanism pipeline:
- Linear projection to Query (either direct or via a low-rank bottleneck with RMSNorm).
- Down-projection to a low-rank KV latent vector and a shared Key RoPE sub-vector.
- RMSNorm application on the KV latent vector.
- Up-projection of the KV latent vector into Key content (NOPE) and Value sub-vectors.
- Rotary Positional Embedding (RoPE) application strictly to the decoupled Query and Key RoPE sub-vectors.
- Concatenation of the content (NOPE) and rotated (RoPE) sub-vectors to form the final Query and Key heads.
- Scaled Dot Product Attention (via FlashAttention).
- Output projection.
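The pipeline above can be sketched roughly as follows. All names and dimensions are illustrative assumptions rather than the actual d9d implementation; the RoPE rotation itself and the low-rank Q path are only marked as comments:

```python
import torch
import torch.nn.functional as F

batch, seq_len, hidden = 2, 8, 64
heads, nope_dim, rope_dim, v_dim, kv_rank = 4, 16, 8, 16, 32

# Down-projection: low-rank KV latent plus one shared Key RoPE sub-vector.
kv_down = torch.nn.Linear(hidden, kv_rank + rope_dim, bias=False)
# Up-projection: per-head Key content (NOPE) and Value sub-vectors.
kv_up = torch.nn.Linear(kv_rank, heads * (nope_dim + v_dim), bias=False)
# Direct Query projection (the low-rank Q path would mirror kv_down/kv_up).
q_proj = torch.nn.Linear(hidden, heads * (nope_dim + rope_dim), bias=False)
o_proj = torch.nn.Linear(heads * v_dim, hidden, bias=False)

x = torch.randn(batch, seq_len, hidden)

q = q_proj(x).view(batch, seq_len, heads, nope_dim + rope_dim).transpose(1, 2)

latent, k_rope = kv_down(x).split([kv_rank, rope_dim], dim=-1)
# RMSNorm on the KV latent (learnable gain omitted).
latent = latent * torch.rsqrt(latent.pow(2).mean(-1, keepdim=True) + 1e-6)
kv = kv_up(latent).view(batch, seq_len, heads, nope_dim + v_dim).transpose(1, 2)
k_nope, v = kv.split([nope_dim, v_dim], dim=-1)

# (RoPE would rotate q[..., nope_dim:] and k_rope here.)
# The shared RoPE key sub-vector is broadcast to every head, then concatenated
# with the per-head content sub-vector to form the final Key heads.
k_rope = k_rope.unsqueeze(1).expand(-1, heads, -1, -1)
k = torch.cat([k_nope, k_rope], dim=-1)

attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = o_proj(attn.transpose(1, 2).reshape(batch, seq_len, heads * v_dim))
print(out.shape)  # torch.Size([2, 8, 64])
```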
#### `q_lora_rank` *(property)*

Rank of the Q low-rank path, or None if Q is projected directly.

#### `__init__(hidden_size, num_attention_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank, q_lora_rank, qk_down_norm_eps, is_causal, rope_style)`

Constructs the MultiHeadLatentAttention layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_size | int | Model hidden dimension. | required |
| num_attention_heads | int | Number of attention heads. | required |
| qk_nope_head_dim | int | Per-head dimension for the content (no-RoPE) part of Q and K. | required |
| qk_rope_head_dim | int | Per-head dimension for the RoPE-rotated part of Q and K. | required |
| v_head_dim | int | Per-head dimension for values. | required |
| kv_lora_rank | int | Rank of the KV latent compression. | required |
| q_lora_rank | int \| None | Rank of the Q low-rank path. If `None`, Q is projected directly. | required |
| qk_down_norm_eps | float | Epsilon for the RMSNorm applied to the KV and Q latent representations. | required |
| is_causal | bool | Whether to apply a causal mask (auto-regressive). | required |
| rope_style | RotaryEmbeddingStyle | Rotary embedding layout style alignment. | required |
#### `forward(hidden_states, attention_mask, position_embeddings)`

Computes Multi-Head Latent Attention.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | Input tensor. Shape: `(batch, seq_len, hidden_size)`. | required |
| attention_mask | Tensor \| None | Optional attention mask. | required |
| position_embeddings | tuple[Tensor, Tensor] | Tuple of `(cos, sin)` rotary embedding tensors. | required |
Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor. Shape: `(batch, seq_len, hidden_size)`. |

#### `reset_parameters()`

Resets all learnable parameters.
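The practical payoff of the latent compression is at inference time: standard attention caches full per-head Keys and Values, while MLA only needs the compressed latent plus the shared RoPE key sub-vector per token. A back-of-envelope comparison, using illustrative sizes (not d9d defaults):

```python
# Per-token KV-cache entries (in scalar elements, dtype-independent).
num_heads, head_dim = 32, 128
kv_lora_rank, qk_rope_head_dim = 512, 64

# Standard MHA caches K and V for every head.
mha_per_token = 2 * num_heads * head_dim
# MLA caches only the KV latent plus the shared RoPE key sub-vector.
mla_per_token = kv_lora_rank + qk_rope_head_dim

print(mha_per_token, mla_per_token, mha_per_token / mla_per_token)
# 8192 576 14.222... -> roughly a 14x smaller KV cache at these sizes
```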
## `d9d.module.block.attention.sdpa`

### FlashSdpa

Bases: `Module`

Executes Scaled Dot Product Attention (SDPA), enforcing the FlashAttention backend.

#### `__init__()`

Constructs the FlashSdpa object.
#### `forward(query_states, key_states, value_states, attention_mask, is_causal, scale)`

Computes Scaled Dot-Product Attention using FlashAttention.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| query_states | Tensor | Query tensor. Shape: `(batch, num_heads, seq_len, head_dim)`. | required |
| key_states | Tensor | Key tensor. Shape: `(batch, num_heads, seq_len, head_dim)`. | required |
| value_states | Tensor | Value tensor. Shape: `(batch, num_heads, seq_len, head_dim)`. | required |
| attention_mask | Tensor \| None | Optional attention mask (usually not needed for FlashAttention with `is_causal=True`). | required |
| is_causal | bool | If True, applies a causal mask (upper-triangular masking). | required |
| scale | float | Scaling factor applied to the dot products (usually `1 / sqrt(head_dim)`). | required |
Returns:

| Type | Description |
|---|---|
| Tensor | The attention output tensor, permuted to channel-last format. Shape: `(batch, seq_len, num_heads, head_dim)`. |