
Pipeline State Management

Internal API Warning

If you use the standard d9d training infrastructure, you do not need to manage pipeline states manually; the framework handles this automatically. This package is primarily intended for users extending d9d.

About

The d9d.internals.pipeline_state package provides a unified mechanism for managing the data lifecycle within a training step. It specifically addresses the complexity of transitioning between the Global Context (an entire training step/batch) and the Sharded Context (partial execution, for example within pipeline-parallel loss computation).

For instance, a typical data flow in a pipelined step is:

  1. Prepare the data using a global view.
  2. Compute the loss value for a microbatch, which requires a sharded view of the data.
  3. Log metrics, using a global view again.

PipelineState abstracts the slicing (Global -> Sharded) and aggregation (Sharded -> Global) operations behind a simple dictionary-like interface, allowing the training loop to act as a seamless bridge between these two contexts.
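The three-step flow above can be sketched with plain Python lists standing in for tensors. This is only an illustration of the slicing and aggregation idea: the helpers slice_for_shard and aggregate are hypothetical names, not part of d9d, and even, contiguous splitting is an assumption made for the example.

```python
# Illustrative only: plain lists stand in for tensors, and the helper names
# below are hypothetical, not part of d9d.


def slice_for_shard(values, shard_id, num_shards):
    """Global -> Sharded: return the contiguous chunk owned by one shard."""
    if len(values) % num_shards != 0:
        raise ValueError("global length must be divisible by num_shards")
    size = len(values) // num_shards
    return values[shard_id * size:(shard_id + 1) * size]


def aggregate(chunks):
    """Sharded -> Global: concatenate per-shard chunks in shard order."""
    return [item for chunk in chunks for item in chunk]


num_shards = 2

# 1. Prepare the data using a global view.
targets = [1.0, 2.0, 3.0, 4.0]
predictions = [1.5, 2.0, 2.0, 4.5]

# 2. Compute a per-sample loss for each microbatch using a sharded view.
per_shard_losses = []
for shard_id in range(num_shards):
    t = slice_for_shard(targets, shard_id, num_shards)
    p = slice_for_shard(predictions, shard_id, num_shards)
    per_shard_losses.append([(pi - ti) ** 2 for pi, ti in zip(p, t)])

# 3. Log metrics using a global view again.
losses = aggregate(per_shard_losses)
mean_loss = sum(losses) / len(losses)
```

In this sketch the training loop calls the slicing and aggregation helpers by hand; the point of PipelineState is to hide both behind dictionary-style reads and writes.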

d9d.internals.pipeline_state

Pipeline State management package.

This package provides mechanisms to store, retrieve, and synchronize state across different stages of a distributed pipeline, providing global and sharded views of that state.

PipelineState

Bases: ABC

Object representing the state of a pipeline.

This class defines the interface for accessing state variables like a dictionary, abstracting away whether the underlying storage is local, sharded, or global.

__contains__(item) abstractmethod

Checks if a key exists in the state.

Parameters:

Name | Type | Description | Default
item | str | The identifier to check. | required

Returns:

Type | Description
bool | True if the key exists, False otherwise.

__getitem__(item) abstractmethod

Retrieves a state value for a given key.

Parameters:

Name | Type | Description | Default
item | str | The identifier for the state variable. | required

Returns:

Type | Description
Any | The value associated with the key.

__setitem__(key, value) abstractmethod

Sets a state value for a given key.

Parameters:

Name | Type | Description | Default
key | str | The identifier for the state variable. | required
value | Any | The value to store. | required
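As a rough illustration of the interface described above, the sketch below declares an abstract base class with the same three dictionary-style methods and one trivial dict-backed subclass. The names PipelineStateSketch and DictBackedState are hypothetical; the actual d9d classes may be structured differently.

```python
from abc import ABC, abstractmethod
from typing import Any


class PipelineStateSketch(ABC):
    """Hypothetical stand-in mirroring the three documented abstract methods."""

    @abstractmethod
    def __getitem__(self, item: str) -> Any:
        """Retrieve the value stored under a key."""

    @abstractmethod
    def __setitem__(self, key: str, value: Any) -> None:
        """Store a value under a key."""

    @abstractmethod
    def __contains__(self, item: str) -> bool:
        """Report whether a key exists."""


class DictBackedState(PipelineStateSketch):
    """Simplest possible implementation: purely local storage in a dict."""

    def __init__(self) -> None:
        self._storage: dict[str, Any] = {}

    def __getitem__(self, item: str) -> Any:
        return self._storage[item]

    def __setitem__(self, key: str, value: Any) -> None:
        self._storage[key] = value

    def __contains__(self, item: str) -> bool:
        return item in self._storage


state = DictBackedState()
state["loss"] = 0.5
has_loss = "loss" in state
```

Because callers only use item access and membership checks, code written against this interface does not need to know whether the storage behind it is local, sharded, or global.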

PipelineStateHandler

Manages the lifecycle and access patterns of pipeline states.

This handler initializes the underlying storage and provides specific views (global or sharded) into that storage.

__init__(sharding_spec, num_shards)

Constructs a PipelineStateHandler object.

Parameters:

Name | Type | Description | Default
sharding_spec | dict[str, ShardingSpecLeaf] | A definition of how specific keys should be sharded. | required
num_shards | int | The total number of shards in the pipeline. | required

global_state()

Returns a view interface for accessing global state.

Returns:

Type | Description
PipelineState | A PipelineState interface that accesses the full, aggregated data.

reset()

Resets the underlying storage, clearing all state.

sharded_state(shard_id)

Returns a view interface for accessing state specific to a shard ID.

Parameters:

Name | Type | Description | Default
shard_id | int | The index of the shard to access. | required

Returns:

Type | Description
PipelineState | A PipelineState interface that accesses partial data for the given shard.
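To make the handler lifecycle concrete, here is a toy, self-contained stand-in that mimics the documented method names (global_state, sharded_state, reset). It is not the d9d implementation. Several things are assumptions made for illustration: the class names ToyStateHandler and _ToyView, the use of a plain set of key names in place of the real sharding_spec (the ShardingSpecLeaf type is not described here), even splitting of lists, and the choice to share non-sharded keys across all views.

```python
from typing import Any, Optional


class ToyStateHandler:
    """Toy stand-in for a state handler; NOT the d9d implementation."""

    def __init__(self, sharded_keys: set[str], num_shards: int) -> None:
        # Assumption: a set of key names replaces the real sharding_spec.
        self._sharded_keys = sharded_keys
        self._num_shards = num_shards
        self._storage: dict[str, Any] = {}

    def global_state(self) -> "_ToyView":
        return _ToyView(self, shard_id=None)

    def sharded_state(self, shard_id: int) -> "_ToyView":
        if not 0 <= shard_id < self._num_shards:
            raise IndexError(f"shard_id {shard_id} out of range")
        return _ToyView(self, shard_id=shard_id)

    def reset(self) -> None:
        self._storage.clear()


class _ToyView:
    """Dictionary-like view; shard_id=None means the global view."""

    def __init__(self, handler: ToyStateHandler, shard_id: Optional[int]) -> None:
        self._h = handler
        self._shard_id = shard_id

    def __contains__(self, item: str) -> bool:
        return item in self._h._storage

    def __getitem__(self, item: str) -> Any:
        value = self._h._storage[item]
        if item not in self._h._sharded_keys:
            return value  # non-sharded keys are shared by every view
        if self._shard_id is not None:
            chunk = value[self._shard_id]
            if chunk is None:
                raise KeyError(item)
            return chunk
        if any(chunk is None for chunk in value):
            raise KeyError(item)  # some shard has not written yet
        # Sharded -> Global: concatenate chunks in shard order.
        return [x for chunk in value for x in chunk]

    def __setitem__(self, key: str, value: Any) -> None:
        h = self._h
        if key not in h._sharded_keys:
            h._storage[key] = value
            return
        if self._shard_id is None:
            # Global -> Sharded: split evenly into num_shards chunks.
            size, rem = divmod(len(value), h._num_shards)
            if rem:
                raise ValueError("length must be divisible by num_shards")
            h._storage[key] = [
                list(value[i * size:(i + 1) * size]) for i in range(h._num_shards)
            ]
        else:
            chunks = h._storage.setdefault(key, [None] * h._num_shards)
            chunks[self._shard_id] = list(value)


handler = ToyStateHandler(sharded_keys={"tokens", "loss"}, num_shards=2)
handler.global_state()["tokens"] = [10, 11, 12, 13]  # written globally
handler.global_state()["step"] = 7                   # not sharded

for shard_id in range(2):
    shard = handler.sharded_state(shard_id)
    shard["loss"] = [t * 0.5 for t in shard["tokens"]]  # written per shard

all_losses = handler.global_state()["loss"]  # read back as one global list
step_seen_by_shard = handler.sharded_state(1)["step"]
handler.reset()  # clear everything before the next training step
```

The toy keeps one shared storage object and hands out lightweight views over it, so a value written through one view is immediately visible through the other. Whether d9d uses the same layout is not stated in this reference.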