About
The d9d.module.block.hidden_states_aggregator package provides interfaces and implementations for collecting, reducing, and managing model hidden states during execution.
This is particularly useful in pipelines where intermediate activations need to be analyzed or stored (e.g., for reward modeling, custom distillation objectives, or analysis) without keeping the entire raw tensor history in memory.
As an end user, you will typically instantiate an aggregator through the factory function create_hidden_states_aggregator.
Aggregators support a pack_with_snapshot mechanism, which combines the currently collected states with a pre-existing "snapshot" tensor (for example, historical data or output from previous pipeline stages). This simplifies state management in stateful or iterative loops.
Modes
HiddenStatesAggregationMode.noop
Acts as a "null" aggregator: it performs no aggregation.
HiddenStatesAggregationMode.mean
The Mean mode (HiddenStatesAggregationMode.mean) performs "eager" reduction. Instead of storing the full [Batch, Seq_Len, Hidden_Dim] tensors for every step, it:
- Takes an aggregation mask.
- Computes the masked average immediately upon receiving the hidden states.
- Stores only the reduced [Batch, Hidden_Dim] vectors.
Because the Seq_Len dimension is dropped before storage, this significantly reduces the memory footprint when accumulating states over many iterations.
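The eager reduction described above can be sketched in plain Python. This is an illustration only, not the library's implementation: nested lists stand in for tensors so the example runs without PyTorch, the function name masked_mean is hypothetical, and returning zeros for a fully masked row is an assumption.

```python
from typing import List

Batch3D = List[List[List[float]]]  # stand-in for a [Batch, Seq_Len, Hidden_Dim] tensor
Mask2D = List[List[int]]           # stand-in for a [Batch, Seq_Len] aggregation mask


def masked_mean(hidden_states: Batch3D, agg_mask: Mask2D) -> List[List[float]]:
    """Average the unmasked positions of each sequence, giving [Batch, Hidden_Dim]."""
    reduced = []
    for sequence, mask_row in zip(hidden_states, agg_mask):
        hidden_dim = len(sequence[0])
        totals = [0.0] * hidden_dim
        kept = 0
        for token_vector, keep in zip(sequence, mask_row):
            if keep:
                kept += 1
                for i, value in enumerate(token_vector):
                    totals[i] += value
        # Assumption: a fully masked row yields zeros instead of dividing by zero.
        denominator = max(kept, 1)
        reduced.append([total / denominator for total in totals])
    return reduced


hidden = [
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # third position is masked out
    [[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]],      # only the first position counts
]
mask = [[1, 1, 0], [1, 0, 0]]

pooled = masked_mean(hidden, mask)
print(pooled)  # [[2.0, 3.0], [5.0, 6.0]]
```

Only the two pooled vectors need to be kept per step, rather than the six token vectors of the input.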
d9d.module.block.hidden_states_aggregator
Aggregation utilities for model hidden states.
BaseHiddenStatesAggregator
Bases: ABC
Abstract base class for hidden states aggregation strategies.
This interface defines how hidden states are collected (added) and how they are finalized (packed), optionally combined with a historical snapshot.
Source code in d9d/module/block/hidden_states_aggregator/base.py
add_hidden_states(hidden_states)
abstractmethod
Accumulates a batch of hidden states into the aggregator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | The tensor containing the hidden states to process. | required |
Source code in d9d/module/block/hidden_states_aggregator/base.py
pack_with_snapshot(snapshot)
abstractmethod
Finalizes the aggregation and combines it with an optional previous snapshot.
This method typically retrieves the accumulated states, processes them (if not done during addition), and concatenates them with the snapshot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| snapshot | Tensor \| None | An optional tensor representing previously aggregated states to be prepended to the current collection. | required |
Returns:
| Type | Description |
|---|---|
| Tensor \| None | The combined result of the snapshot and the newly aggregated states, or None if no states were collected. |
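The add/pack contract documented above can be illustrated with a small stand-in. This is a sketch under stated assumptions, not the code in base.py: BaseAggregatorSketch and CollectingAggregator are hypothetical names, lists of rows stand in for tensors, and prepending the snapshot along the first axis is an assumption about how the concatenation is laid out.

```python
from abc import ABC, abstractmethod
from typing import List, Optional

Rows = List[List[float]]  # stand-in for a 2-D tensor of aggregated states


class BaseAggregatorSketch(ABC):
    """Hypothetical mirror of the two documented abstract methods."""

    @abstractmethod
    def add_hidden_states(self, hidden_states: Rows) -> None:
        """Accumulate a batch of hidden states."""

    @abstractmethod
    def pack_with_snapshot(self, snapshot: Optional[Rows]) -> Optional[Rows]:
        """Finalize and combine with an optional previous snapshot."""


class CollectingAggregator(BaseAggregatorSketch):
    """Toy subclass that stores rows as-is (hypothetical, for illustration)."""

    def __init__(self) -> None:
        self._collected: Rows = []

    def add_hidden_states(self, hidden_states: Rows) -> None:
        self._collected.extend(hidden_states)

    def pack_with_snapshot(self, snapshot: Optional[Rows]) -> Optional[Rows]:
        # Documented behavior: return None if no states were collected.
        if not self._collected:
            return None
        # Assumption: the snapshot is prepended along the first axis.
        return (snapshot or []) + self._collected


aggregator = CollectingAggregator()
empty_result = aggregator.pack_with_snapshot(None)

aggregator.add_hidden_states([[1.0, 2.0]])
aggregator.add_hidden_states([[3.0, 4.0]])
packed = aggregator.pack_with_snapshot([[0.0, 0.0]])
print(empty_result)  # None
print(packed)        # [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
```

In an iterative loop, the value returned by one call could be passed back in as the snapshot for the next stage, which is the use case the snapshot parameter is described for.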
Source code in d9d/module/block/hidden_states_aggregator/base.py
HiddenStatesAggregationMode
Bases: StrEnum
Enumeration of available hidden state aggregation strategies.
Attributes:
| Name | Type | Description |
|---|---|---|
| no | | Performs no aggregation (No-Op). |
| mean | | Computes the mean of hidden states, taking a mask into account. |
Source code in d9d/module/block/hidden_states_aggregator/factory.py
create_hidden_states_aggregator(mode, agg_mask)
Factory function to create a hidden states aggregator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mode | HiddenStatesAggregationMode | The specific aggregation mode to instantiate. | required |
| agg_mask | Tensor \| None | A tensor mask required for specific modes. Can be None if the selected mode does not require masking. | required |
Returns:
| Type | Description |
|---|---|
| BaseHiddenStatesAggregator | An instance of a concrete BaseHiddenStatesAggregator subclass. |
Raises:
| Type | Description |
|---|---|
| ValueError | If 'mean' mode is selected but 'agg_mask' is None, or if an unknown mode is provided. |
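The validation rules in the Raises table can be sketched as follows. This is a stand-in written for illustration, not the code in factory.py: ModeSketch and create_aggregator_sketch are hypothetical names, the member names noop and mean follow the Modes section, and the function returns a plain description dict instead of a real aggregator object.

```python
from enum import Enum
from typing import Any, Dict, Optional


class ModeSketch(str, Enum):
    """Hypothetical stand-in for HiddenStatesAggregationMode."""

    noop = "noop"
    mean = "mean"


def create_aggregator_sketch(mode: Any, agg_mask: Optional[Any]) -> Dict[str, Any]:
    """Mimic the documented checks: mean needs a mask, unknown modes are rejected."""
    if mode == ModeSketch.noop:
        # The no-op mode does not use a mask, so None is acceptable here.
        return {"mode": "noop", "agg_mask": None}
    if mode == ModeSketch.mean:
        if agg_mask is None:
            raise ValueError("mean mode requires agg_mask")
        return {"mode": "mean", "agg_mask": agg_mask}
    raise ValueError(f"unknown aggregation mode: {mode!r}")


noop_config = create_aggregator_sketch(ModeSketch.noop, None)
mean_config = create_aggregator_sketch(ModeSketch.mean, [[1, 1, 0]])

try:
    create_aggregator_sketch(ModeSketch.mean, None)
    missing_mask_rejected = False
except ValueError:
    missing_mask_rejected = True

print(noop_config["mode"], mean_config["mode"], missing_mask_rejected)
```

Failing at construction time, rather than on the first add_hidden_states call, surfaces a missing mask before any model execution has happened.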
Source code in d9d/module/block/hidden_states_aggregator/factory.py