About

The d9d.module.block.embedding package provides enhanced embedding layers.

Features

Currently, this package provides only the SplitTokenEmbeddings module. You can use it as:

  • A regular token embedding layer: specify a single split whose size is the global vocabulary size.
  • A prompt-tuning embedding layer: add the new tokens to your tokenizer and specify two splits, where the first holds the original token embeddings and the second holds the newly added learnable prompt tokens. Then unfreeze only the nn.Embedding module that belongs to the second split, as sketched below.
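
A minimal sketch of that prompt-tuning setup, assuming SplitTokenEmbeddings is importable from d9d.module.block.embedding as documented here; the split names 'orig' and 'prompt' and all sizes are illustrative:

from d9d.module.block.embedding import SplitTokenEmbeddings

# Illustrative sizes: the tokenizer originally had 32_000 tokens and
# 16 new prompt tokens were appended to it.
embeddings = SplitTokenEmbeddings(
    split_vocab_size={"orig": 32_000, "prompt": 16},
    split_order=["orig", "prompt"],
    hidden_size=1024,
)

# Freeze the original vocabulary and train only the newly added prompt split.
embeddings.token_embedding["orig"].weight.requires_grad_(False)
embeddings.token_embedding["prompt"].weight.requires_grad_(True)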

d9d.module.block.embedding

Package providing various embedding layer implementations

SplitTokenEmbeddings

Bases: Module, ModuleLateInit

A token embedding layer composed of multiple named, independent embedding tables.

This class maintains a dictionary of embedding layers, mapping contiguous ranges of global vocabulary indices to specific named splits (e.g., 'orig', 'special', 'prompt_prefix'). This is useful for model adaptation strategies where different sets of tokens require different initialization and training behaviors.
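
For example, a minimal sketch of how the split order defines contiguous global index ranges (split names and sizes are illustrative, and the import path assumes the class is exported from this package):

from d9d.module.block.embedding import SplitTokenEmbeddings

layer = SplitTokenEmbeddings(
    split_vocab_size={"orig": 100, "special": 8},
    split_order=["orig", "special"],
    hidden_size=32,
)

# With this order, global ids 0..99 are looked up in the 'orig' table
# (local ids 0..99) and global ids 100..107 in the 'special' table
# (local ids 0..7).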

Source code in d9d/module/block/embedding/shard_token_embedding.py
class SplitTokenEmbeddings(nn.Module, ModuleLateInit):
    """
    A token embedding layer composed of multiple named, independent embedding tables.

    This class maintains a dictionary of embedding layers, mapping contiguous
    ranges of global vocabulary indices to specific named splits (e.g., 'orig',
    'special', 'prompt_prefix'). This is useful for model adaptation strategies where
    different sets of tokens require different initialization and training behaviors.
    """

    def __init__(
            self,
            split_vocab_size: dict[str, int],
            split_order: Sequence[str],
            hidden_size: int
    ):
        """
        Constructs the SplitTokenEmbeddings object.

        Args:
            split_vocab_size: A dictionary mapping split names to their vocabulary sizes.
            split_order: A sequence defining the order in which splits are concatenated
                to form the global vocabulary. Keys provided here must exist in
                split_vocab_size.
            hidden_size: The dimensionality of the embedding vectors.
        """

        super().__init__()

        token_embedding = nn.ModuleDict({
            split_name: nn.Embedding(vocab_size, hidden_size)
            for split_name, vocab_size in split_vocab_size.items()
        })
        self.token_embedding: Mapping[str, nn.Embedding] = cast(Mapping[str, nn.Embedding], token_embedding)

        self._id_start, self._id_end = _build_token_start_end_indices(split_vocab_size, split_order)
        self._hidden_size = hidden_size
        self._split_order = split_order

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Retrieves embeddings for the input indices by routing them to appropriate internal layers.

        Args:
            input_ids: Tensor of arbitrary shape containing global vocabulary indices.

        Returns:
            Tensor of same shape as input_ids plus a last dimension of hidden_size.
        """

        metadata_weight = next(iter(self.token_embedding.values())).weight
        # todo custom cuda kernel for indexing and filling?

        embed = torch.empty(
            size=(input_ids.shape[0], input_ids.shape[1], self._hidden_size),
            device=metadata_weight.device,
            dtype=metadata_weight.dtype,
        )

        for split_name in self._split_order:
            start_idx, end_idx = self._id_start[split_name], self._id_end[split_name]
            is_split_mask = (input_ids >= start_idx) & (input_ids < end_idx)
            split_embed = self.token_embedding[split_name](input_ids[is_split_mask] - start_idx)
            embed[is_split_mask] = split_embed

        return embed

    def reset_parameters(self):
        """
        Resets parameters for all registered embedding splits.
        """

        for layer in self.token_embedding.values():
            layer.reset_parameters()

__init__(split_vocab_size, split_order, hidden_size)

Constructs the SplitTokenEmbeddings object.

Parameters:

  • split_vocab_size (dict[str, int], required): A dictionary mapping split names to their vocabulary sizes.
  • split_order (Sequence[str], required): A sequence defining the order in which splits are concatenated to form the global vocabulary. Keys provided here must exist in split_vocab_size.
  • hidden_size (int, required): The dimensionality of the embedding vectors.
Source code in d9d/module/block/embedding/shard_token_embedding.py
def __init__(
        self,
        split_vocab_size: dict[str, int],
        split_order: Sequence[str],
        hidden_size: int
):
    """
    Constructs the SplitTokenEmbeddings object.

    Args:
        split_vocab_size: A dictionary mapping split names to their vocabulary sizes.
        split_order: A sequence defining the order in which splits are concatenated
            to form the global vocabulary. Keys provided here must exist in
            split_vocab_size.
        hidden_size: The dimensionality of the embedding vectors.
    """

    super().__init__()

    token_embedding = nn.ModuleDict({
        split_name: nn.Embedding(vocab_size, hidden_size)
        for split_name, vocab_size in split_vocab_size.items()
    })
    self.token_embedding: Mapping[str, nn.Embedding] = cast(Mapping[str, nn.Embedding], token_embedding)

    self._id_start, self._id_end = _build_token_start_end_indices(split_vocab_size, split_order)
    self._hidden_size = hidden_size
    self._split_order = split_order

forward(input_ids)

Retrieves embeddings for the input indices by routing them to appropriate internal layers.

Parameters:

  • input_ids (Tensor, required): Tensor of arbitrary shape containing global vocabulary indices.

Returns:

  • Tensor: Tensor of the same shape as input_ids plus a last dimension of hidden_size.
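
A minimal usage sketch, assuming the class is importable from the package as above; the ids, sizes, and split names are illustrative, and the input is 2D (batch, sequence), which is the shape the implementation indexes explicitly:

import torch

from d9d.module.block.embedding import SplitTokenEmbeddings

layer = SplitTokenEmbeddings(
    split_vocab_size={"orig": 100, "special": 8},
    split_order=["orig", "special"],
    hidden_size=32,
)

# Global ids 100..107 route to the 'special' table; everything below 100
# routes to 'orig'. Shapes: (2, 3) input -> (2, 3, 32) output.
input_ids = torch.tensor([[1, 5, 104], [100, 2, 3]])
embed = layer(input_ids)
print(embed.shape)  # torch.Size([2, 3, 32])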

Source code in d9d/module/block/embedding/shard_token_embedding.py
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
    """
    Retrieves embeddings for the input indices by routing them to appropriate internal layers.

    Args:
        input_ids: Tensor of arbitrary shape containing global vocabulary indices.

    Returns:
        Tensor of same shape as input_ids plus a last dimension of hidden_size.
    """

    metadata_weight = next(iter(self.token_embedding.values())).weight
    # todo custom cuda kernel for indexing and filling?

    embed = torch.empty(
        size=(input_ids.shape[0], input_ids.shape[1], self._hidden_size),
        device=metadata_weight.device,
        dtype=metadata_weight.dtype,
    )

    for split_name in self._split_order:
        start_idx, end_idx = self._id_start[split_name], self._id_end[split_name]
        is_split_mask = (input_ids >= start_idx) & (input_ids < end_idx)
        split_embed = self.token_embedding[split_name](input_ids[is_split_mask] - start_idx)
        embed[is_split_mask] = split_embed

    return embed

reset_parameters()

Resets parameters for all registered embedding splits.

Source code in d9d/module/block/embedding/shard_token_embedding.py
def reset_parameters(self):
    """
    Resets parameters for all registered embedding splits.
    """

    for layer in self.token_embedding.values():
        layer.reset_parameters()