Warning:
If you are using the standard d9d training infrastructure, you do not need to manage pipeline states manually; the framework handles this automatically. This package is primarily intended for users extending d9d.
About
The d9d.internals.pipeline_state package provides a unified mechanism for managing the data lifecycle within a training step. It specifically addresses the complexity of transitioning between the Global Context (an entire training step or batch) and the Sharded Context (partial execution, for example pipeline-parallel loss computation).
For instance, a typical data flow in a pipelined step is:
- Prepare the data using a global view.
- Compute the loss value for each microbatch, which requires a sharded view of the data.
- Log metrics, again using a global view.
PipelineState abstracts the slicing (Global -> Sharded) and aggregation (Sharded -> Global) operations behind a simple dictionary-like interface, allowing the training loop to act as a seamless bridge between these two contexts.
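The slicing and aggregation round trip can be illustrated with plain Python lists. This is a minimal sketch, assuming data is split evenly along its leading dimension and aggregated by concatenation; the helper names shard_view and aggregate are hypothetical and are not part of d9d.

```python
def shard_view(values: list, num_shards: int, shard_id: int) -> list:
    """Global -> Sharded: return the contiguous slice owned by one shard.

    Hypothetical simplification: len(values) must be divisible by num_shards.
    """
    if len(values) % num_shards != 0:
        raise ValueError("values must split evenly across shards")
    size = len(values) // num_shards
    return values[shard_id * size:(shard_id + 1) * size]


def aggregate(shards: list[list]) -> list:
    """Sharded -> Global: concatenate per-shard results in shard order."""
    result = []
    for shard in shards:
        result.extend(shard)
    return result


# Global context: a batch of 8 samples prepared once.
batch = list(range(8))
num_microbatches = 4

# Sharded context: compute a toy per-sample "loss" for each microbatch.
per_shard_losses = []
for shard_id in range(num_microbatches):
    microbatch = shard_view(batch, num_microbatches, shard_id)
    per_shard_losses.append([x * 0.5 for x in microbatch])

# Global context again: aggregate the partial results for logging.
all_losses = aggregate(per_shard_losses)
mean_loss = sum(all_losses) / len(all_losses)  # → 1.75
```

In this framing, PipelineState hides both helper calls behind key lookups, so the training loop only reads and writes keys.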
d9d.internals.pipeline_state
Pipeline State management package.
This package provides mechanisms to store, retrieve, and synchronize state across different stages of a distributed pipeline, exposing global and sharded views of that state.
PipelineState
Bases: ABC
Object representing the state of a pipeline.
This class defines the interface for accessing state variables like a dictionary, abstracting away whether the underlying storage is local, sharded, or global.
Source code in d9d/internals/pipeline_state/api.py
__contains__(item)
abstractmethod
Checks if a key exists in the state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | str | The identifier to check. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if the key exists, False otherwise. |
Source code in d9d/internals/pipeline_state/api.py
__getitem__(item)
abstractmethod
Retrieves a state value for a given key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| item | str | The identifier for the state variable. | required |
Returns:
| Type | Description |
|---|---|
| Any | The value associated with the key. |
Source code in d9d/internals/pipeline_state/api.py
__setitem__(key, value)
abstractmethod
Sets a state value for a given key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | str | The identifier for the state variable. | required |
| value | Any | The value to store. | required |
Source code in d9d/internals/pipeline_state/api.py
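To make the dictionary-like contract concrete, the following is a minimal sketch of an abstract class with the same three abstract methods, plus a trivial dict-backed subclass with no sharding at all. The names PipelineStateSketch and DictPipelineState are hypothetical, and raising KeyError for a missing key is an assumption, since the reference above does not document error handling.

```python
from abc import ABC, abstractmethod
from typing import Any


class PipelineStateSketch(ABC):
    """Mirrors the documented PipelineState interface (hypothetical name)."""

    @abstractmethod
    def __setitem__(self, key: str, value: Any) -> None: ...

    @abstractmethod
    def __getitem__(self, item: str) -> Any: ...

    @abstractmethod
    def __contains__(self, item: str) -> bool: ...


class DictPipelineState(PipelineStateSketch):
    """Simplest possible backing store: one local dict, no sharding."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def __setitem__(self, key: str, value: Any) -> None:
        self._data[key] = value

    def __getitem__(self, item: str) -> Any:
        # Assumption: unknown keys raise KeyError, like a regular dict.
        return self._data[item]

    def __contains__(self, item: str) -> bool:
        return item in self._data


state = DictPipelineState()
state["num_tokens"] = 4096
```

Because all three methods are abstract, the base class cannot be instantiated directly; only concrete views (local, global, or sharded) can.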
PipelineStateHandler
Manages the lifecycle and access patterns of pipeline states.
This handler initializes the underlying storage and provides specific views (global or sharded) into that storage.
Source code in d9d/internals/pipeline_state/handler.py
__init__(sharding_spec, num_shards)
Constructs a PipelineStateHandler object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sharding_spec | dict[str, ShardingSpecLeaf] | A definition of how specific keys should be sharded. | required |
| num_shards | int | The total number of shards in the pipeline. | required |
Source code in d9d/internals/pipeline_state/handler.py
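As an illustration of how a per-key sharding spec and num_shards might interact, the sketch below uses a plain dict mapping each key to "split" or "replicate" as a hypothetical stand-in for dict[str, ShardingSpecLeaf]. ShardingSpecLeaf itself is not described on this page, so its fields and options should not be inferred from this example. How d9d handles batches that do not divide evenly is also not stated here; the sketch gives the earlier shards one extra item, which is a common but arbitrary choice.

```python
from typing import Any, Literal

# Hypothetical stand-in for dict[str, ShardingSpecLeaf].
ToyShardingSpec = dict[str, Literal["split", "replicate"]]


def shard_bounds(length: int, num_shards: int) -> list[tuple[int, int]]:
    """Contiguous [start, end) ranges; earlier shards absorb any remainder."""
    base, extra = divmod(length, num_shards)
    bounds = []
    start = 0
    for shard_id in range(num_shards):
        end = start + base + (1 if shard_id < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def materialize_shard(
    global_data: dict[str, Any],
    spec: ToyShardingSpec,
    num_shards: int,
    shard_id: int,
) -> dict[str, Any]:
    """Builds the view one shard would see, following the toy spec."""
    view = {}
    for key, value in global_data.items():
        if spec.get(key, "replicate") == "split":
            # Slice along the leading dimension.
            start, end = shard_bounds(len(value), num_shards)[shard_id]
            view[key] = value[start:end]
        else:
            # Replicated keys are passed to every shard unchanged.
            view[key] = value
    return view


spec: ToyShardingSpec = {"input_ids": "split", "vocab_size": "replicate"}
data = {"input_ids": [[1], [2], [3], [4], [5]], "vocab_size": 32000}
shard_1 = materialize_shard(data, spec, num_shards=2, shard_id=1)
```

Keeping the spec per key lets batch-shaped tensors be split while scalars and configuration values stay whole on every shard.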
global_state()
Returns a view interface for accessing global state.
Returns:
| Type | Description |
|---|---|
| PipelineState | A PipelineState interface that accesses the full, aggregated data. |
Source code in d9d/internals/pipeline_state/handler.py
reset()
Resets the underlying storage, clearing all state.
Source code in d9d/internals/pipeline_state/handler.py
sharded_state(shard_id)
Returns a view interface for accessing state specific to a shard ID.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shard_id | int | The index of the shard to access. | required |
Returns:
| Type | Description |
|---|---|
| PipelineState | A PipelineState interface that accesses partial data for the given shard. |
Source code in d9d/internals/pipeline_state/handler.py
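The handler's operations (global_state, sharded_state, and reset) can be tied together in a self-contained toy model. This is a sketch under stated assumptions, not d9d's implementation: ToyStateHandler, ToyGlobalView, and ToyShardedView are hypothetical names, the set split_keys stands in for the real sharding_spec, and the rules that sharded writes are kept per shard and concatenated on global reads are assumptions chosen to mirror the Global -> Sharded -> Global flow described in the About section.

```python
from typing import Any


class ToyStateHandler:
    """Hypothetical stand-in for PipelineStateHandler (not d9d's code)."""

    def __init__(self, split_keys: set[str], num_shards: int) -> None:
        self._split_keys = split_keys  # keys sliced along dim 0 for shards
        self._num_shards = num_shards
        self.reset()

    def reset(self) -> None:
        # Clear global values and all per-shard partial values.
        self._global: dict[str, Any] = {}
        self._shards: dict[str, dict[int, Any]] = {}

    def global_state(self) -> "ToyGlobalView":
        return ToyGlobalView(self)

    def sharded_state(self, shard_id: int) -> "ToyShardedView":
        if not 0 <= shard_id < self._num_shards:
            raise IndexError(f"shard_id {shard_id} out of range")
        return ToyShardedView(self, shard_id)


class ToyGlobalView:
    def __init__(self, handler: ToyStateHandler) -> None:
        self._h = handler

    def __setitem__(self, key: str, value: Any) -> None:
        self._h._global[key] = value

    def __getitem__(self, key: str) -> Any:
        if key in self._h._global:
            return self._h._global[key]
        # Aggregate: concatenate per-shard lists in shard order.
        parts = self._h._shards[key]
        return [x for shard_id in sorted(parts) for x in parts[shard_id]]

    def __contains__(self, key: str) -> bool:
        return key in self._h._global or key in self._h._shards


class ToyShardedView:
    def __init__(self, handler: ToyStateHandler, shard_id: int) -> None:
        self._h = handler
        self._id = shard_id

    def __setitem__(self, key: str, value: Any) -> None:
        self._h._shards.setdefault(key, {})[self._id] = value

    def __getitem__(self, key: str) -> Any:
        if self._id in self._h._shards.get(key, {}):
            return self._h._shards[key][self._id]
        value = self._h._global[key]
        if key not in self._h._split_keys:
            return value  # replicated key: every shard sees the full value
        size = len(value) // self._h._num_shards  # assumes an even split
        return value[self._id * size:(self._id + 1) * size]

    def __contains__(self, key: str) -> bool:
        return key in self._h._global or self._id in self._h._shards.get(key, {})


handler = ToyStateHandler(split_keys={"labels"}, num_shards=2)
handler.global_state()["labels"] = [0, 1, 2, 3]
for shard_id in range(2):
    view = handler.sharded_state(shard_id)
    view["loss_per_sample"] = [float(y) for y in view["labels"]]
losses = handler.global_state()["loss_per_sample"]  # → [0.0, 1.0, 2.0, 3.0]
handler.reset()  # start the next training step with empty state
```

Routing every access through a view object means the same training code can run against either context; only the handler decides how a key is sliced or aggregated.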