# Model Heads

## About
The d9d.module.block.head package handles the model heads.
## Features

### Causal Language Modelling
`SplitLanguageModellingHead` provides a causal language modelling head that computes per-token log-probabilities. It uses the efficient fused linear + cross-entropy kernel from the Cut-Cross-Entropy project, which avoids materializing the full logit tensor. The vocabulary can be split into multiple independent segments, following the `SplitTokenEmbeddings` embedding implementation.
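The split layout means each segment owns a contiguous block of global vocabulary ids, in `split_order`. A minimal sketch of that index bookkeeping (hypothetical helper, not the actual d9d implementation):

```python
# Hypothetical sketch: how split_order determines the mapping from
# split-local token indices to global vocabulary indices.
from typing import Dict, Sequence


def global_offsets(split_vocab_size: Dict[str, int],
                   split_order: Sequence[str]) -> Dict[str, int]:
    """Return the starting global index of each vocabulary split."""
    offsets, cursor = {}, 0
    for name in split_order:
        offsets[name] = cursor
        cursor += split_vocab_size[name]
    return offsets


offsets = global_offsets({"regular": 32000, "special": 128},
                         ["regular", "special"])
# "regular" tokens occupy global ids [0, 32000),
# "special" tokens occupy global ids [32000, 32128)
```

Keeping the concatenation order fixed is what lets labels use plain global indices regardless of which split a token belongs to.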
## d9d.module.block.head
### LM_IGNORE_INDEX

`LM_IGNORE_INDEX = -100` *module-attribute*

Index ignored by the LM head when calculating log-probabilities.
### ClassificationHead

Bases: `Module`, `ModuleLateInit`
A classification head module that is typically used on top of model hidden states.
It applies dropout followed by a linear projection to produce logits for a specified number of classes. It supports optional pooling via a mask, allowing for selection of specific tokens (e.g., [CLS] tokens or specific sequence positions) before projection.
#### __init__(hidden_size, num_labels, dropout)
#### forward(hidden_states, pooling_mask)
Computes class logits from hidden states.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_states` | `Tensor` | Input tensor of hidden states. | required |
| `pooling_mask` | `Tensor \| None` | Optional mask to select specific hidden states; if provided, the hidden states are indexed with this mask before projection. | required |
Returns:

| Type | Description |
|---|---|
| `Tensor` | A tensor containing the unnormalized logits. |
#### reset_parameters()

Resets module parameters.
### SplitLanguageModellingHead

Bases: `Module`, `ModuleLateInit`
A segmented language modelling head that computes per-token cross-entropy loss values using a composed weight matrix.

This class maintains separate linear layers for different segments of the vocabulary (e.g., regular vs. special tokens). During the forward pass, it concatenates the weights to form a unified projection matrix and computes the cross-entropy loss efficiently, typically using a fused kernel to avoid materializing the full logits. The concatenation order of the weights is determined by `split_order`, which ensures consistency with the global vocabulary indices.
#### __init__(split_vocab_size, split_order, hidden_size)
Constructs the SplitLanguageModellingHead object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `split_vocab_size` | `dict[str, int]` | A dictionary mapping split names to their output vocabulary sizes. | required |
| `split_order` | `Sequence[str]` | A sequence defining the order in which vocabulary segments are concatenated. This determines the mapping of global indices to specific heads. | required |
| `hidden_size` | `int` | The input dimensionality (hidden state size). | required |
#### forward(hidden_states, labels)
Computes the cross-entropy loss for the given hidden states and labels.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_states` | `Tensor` | Input tensor of hidden states (last dimension `hidden_size`). | required |
| `labels` | `Tensor` | Target label tensor of global vocabulary indices; positions equal to `LM_IGNORE_INDEX` are ignored. | required |
Returns:

| Type | Description |
|---|---|
| `Tensor` | A tensor containing per-token loss values (`reduction='none'`), matching the shape of the labels tensor. |
#### reset_parameters()

Resets module parameters.
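Without the fused Cut-Cross-Entropy kernel, the forward pass described above can be approximated by explicitly concatenating the per-split weights and calling plain cross-entropy. The names below are assumptions for illustration, not the d9d source; note this fallback materializes the full logit tensor, which is exactly what the fused kernel avoids:

```python
import torch
import torch.nn.functional as F

LM_IGNORE_INDEX = -100


def split_lm_loss(hidden_states: torch.Tensor,
                  labels: torch.Tensor,
                  split_weights: dict[str, torch.Tensor],
                  split_order: list[str]) -> torch.Tensor:
    # Concatenate per-split weights in split_order: row blocks then line up
    # with global label indices.
    weight = torch.cat([split_weights[name] for name in split_order], dim=0)
    # Full (..., total_vocab) logits; the fused kernel skips this step.
    logits = hidden_states @ weight.T
    return F.cross_entropy(logits.flatten(0, -2), labels.flatten(),
                           ignore_index=LM_IGNORE_INDEX,
                           reduction="none").view_as(labels)


weights = {"regular": torch.randn(10, 8), "special": torch.randn(2, 8)}
h = torch.randn(3, 4, 8)                      # (batch, seq, hidden)
labels = torch.randint(0, 12, (3, 4))         # global ids over both splits
labels[0, 0] = LM_IGNORE_INDEX                # this position contributes 0 loss
loss = split_lm_loss(h, labels, weights, ["regular", "special"])
```

The result matches the documented contract: per-token loss values with the same shape as `labels`, with zeros at ignored positions.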