About

The d9d.module.block.hidden_states_aggregator package provides interfaces and implementations for collecting, reducing, and managing model hidden states during execution.

This is particularly useful in pipelines where intermediate activations need to be analyzed or stored (e.g., for reward modeling or custom distillation objectives) without keeping the entire raw tensor history in memory.

As an end user, you will typically instantiate an aggregator through the factory function create_hidden_states_aggregator.

Aggregators support a pack_with_snapshot mechanism, which combines the currently collected states with a pre-existing "snapshot" tensor (for example, historical data or results from earlier pipeline stages). The snapshot is prepended to the newly aggregated states, which simplifies state management in stateful or iterative loops.

Modes

HiddenStatesAggregationMode.no

Acts as a "null" aggregator that performs no aggregation (no-op).

HiddenStatesAggregationMode.mean

The Mean mode (HiddenStatesAggregationMode.mean) performs "eager" reduction. Instead of storing the full [Batch, Seq_Len, Hidden_Dim] tensors for every step, it:

  1. Takes an aggregation mask.
  2. Computes the masked average immediately upon receiving the hidden states.
  3. Stores only the reduced [Batch, Hidden_Dim] vectors.

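Per step, the stored data shrinks from [Batch, Seq_Len, Hidden_Dim] to [Batch, Hidden_Dim], a factor of Seq_Len, which significantly reduces the memory footprint when accumulating states over many iterations.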
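The three steps above can be sketched as a single masked-mean function. masked_mean is a hypothetical name, and the sketch assumes agg_mask has shape [Batch, Seq_Len] with nonzero entries marking positions to include; the exact mask convention used by HiddenStatesAggregatorMean is not documented here. Clamping the per-row count to 1, so that a fully masked row yields zeros instead of NaN, is a choice of this sketch.

```python
import torch


def masked_mean(hidden_states: torch.Tensor, agg_mask: torch.Tensor) -> torch.Tensor:
    """Reduce [Batch, Seq_Len, Hidden_Dim] to [Batch, Hidden_Dim] over masked positions."""
    # Assumption: agg_mask is [Batch, Seq_Len]; nonzero means "include this position".
    mask = agg_mask.to(hidden_states.dtype).unsqueeze(-1)  # [B, S, 1]
    summed = (hidden_states * mask).sum(dim=1)              # [B, H]
    counts = mask.sum(dim=1).clamp(min=1.0)                 # [B, 1], avoids division by zero
    return summed / counts


hidden = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)  # B=1, S=3, H=4
mask = torch.tensor([[1, 1, 0]])  # ignore the last position
print(masked_mean(hidden, mask))  # tensor([[2., 3., 4., 5.]])
```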

d9d.module.block.hidden_states_aggregator

Aggregation utilities for model hidden states.

BaseHiddenStatesAggregator

Bases: ABC

Abstract base class for hidden states aggregation strategies.

This interface defines how hidden states should be collected (added) and how they should be finalized (packed) and combined with optional historical snapshots.

Source code in d9d/module/block/hidden_states_aggregator/base.py
class BaseHiddenStatesAggregator(abc.ABC):
    """Abstract base class for hidden states aggregation strategies.

    This interface defines how hidden states should be collected (added) and
    how they should be finalized (packed) and combined with optional historical snapshots.
    """

    @abc.abstractmethod
    def add_hidden_states(self, hidden_states: torch.Tensor) -> None:
        """Accumulates a batch of hidden states into the aggregator.

        Args:
            hidden_states: The tensor containing the hidden states to process.
        """

    @abc.abstractmethod
    def pack_with_snapshot(self, snapshot: torch.Tensor | None) -> torch.Tensor | None:
        """Finalizes the aggregation and combines it with an optional previous snapshot.

        This method typically retrieves the accumulated states, processes them
        (if not done during addition), and concatenates them with the snapshot.

        Args:
            snapshot: An optional tensor representing previously aggregated states
                to be prepended to the current collection.

        Returns:
            The combined result of the snapshot and the newly aggregated states,
            or None if no states were collected.
        """

add_hidden_states(hidden_states) (abstract method)

Accumulates a batch of hidden states into the aggregator.

Parameters:

  hidden_states (Tensor, required): The tensor containing the hidden states to process.

pack_with_snapshot(snapshot) (abstract method)

Finalizes the aggregation and combines it with an optional previous snapshot.

This method typically retrieves the accumulated states, processes them (if not done during addition), and concatenates them with the snapshot.

Parameters:

  snapshot (Tensor | None, required): An optional tensor representing previously aggregated states to be prepended to the current collection.

Returns:

  Tensor | None: The combined result of the snapshot and the newly aggregated states, or None if no states were collected.


HiddenStatesAggregationMode

Bases: StrEnum

Enumeration of available hidden state aggregation strategies.

Attributes:

  no: Performs no aggregation (no-op).
  mean: Computes the mean of hidden states, taking a mask into account.

Source code in d9d/module/block/hidden_states_aggregator/factory.py
class HiddenStatesAggregationMode(StrEnum):
    """Enumeration of available hidden state aggregation strategies.

    Attributes:
        no: Performs no aggregation (No-Op).
        mean: Computes the mean of hidden states, taking a mask into account.
    """

    no = "no"
    mean = "mean"

create_hidden_states_aggregator(mode, agg_mask)

Factory function to create a hidden states aggregator.

Parameters:

  mode (HiddenStatesAggregationMode, required): The specific aggregation mode to instantiate.
  agg_mask (Tensor | None, required): The aggregation mask. It is required by the mean mode and may be None for modes that do not use masking (such as no).

Returns:

  BaseHiddenStatesAggregator: An instance of a concrete BaseHiddenStatesAggregator subclass.

Raises:

  ValueError: If 'mean' mode is selected but 'agg_mask' is None, or if an unknown mode is provided.

Source code in d9d/module/block/hidden_states_aggregator/factory.py
def create_hidden_states_aggregator(
        mode: HiddenStatesAggregationMode, agg_mask: torch.Tensor | None
) -> BaseHiddenStatesAggregator:
    """Factory function to create a hidden states aggregator.

    Args:
        mode: The specific aggregation mode to instantiate.
        agg_mask: A tensor mask required for specific modes.
            Can be None if the selected mode does not require masking.

    Returns:
        An instance of a concrete BaseHiddenStatesAggregator subclass.

    Raises:
        ValueError: If 'mean' mode is selected but 'agg_mask' is None, or if
            an unknown mode is provided.
    """

    match mode:
        case HiddenStatesAggregationMode.no:
            return HiddenStatesAggregatorNoOp()
        case HiddenStatesAggregationMode.mean:
            if agg_mask is None:
                raise ValueError("You have to specify aggregation mask")
            return HiddenStatesAggregatorMean(agg_mask)
        case _:
            raise ValueError("Unknown hidden states aggregation mode")