Attention Layers

About

The d9d.module.block.attention package provides optimized attention mechanism implementations.

Features

Scaled Dot-Product Attention Kernels

  • FlashSdpa - FlashAttention 2 (via the PyTorch SDPA API)

Grouped-Query Attention

GroupedQueryAttention is a Grouped-Query Attention implementation.

Because the query and key/value head counts are configured independently, it can also serve as a Multi-Head Attention or Multi-Query Attention module.

  • Uses the FlashSdpa kernel.
  • Uses Rotary Positional Encoding.
  • Supports optional QK Normalization.

Multi-Head Latent Attention

MultiHeadLatentAttention is an implementation of the Multi-Head Latent Attention (MLA) mechanism introduced in DeepSeek-V2.

  • Uses the FlashSdpa kernel.
  • Uses Rotary Positional Encoding.

d9d.module.block.attention

Provides attention layer implementations.

GroupedQueryAttention

Bases: Module, ModuleLateInit

Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.

This module performs the full attention mechanism pipeline:

  1. Linear projection to Q, K, V.
  2. Optional RMS Normalization on Q and K.
  3. Rotary Positional Embedding (RoPE) application.
  4. Scaled Dot-Product Attention (via FlashAttention).
  5. Output projection.
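The steps of this pipeline can be sketched in NumPy for a single sequence (an illustrative reference only, not the d9d implementation; all weight names are made up, RMSNorm gains are omitted, and a rotate-half RoPE layout is assumed):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm without a learnable gain: x / sqrt(mean(x^2) + eps)
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

def rope(x, cos, sin):
    # Rotate-half RoPE on x of shape (heads, seq, head_dim)
    x1, x2 = np.split(x, 2, axis=-1)
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated * sin

def gqa_forward(h, wq, wk, wv, wo, n_q_heads, n_kv_heads, head_dim, cos, sin):
    seq, _ = h.shape
    # 1. Linear projection to Q, K, V
    q = (h @ wq).reshape(seq, n_q_heads, head_dim).transpose(1, 0, 2)
    k = (h @ wk).reshape(seq, n_kv_heads, head_dim).transpose(1, 0, 2)
    v = (h @ wv).reshape(seq, n_kv_heads, head_dim).transpose(1, 0, 2)
    # 2. Optional RMS normalization on Q and K
    q, k = rms_norm(q), rms_norm(k)
    # 3. RoPE application
    q, k = rope(q, cos, sin), rope(k, cos, sin)
    # 4. Scaled dot-product attention; each KV head serves a group of Q heads
    k = np.repeat(k, n_q_heads // n_kv_heads, axis=0)
    v = np.repeat(v, n_q_heads // n_kv_heads, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    scores = scores + np.triu(np.full((seq, seq), -np.inf), k=1)  # causal mask
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)
    out = (w @ v).transpose(1, 0, 2).reshape(seq, n_q_heads * head_dim)
    # 5. Output projection
    return out @ wo

rng = np.random.default_rng(0)
seq, hidden, n_q, n_kv, d = 4, 16, 4, 2, 8
ang = np.arange(seq)[:, None] * (1.0 / 10000 ** (np.arange(d // 2) / (d // 2)))
cos = np.concatenate([np.cos(ang)] * 2, axis=-1)
sin = np.concatenate([np.sin(ang)] * 2, axis=-1)
h = rng.normal(size=(seq, hidden))
wq = rng.normal(size=(hidden, n_q * d)); wk = rng.normal(size=(hidden, n_kv * d))
wv = rng.normal(size=(hidden, n_kv * d)); wo = rng.normal(size=(n_q * d, hidden))
y = gqa_forward(h, wq, wk, wv, wo, n_q, n_kv, d, cos, sin)  # (seq, hidden)
```

The KV-head repeat in step 4 is what makes the module general: with n_kv_heads equal to n_q_heads it degenerates to MHA, and with a single KV head it is MQA.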

__init__(hidden_size, num_attention_heads, num_key_value_heads, head_dim, qk_norm_eps, is_causal, rope_style)

Constructs the GroupedQueryAttention layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_size | int | Hidden size. | required |
| num_attention_heads | int | Number of Query heads. | required |
| num_key_value_heads | int | Number of Key/Value heads. If less than num_attention_heads, GQA/MQA is enabled. | required |
| head_dim | int | Dimensionality of a single attention head. | required |
| qk_norm_eps | float \| None | Epsilon for the RMSNorm applied to Q and K. If None, normalization is disabled. | required |
| is_causal | bool | Whether to apply a causal mask (auto-regressive constraint). | required |
| rope_style | RotaryEmbeddingStyle | Rotary embedding layout style. | required |
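The ratio of num_attention_heads to num_key_value_heads selects the attention variant; a small plain-Python illustration (not library code):

```python
def attention_variant(num_attention_heads: int, num_key_value_heads: int) -> str:
    # Each KV head is shared by a contiguous group of query heads,
    # so the KV head count must divide the query head count evenly.
    assert num_attention_heads % num_key_value_heads == 0
    if num_key_value_heads == num_attention_heads:
        return "MHA"  # one KV head per query head
    if num_key_value_heads == 1:
        return "MQA"  # a single KV head shared by all query heads
    return "GQA"      # grouped sharing in between

print(attention_variant(32, 32))  # MHA
print(attention_variant(32, 1))   # MQA
print(attention_variant(32, 8))   # GQA (groups of 4 query heads per KV head)
```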

forward(hidden_states, attention_mask, position_embeddings)

Computes the attention operation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | Tensor | Input tensor. Shape: (batch, seq_len, hidden_size). | required |
| attention_mask | Tensor \| None | Optional mask associated with the inputs. | required |
| position_embeddings | tuple[Tensor, Tensor] | Tuple of (cos, sin) tensors for RoPE application. Each tensor should be of shape (batch, seq_len, head_dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The attention output tensor. Shape: (batch, seq_len, hidden_size). |
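The (cos, sin) pair passed as position_embeddings can be built from standard rotary frequency tables. A hedged NumPy sketch (the exact layout depends on rope_style; a rotate-half layout with base 10000 is assumed here):

```python
import numpy as np

def build_rope_tables(batch: int, seq_len: int, head_dim: int, base: float = 10000.0):
    # Standard RoPE inverse frequencies over half the head dimension.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, head_dim/2)
    # Duplicate to cover the full head_dim (rotate-half layout assumption).
    angles = np.concatenate([angles, angles], axis=-1)        # (seq_len, head_dim)
    cos = np.broadcast_to(np.cos(angles), (batch, seq_len, head_dim))
    sin = np.broadcast_to(np.sin(angles), (batch, seq_len, head_dim))
    return cos, sin

# Both tables match the documented shape (batch, seq_len, head_dim).
cos, sin = build_rope_tables(batch=2, seq_len=8, head_dim=64)
```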

reset_parameters()

Resets module parameters.

LowRankProjection

Bases: Module

Implements a low-rank linear projection: a down-projection to a bottleneck dimension, an intermediate normalization layer, and an up-projection.
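A low-rank projection factors one large matrix into two smaller ones around a bottleneck; a NumPy sketch of the idea (illustrative only — learnable norm gains and the actual d9d internals are omitted):

```python
import numpy as np

def rms_norm(x, eps):
    # RMSNorm without a learnable gain
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

def low_rank_projection(x, w_down, w_up, norm_eps=1e-6):
    # x: (..., in_features) -> (..., bottleneck) -> (..., out_features)
    latent = x @ w_down                   # down-projection
    latent = rms_norm(latent, norm_eps)   # intermediate RMSNorm
    return latent @ w_up                  # up-projection

in_features, bottleneck, out_features = 1024, 64, 1024
rng = np.random.default_rng(0)
w_down = rng.normal(size=(in_features, bottleneck))
w_up = rng.normal(size=(bottleneck, out_features))
x = rng.normal(size=(3, in_features))
y = low_rank_projection(x, w_down, w_up)

# Parameter count drops from in*out to in*b + b*out:
full = in_features * out_features                            # 1048576
low = in_features * bottleneck + bottleneck * out_features   # 131072
```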

__init__(in_features, bottleneck, out_features, norm_eps)

Constructs the LowRankProjection object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | Input dimensionality. | required |
| bottleneck | int | Intermediate low-rank dimensionality. | required |
| out_features | int | Output dimensionality. | required |
| norm_eps | float | Epsilon value for the intermediate RMSNorm layer. | required |

forward(x)

Applies the low-rank projection to the inputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Projected output tensor. |

reset_parameters()

Resets module parameters.

MultiHeadLatentAttention

Bases: Module, ModuleLateInit

Implements Multi-Head Latent Attention (MLA) from DeepSeek-V2.

This module performs the full attention mechanism pipeline:

  1. Linear projection to Query (either direct or via a low-rank bottleneck with RMSNorm).
  2. Down-projection to a low-rank KV latent vector and a shared Key RoPE sub-vector.
  3. RMSNorm application on the KV latent vector.
  4. Up-projection of the KV latent vector into Key content (NOPE) and Value sub-vectors.
  5. Rotary Positional Embedding (RoPE) application strictly to the decoupled Query and Key RoPE sub-vectors.
  6. Concatenation of the content (NOPE) and rotated (RoPE) sub-vectors to form the final Query and Key heads.
  7. Scaled Dot Product Attention (via FlashAttention).
  8. Output projection.
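The shape bookkeeping in steps 1-6 can be walked through in NumPy with small illustrative dimensions (this is not the actual implementation; the RoPE rotation itself and steps 7-8, which proceed as in GQA, are omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq = 2, 6
hidden, n_heads = 512, 8
nope_dim, rope_dim, v_dim, kv_rank = 32, 16, 32, 64

def rms_norm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

h = rng.normal(size=(batch, seq, hidden)) * 0.02

# 1. Direct query projection (the q_lora_rank=None case)
wq = rng.normal(size=(hidden, n_heads * (nope_dim + rope_dim))) * 0.02
q = (h @ wq).reshape(batch, seq, n_heads, nope_dim + rope_dim)

# 2. Down-projection: shared KV latent + shared key-RoPE sub-vector
w_down = rng.normal(size=(hidden, kv_rank + rope_dim)) * 0.02
down = h @ w_down
kv_latent, k_rope = down[..., :kv_rank], down[..., kv_rank:]

# 3. RMSNorm on the KV latent
kv_latent = rms_norm(kv_latent)

# 4. Up-projection into per-head K content (NOPE) and V sub-vectors
w_up = rng.normal(size=(kv_rank, n_heads * (nope_dim + v_dim))) * 0.02
kv = (kv_latent @ w_up).reshape(batch, seq, n_heads, nope_dim + v_dim)
k_nope, v = kv[..., :nope_dim], kv[..., nope_dim:]

# 5./6. After RoPE rotation (omitted here), the shared key-RoPE sub-vector is
# broadcast across heads and concatenated with the content part.
k_rope_heads = np.broadcast_to(k_rope[:, :, None, :],
                               (batch, seq, n_heads, rope_dim))
k_full = np.concatenate([k_nope, k_rope_heads], axis=-1)
```

Note that only kv_latent (kv_rank wide) and k_rope (qk_rope_head_dim wide) depend on the input at decode time, which is what makes the latent KV cache compact.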

q_lora_rank property

Rank of the Q low-rank path, or None if Q is projected directly.

__init__(hidden_size, num_attention_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank, q_lora_rank, qk_down_norm_eps, is_causal, rope_style)

Constructs the MultiHeadLatentAttention layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_size | int | Model hidden dimension. | required |
| num_attention_heads | int | Number of attention heads. | required |
| qk_nope_head_dim | int | Per-head dimension for the content (no-RoPE) part of Q and K. | required |
| qk_rope_head_dim | int | Per-head dimension for the RoPE-rotated part of Q and K. | required |
| v_head_dim | int | Per-head dimension for values. | required |
| kv_lora_rank | int | Rank of the KV latent compression. | required |
| q_lora_rank | int \| None | Rank of the Q low-rank path. If None, Q is projected directly. | required |
| qk_down_norm_eps | float | Epsilon for the RMSNorm applied to the KV and Q latent representations. | required |
| is_causal | bool | Whether to apply a causal mask (auto-regressive). | required |
| rope_style | RotaryEmbeddingStyle | Rotary embedding layout style. | required |
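kv_lora_rank controls how aggressively the per-token KV cache is compressed relative to caching full K/V heads. A quick comparison with illustrative numbers (not DeepSeek-V2's actual configuration):

```python
# Illustrative dimensions
n_heads, qk_nope, qk_rope, v_dim, kv_rank = 8, 32, 16, 32, 64

# Standard MHA caches full K and V vectors for every head per token.
mha_cache = n_heads * ((qk_nope + qk_rope) + v_dim)

# MLA caches only the KV latent plus the shared key-RoPE sub-vector per token.
mla_cache = kv_rank + qk_rope

print(mha_cache, mla_cache)  # 640 vs 80: an 8x reduction with these numbers
```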

forward(hidden_states, attention_mask, position_embeddings)

Computes Multi-Head Latent Attention.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | Tensor | Input tensor. Shape: (batch, seq_len, hidden_size). | required |
| attention_mask | Tensor \| None | Optional attention mask. | required |
| position_embeddings | tuple[Tensor, Tensor] | Tuple (cos, sin) for the RoPE sub-vectors. Each tensor shape: (batch, seq_len, qk_rope_head_dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor. Shape: (batch, seq_len, hidden_size). |

reset_parameters()

Resets all learnable parameters.

d9d.module.block.attention.sdpa

FlashSdpa

Bases: Module

Executes Scaled Dot-Product Attention (SDPA), enforcing the FlashAttention backend.
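The semantics this module computes can be written as a plain NumPy reference (illustrative only; the real FlashAttention kernel fuses these steps and never materializes the full score matrix):

```python
import numpy as np

def sdpa_reference(q, k, v, attention_mask=None, is_causal=True, scale=None):
    # q: (batch, n_q_heads, seq, d); k, v: (batch, n_kv_heads, seq, d)
    b, hq, s, d = q.shape
    hkv = k.shape[1]
    if scale is None:
        scale = 1.0 / np.sqrt(d)
    # Repeat KV heads so each query head has a matching KV head (GQA support).
    k = np.repeat(k, hq // hkv, axis=1)
    v = np.repeat(v, hq // hkv, axis=1)
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale
    if is_causal:
        # Mask out positions strictly above the diagonal.
        scores = np.where(np.triu(np.ones((s, s), dtype=bool), k=1), -np.inf, scores)
    if attention_mask is not None:
        scores = scores + attention_mask  # additive-mask assumption
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)
    return w @ v  # (batch, n_q_heads, seq, d)

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4, 5, 8))
k = rng.normal(size=(2, 2, 5, 8))
v = rng.normal(size=(2, 2, 5, 8))
out = sdpa_reference(q, k, v)
```

With is_causal=True, the first query position can only attend to itself, so its output equals the value vector at position 0 of the corresponding KV head.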

__init__()

Constructs the FlashSdpa object.

forward(query_states, key_states, value_states, attention_mask, is_causal, scale)

Computes Scaled Dot-Product Attention using FlashAttention.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query_states | Tensor | Query tensor. Shape: (batch, n_q_heads, seq_len, head_dim). | required |
| key_states | Tensor | Key tensor. Shape: (batch, n_kv_heads, seq_len, head_dim). | required |
| value_states | Tensor | Value tensor. Shape: (batch, n_kv_heads, seq_len, head_dim). | required |
| attention_mask | Tensor \| None | Optional attention mask (usually not needed for FlashAttention with is_causal=True). | required |
| is_causal | bool | If True, applies a causal mask (upper-triangular masking). | required |
| scale | float | Scaling factor applied to the dot products (usually 1 / sqrt(head_dim)). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The attention output tensor, permuted to channel-last format. Shape: (batch, seq_len, n_q_heads, head_dim). |
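The channel-last permutation moves seq_len ahead of the head axis so callers can flatten the heads for the output projection; in NumPy-style shapes (illustrative):

```python
import numpy as np

out = np.zeros((2, 4, 5, 8))     # (batch, n_q_heads, seq_len, head_dim)
out = out.transpose(0, 2, 1, 3)  # -> (batch, seq_len, n_q_heads, head_dim)
flat = out.reshape(2, 5, 4 * 8)  # heads flattened, ready for output projection
```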