About

The d9d.module.block.attention package provides optimized attention mechanism implementations.

Features

Scaled Dot-Product Attention Kernels

  • FlashSdpa - FlashAttention 2 via the PyTorch SDPA API (torch.nn.functional.scaled_dot_product_attention)

Grouped-Query Attention

GroupedQueryAttention is a Grouped-Query Attention implementation.

Because the numbers of query and key/value heads are configurable, it can also be used as a Multi-Head Attention or Multi-Query Attention module (a configuration sketch follows the feature list below).

  • Uses the FlashSdpa kernel.

  • Applies Rotary Positional Embeddings (RoPE).

  • Supports optional QK Normalization.
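
As a configuration sketch, the attention variant is selected purely by the ratio of query heads to key/value heads. The import path below is an assumption based on the package name shown on this page; the head counts and sizes are illustrative.

from d9d.module.block.attention import GroupedQueryAttention  # assumed export location

common = dict(hidden_size=1024, head_dim=64, qk_norm_eps=1e-6, is_causal=True)

# Multi-Head Attention: every query head has its own key/value head.
mha = GroupedQueryAttention(num_attention_heads=16, num_key_value_heads=16, **common)

# Grouped-Query Attention: 16 query heads share 4 key/value heads (groups of 4).
gqa = GroupedQueryAttention(num_attention_heads=16, num_key_value_heads=4, **common)

# Multi-Query Attention: all query heads share a single key/value head.
mqa = GroupedQueryAttention(num_attention_heads=16, num_key_value_heads=1, **common)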

d9d.module.block.attention

Provides attention layer implementations.

GroupedQueryAttention

Bases: Module, ModuleLateInit

Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.

This module performs the full attention mechanism pipeline:

1. Linear projection to Q, K, V.
2. Optional RMS Normalization on Q and K.
3. Rotary Positional Embedding (RoPE) application.
4. Scaled Dot Product Attention (via FlashAttention).
5. Output projection.
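
A minimal end-to-end sketch of this pipeline is shown below. The import path is assumed from the package name, the FlashAttention backend generally requires a CUDA device with fp16/bf16 tensors (assumed environment), and the cos/sin tensors are identity placeholders standing in for real rotary embeddings, which would normally come from a separate rotary-embedding module.

import torch

from d9d.module.block.attention import GroupedQueryAttention  # assumed export location

device, dtype = "cuda", torch.bfloat16  # assumed environment for the flash backend

attn = GroupedQueryAttention(
    hidden_size=1024, num_attention_heads=16, num_key_value_heads=4,
    head_dim=64, qk_norm_eps=1e-6, is_causal=True
).to(device=device, dtype=dtype)

batch, seq_len, head_dim = 2, 128, 64
hidden_states = torch.randn(batch, seq_len, 1024, device=device, dtype=dtype)

# Placeholder RoPE tensors of shape (batch, seq_len, head_dim); with the standard
# cos/sin formulation, cos=1 and sin=0 amounts to applying no rotation.
cos = torch.ones(batch, seq_len, head_dim, device=device, dtype=dtype)
sin = torch.zeros(batch, seq_len, head_dim, device=device, dtype=dtype)

out = attn(hidden_states, attention_mask=None, position_embeddings=(cos, sin))
print(out.shape)  # torch.Size([2, 128, 1024]) -> (batch, seq_len, hidden_size)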

Source code in d9d/module/block/attention/grouped_query.py
class GroupedQueryAttention(nn.Module, ModuleLateInit):
    """
    Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.

    This module performs the full attention mechanism pipeline:
    1.  Linear projection to Q, K, V.
    2.  Optional RMS Normalization on Q and K.
    3.  Rotary Positional Embedding (RoPE) application.
    4.  Scaled Dot Product Attention (via FlashAttention).
    5.  Output projection.
    """

    def __init__(
            self,
            hidden_size: int,
            num_attention_heads: int,
            num_key_value_heads: int,
            head_dim: int,
            qk_norm_eps: float | None,
            is_causal: bool
    ):
        """
        Constructs the GroupedQueryAttention layer.

        Args:
            hidden_size: Hidden size.
            num_attention_heads: Number of Query heads.
            num_key_value_heads: Number of Key/Value heads. If less than `num_attention_heads`, GQA/MQA is enabled.
            head_dim: Dimensionality of a single attention head.
            qk_norm_eps: Epsilon for LayerNorm/RMSNorm applied to Q and K. If None, normalization is disabled.
            is_causal: Whether to apply a causal mask (auto-regressive constraint).
        """

        super().__init__()

        self._head_dim = head_dim
        self._num_key_value_groups = num_attention_heads // num_key_value_heads
        self._scaling = head_dim ** -0.5

        self.q_proj = nn.Linear(
            hidden_size, num_attention_heads * head_dim, bias=False
        )

        self.k_proj = nn.Linear(
            hidden_size, num_key_value_heads * head_dim, bias=False
        )

        self.v_proj = nn.Linear(
            hidden_size, num_key_value_heads * head_dim, bias=False
        )

        self.o_proj = nn.Linear(
            num_attention_heads * head_dim, hidden_size, bias=False
        )

        self.q_norm: nn.RMSNorm | None
        self.k_norm: nn.RMSNorm | None

        if qk_norm_eps is not None:
            self.q_norm = nn.RMSNorm(normalized_shape=head_dim,
                                     eps=qk_norm_eps)
            self.k_norm = nn.RMSNorm(normalized_shape=head_dim,
                                     eps=qk_norm_eps)
        else:
            self.q_norm = None
            self.k_norm = None

        self.rope = RotaryEmbeddingApplicator()
        self.kernel = FlashSdpa()
        self._is_causal = is_causal

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: torch.Tensor | None,
            position_embeddings: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Computes the attention operation.

        Args:
            hidden_states: Input tensor. Shape: `(batch, seq_len, hidden_size)`.
            attention_mask: Optional mask associated with the inputs.
            position_embeddings: Tuple of `(cos, sin)` tensors for RoPE application.
                Each tensor should be of shape `(batch, seq_len, head_dim)`

        Returns:
            The attention output tensor. Shape: `(batch, seq_len, hidden_size)`.
        """

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self._head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        if self.q_norm is not None:
            query_states = self.q_norm(query_states)
        query_states = query_states.transpose(1, 2)

        key_states = self.k_proj(hidden_states).view(hidden_shape)
        if self.k_norm is not None:
            key_states = self.k_norm(key_states)
        key_states = key_states.transpose(1, 2)

        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states, key_states = self.rope(query_states, key_states, position_embeddings[0], position_embeddings[1])

        outputs = self.kernel(
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            is_causal=self._is_causal,
            scale=self._scaling
        )

        outputs = outputs.reshape(*input_shape, -1).contiguous()
        outputs = self.o_proj(outputs)
        return outputs

    def reset_parameters(self):
        """Resets module parameters."""

        self.q_proj.reset_parameters()
        self.k_proj.reset_parameters()
        self.v_proj.reset_parameters()
        self.o_proj.reset_parameters()
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        if self.k_norm is not None:
            self.k_norm.reset_parameters()

__init__(hidden_size, num_attention_heads, num_key_value_heads, head_dim, qk_norm_eps, is_causal)

Constructs the GroupedQueryAttention layer.

Parameters:

  hidden_size (int, required): Hidden size.
  num_attention_heads (int, required): Number of Query heads.
  num_key_value_heads (int, required): Number of Key/Value heads. If less than num_attention_heads, GQA/MQA is enabled.
  head_dim (int, required): Dimensionality of a single attention head.
  qk_norm_eps (float | None, required): Epsilon for LayerNorm/RMSNorm applied to Q and K. If None, normalization is disabled.
  is_causal (bool, required): Whether to apply a causal mask (auto-regressive constraint).

Source code in d9d/module/block/attention/grouped_query.py
def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        qk_norm_eps: float | None,
        is_causal: bool
):
    """
    Constructs the GroupedQueryAttention layer.

    Args:
        hidden_size: Hidden size.
        num_attention_heads: Number of Query heads.
        num_key_value_heads: Number of Key/Value heads. If less than `num_attention_heads`, GQA/MQA is enabled.
        head_dim: Dimensionality of a single attention head.
        qk_norm_eps: Epsilon for LayerNorm/RMSNorm applied to Q and K. If None, normalization is disabled.
        is_causal: Whether to apply a causal mask (auto-regressive constraint).
    """

    super().__init__()

    self._head_dim = head_dim
    self._num_key_value_groups = num_attention_heads // num_key_value_heads
    self._scaling = head_dim ** -0.5

    self.q_proj = nn.Linear(
        hidden_size, num_attention_heads * head_dim, bias=False
    )

    self.k_proj = nn.Linear(
        hidden_size, num_key_value_heads * head_dim, bias=False
    )

    self.v_proj = nn.Linear(
        hidden_size, num_key_value_heads * head_dim, bias=False
    )

    self.o_proj = nn.Linear(
        num_attention_heads * head_dim, hidden_size, bias=False
    )

    self.q_norm: nn.RMSNorm | None
    self.k_norm: nn.RMSNorm | None

    if qk_norm_eps is not None:
        self.q_norm = nn.RMSNorm(normalized_shape=head_dim,
                                 eps=qk_norm_eps)
        self.k_norm = nn.RMSNorm(normalized_shape=head_dim,
                                 eps=qk_norm_eps)
    else:
        self.q_norm = None
        self.k_norm = None

    self.rope = RotaryEmbeddingApplicator()
    self.kernel = FlashSdpa()
    self._is_causal = is_causal

forward(hidden_states, attention_mask, position_embeddings)

Computes the attention operation.

Parameters:

  hidden_states (Tensor, required): Input tensor. Shape: (batch, seq_len, hidden_size).
  attention_mask (Tensor | None, required): Optional mask associated with the inputs.
  position_embeddings (tuple[Tensor, Tensor], required): Tuple of (cos, sin) tensors for RoPE application. Each tensor should be of shape (batch, seq_len, head_dim).

Returns:

  Tensor: The attention output tensor. Shape: (batch, seq_len, hidden_size).
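
Since the caller supplies position_embeddings, the sketch below shows one way to build cos/sin tables of the expected shape under the conventional rotary-embedding formulation. rope_cos_sin is a hypothetical helper; the package's own rotary-embedding utilities (not documented on this page) would normally be preferred.

import torch

def rope_cos_sin(batch: int, seq_len: int, head_dim: int, base: float = 10000.0):
    # Conventional rotary-embedding frequencies: one inverse frequency per pair of dims.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)            # (seq_len, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)             # (seq_len, head_dim)
    cos = emb.cos().unsqueeze(0).expand(batch, -1, -1)  # (batch, seq_len, head_dim)
    sin = emb.sin().unsqueeze(0).expand(batch, -1, -1)
    return cos, sin

cos, sin = rope_cos_sin(batch=2, seq_len=128, head_dim=64)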

Source code in d9d/module/block/attention/grouped_query.py
def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
    """
    Computes the attention operation.

    Args:
        hidden_states: Input tensor. Shape: `(batch, seq_len, hidden_size)`.
        attention_mask: Optional mask associated with the inputs.
        position_embeddings: Tuple of `(cos, sin)` tensors for RoPE application.
            Each tensor should be of shape `(batch, seq_len, head_dim)`

    Returns:
        The attention output tensor. Shape: `(batch, seq_len, hidden_size)`.
    """

    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self._head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape)
    if self.q_norm is not None:
        query_states = self.q_norm(query_states)
    query_states = query_states.transpose(1, 2)

    key_states = self.k_proj(hidden_states).view(hidden_shape)
    if self.k_norm is not None:
        key_states = self.k_norm(key_states)
    key_states = key_states.transpose(1, 2)

    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    query_states, key_states = self.rope(query_states, key_states, position_embeddings[0], position_embeddings[1])

    outputs = self.kernel(
        query_states,
        key_states,
        value_states,
        attention_mask=attention_mask,
        is_causal=self._is_causal,
        scale=self._scaling
    )

    outputs = outputs.reshape(*input_shape, -1).contiguous()
    outputs = self.o_proj(outputs)
    return outputs

reset_parameters()

Resets module parameters.

Source code in d9d/module/block/attention/grouped_query.py
def reset_parameters(self):
    """Resets module parameters."""

    self.q_proj.reset_parameters()
    self.k_proj.reset_parameters()
    self.v_proj.reset_parameters()
    self.o_proj.reset_parameters()
    if self.q_norm is not None:
        self.q_norm.reset_parameters()
    if self.k_norm is not None:
        self.k_norm.reset_parameters()

d9d.module.block.attention.sdpa

FlashSdpa

Bases: Module

Executes Scaled Dot Product Attention (SDPA) enforcing the FlashAttention backend.
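
For reference, the kernel computes standard scaled dot-product attention. With s the explicit scale argument (typically 1 / sqrt(head_dim)) and M the optional additive mask (or the implicit causal mask when is_causal=True):

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(s \, Q K^{\top} + M\right) V
\]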

Source code in d9d/module/block/attention/sdpa/flash.py
class FlashSdpa(nn.Module):
    """Executes Scaled Dot Product Attention (SDPA) enforcing the FlashAttention backend."""

    def __init__(self):
        """
        Constructs the FlashSdpa object.
        """
        super().__init__()

    def forward(
            self,
            query_states: torch.Tensor,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            attention_mask: torch.Tensor | None,
            is_causal: bool,
            scale: float
    ) -> torch.Tensor:
        """
        Computes Scaled Dot-Product Attention using FlashAttention.

        Args:
            query_states: Query tensor. Shape: `(batch, n_q_heads, seq_len, head_dim)`.
            key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
            value_states: Value tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
            attention_mask: Optional attention mask (usually not needed for FlashAttn with causal=True).
            is_causal: If True, applies a causal mask (upper triangular masking).
            scale: Scaling factor applied to the dot products (usually `1 / sqrt(head_dim)`).

        Returns:
            The attention output tensor, permuted to channel-last format.
                Shape: `(batch, seq_len, n_q_heads, head_dim)`.
        """

        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            results = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=is_causal,
                scale=scale,
                enable_gqa=query_states.shape[1] != key_states.shape[1]
            )
            return results.transpose(1, 2).contiguous()

__init__()

Constructs the FlashSdpa object.

Source code in d9d/module/block/attention/sdpa/flash.py
def __init__(self):
    """
    Constructs the FlashSdpa object.
    """
    super().__init__()

forward(query_states, key_states, value_states, attention_mask, is_causal, scale)

Computes Scaled Dot-Product Attention using FlashAttention.

Parameters:

  query_states (Tensor, required): Query tensor. Shape: (batch, n_q_heads, seq_len, head_dim).
  key_states (Tensor, required): Key tensor. Shape: (batch, n_kv_heads, seq_len, head_dim).
  value_states (Tensor, required): Value tensor. Shape: (batch, n_kv_heads, seq_len, head_dim).
  attention_mask (Tensor | None, required): Optional attention mask (usually not needed for FlashAttention with is_causal=True).
  is_causal (bool, required): If True, applies a causal mask (upper triangular masking).
  scale (float, required): Scaling factor applied to the dot products (usually 1 / sqrt(head_dim)).

Returns:

  Tensor: The attention output tensor, permuted to channel-last format. Shape: (batch, seq_len, n_q_heads, head_dim).
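
A direct-usage sketch with GQA-shaped inputs (8 query heads sharing 2 key/value heads) follows. The import path is assumed from the module name above, and the FlashAttention backend generally requires a CUDA device with fp16/bf16 tensors (assumed environment).

import torch

from d9d.module.block.attention.sdpa import FlashSdpa  # assumed export location

kernel = FlashSdpa()

batch, seq_len, head_dim = 2, 128, 64
n_q_heads, n_kv_heads = 8, 2  # differing head counts enable GQA inside the kernel

device, dtype = "cuda", torch.bfloat16  # assumed environment for the flash backend
q = torch.randn(batch, n_q_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, n_kv_heads, seq_len, head_dim, device=device, dtype=dtype)

out = kernel(q, k, v, attention_mask=None, is_causal=True, scale=head_dim ** -0.5)
print(out.shape)  # torch.Size([2, 128, 8, 64]) -> (batch, seq_len, n_q_heads, head_dim)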

Source code in d9d/module/block/attention/sdpa/flash.py
def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        is_causal: bool,
        scale: float
) -> torch.Tensor:
    """
    Computes Scaled Dot-Product Attention using FlashAttention.

    Args:
        query_states: Query tensor. Shape: `(batch, n_q_heads, seq_len, head_dim)`.
        key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
        value_states: Value tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
        attention_mask: Optional attention mask (usually not needed for FlashAttn with causal=True).
        is_causal: If True, applies a causal mask (upper triangular masking).
        scale: Scaling factor applied to the dot products (usually `1 / sqrt(head_dim)`).

    Returns:
        The attention output tensor, permuted to channel-last format.
            Shape: `(batch, seq_len, n_q_heads, head_dim)`.
    """

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        results = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=query_states.shape[1] != key_states.shape[1]
        )
        return results.transpose(1, 2).contiguous()