Pipeline State Management
Internal API Warning
If you are using the standard d9d training infrastructure, you do not need to manage pipeline states manually; the framework handles this automatically. This package is primarily intended for users extending d9d.
About
The d9d.internals.pipeline_state package provides a unified mechanism for managing the data lifecycle within a training step. It specifically addresses the complexity of transitioning between the Global Context (an entire training step or batch) and the Sharded Context (partial execution, such as pipeline-parallel loss computation on a single microbatch).
For instance, a typical data flow in a pipelined step is:
- Prepare the data using a global view.
- Compute the loss for a microbatch, which requires a sharded view of the data.
- Log metrics using a global view again.
PipelineState abstracts the slicing (Global -> Sharded) and aggregation (Sharded -> Global) operations behind a simple dictionary-like interface, allowing the training loop to act as a seamless bridge between these two contexts.
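To make the two directions concrete, here is a minimal sketch in plain Python of the three-step flow above. It assumes the batch is a flat list split into equal, contiguous chunks along its first axis, and that aggregation is concatenation in shard order. `shard` and `aggregate` are hypothetical helpers for illustration, not part of the d9d API; the real package may slice and aggregate differently (for example, per tensor and per key).

```python
from typing import Any


def shard(global_batch: list[Any], num_shards: int, shard_id: int) -> list[Any]:
    """Global -> Sharded: return the contiguous slice owned by one shard (assumed layout)."""
    if len(global_batch) % num_shards != 0:
        raise ValueError("batch size must be divisible by num_shards")
    size = len(global_batch) // num_shards
    return global_batch[shard_id * size:(shard_id + 1) * size]


def aggregate(per_shard: list[list[Any]]) -> list[Any]:
    """Sharded -> Global: concatenate per-shard results back in shard order."""
    return [item for part in per_shard for item in part]


# 1. Prepare the data using a global view.
global_batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
num_microbatches = 3

# 2. Compute a per-sample "loss" for each microbatch using a sharded view.
per_shard_losses = []
for shard_id in range(num_microbatches):
    microbatch = shard(global_batch, num_microbatches, shard_id)
    per_shard_losses.append([x * x for x in microbatch])

# 3. Log metrics using a global view again.
all_losses = aggregate(per_shard_losses)
mean_loss = sum(all_losses) / len(all_losses)
print(mean_loss)
```

The point of hiding these two functions behind a dictionary-like interface is that the loss code in step 2 only ever sees its own slice, while the logging code in step 3 only ever sees the reassembled whole.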
d9d.internals.pipeline_state
Pipeline State management package.
This package provides mechanisms to store, retrieve, and synchronize state across different stages of a distributed pipeline, providing global and sharded views of this state.
PipelineState
Bases: ABC
Object representing the state of a pipeline.
This class defines the interface for accessing state variables like a dictionary, abstracting away whether the underlying storage is local, sharded, or global.
__contains__(item) (abstractmethod)
__getitem__(item) (abstractmethod)
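A minimal sketch of what implementing this interface might look like, using only the two abstract methods listed above. `PipelineStateSketch` and `DictPipelineState` are hypothetical names chosen for illustration; in d9d, views are obtained from PipelineStateHandler, and their backing storage is not necessarily a plain dict.

```python
from abc import ABC, abstractmethod
from typing import Any


class PipelineStateSketch(ABC):
    """Read-only, dict-like interface mirroring the two abstract methods above."""

    @abstractmethod
    def __contains__(self, item: str) -> bool:
        """Return True if a value is stored under `item` (backs the `in` operator)."""

    @abstractmethod
    def __getitem__(self, item: str) -> Any:
        """Return the value stored under `item` (backs `state[item]`)."""


class DictPipelineState(PipelineStateSketch):
    """Hypothetical concrete view backed by an in-memory dict."""

    def __init__(self, data: dict[str, Any]) -> None:
        self._data = data

    def __contains__(self, item: str) -> bool:
        return item in self._data

    def __getitem__(self, item: str) -> Any:
        return self._data[item]


state = DictPipelineState({"loss": 0.5})
print("loss" in state)  # True
print(state["loss"])    # 0.5
```

Because both methods are abstract, the base class itself cannot be instantiated; only a subclass that decides where the data actually lives (local, sharded, or global) can be.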
PipelineStateHandler
Manages the lifecycle and access patterns of pipeline states.
This handler initializes the underlying storage and provides specific views (global or sharded) into that storage.
__init__(sharding_spec, num_shards)
global_state()
Returns a view interface for accessing global state.
Returns:

| Type | Description |
|---|---|
| PipelineState | A PipelineState interface that accesses the full, aggregated data. |
reset()
Resets the underlying storage, clearing all state.
sharded_state(shard_id)
Returns a view interface for accessing state specific to a shard ID.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| shard_id | int | The index of the shard to access. | required |
Returns:

| Type | Description |
|---|---|
| PipelineState | A PipelineState interface that accesses partial data for the given shard. |
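The following toy handler sketches how the pieces of PipelineStateHandler might fit together: one shared storage object, a global view that aggregates per-shard writes, per-shard views that slice global values, and a reset() that clears everything. Several details are assumptions. `sharding_spec` is modeled as a hypothetical dict mapping each key to True (split along the first axis into `num_shards` equal chunks) or False (replicated to every shard), and the `__setitem__` methods used for writing are invented for illustration, since the documented PipelineState interface lists only `__contains__` and `__getitem__`. Treat it as a model of the semantics, not as d9d's implementation.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class _Storage:
    sharding_spec: dict[str, bool]  # hypothetical: key -> split along first axis?
    num_shards: int
    global_values: dict[str, Any] = field(default_factory=dict)
    shard_values: dict[str, dict[int, Any]] = field(default_factory=dict)


class GlobalView:
    """Full, aggregated view of the state."""

    def __init__(self, storage: _Storage) -> None:
        self._s = storage

    def __contains__(self, key: str) -> bool:
        return key in self._s.global_values or key in self._s.shard_values

    def __getitem__(self, key: str) -> Any:
        if key in self._s.global_values:
            return self._s.global_values[key]
        # Sharded -> Global: concatenate per-shard lists in shard order.
        parts = self._s.shard_values[key]
        return [x for i in sorted(parts) for x in parts[i]]

    def __setitem__(self, key: str, value: Any) -> None:  # assumed write API
        self._s.global_values[key] = value


class ShardView:
    """Partial view of the state for a single shard."""

    def __init__(self, storage: _Storage, shard_id: int) -> None:
        self._s = storage
        self._id = shard_id

    def __contains__(self, key: str) -> bool:
        own = self._s.shard_values.get(key, {})
        return self._id in own or key in self._s.global_values

    def __getitem__(self, key: str) -> Any:
        own = self._s.shard_values.get(key, {})
        if self._id in own:
            return own[self._id]
        value = self._s.global_values[key]
        if not self._s.sharding_spec.get(key, False):
            return value  # replicated: every shard sees the full value
        # Global -> Sharded: equal-size contiguous chunk along the first axis.
        size = len(value) // self._s.num_shards
        return value[self._id * size:(self._id + 1) * size]

    def __setitem__(self, key: str, value: Any) -> None:  # assumed write API
        self._s.shard_values.setdefault(key, {})[self._id] = value


class ToyPipelineStateHandler:
    def __init__(self, sharding_spec: dict[str, bool], num_shards: int) -> None:
        self._storage = _Storage(sharding_spec, num_shards)

    def global_state(self) -> GlobalView:
        return GlobalView(self._storage)

    def sharded_state(self, shard_id: int) -> ShardView:
        if not 0 <= shard_id < self._storage.num_shards:
            raise IndexError(f"shard_id {shard_id} out of range")
        return ShardView(self._storage, shard_id)

    def reset(self) -> None:
        # Clear in place so previously returned views observe the reset.
        self._storage.global_values.clear()
        self._storage.shard_values.clear()


# Usage: prepare globally, compute per shard, read back globally, then reset.
handler = ToyPipelineStateHandler({"inputs": True, "scale": False}, num_shards=2)
g = handler.global_state()
g["inputs"] = [1, 2, 3, 4]
g["scale"] = 10
for shard_id in range(2):
    s = handler.sharded_state(shard_id)
    s["losses"] = [x * s["scale"] for x in s["inputs"]]
aggregated = g["losses"]
print(aggregated)     # [10, 20, 30, 40]
handler.reset()
print("inputs" in g)  # False
```

In this sketch, reset() clears the storage in place rather than replacing it, so views handed out before the reset see the cleared state instead of silently reading data from the previous step.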