Warning:

If you use the standard d9d training infrastructure, you do not need to manage pipeline states manually: the framework handles this automatically. This package is primarily intended for users extending d9d.

About

The d9d.internals.pipeline_state package provides a unified mechanism to manage the data lifecycle within a training step. It specifically addresses the complexity of transitioning between the Global Context (an entire training step/batch) and the Sharded Context (partial execution, e.g. loss computation for a single microbatch under pipeline parallelism).

For instance, a typical data flow in a pipelined step is:

  1. Prepare the data using a global view.
  2. Compute the loss value for a microbatch, which requires a sharded view of the data.
  3. Log metrics, using a global view again.

PipelineState abstracts the slicing (Global -> Sharded) and aggregation (Sharded -> Global) operations behind a simple dictionary-like interface, allowing the training loop to act as a seamless bridge between these two contexts.
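
The three-step flow above can be illustrated without d9d at all. The following is a minimal sketch using plain Python lists; the helper names split_into_shards and merge_shards are hypothetical and only stand in for the slicing and aggregation that PipelineState hides, and the assumption of equal, contiguous microbatches is ours rather than something this page specifies.

```python
def split_into_shards(values: list[float], num_shards: int) -> list[list[float]]:
    """Global -> Sharded: cut a batch into equal, contiguous microbatches."""
    if len(values) % num_shards != 0:
        raise ValueError("batch size must be divisible by num_shards")
    size = len(values) // num_shards
    return [values[i * size:(i + 1) * size] for i in range(num_shards)]


def merge_shards(shards: list[list[float]]) -> list[float]:
    """Sharded -> Global: concatenate per-microbatch results in shard order."""
    return [value for shard in shards for value in shard]


# 1. Prepare the data using a global view (the whole batch).
targets = [1.0, 2.0, 3.0, 4.0]
predictions = [1.5, 2.0, 2.0, 6.0]

# 2. Compute a per-sample squared error one microbatch at a time (sharded view).
target_shards = split_into_shards(targets, num_shards=2)
prediction_shards = split_into_shards(predictions, num_shards=2)
loss_shards = [
    [(p - t) ** 2 for p, t in zip(pred_shard, target_shard)]
    for pred_shard, target_shard in zip(prediction_shards, target_shards)
]

# 3. Log metrics using a global view again.
per_sample_loss = merge_shards(loss_shards)
mean_loss = sum(per_sample_loss) / len(per_sample_loss)
```

In d9d itself, both conversions sit behind key lookups on a PipelineState view, so the training loop never calls such helpers directly.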

d9d.internals.pipeline_state

Pipeline State management package.

This package provides mechanisms to store, retrieve, and synchronize state across different stages of a distributed pipeline, exposing both global and sharded views of that state.

PipelineState

Bases: ABC

Object representing the state of a pipeline.

This class defines the interface for accessing state variables like a dictionary, abstracting away whether the underlying storage is local, sharded, or global.

Source code in d9d/internals/pipeline_state/api.py
import abc
from typing import Any


class PipelineState(abc.ABC):
    """
    Object representing the state of a pipeline.

    This class defines the interface for accessing state variables like a dictionary,
    abstracting away whether the underlying storage is local, sharded, or global.
    """

    @abc.abstractmethod
    def __setitem__(self, key: str, value: Any):
        """
        Sets a state value for a given key.

        Args:
            key: The identifier for the state variable.
            value: The value to store.
        """

    @abc.abstractmethod
    def __getitem__(self, item: str) -> Any:
        """
        Retrieves a state value for a given key.

        Args:
            item: The identifier for the state variable.

        Returns:
            The value associated with the key.
        """

    @abc.abstractmethod
    def __contains__(self, item: str) -> bool:
        """
        Checks if a key exists in the state.

        Args:
            item: The identifier to check.

        Returns:
            True if the key exists, False otherwise.
        """
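
To make the interface concrete, here is a minimal sketch of a subclass that satisfies it with an ordinary dict. DictPipelineState is a hypothetical name, not a class shipped by d9d, and the PipelineState base is redefined locally (mirroring the three abstract methods listed above) only so that the snippet runs on its own; when extending d9d you would presumably subclass the real class instead.

```python
import abc
from typing import Any


class PipelineState(abc.ABC):
    """Local stand-in that mirrors the abstract interface documented above."""

    @abc.abstractmethod
    def __setitem__(self, key: str, value: Any): ...

    @abc.abstractmethod
    def __getitem__(self, item: str) -> Any: ...

    @abc.abstractmethod
    def __contains__(self, item: str) -> bool: ...


class DictPipelineState(PipelineState):
    """Hypothetical implementation that keeps everything in one local dict."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def __setitem__(self, key: str, value: Any) -> None:
        self._data[key] = value

    def __getitem__(self, item: str) -> Any:
        return self._data[item]

    def __contains__(self, item: str) -> bool:
        return item in self._data


state = DictPipelineState()
state["loss"] = 0.5          # routed through __setitem__
has_loss = "loss" in state   # routed through __contains__
loss = state["loss"]         # routed through __getitem__
```

Because all three methods are abstract, instantiating PipelineState directly raises TypeError; only subclasses that implement every method can be constructed.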

__contains__(item) abstractmethod

Checks if a key exists in the state.

Parameters:

Name | Type | Description | Default
item | str | The identifier to check. | required

Returns:

Type | Description
bool | True if the key exists, False otherwise.

Source code in d9d/internals/pipeline_state/api.py
@abc.abstractmethod
def __contains__(self, item: str) -> bool:
    """
    Checks if a key exists in the state.

    Args:
        item: The identifier to check.

    Returns:
        True if the key exists, False otherwise.
    """

__getitem__(item) abstractmethod

Retrieves a state value for a given key.

Parameters:

Name | Type | Description | Default
item | str | The identifier for the state variable. | required

Returns:

Type | Description
Any | The value associated with the key.

Source code in d9d/internals/pipeline_state/api.py
@abc.abstractmethod
def __getitem__(self, item: str) -> Any:
    """
    Retrieves a state value for a given key.

    Args:
        item: The identifier for the state variable.

    Returns:
        The value associated with the key.
    """

__setitem__(key, value) abstractmethod

Sets a state value for a given key.

Parameters:

Name | Type | Description | Default
key | str | The identifier for the state variable. | required
value | Any | The value to store. | required

Source code in d9d/internals/pipeline_state/api.py
@abc.abstractmethod
def __setitem__(self, key: str, value: Any):
    """
    Sets a state value for a given key.

    Args:
        key: The identifier for the state variable.
        value: The value to store.
    """

PipelineStateHandler

Manages the lifecycle and access patterns of pipeline states.

This handler initializes the underlying storage and provides specific views (global or sharded) into that storage.

Source code in d9d/internals/pipeline_state/handler.py
class PipelineStateHandler:
    """
    Manages the lifecycle and access patterns of pipeline states.

    This handler initializes the underlying storage and provides specific views
    (global or sharded) into that storage.
    """

    def __init__(self, sharding_spec: dict[str, ShardingSpecLeaf], num_shards: int):
        """
        Constructs a PipelineStateHandler object.

        Args:
            sharding_spec: A definition of how specific keys should be sharded.
            num_shards: The total number of shards in the pipeline.
        """

        self._storage = PipelineStateStorage(
            sharding_spec={(k,): v for k, v in sharding_spec.items()},
            num_shards=num_shards
        )

    def global_state(self) -> PipelineState:
        """
        Returns a view interface for accessing global state.

        Returns:
            A PipelineState interface that accesses the full, aggregated data.
        """

        return PipelineStateGlobal(self._storage)

    def sharded_state(self, shard_id: int) -> PipelineState:
        """
        Returns a view interface for accessing state specific to a shard ID.

        Args:
            shard_id: The index of the shard to access.

        Returns:
            A PipelineState interface that accesses partial data for the given shard.
        """

        return PipelineStateShard(self._storage, shard_id)

    def reset(self):
        """
        Resets the underlying storage, clearing all state.
        """

        self._storage.reset()
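
The lifecycle that the handler manages (write globally, read and write per shard, read globally, reset) can be sketched with a toy re-implementation. Everything below is an assumption made for illustration: the Toy* classes are hypothetical, they ignore sharding_spec entirely, and they simply split every list evenly along its first axis, whereas the real PipelineStateStorage and ShardingSpecLeaf semantics are not shown on this page.

```python
from typing import Any


class ToyStorage:
    """Hypothetical storage: one dict of values per shard."""

    def __init__(self, num_shards: int) -> None:
        self.num_shards = num_shards
        self.shards: list[dict[str, Any]] = [{} for _ in range(num_shards)]

    def reset(self) -> None:
        for shard in self.shards:
            shard.clear()


class ToyGlobalView:
    """Global view: slices on write, concatenates on read."""

    def __init__(self, storage: ToyStorage) -> None:
        self._storage = storage

    def __setitem__(self, key: str, value: list) -> None:
        n = self._storage.num_shards
        if len(value) % n != 0:
            raise ValueError("value length must be divisible by num_shards")
        size = len(value) // n
        for i, shard in enumerate(self._storage.shards):
            shard[key] = value[i * size:(i + 1) * size]

    def __getitem__(self, item: str) -> list:
        return [v for shard in self._storage.shards for v in shard[item]]

    def __contains__(self, item: str) -> bool:
        return all(item in shard for shard in self._storage.shards)


class ToyShardView:
    """Sharded view: reads and writes only the slice owned by one shard."""

    def __init__(self, storage: ToyStorage, shard_id: int) -> None:
        self._data = storage.shards[shard_id]

    def __setitem__(self, key: str, value: Any) -> None:
        self._data[key] = value

    def __getitem__(self, item: str) -> Any:
        return self._data[item]

    def __contains__(self, item: str) -> bool:
        return item in self._data


class ToyStateHandler:
    """Mirrors the method names of PipelineStateHandler, minus sharding_spec."""

    def __init__(self, num_shards: int) -> None:
        self._storage = ToyStorage(num_shards)

    def global_state(self) -> ToyGlobalView:
        return ToyGlobalView(self._storage)

    def sharded_state(self, shard_id: int) -> ToyShardView:
        return ToyShardView(self._storage, shard_id)

    def reset(self) -> None:
        self._storage.reset()


handler = ToyStateHandler(num_shards=2)
handler.global_state()["tokens"] = [10, 11, 12, 13]    # global write, sliced
first_shard_tokens = handler.sharded_state(0)["tokens"]
for shard_id in range(2):
    shard = handler.sharded_state(shard_id)
    shard["loss"] = [t * 0.5 for t in shard["tokens"]]  # placeholder "loss"
losses = handler.global_state()["loss"]                 # global read, merged
handler.reset()                                         # clear before next step
```

Both views share one storage object, which is why a value written through a shard view is visible through the global view without any explicit synchronization call.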

__init__(sharding_spec, num_shards)

Constructs a PipelineStateHandler object.

Parameters:

Name | Type | Description | Default
sharding_spec | dict[str, ShardingSpecLeaf] | A definition of how specific keys should be sharded. | required
num_shards | int | The total number of shards in the pipeline. | required

Source code in d9d/internals/pipeline_state/handler.py
def __init__(self, sharding_spec: dict[str, ShardingSpecLeaf], num_shards: int):
    """
    Constructs a PipelineStateHandler object.

    Args:
        sharding_spec: A definition of how specific keys should be sharded.
        num_shards: The total number of shards in the pipeline.
    """

    self._storage = PipelineStateStorage(
        sharding_spec={(k,): v for k, v in sharding_spec.items()},
        num_shards=num_shards
    )

global_state()

Returns a view interface for accessing global state.

Returns:

Type | Description
PipelineState | A PipelineState interface that accesses the full, aggregated data.

Source code in d9d/internals/pipeline_state/handler.py
def global_state(self) -> PipelineState:
    """
    Returns a view interface for accessing global state.

    Returns:
        A PipelineState interface that accesses the full, aggregated data.
    """

    return PipelineStateGlobal(self._storage)

reset()

Resets the underlying storage, clearing all state.

Source code in d9d/internals/pipeline_state/handler.py
def reset(self):
    """
    Resets the underlying storage, clearing all state.
    """

    self._storage.reset()

sharded_state(shard_id)

Returns a view interface for accessing state specific to a shard ID.

Parameters:

Name | Type | Description | Default
shard_id | int | The index of the shard to access. | required

Returns:

Type | Description
PipelineState | A PipelineState interface that accesses partial data for the given shard.

Source code in d9d/internals/pipeline_state/handler.py
def sharded_state(self, shard_id: int) -> PipelineState:
    """
    Returns a view interface for accessing state specific to a shard ID.

    Args:
        shard_id: The index of the shard to access.

    Returns:
        A PipelineState interface that accesses partial data for the given shard.
    """

    return PipelineStateShard(self._storage, shard_id)