About
The d9d.module.block.head package handles the model heads.
Features
Causal Language Modelling
SplitLanguageModellingHead provides a causal language modelling head that computes per-token loss values (the negative log-probabilities of the labels).
It uses the efficient fused linear + cross-entropy kernel from the Cut-Cross-Entropy project, which avoids materializing the full logit tensor.
The vocabulary can be split into multiple independent segments, following the SplitTokenEmbeddings embedding implementation.
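The computation behind the head can be illustrated with a minimal unfused sketch in plain PyTorch (an assumption for illustration, not the d9d implementation, which uses the fused kernel precisely to avoid building the `logits` tensor below):

```python
# Minimal unfused reference sketch (NOT the d9d implementation): one weight
# matrix per vocabulary split, concatenated in a fixed order, then per-token
# cross-entropy over the full logits. The sizes below are hypothetical.
import torch
import torch.nn.functional as F

hidden_size = 8
split_vocab_size = {"regular": 10, "special": 3}
split_order = ("regular", "special")

# One projection weight per split, as separate linear layers would hold them.
weights = {
    name: torch.randn(size, hidden_size)
    for name, size in split_vocab_size.items()
}

hidden_states = torch.randn(4, hidden_size)  # 4 tokens
labels = torch.tensor([0, 9, 10, 12])        # global vocabulary indices

# Concatenate in split order so global indices line up:
# "regular" covers ids 0..9, "special" covers ids 10..12.
full_weight = torch.cat([weights[name] for name in split_order], dim=0)

logits = hidden_states @ full_weight.T       # full logit tensor materialized
loss = F.cross_entropy(logits, labels, reduction="none")
print(loss.shape)  # one loss value per token
```

The fused kernel produces the same per-token losses while never holding the `(num_tokens, vocab_size)` logit tensor in memory, which matters when the vocabulary is large.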
d9d.module.block.head
SplitLanguageModellingHead
Bases: Module, ModuleLateInit
A segmented language modeling head that computes per-token cross-entropy loss values using a composed weight matrix.
This class maintains separate linear layers for different segments of the vocabulary (e.g., regular vs. special tokens). During the forward pass, it concatenates the weights to form a unified projection matrix and computes the cross-entropy loss efficiently, typically using a fused kernel to avoid materializing full logits.
The concatenation order of the weights is determined by split_order, which ensures
consistency with the global vocabulary indices.
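The index mapping implied by split_order can be sketched as follows (assumed semantics based on the description above, not taken from the d9d source):

```python
# Each split occupies a contiguous range of global vocabulary indices,
# in the order given by split_order. Sizes here are hypothetical.
split_vocab_size = {"regular": 10, "special": 3}
split_order = ("regular", "special")

ranges = {}
offset = 0
for name in split_order:
    ranges[name] = range(offset, offset + split_vocab_size[name])
    offset += split_vocab_size[name]

print(ranges)  # {'regular': range(0, 10), 'special': range(10, 13)}
```

Changing split_order would shift these offsets, so it must match the order used by the corresponding embedding layer.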
Source code in d9d/module/block/head/language_modelling.py
__init__(split_vocab_size, split_order, hidden_size)
Constructs the SplitLanguageModellingHead object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `split_vocab_size` | `dict[str, int]` | A dictionary mapping split names to their output vocabulary sizes. | *required* |
| `split_order` | `Sequence[str]` | A sequence defining the order in which vocabulary segments should be concatenated. This determines the mapping of global indices to specific heads. | *required* |
| `hidden_size` | `int` | The input dimensionality (hidden state size). | *required* |
Source code in d9d/module/block/head/language_modelling.py
forward(hidden_states, labels)
Computes the cross-entropy loss for the given hidden states and labels.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_states` | `Tensor` | Input tensor of hidden states. | *required* |
| `labels` | `Tensor` | Target label tensor. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | A tensor containing per-token loss values (`reduction='none'`), matching the shape of the labels tensor. |
Source code in d9d/module/block/head/language_modelling.py
reset_parameters()
Resets module parameters.
Source code in d9d/module/block/head/language_modelling.py