Attention Layers

About

The d9d.module.block.attention package provides optimized attention mechanism implementations.

Softmax Attention

Grouped-Query Attention

GroupedQueryAttention is a Grouped-Query Attention implementation.

Thanks to its general formulation, it can also be used as a Multi-Head Attention or Multi-Query Attention module.

Multi-Head Latent Attention

MultiHeadLatentAttention is an implementation of the Multi-Head Latent Attention (MLA) mechanism introduced in DeepSeek-V2.

Scaled Dot-Product Attention Kernels

  • FlashSdpa - Scaled Dot-Product Attention backed by FlashAttention 4.

d9d.module.block.attention

Provides attention layer implementations.

GroupedQueryAttention

Bases: Module, ModuleLateInit

Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.

This module performs the full attention mechanism pipeline:

  1. Linear projection to Q, K, V.
  2. Optional RMS Normalization on Q and K.
  3. Rotary Positional Embedding (RoPE) application.
  4. Scaled Dot Product Attention (via FlashAttention).
  5. Optional sigmoid output gating.
  6. Output projection.

__init__(hidden_size, num_attention_heads, num_key_value_heads, head_dim, qk_norm_eps, is_causal, rope_style, rope_dim=None, enable_output_gate=False, qk_norm_zero_centered=False)

Constructs the GroupedQueryAttention layer.

Parameters:

Name Type Description Default
hidden_size int

Hidden size.

required
num_attention_heads int

Number of Query heads.

required
num_key_value_heads int

Number of Key/Value heads. If less than num_attention_heads, GQA/MQA is enabled.

required
head_dim int

Dimensionality of a single attention head.

required
qk_norm_eps float | None

Epsilon for LayerNorm/RMSNorm applied to Q and K. If None, normalization is disabled.

required
is_causal bool

Whether to apply a causal mask (auto-regressive constraint).

required
rope_style RotaryEmbeddingStyle

Rotary embedding layout style.

required
rope_dim int | None

Dimension of the RoPE sub-vector. If None, RoPE is applied to the full head_dim.

None
enable_output_gate bool

If True, enables sigmoid output gating (Qwen 3.5 style).

False
qk_norm_zero_centered bool

If True, utilizes zero-centered scaling weights for the optional Q and K normalization layers.

False
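
A minimal construction sketch, assuming the class and RotaryEmbeddingStyle can be imported from this package and that NEOX is a valid style member (both are assumptions; only the constructor arguments are taken from the table above):

```python
from d9d.module.block.attention import GroupedQueryAttention, RotaryEmbeddingStyle  # assumed import path

# 32 query heads sharing 8 key/value heads -> grouped-query attention.
gqa = GroupedQueryAttention(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    qk_norm_eps=1e-6,        # pass None to disable QK normalization
    is_causal=True,
    rope_style=RotaryEmbeddingStyle.NEOX,  # assumed member name
)

# The same class covers the degenerate cases:
#   num_key_value_heads == num_attention_heads -> Multi-Head Attention (MHA)
#   num_key_value_heads == 1                   -> Multi-Query Attention (MQA)
```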
forward(hidden_states, attention_mask, position_embeddings)

Computes the attention operation.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor. Shape: (batch, seq_len, hidden_size).

required
attention_mask Tensor | None

Optional mask associated with the inputs.

required
position_embeddings tuple[Tensor, Tensor]

Tuple of (cos, sin) tensors for RoPE application. Each tensor should be of shape (batch, seq_len, rope_dim) when partial RoPE is used, or (batch, seq_len, head_dim) otherwise.

required

Returns:

Type Description
Tensor

The attention output tensor. Shape: (batch, seq_len, hidden_size).
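
Continuing the sketch above, a forward call needs the input hidden states, an optional mask, and the (cos, sin) rotary tensors. The cos/sin values here are random placeholders that only illustrate shapes; in practice they come from a rotary embedding module, and the FlashAttention backend may require a GPU:

```python
import torch

batch, seq_len, hidden_size, head_dim = 2, 16, 4096, 128

hidden_states = torch.randn(batch, seq_len, hidden_size)
# Full-head RoPE (rope_dim=None): cos/sin carry head_dim features.
# With partial RoPE the last dimension would be rope_dim instead.
cos = torch.randn(batch, seq_len, head_dim)
sin = torch.randn(batch, seq_len, head_dim)

out = gqa(hidden_states, attention_mask=None, position_embeddings=(cos, sin))
assert out.shape == (batch, seq_len, hidden_size)
```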

reset_parameters()

Resets module parameters.

LowRankProjection

Bases: Module

Implements a low-rank linear projection with an intermediate normalization layer.

__init__(in_features, bottleneck, out_features, norm_eps)

Constructs the LowRankProjection object.

Parameters:

Name Type Description Default
in_features int

Input dimensionality.

required
bottleneck int

Intermediate low-rank dimensionality.

required
out_features int

Output dimensionality.

required
norm_eps float

Epsilon value for the intermediate RMSNorm layer.

required
forward(x)

Applies the low-rank projection to the inputs.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required

Returns:

Type Description
Tensor

Projected output tensor.
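
A small end-to-end sketch (the import path is assumed from this page; the dimensions are arbitrary):

```python
import torch
from d9d.module.block.attention import LowRankProjection  # assumed import path

proj = LowRankProjection(
    in_features=4096,
    bottleneck=512,      # low-rank intermediate dimensionality
    out_features=4096,
    norm_eps=1e-6,       # epsilon of the intermediate RMSNorm
)

x = torch.randn(2, 16, 4096)
y = proj(x)              # down-project -> RMSNorm -> up-project
assert y.shape == (2, 16, 4096)
```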

reset_parameters()

Resets module parameters.

MultiHeadLatentAttention

Bases: Module, ModuleLateInit

Implements Multi-Head Latent Attention (MLA) from DeepSeek-V2.

This module performs the full attention mechanism pipeline:

  1. Linear projection to Query (either direct or via a low-rank bottleneck with RMSNorm).
  2. Down-projection to a low-rank KV latent vector and a shared Key RoPE sub-vector.
  3. RMSNorm application on the KV latent vector.
  4. Up-projection of the KV latent vector into Key content (NOPE) and Value sub-vectors.
  5. Rotary Positional Embedding (RoPE) application strictly to the decoupled Query and Key RoPE sub-vectors.
  6. Concatenation of the content (NOPE) and rotated (RoPE) sub-vectors to form the final Query and Key heads.
  7. Scaled Dot Product Attention (via FlashAttention).
  8. Output projection.
q_lora_rank property

Rank of the Q low-rank path, or None if Q is projected directly.

__init__(hidden_size, num_attention_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank, q_lora_rank, qk_down_norm_eps, is_causal, rope_style)

Constructs the MultiHeadLatentAttention layer.

Parameters:

Name Type Description Default
hidden_size int

Model hidden dimension.

required
num_attention_heads int

Number of attention heads.

required
qk_nope_head_dim int

Per-head dimension for the content (no-RoPE) part of Q and K.

required
qk_rope_head_dim int

Per-head dimension for the RoPE-rotated part of Q and K.

required
v_head_dim int

Per-head dimension for values.

required
kv_lora_rank int

Rank of the KV latent compression.

required
q_lora_rank int | None

Rank of the Q low-rank path. If None, Q is projected directly.

required
qk_down_norm_eps float

Epsilon for the RMSNorm applied to the KV and Q latent representations.

required
is_causal bool

Whether to apply a causal mask (auto-regressive).

required
rope_style RotaryEmbeddingStyle

Rotary embedding layout style.

required
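
A construction sketch with dimensions loosely mirroring the DeepSeek-V2 configuration (illustrative only; the import path and the RotaryEmbeddingStyle member name are assumptions):

```python
from d9d.module.block.attention import MultiHeadLatentAttention, RotaryEmbeddingStyle  # assumed import path

mla = MultiHeadLatentAttention(
    hidden_size=5120,
    num_attention_heads=128,
    qk_nope_head_dim=128,   # content (no-RoPE) part of each Q/K head
    qk_rope_head_dim=64,    # decoupled RoPE part of each Q/K head
    v_head_dim=128,
    kv_lora_rank=512,       # KV latent compression rank
    q_lora_rank=1536,       # pass None to project Q directly
    qk_down_norm_eps=1e-6,
    is_causal=True,
    rope_style=RotaryEmbeddingStyle.NEOX,  # assumed member name
)
```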
forward(hidden_states, attention_mask, position_embeddings)

Computes Multi-Head Latent Attention.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor. Shape: (batch, seq_len, hidden_size).

required
attention_mask Tensor | None

Optional attention mask.

required
position_embeddings tuple[Tensor, Tensor]

Tuple (cos, sin) for the RoPE sub-vectors. Each tensor shape: (batch, seq_len, qk_rope_head_dim).

required

Returns:

Type Description
Tensor

Output tensor. Shape: (batch, seq_len, hidden_size).
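
Continuing the sketch above; note that cos/sin cover only the decoupled RoPE sub-vector. Random placeholders are used purely to illustrate shapes:

```python
import torch

batch, seq_len = 2, 16
hidden_states = torch.randn(batch, seq_len, 5120)
cos = torch.randn(batch, seq_len, 64)   # qk_rope_head_dim, not the full head dim
sin = torch.randn(batch, seq_len, 64)

out = mla(hidden_states, attention_mask=None, position_embeddings=(cos, sin))
assert out.shape == (batch, seq_len, 5120)
```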

reset_parameters()

Resets all learnable parameters.

d9d.module.block.attention.sdpa

FlashSdpa

Bases: Module

Scaled Dot-Product Attention using FlashAttention 4.

When num_sinks is provided, a learnable per-head sink logit is added to the softmax denominator (attention-sink mechanism). This lets a fraction of attention mass be absorbed by the sink, effectively soft-gating the output without materializing an extra KV column.
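
A naive, unfused reference of that computation might look like the sketch below. It only illustrates the math of the sink term; the actual module runs a fused FlashAttention kernel, and the (batch, heads, seq, head_dim) layout here is chosen for readability rather than taken from the kernel interface:

```python
import torch

def sdpa_with_sinks_reference(q, k, v, sink_logits, scale):
    """q, k, v: (batch, heads, seq, head_dim); sink_logits: (heads,)."""
    logits = (q @ k.transpose(-2, -1)) * scale                    # (b, h, s, s)
    sink = sink_logits.view(1, -1, 1, 1).expand(*logits.shape[:-1], 1)
    # The sink logit participates in the softmax normalization ...
    probs = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)
    # ... but its column is dropped, so part of the attention mass is
    # absorbed without ever attending to an extra KV entry.
    return probs[..., :-1] @ v
```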

Parameters:

Name Type Description Default
num_sinks int | None

Number of learnable sink scalars (one per query head). None (default) disables sinks and gives plain attention.

None
window_size int | None

Sliding-window size for local attention. None (default) disables the window and uses full attention.

None
forward(query_states, key_states, value_states, attention_mask, is_causal, scale)

Computes Scaled Dot-Product Attention.

Parameters:

Name Type Description Default
query_states Tensor

Query tensor. Shape: (batch, seq_len, n_q_heads, head_dim).

required
key_states Tensor

Key tensor. Shape: (batch, seq_len, n_kv_heads, head_dim).

required
value_states Tensor

Value tensor. Shape: (batch, seq_len, n_kv_heads, head_dim).

required
attention_mask Tensor | None

Unused. Present for interface compatibility.

required
is_causal bool

If True, applies a causal mask.

required
scale float

Softmax scaling factor (usually 1/sqrt(head_dim)).

required

Returns:

Type Description
Tensor

Attention output. Shape: (batch, seq_len, n_q_heads, head_dim).
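
A call sketch (the device and dtype are assumptions based on typical FlashAttention requirements; the construction arguments mirror the defaults documented above):

```python
import torch
from d9d.module.block.attention.sdpa import FlashSdpa  # module path from this page

sdpa = FlashSdpa(num_sinks=None, window_size=None)  # plain full attention

batch, seq_len, n_q_heads, n_kv_heads, head_dim = 2, 16, 32, 8, 128
q = torch.randn(batch, seq_len, n_q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)

out = sdpa(q, k, v, attention_mask=None, is_causal=True, scale=head_dim ** -0.5)
# out: (batch, seq_len, n_q_heads, head_dim)
```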

Linear Attention

Gated DeltaNet

GatedDeltaNet is an implementation of the Gated DeltaNet (GDN) attention mechanism.

It is a linear-attention alternative to softmax attention that combines the Delta Rule with Mamba-style data-dependent gating and short causal convolutions.

d9d.module.block.attention.linear

GatedDeltaNet

Bases: Module, ModuleLateInit

Implements Gated DeltaNet (GDN) attention mechanism.

This module combines linear attention based on the Delta Rule with Mamba-style data-dependent gating and short causal convolutions.

Pipeline
  1. Linear projections for Q, K, V, output gate (G), decay gate (GK), and write strength (Beta).
  2. Causal short depthwise convolution applied to Q, K, V.
  3. Data-dependent decay computation (Mamba-style or log-sigmoid).
  4. GQA/MQA head expansion for Q and K.
  5. Chunked Gated Delta Rule (with optional internal L2 norm on Q/K).
  6. Per-head RMSNorm and SiLU-gated output projection.
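
As a point of reference for step 5, a naive per-timestep version of the gated delta rule could be written as below. The module itself uses a chunked, fused kernel; this sketch follows the standard Gated DeltaNet formulation, and its tensor layout is an assumption:

```python
import torch

def gated_delta_rule_reference(q, k, v, beta, alpha):
    """q, k: (b, h, s, d_k); v: (b, h, s, d_v); beta, alpha: (b, h, s), values in (0, 1)."""
    b, h, s, d_k = k.shape
    state = k.new_zeros(b, h, d_k, v.shape[-1])   # associative memory S
    outputs = []
    for t in range(s):
        q_t, k_t, v_t = q[..., t, :], k[..., t, :], v[..., t, :]
        b_t = beta[..., t, None]
        state = alpha[..., t, None, None] * state                     # data-dependent decay
        pred = torch.einsum("bhk,bhkv->bhv", k_t, state)              # what memory currently predicts for k_t
        state = state + torch.einsum("bhk,bhv->bhkv", k_t, b_t * (v_t - pred))  # delta-rule write
        outputs.append(torch.einsum("bhk,bhkv->bhv", q_t, state))     # read with the query
    return torch.stack(outputs, dim=2)                                # (b, h, s, d_v)
```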
__init__(hidden_size, num_query_key_heads, num_value_heads, head_qk_dim, head_v_dim, norm_eps, conv_size, decay_gate, use_qk_l2norm=True)

Constructs a GatedDeltaNet object.

Parameters:

Name Type Description Default
hidden_size int

Hidden size.

required
num_query_key_heads int

Number of query and key attention heads before grouped expansion.

required
num_value_heads int

Number of value attention heads.

required
head_qk_dim int

Dimension allocated for a single query or key per head.

required
head_v_dim int

Dimension allocated for a single value per head.

required
norm_eps float

Small constant added for numerical stability to the normalization layer.

required
conv_size int

Size of the causal convolution kernel context.

required
decay_gate AnyDecayGateParameters

Structured parameters to initialize the selected decay gate mechanism.

required
use_qk_l2norm bool

Whether to enable L2 normalization applied to Q/K internally.

True

Raises:

Type Description
ValueError

When num_value_heads is not evenly divisible by num_query_key_heads.

forward(hidden_states, attention_mask=None)

Runs forward pass.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor sequence of shape (batch, seq_len, hidden_size).

required
attention_mask Tensor | None

Optional padding mask tensor of shape (batch, seq_len).

None

Returns:

Type Description
Tensor

Output tensor with the same shape as the input.
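
An end-to-end sketch (the import path is taken from this page; constructing MambaDecayGateParameters with no arguments assumes its fields have defaults, which is not documented here, and the chunked delta-rule kernel may require a GPU to run):

```python
import torch
from d9d.module.block.attention.linear import GatedDeltaNet, MambaDecayGateParameters

gdn = GatedDeltaNet(
    hidden_size=2048,
    num_query_key_heads=8,
    num_value_heads=16,     # must be a multiple of num_query_key_heads, else ValueError
    head_qk_dim=128,
    head_v_dim=128,
    norm_eps=1e-6,
    conv_size=4,
    decay_gate=MambaDecayGateParameters(),  # assumption: fields have defaults
)

x = torch.randn(2, 16, 2048)
out = gdn(x)                # attention_mask defaults to None (no padding)
assert out.shape == x.shape
```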

reset_parameters()

Resets learnable parameters of this module.

LogSigmoidDecayGateParameters

Bases: BaseModel

Configuration parameters for the LogSigmoid decay gate.

MambaDecayGateParameters

Bases: BaseModel

Configuration parameters for the Mamba-style decay gate.