Model Heads

About

The d9d.module.block.head package provides model head modules.

Features

Causal Language Modelling

SplitLanguageModellingHead provides a causal language modelling head that computes per-token logprobs.

It uses the efficient fused linear-cross-entropy kernel from the Cut-Cross-Entropy project and avoids materializing the full logit tensor.

Supports splitting the vocabulary into multiple independent segments, following the SplitTokenEmbeddings embedding implementation.

d9d.module.block.head

LM_IGNORE_INDEX = -100 module-attribute

Label index ignored by the LM head when calculating log-probabilities.
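A common pattern is to overwrite label positions that should not contribute to the loss (for example, padding) with this sentinel. A minimal sketch, assuming standard PyTorch tensors; the attention-mask setup here is illustrative and not part of this package's API:

```python
import torch

LM_IGNORE_INDEX = -100  # positions with this label are skipped by the LM head

labels = torch.tensor([[5, 7, 2, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # trailing padding positions

# Replace padding labels with the ignore index so they produce no loss
masked_labels = labels.masked_fill(attention_mask == 0, LM_IGNORE_INDEX)
# masked_labels → tensor([[5, 7, 2, -100, -100]])
```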

ClassificationHead

Bases: Module, ModuleLateInit

A classification head module that is typically used on top of model hidden states.

It applies dropout followed by a linear projection to produce logits for a specified number of classes. It supports optional pooling via a mask, allowing for selection of specific tokens (e.g., [CLS] tokens or specific sequence positions) before projection.

__init__(hidden_size, num_labels, dropout)

Constructs the ClassificationHead object.

Parameters:

- hidden_size (int): The input dimensionality (hidden state size). Required.
- num_labels (int): The number of output classes. Required.
- dropout (float): The dropout probability. Required.

forward(hidden_states, pooling_mask)

Computes class logits from hidden states.

Parameters:

- hidden_states (Tensor): Input tensor of hidden states. Required.
- pooling_mask (Tensor | None): Optional mask to select specific hidden states. If provided, the input is indexed as hidden_states[pooling_mask == 1], flattening the batch and sequence dimensions into a single dimension of selected tokens. Required.

Returns:

- Tensor: A tensor containing the unnormalized logits.

reset_parameters()

Resets module parameters.
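The documented behaviour (dropout, then linear projection, with optional mask-based pooling) can be sketched as follows. This is a hypothetical re-implementation for illustration, not the package's actual code:

```python
import torch
from torch import nn


class SketchClassificationHead(nn.Module):
    """Illustrative stand-in mirroring the documented ClassificationHead."""

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states, pooling_mask=None):
        if pooling_mask is not None:
            # Select marked tokens; batch and sequence dims collapse into one
            hidden_states = hidden_states[pooling_mask == 1]
        return self.proj(self.dropout(hidden_states))


head = SketchClassificationHead(hidden_size=16, num_labels=3, dropout=0.1).eval()
x = torch.randn(2, 5, 16)              # (batch, seq, hidden)
mask = torch.zeros(2, 5, dtype=torch.long)
mask[:, 0] = 1                         # pool the first token of each sequence
logits = head(x, mask)                 # shape (2, 3): one row per selected token
```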

SplitLanguageModellingHead

Bases: Module, ModuleLateInit

A segmented language modeling head that computes per-token cross-entropy loss values using a composed weight matrix.

This class maintains separate linear layers for different segments of the vocabulary (e.g., regular vs. special tokens). During the forward pass, it concatenates the weights to form a unified projection matrix and computes the cross-entropy loss efficiently, typically using a fused kernel to avoid materializing full logits.

The concatenation order of the weights is determined by split_order, which ensures consistency with the global vocabulary indices.
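The resulting mapping from global label indices to splits can be illustrated in plain Python. The sizes and split names below are made-up examples, and global_to_split is a hypothetical helper, not part of the package:

```python
split_vocab_size = {"special": 4, "regular": 10}
split_order = ["regular", "special"]  # regular tokens occupy global ids 0..9


def global_to_split(idx: int) -> tuple[str, int]:
    """Map a global vocabulary index to (split name, local index)."""
    offset = 0
    for name in split_order:
        size = split_vocab_size[name]
        if idx < offset + size:
            return name, idx - offset
        offset += size
    raise IndexError(idx)


global_to_split(3)   # falls in the "regular" segment
global_to_split(11)  # falls in the "special" segment
```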

__init__(split_vocab_size, split_order, hidden_size)

Constructs the SplitLanguageModellingHead object.

Parameters:

- split_vocab_size (dict[str, int]): A dictionary mapping split names to their output vocabulary sizes. Required.
- split_order (Sequence[str]): A sequence defining the order in which vocabulary segments should be concatenated. This determines the mapping of global indices to specific heads. Required.
- hidden_size (int): The input dimensionality (hidden state size). Required.

forward(hidden_states, labels)

Computes the cross-entropy loss for the given hidden states and labels.

Parameters:

- hidden_states (Tensor): Input tensor of shape (B, S, H). Required.
- labels (Tensor): Target label tensor of shape (B, S). Indices must correspond to the global vocabulary formed by concatenating splits in split_order. Required.

Returns:

- Tensor: A tensor containing per-token loss values (reduction='none'), matching the shape of the labels tensor.
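The output is documented to match an unfused per-token cross-entropy. A reference sketch, assuming a single concatenated weight matrix (the head itself avoids materializing the logits via the fused kernel; here they are materialized for clarity):

```python
import torch
import torch.nn.functional as F

B, S, H, V = 2, 4, 8, 12
hidden = torch.randn(B, S, H)
weight = torch.randn(V, H)           # concatenated split weights
labels = torch.randint(0, V, (B, S))
labels[0, -1] = -100                 # LM_IGNORE_INDEX: excluded from the loss

logits = hidden @ weight.T           # (B, S, V), materialized only in this sketch
loss = F.cross_entropy(
    logits.view(-1, V),
    labels.view(-1),
    ignore_index=-100,
    reduction="none",
).view(B, S)                         # per-token losses, same shape as labels
```

Positions labelled with the ignore index come back as zero loss.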

reset_parameters()

Resets module parameters.