
Hidden States Aggregation

About

The d9d.module.block.hidden_states_aggregator package provides interfaces and implementations for collecting, reducing, and managing model hidden states during execution.

This is particularly useful in pipelines where intermediate activations need to be analyzed or stored (e.g., for reward modeling or custom distillation objectives) without keeping the entire raw tensor history in memory.

As an end user, you will typically instantiate an aggregator through the factory function create_hidden_states_aggregator.

Aggregators also provide a pack_with_snapshot method, which combines the currently collected states with a pre-existing "snapshot" tensor (historical data, or data from previous pipeline stages). This simplifies state management in stateful or iterative loops.

Modes

HiddenStatesAggregationMode.noop

Acts as a "null" aggregator: it performs no aggregation.

HiddenStatesAggregationMode.mean

The Mean mode (HiddenStatesAggregationMode.mean) performs "eager" reduction. Instead of storing the full [Batch, Seq_Len, Hidden_Dim] tensors for every step, it:

  1. Takes an aggregation mask.
  2. Computes the masked average immediately upon receiving the hidden states.
  3. Stores only the reduced [Batch, Hidden_Dim] vectors.

This significantly reduces the memory footprint when accumulating states over many iterations.
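The three steps above can be sketched in PyTorch as follows. This is a minimal illustration, not d9d's implementation: it assumes agg_mask has shape [Batch, Seq_Len] with 1 marking positions to include, and the helper name masked_mean and the guard against all-zero mask rows are hypothetical.

```python
import torch


def masked_mean(hidden_states: torch.Tensor, agg_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reduce [batch, seq_len, hidden_dim] to [batch, hidden_dim]."""
    # Cast the [batch, seq_len] mask to the hidden-state dtype and add a trailing
    # axis so it broadcasts over hidden_dim.
    mask = agg_mask.to(hidden_states.dtype).unsqueeze(-1)
    # Sum only the positions where the mask is non-zero.
    summed = (hidden_states * mask).sum(dim=1)
    # Count included positions per sequence; clamping avoids division by zero
    # for rows whose mask is entirely zero (an assumption of this sketch).
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts


hidden = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)  # [1, 3, 4]
mask = torch.tensor([[1, 1, 0]])  # ignore the last position
print(masked_mean(hidden, mask))  # tensor([[2., 3., 4., 5.]])
```

Only the [1, 4] result would need to be kept per step, instead of the full [1, 3, 4] input.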

d9d.module.block.hidden_states_aggregator

Aggregation utilities for model hidden states.

BaseHiddenStatesAggregator

Bases: ABC

Abstract base class for hidden states aggregation strategies.

This interface defines how hidden states are collected (added) and how they are finalized (packed) and combined with an optional historical snapshot.

add_hidden_states(hidden_states) abstractmethod

Accumulates a batch of hidden states into the aggregator.

Parameters:

hidden_states (Tensor, required): The tensor containing the hidden states to process.

pack_with_snapshot(snapshot) abstractmethod

Finalizes the aggregation and combines it with an optional previous snapshot.

This method typically retrieves the accumulated states, processes them (if not done during addition), and concatenates them with the snapshot.

Parameters:

snapshot (Tensor | None, required): An optional tensor representing previously aggregated states, to be prepended to the current collection.

Returns:

Tensor | None: The combined result of the snapshot and the newly aggregated states, or None if no states were collected.
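To make the add/pack contract concrete, here is a minimal, hypothetical PyTorch sketch of a mean aggregator built on an interface shaped like BaseHiddenStatesAggregator. The class names are illustrative, and the layout of the packed tensor is an assumption: this sketch stacks each step's [Batch, Hidden_Dim] result along a new leading axis and prepends the snapshot along that same axis. The real aggregators may lay out the packed tensor differently.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import torch


class AggregatorInterfaceSketch(ABC):
    """Hypothetical mirror of the documented add/pack interface."""

    @abstractmethod
    def add_hidden_states(self, hidden_states: torch.Tensor) -> None: ...

    @abstractmethod
    def pack_with_snapshot(self, snapshot: torch.Tensor | None) -> torch.Tensor | None: ...


class MeanAggregatorSketch(AggregatorInterfaceSketch):
    """Hypothetical eager mean aggregator; not d9d's implementation."""

    def __init__(self, agg_mask: torch.Tensor) -> None:
        # [batch, seq_len] -> [batch, seq_len, 1] so it broadcasts over hidden_dim.
        self._mask = agg_mask.unsqueeze(-1)
        self._collected: list[torch.Tensor] = []

    def add_hidden_states(self, hidden_states: torch.Tensor) -> None:
        # Reduce immediately; the full [batch, seq_len, hidden_dim] input is not kept.
        mask = self._mask.to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        self._collected.append(summed / counts)  # [batch, hidden_dim]

    def pack_with_snapshot(self, snapshot: torch.Tensor | None) -> torch.Tensor | None:
        # Follows the documented return value: None when nothing was collected.
        if not self._collected:
            return None
        # ASSUMPTION: one slot per step along a new leading axis -> [steps, batch, hidden_dim].
        current = torch.stack(self._collected, dim=0)
        if snapshot is None:
            return current
        # The snapshot is prepended, so older states come first.
        return torch.cat([snapshot, current], dim=0)


mask = torch.tensor([[1, 1, 0]])
agg = MeanAggregatorSketch(mask)
agg.add_hidden_states(torch.ones(1, 3, 4))
agg.add_hidden_states(torch.zeros(1, 3, 4))
packed = agg.pack_with_snapshot(torch.full((1, 1, 4), 7.0))
print(packed.shape)  # torch.Size([3, 1, 4])
```

Because only reduced vectors are stored, memory in this sketch grows with the number of steps rather than with sequence length.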

HiddenStatesAggregationMode

Bases: StrEnum

Enumeration of available hidden state aggregation strategies.

Attributes:

noop: Performs no aggregation (no-op).

mean: Computes the mean of hidden states, taking a mask into account.

create_hidden_states_aggregator(mode, agg_mask)

Factory function to create a hidden states aggregator.

Parameters:

mode (HiddenStatesAggregationMode, required): The aggregation mode to instantiate.

agg_mask (Tensor | None, required): A tensor mask required by modes that use masking, such as mean. Can be None if the selected mode does not require masking.

Returns:

BaseHiddenStatesAggregator: An instance of a concrete BaseHiddenStatesAggregator subclass.

Raises:

ValueError: If 'mean' mode is selected but 'agg_mask' is None, or if an unknown mode is provided.