About

The d9d.module.block.head package provides model head implementations.

Features

Causal Language Modelling

SplitLanguageModellingHead provides a causal language modelling head that computes per-token loss values (the negative log-probabilities of the labels).

It uses the efficient fused linear cross-entropy kernel from the Cut-Cross-Entropy project, which avoids materializing the full logit tensor.

It supports splitting the vocabulary into multiple independent segments, following the SplitTokenEmbeddings embedding implementation.
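For intuition, the sketch below shows a naive, unfused equivalent of what the fused kernel computes. All tensor names and sizes here are illustrative assumptions; the real head never builds the logits buffer.

import torch
import torch.nn.functional as F

B, S, H, V = 2, 8, 16, 32             # illustrative sizes
hidden_states = torch.randn(B, S, H)  # model hidden states
weight = torch.randn(V, H)            # concatenated lm_head weight
labels = torch.randint(0, V, (B, S))

# Naive reference: materializes the full (B, S, V) logit tensor, which is
# exactly the buffer the fused kernel avoids allocating.
logits = hidden_states @ weight.T
losses = F.cross_entropy(
    logits.flatten(0, 1),
    labels.flatten(),
    reduction="none"
).view(B, S)                          # per-token losses, shape (B, S)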

d9d.module.block.head

SplitLanguageModellingHead

Bases: Module, ModuleLateInit

A segmented language modeling head that computes per-token cross-entropy loss values using a composed weight matrix.

This class maintains separate linear layers for different segments of the vocabulary (e.g., regular vs. special tokens). During the forward pass, it concatenates the weights to form a unified projection matrix and computes the cross-entropy loss efficiently, typically using a fused kernel to avoid materializing full logits.

The concatenation order of the weights is determined by split_order, which ensures consistency with the global vocabulary indices.
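For example, a hypothetical two-split configuration (the split names and sizes below are assumptions chosen for illustration) induces the following global index layout:

split_vocab_size = {"regular": 32000, "special": 128}
split_order = ["regular", "special"]

# Global label indices follow the concatenation order:
#   0 ... 31999     -> the "regular" head
#   32000 ... 32127 -> the "special" head
head = SplitLanguageModellingHead(
    split_vocab_size=split_vocab_size,
    split_order=split_order,
    hidden_size=1024
)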

Source code in d9d/module/block/head/language_modelling.py
class SplitLanguageModellingHead(nn.Module, ModuleLateInit):
    """
    A segmented language modeling head that computes per-token cross-entropy loss values using a composed weight matrix.

    This class maintains separate linear layers for different segments of the vocabulary
    (e.g., regular vs. special tokens). During the forward pass, it concatenates the
    weights to form a unified projection matrix and computes the cross-entropy loss
    efficiently, typically using a fused kernel to avoid materializing full logits.

    The concatenation order of the weights is determined by `split_order`, which ensures
    consistency with the global vocabulary indices.
    """

    def __init__(
            self,
            split_vocab_size: dict[str, int],
            split_order: Sequence[str],
            hidden_size: int
    ):
        """
        Constructs the SplitLanguageModellingHead object.

        Args:
            split_vocab_size: A dictionary mapping split names to their output vocabulary sizes.
            split_order: A sequence defining the order in which vocabulary segments should be
                concatenated. This determines the mapping of global indices to specific heads.
            hidden_size: The input dimensionality (hidden state size).
        """

        super().__init__()

        lm_head = nn.ModuleDict({
            split_name: nn.Linear(hidden_size, vocab_size, bias=False)
            for split_name, vocab_size in split_vocab_size.items()
        })

        self.lm_head: Mapping[str, nn.Linear] = cast(Mapping[str, nn.Linear], lm_head)
        self._split_order = split_order
        self._hidden_size = hidden_size

    def forward(
            self,
            hidden_states: torch.Tensor,
            labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the cross-entropy loss for the given hidden states and labels.

        Args:
            hidden_states: Input tensor of shape `(B, S, H)`.
            labels: Target label tensor of shape `(B, S)`. Indices must correspond
                to the global vocabulary formed by concatenating splits in `split_order`.

        Returns:
            A tensor containing per-token loss values (reduction='none'), matching the
            shape of the labels tensor.
        """

        lm_head_weight = torch.cat([self.lm_head[split_name].weight for split_name in self._split_order], dim=0)

        losses = linear_cross_entropy(
            hidden_states,
            lm_head_weight,
            labels,
            ignore_index=_IGNORE_INDEX,
            reduction="none"
        )
        return losses

    def reset_parameters(self):
        """Resets module parameters."""

        for head in self.lm_head.values():
            head.reset_parameters()

__init__(split_vocab_size, split_order, hidden_size)

Constructs the SplitLanguageModellingHead object.

Parameters:

split_vocab_size (dict[str, int], required)
    A dictionary mapping split names to their output vocabulary sizes.

split_order (Sequence[str], required)
    A sequence defining the order in which vocabulary segments should be concatenated. This determines the mapping of global indices to specific heads.

hidden_size (int, required)
    The input dimensionality (hidden state size).

forward(hidden_states, labels)

Computes the cross-entropy loss for the given hidden states and labels.

Parameters:

hidden_states (Tensor, required)
    Input tensor of shape (B, S, H).

labels (Tensor, required)
    Target label tensor of shape (B, S). Indices must correspond to the global vocabulary formed by concatenating splits in split_order.

Returns:

Tensor
    A tensor containing per-token loss values (reduction='none'), matching the shape of the labels tensor.
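A minimal usage sketch, assuming the head built in the example above and an ignore index of -100 (the actual value of _IGNORE_INDEX is not shown in this excerpt). Shapes are illustrative; the fused kernel typically expects CUDA tensors in bf16/fp16.

import torch

B, S, H = 2, 8, 1024
hidden_states = torch.randn(B, S, H)

# Labels index the global (concatenated) vocabulary; positions set to the
# ignore index are excluded from the loss.
labels = torch.randint(0, 32128, (B, S))
labels[:, 0] = -100  # e.g. mask a prompt or padding position

losses = head(hidden_states, labels)  # shape (B, S)
# Mean over valid tokens, assuming ignored positions yield zero loss
# with reduction="none":
mean_loss = losses.sum() / (labels != -100).sum()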


reset_parameters()

Resets module parameters.
