About

The d9d.module.block.positional package manages positional encoding logic.

Features

Rotary Positional Encoding

Rotary Positional Encoding from RoFormer.

See RotaryEmbeddingProvider and RotaryEmbeddingApplicator classes.

The first is typically bound to a model class and provides the (cos, sin) embedding tensors for specified position IDs.

The second is typically bound to an attention module implementation and modifies the query and key states at runtime.

d9d.module.block.positional

Provides modules for positional embeddings, such as Rotary Positional Embeddings.

RotaryEmbeddingApplicator

Bases: Module

Applies Rotary Positional Embeddings (RoPE) to Q and K projections.

Source code in d9d/module/block/positional/rope.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class RotaryEmbeddingApplicator(nn.Module):
    """Stateless module that applies Rotary Positional Embeddings (RoPE) to Q and K projections."""

    def __init__(self):
        """
        Constructs RotaryEmbeddingApplicator object.

        No parameters are held; only the base Module is initialised.
        """
        super().__init__()

    def forward(
            self,
            query_states: torch.Tensor,
            key_states: torch.Tensor,
            position_embedding_cos: torch.Tensor,
            position_embedding_sin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Rotates query and key states using the provided cosine and sine embeddings.

        Args:
            query_states: Query tensor. Shape: `(batch, n_heads, seq_len, head_dim)`.
            key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
            position_embedding_cos: Per-position cosine values.
                Shape: `(batch, seq_len, head_dim)`.
            position_embedding_sin: Per-position sine values.
                Shape: `(batch, seq_len, head_dim)`.

        Returns:
            A tuple containing the rotated query and key tensors.
        """

        # The actual rotation is delegated to the shared module-level helper.
        rotated_query, rotated_key = _apply_rotary_pos_emb(
            query_states,
            key_states,
            position_embedding_cos,
            position_embedding_sin
        )

        return rotated_query, rotated_key

__init__()

Constructs RotaryEmbeddingApplicator object.

Source code in d9d/module/block/positional/rope.py
118
119
120
121
122
123
def __init__(self):
    """
    Constructs RotaryEmbeddingApplicator object.

    The applicator is stateless; only the base Module is initialised.
    """
    super().__init__()

forward(query_states, key_states, position_embedding_cos, position_embedding_sin)

Rotates query and key states using provided cosine and sine embeddings.

Parameters:

Name Type Description Default
query_states Tensor

Query tensor. Shape: (batch, n_heads, seq_len, head_dim).

required
key_states Tensor

Key tensor. Shape: (batch, n_kv_heads, seq_len, head_dim).

required
position_embedding_cos Tensor

Cosine values for positions. Shape: (batch, seq_len, head_dim).

required
position_embedding_sin Tensor

Sine values for positions. Shape: (batch, seq_len, head_dim).

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple containing the rotated query and key tensors.

Source code in d9d/module/block/positional/rope.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        position_embedding_cos: torch.Tensor,
        position_embedding_sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Rotates query and key states using the provided cosine and sine embeddings.

    Args:
        query_states: Query tensor. Shape: `(batch, n_heads, seq_len, head_dim)`.
        key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
        position_embedding_cos: Per-position cosine values.
            Shape: `(batch, seq_len, head_dim)`.
        position_embedding_sin: Per-position sine values.
            Shape: `(batch, seq_len, head_dim)`.

    Returns:
        A tuple containing the rotated query and key tensors.
    """

    # Delegate the rotation itself to the shared module-level helper.
    rotated_query, rotated_key = _apply_rotary_pos_emb(
        query_states,
        key_states,
        position_embedding_cos,
        position_embedding_sin
    )

    return rotated_query, rotated_key

RotaryEmbeddingProvider

Bases: Module, ModuleLateInit

Module that manages and provides Rotary Positional Embeddings.

Source code in d9d/module/block/positional/rope.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class RotaryEmbeddingProvider(nn.Module, ModuleLateInit):
    """Caches Rotary Positional Embedding tables and serves them for requested positions."""

    def __init__(self, rope_base: int, head_dim: int, max_position_ids: int):
        """
        Constructs the RotaryEmbeddingProvider.

        Args:
            rope_base: Base value forwarded to the cos/sin table precomputation.
            head_dim: Dimensionality of a single attention head.
            max_position_ids: Number of positions to precompute embeddings for.
        """

        super().__init__()
        self._rope_base = rope_base
        self._head_dim = head_dim
        self._max_position_ids = max_position_ids
        # Non-persistent caches (excluded from state_dict); filled by reset_parameters().
        table_shape = (max_position_ids, head_dim)
        self.cos_emb = nn.Buffer(torch.empty(table_shape), persistent=False)
        self.sin_emb = nn.Buffer(torch.empty(table_shape), persistent=False)

    def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves cached cosine and sine embeddings for specific positions.

        Args:
            position_ids: Tensor of position indices.

        Returns:
            A tuple of (cos, sin) tensors aligned with the input positions.
        """

        cos = self.cos_emb[position_ids]
        sin = self.sin_emb[position_ids]
        return cos, sin

    def reset_parameters(self):
        """Recomputes the cached cos/sin tables on the buffers' current device and dtype."""

        with torch.no_grad():
            cos, sin = prepare_rotary_cos_sin_emb(
                rope_base=self._rope_base,
                head_dim=self._head_dim,
                max_position_ids=self._max_position_ids,
                device=self.cos_emb.device,
                dtype=self.cos_emb.dtype
            )
            self.cos_emb.data = cos
            self.sin_emb.data = sin

__init__(rope_base, head_dim, max_position_ids)

Constructs the RotaryEmbeddingProvider.

Source code in d9d/module/block/positional/rope.py
63
64
65
66
67
68
69
70
71
def __init__(self, rope_base: int, head_dim: int, max_position_ids: int):
    """
    Constructs the RotaryEmbeddingProvider.

    Args:
        rope_base: Base value forwarded to the cos/sin table precomputation.
        head_dim: Dimensionality of a single attention head.
        max_position_ids: Number of positions to precompute embeddings for.
    """

    super().__init__()
    self._rope_base = rope_base
    self._head_dim = head_dim
    self._max_position_ids = max_position_ids
    # Non-persistent caches (excluded from state_dict); filled by reset_parameters().
    table_shape = (max_position_ids, head_dim)
    self.cos_emb = nn.Buffer(torch.empty(table_shape), persistent=False)
    self.sin_emb = nn.Buffer(torch.empty(table_shape), persistent=False)

forward(position_ids)

Retrieves cached cosine and sine embeddings for specific positions.

Parameters:

Name Type Description Default
position_ids Tensor

Tensor of position indices.

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple of (cos, sin) tensors aligned with the input positions.

Source code in d9d/module/block/positional/rope.py
73
74
75
76
77
78
79
80
81
82
83
84
def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Retrieves cached cosine and sine embeddings for specific positions.

    Args:
        position_ids: Tensor of position indices.

    Returns:
        A tuple of (cos, sin) tensors aligned with the input positions.
    """

    # Plain advanced indexing into the precomputed tables; no recomputation happens here.
    cos = self.cos_emb[position_ids]
    sin = self.sin_emb[position_ids]
    return cos, sin