The d9d Approach

d9d implements a modern, highly modular pipelining engine designed for performance, stability and customization.

Dynamic Shapes & Algorithmic Shape Inference

For P2P (point-to-point) communication, the receiver must know the shape of the incoming tensor in order to pre-allocate buffers. d9d asks your model to implement a lightweight protocol (ModuleSupportsPipelining) that calculates stage input and output shapes mathematically from the batch input shapes, without performing a heavy forward pass or distributed graph tracing.

This allows d9d to support dynamic shapes (e.g., varying sequence lengths) efficiently across runs.
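
For intuition, the required arithmetic is trivial. A minimal sketch (the hidden size of 4096 is an assumed placeholder):

import torch

n_microbatches = 8
global_batch = {"input_ids": torch.empty((64, 2048), dtype=torch.long)}

# Per-microbatch shapes follow directly from the global batch shape.
micro_batch_size = global_batch["input_ids"].shape[0] // n_microbatches  # 8
seq_len = global_batch["input_ids"].shape[1]  # 2048

# A non-first stage can pre-allocate its P2P receive buffer without
# ever running a forward pass (hidden size is a placeholder here).
recv_buffer = torch.empty((micro_batch_size, seq_len, 4096))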

Construction Consistency (No Patching)

A common anti-pattern in distributed training is "Instantiate-then-Delete": creating the full model on the CPU/meta device and then hacking it apart with del model.layers[N:].

We reject this pattern for three reasons:

  1. Fragility: Changes to the model architecture require changes to the external slicing script.
  2. Leaky Abstractions: Forward methods become littered with if self.layer is not None checks.
  3. Invalid States: The model object exists in a "zombie" state until it is sliced.

In d9d, models are Pipeline-Aware. Each pipeline rank constructs only the sub-graph it owns. The object returned is compliant, complete, and valid immediately.
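
The contrast is easiest to see in code. A minimal sketch (nn.Linear stands in for a real transformer block):

from torch import nn

# Rejected pattern: build everything, then slice.
#   model = MyModel(n_layers=12)  # every rank pays full construction cost
#   del model.layers[6:]          # rank 0 mutates itself into a half-model

# Pipeline-aware pattern: each rank builds only the layers it owns.
def build_stage(start: int, end: int) -> nn.Module:
    return nn.ModuleDict({str(i): nn.Linear(16, 16) for i in range(start, end)})

stage_for_rank_0 = build_stage(0, 6)   # owns layers 0-5, valid immediately
stage_for_rank_1 = build_stage(6, 12)  # owns layers 6-11, valid immediately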

Making Models Compatible

Implementing the Protocol

To use Pipeline Parallelism in d9d, your model must implement the d9d.pipelining.api.ModuleSupportsPipelining protocol, which lets the framework manage memory and buffer allocation.

Forward Compatibility

  • Pipelined models currently only support returning a dictionary (dict[str, torch.Tensor]); support for arbitrary PyTrees is planned for future releases. The keys of the dictionary returned by your forward method must strictly match the keys of the dictionary produced by infer_stage_outputs_from_pipeline_inputs.
  • The named arguments accepted by your forward method must strictly match the keys of the dictionary produced by infer_stage_inputs_from_pipeline_inputs.

This allows the communication handler to map tensor names to P2P buffers deterministically.
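
During development you can sanity-check this contract with a small helper along these lines (an illustrative sketch, not part of the d9d API):

import torch

from d9d.pipelining.api import ModuleSupportsPipelining

def check_pipelining_contract(
    model,
    batch_inputs: dict[str, torch.Tensor],
    stage_inputs: dict[str, torch.Tensor],
    n_microbatches: int,
) -> None:
    # The protocol is runtime-checkable, so isinstance verifies the methods exist.
    assert isinstance(model, ModuleSupportsPipelining)

    expected_in = model.infer_stage_inputs_from_pipeline_inputs(batch_inputs, n_microbatches)
    assert set(stage_inputs) == set(expected_in), "forward kwargs must match inferred inputs"

    outputs = model(**stage_inputs)
    expected_out = model.infer_stage_outputs_from_pipeline_inputs(batch_inputs, n_microbatches)
    assert set(outputs) == set(expected_out), "forward outputs must match inferred outputs"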

Example

Below is a skeleton of a Transformer-like model implemented for d9d pipelining.

import torch
from torch import nn
from d9d.pipelining.api import PipelineStageInfo, distribute_layers_for_pipeline_stage

class MyModelChunk(nn.Module):
    def __init__(self, stage: PipelineStageInfo, config):
        super().__init__()
        self.stage = stage
        self.config = config

        # 1. Determine what layers live here
        self.start_layer, self.end_layer = distribute_layers_for_pipeline_stage(
            config.n_layers, num_virtual_layers_pre=1, num_virtual_layers_post=1, stage=stage
        )

        # 2. Build sub-modules, keyed by the global layer index. Using an
        #    nn.ModuleDict keeps parameter names identical across pipeline
        #    configurations, which keeps checkpoints compatible.
        self.layers = nn.ModuleDict({
            str(layer): TransformerBlock(...)
            for layer in range(self.start_layer, self.end_layer)
        })

        # Only build embeddings on first stage
        if stage.is_current_stage_first:
            self.embed = nn.Embedding(...)

        # Only build head on last stage
        if stage.is_current_stage_last:
            self.head = nn.Linear(...)

    def forward(self, input_ids=None, hidden_states=None):        
        # Run embeddings only on first stage
        if self.stage.is_current_stage_first:
            x = self.embed(input_ids)
        else:
            x = hidden_states

        # Run local layers
        for layer_idx in range(self.start_layer, self.end_layer):
            x = self.layers[str(layer_idx)](x)

        outputs = {
            "hidden_states": x
        }

        # Last stage logic
        if self.stage.is_current_stage_last:
            logits = self.head(x)
            outputs['logits'] = logits

        return outputs

    # --- Protocol Implementation ---

    def infer_stage_inputs_from_pipeline_inputs(self, inputs: dict[str, torch.Tensor], n_microbatches: int):
        batch_size = inputs['input_ids'].shape[0]
        micro_batch_size = batch_size // n_microbatches
        seq_len = inputs['input_ids'].shape[1]

        if self.stage.is_current_stage_first:
            # First stage receives raw input IDs
            return {"input_ids": torch.empty((micro_batch_size, seq_len), dtype=torch.long)}
        else:
            # Intermediate stages receive hidden states from previous stage
            return {"hidden_states": torch.empty((micro_batch_size, seq_len, self.hidden_dim))}

    def infer_stage_outputs_from_pipeline_inputs(self, inputs: dict[str, torch.Tensor], n_microbatches: int):
        batch_size = inputs['input_ids'].shape[0]
        micro_batch_size = batch_size // n_microbatches
        seq_len = inputs['input_ids'].shape[1]

        outputs = {"hidden_states": torch.empty((micro_batch_size, seq_len, self.config.hidden_dim))}

        if self.stage.is_current_stage_last:
            # Last stage outputs logits too
            outputs["logits"] = torch.empty((micro_batch_size, seq_len, self.config.vocab_size))

        return outputs

Using the Pipeline

Supported Schedules

Each schedule is selected by a JSON description:

  • {"schedule": "inference"}: Inference-only pipeline execution. Runs all forward passes sequentially without any backward passes.
  • {"schedule": "gpipe"}: Standard GPipe execution. Assumes a single stage per rank and processes all microbatches for the forward pass before switching to the backward pass.
  • {"schedule": "looped_bfs", "num_stages_per_rank": 2}: Looped Breadth-First Search execution. Supports multiple stages per rank (virtualization) and executes all work for a specific stage before moving to the next.
  • {"schedule": "1f1b", "num_stages_per_rank": 1, "zero_bubble": true}: Interleaved 1F1B and Interleaved Zero Bubble execution. Supports multiple stages per rank and shards backward passes into dI and dW when zero_bubble is enabled.
  • {"schedule": "zero_bubble_v"}: Zero Bubble V (ZBV) execution. A specialized V-shape topology schedule that splits backward passes into input and weight gradients. Requires exactly 2 stages per rank.
  • {"schedule": "dual_pipe_v"}: DualPipeV execution. A bidirectional pipeline schedule for high-throughput training using V-shape topology and reciprocal forward/backward scheduling.

Batch Sharding

Pipelining works by splitting the input batch into N microbatches. By default, d9d assumes all input and output tensors should be split along dimension 0.

However, if your inputs require a different sharding strategy, you can customize it via PipelineShardingSpec.

Please see the sharding utils docs.

from d9d.pipelining.api import PipelineShardingSpec
from torch.distributed.tensor import Shard

# Example: Split 'images' on dim 1, but replicate 'camera_angles' across all microbatches
my_spec = PipelineShardingSpec(
    input_data={
        "images": Shard(1),
        "camera_angles": None
    }
    # input_kwargs can be defined similarly
)
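
For intuition, Shard(1) means each microbatch receives a slice along dimension 1, while a None entry replicates the tensor to every microbatch. Roughly (an illustration of the semantics, not d9d internals):

import torch

images = torch.randn(2, 8, 3, 224, 224)  # dim 1 is the axis being split
camera_angles = torch.randn(2, 8)

n_microbatches = 4
image_shards = torch.chunk(images, chunks=n_microbatches, dim=1)
microbatches = [
    {"images": img, "camera_angles": camera_angles}  # replicated, not split
    for img in image_shards
]  # each 'images' shard has shape (2, 2, 3, 224, 224)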

Usage within the Trainer

Pipelining is available in the Trainer framework. When configuring the Trainer, simply provide an AnyPipelineScheduleConfig in your training arguments; the Trainer handles schedule construction and layer distribution automatically.
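
In outline (TrainerArguments and the field name below are hypothetical placeholders; consult the Trainer docs for the exact spelling):

from d9d.pipelining.factory import PipelineSchedule1F1BConfig

schedule_config = PipelineSchedule1F1BConfig(num_stages_per_rank=2, zero_bubble=True)

# Hypothetical wiring - the Trainer accepts any AnyPipelineScheduleConfig:
# trainer_args = TrainerArguments(..., pipeline_schedule=schedule_config)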

Advanced - Manual Usage

If you want to use pipelining outside the Trainer (e.g., in custom training loops), use the build_schedule factory.

The build_schedule function requires a model provider. Instead of passing an instantiated model, you pass a function that accepts a PipelineStageInfo and returns the nn.Module for that stage. This ensures construction consistency.

from torch import Tensor
from torch.distributed.tensor import Shard
import torch.nn.functional as F

from d9d.core.dist_context import DistributedContext
from d9d.core.sharding import shard_tree
from d9d.pipelining.factory import build_schedule, PipelineSchedule1F1BConfig
from d9d.pipelining.api import PipelineShardingSpec


# 0. Define an object that manages loss calculation across steps
class PipelineLossHandler:
    def __init__(self, num_microbatches: int):
        self._shard_spec = {
            'target': Shard(0)
        }
        self._num_microbatches = num_microbatches
        self._targets = None

    def set_targets(self, targets: Tensor):
        self._targets = shard_tree(
            {'target': targets},
            sharding_spec=self._shard_spec,
            num_shards=self._num_microbatches,
            enforce_even_split=True
        )

    def compute_loss(self, outputs: dict[str, Tensor], microbatch_idx: int):
        # Implement any custom logic here.
        # shard_tree yields one tree per microbatch, so pick the shard first,
        # then the 'target' leaf.
        current_target = self._targets[microbatch_idx]['target']
        logits = outputs['logits']
        return F.cross_entropy(logits.view(-1, logits.shape[-1]), current_target.view(-1))


# 1. Define configuration
dist_context: DistributedContext = ...
model_config = ...
n_microbatches = 32
schedule_config = PipelineSchedule1F1BConfig(
    num_stages_per_rank=4,  # 4 Virtual stages per rank
    zero_bubble=True  # Enable ZB1P optimization
)

# 2. Build the schedule, model shards and loss compute function
loss_handler = PipelineLossHandler(num_microbatches=n_microbatches)
schedule_info, modules = build_schedule(
    dist_context=dist_context,
    n_microbatches=n_microbatches,
    schedule_config=schedule_config,
    model_provider=lambda stage: MyModelChunk(stage, model_config),  # Factory function
    loss_fn=loss_handler.compute_loss,
    sharding_spec=PipelineShardingSpec()  # Default sharding across dim 0
)

# 3. Execution
# The schedule object exposes a simple step API
inputs = {"input_ids": ...}  # Full batch
loss_handler.set_targets(...)  # Set targets for full batch
schedule_info.schedule.configure_buffers(inputs, kwargs={})  # Pre-allocate buffers
schedule_info.schedule.step(inputs, kwargs={})

d9d.pipelining.api

Pipelining API intended to be accessible by end users.

ModuleSupportsPipelining

Bases: Protocol

Protocol for modules that support pipeline parallelism metadata inference.

Classes implementing this protocol enable the framework to pre-calculate tensor shapes and types required for inter-stage communication (p2p) without executing the full forward pass.

Source code in d9d/pipelining/api/module.py
@typing.runtime_checkable
class ModuleSupportsPipelining(typing.Protocol):
    """
    Protocol for modules that support pipeline parallelism metadata inference.

    Classes implementing this protocol enable the framework to pre-calculate
    tensor shapes and types required for inter-stage communication (p2p)
    without executing the full forward pass.
    """

    def infer_stage_inputs_from_pipeline_inputs(
            self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        """
        Infers the input tensors metadata for the current pipeline stage based on global batch inputs.

        Args:
            inputs: Global inputs for the pipeline.
            n_microbatches: Number of microbatches the global batch is split into.

        Returns:
            Dictionary of input tensors expected by this specific stage locally.
        """

        ...

    def infer_stage_outputs_from_pipeline_inputs(
            self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        """
        Infers the output tensors metadata for the current pipeline stage based on global batch inputs.

        Args:
            inputs: Global inputs for the pipeline (typically a batch).
            n_microbatches: Number of microbatches the global batch is split into.

        Returns:
            Dictionary of output tensors produced by this specific stage locally.
        """

        ...

infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)

Infers the input tensors metadata for the current pipeline stage based on global batch inputs.

Parameters:

  • inputs (dict[str, Tensor], required): Global inputs for the pipeline.
  • n_microbatches (int, required): Number of microbatches the global batch is split into.

Returns:

  • dict[str, Tensor]: Dictionary of input tensors expected by this specific stage locally.

infer_stage_outputs_from_pipeline_inputs(inputs, n_microbatches)

Infers the output tensors metadata for the current pipeline stage based on global batch inputs.

Parameters:

  • inputs (dict[str, Tensor], required): Global inputs for the pipeline (typically a batch).
  • n_microbatches (int, required): Number of microbatches the global batch is split into.

Returns:

  • dict[str, Tensor]: Dictionary of output tensors produced by this specific stage locally.

PipelineSchedule

Bases: ABC

Abstract base class defining the interface for pipeline execution schedules.

Source code in d9d/pipelining/api/schedule.py
class PipelineSchedule(abc.ABC):
    """Abstract base class defining the interface for pipeline execution schedules."""

    @abc.abstractmethod
    def configure_buffers(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        """
        Configures internal state and buffers based on input shapes.

        This method allows the schedule to pre-allocate memory or setup sharding
        specifications based on the structure of the input data before execution begins.

        Args:
            inputs: A dictionary of input tensors.
            kwargs: A dictionary of keyword arguments.
        """

        ...

    @abc.abstractmethod
    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        """
        Executes a single pipeline step using the provided inputs.

         This typically involves distributing inputs across microbatches,
         executing forward and backward passes according to the specific schedule logic,
         and handling communications between stages.

         Args:
             inputs: A dictionary of global input tensors.
             kwargs: A dictionary of global keyword arguments.
         """

        ...

configure_buffers(inputs, kwargs) abstractmethod

Configures internal state and buffers based on input shapes.

This method allows the schedule to pre-allocate memory or setup sharding specifications based on the structure of the input data before execution begins.

Parameters:

  • inputs (dict[str, Tensor], required): A dictionary of input tensors.
  • kwargs (dict[str, Any], required): A dictionary of keyword arguments.

step(inputs, kwargs) abstractmethod

Executes a single pipeline step using the provided inputs.

This typically involves distributing inputs across microbatches, executing forward and backward passes according to the specific schedule logic, and handling communications between stages.

Parameters:

  • inputs (dict[str, Tensor], required): A dictionary of global input tensors.
  • kwargs (dict[str, Any], required): A dictionary of global keyword arguments.
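
As an illustrative sketch (assuming PipelineSchedule is importable from d9d.pipelining.api), a subclass only needs these two methods; in practice schedules are obtained from the build_schedule factory rather than written by hand:

from typing import Any

import torch

from d9d.pipelining.api import PipelineSchedule

class NoOpSchedule(PipelineSchedule):
    """Illustrative stub satisfying the abstract interface; it does no work."""

    def configure_buffers(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        pass  # a real schedule would pre-allocate P2P buffers here

    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        pass  # a real schedule would run forward/backward over microbatches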


PipelineStageInfo dataclass

Holds information about the current position within the distributed pipeline.

Attributes:

  • current_stage (int): The 0-based index of the current pipeline stage.
  • num_stages (int): The total number of stages in the pipeline.

Source code in d9d/pipelining/api/module.py
@dataclasses.dataclass
class PipelineStageInfo:
    """
    Holds information about the current position within the distributed pipeline.

    Attributes:
        current_stage: The 0-based index of the current pipeline stage.
        num_stages: The total number of stages in the pipeline.
    """

    current_stage: int
    num_stages: int

    @property
    def is_current_stage_first(self) -> bool:
        """
        Determines if this is the first stage in the pipeline.

        Returns:
            True if current_stage is 0.
        """

        return self.current_stage == 0

    @property
    def is_current_stage_last(self) -> bool:
        """
        Determines if this is the last stage in the pipeline.

        Returns:
            True if current_stage is the last index.
        """

        return self.current_stage == self.num_stages - 1

is_current_stage_first property

Determines if this is the first stage in the pipeline.

Returns:

  • bool: True if current_stage is 0.

is_current_stage_last property

Determines if this is the last stage in the pipeline.

Returns:

  • bool: True if current_stage is the last index.

distribute_layers_for_pipeline_stage(num_layers, num_virtual_layers_pre, num_virtual_layers_post, stage)

Calculates the layer index range for a specific pipeline stage.

This function distributes a given number of layers across multiple pipeline stages as evenly as possible. It accounts for additional, non-layer computational load on the first and last stages (e.g., embeddings and the LM head) by using the concept of 'virtual layers' to reserve capacity.

Parameters:

  • num_layers (int, required): The total number of primary model layers to be distributed (e.g., the transformer blocks).
  • num_virtual_layers_pre (int, required): The number of 'virtual' layers representing the computational cost of modules on the first stage, before the main layers (e.g., token and positional embeddings).
  • num_virtual_layers_post (int, required): The number of 'virtual' layers representing the computational cost of modules on the last stage, after the main layers (e.g., the final layer normalization and LM head).
  • stage (PipelineStageInfo, required): An object containing total stages and current stage index.

Returns:

  • tuple[int, int]: A tuple (start_index, end_index) representing the slice of layers for the given stage. The start_index is inclusive and the end_index is exclusive.

Raises:

  • ValueError: If the pipeline configuration results in a stage having zero or negative layers assigned (pipeline too long for the model size).

Source code in d9d/pipelining/api/module.py
def distribute_layers_for_pipeline_stage(
        num_layers: int,
        num_virtual_layers_pre: int,
        num_virtual_layers_post: int,
        stage: PipelineStageInfo
) -> tuple[int, int]:
    """
    Calculates the layer index range for a specific pipeline stage.

    This function distributes a given number of layers across multiple pipeline
    stages as evenly as possible. It accounts for additional, non-layer
    computational load on the first and last stages (e.g., embeddings and the
    LM head) by using the concept of 'virtual layers' to reserve capacity.

    Args:
        num_layers: The total number of primary model layers to be distributed
            (e.g., the transformer blocks).
        num_virtual_layers_pre: The number of 'virtual' layers representing the
            computational cost of modules on the *first* stage, before the main
            layers (e.g., token and positional embeddings).
        num_virtual_layers_post: The number of 'virtual' layers representing the
            computational cost of modules on the *last* stage, after the main
            layers (e.g., the final layer normalization and LM head).
        stage: An object containing total stages and current stage index.

    Returns:
        A tuple (start_index, end_index), representing the slice of layers for
            the given stage. The start_index is inclusive and the end_index is
            exclusive.

    Raises:
        ValueError: If the pipeline configuration results in a stage having zero
            or negative layers assigned (pipeline too long for the model size).
    """

    num_layers_virtual = num_layers + num_virtual_layers_pre + num_virtual_layers_post

    base_layers_per_stage = num_layers_virtual // stage.num_stages
    extra_layers = num_layers_virtual % stage.num_stages

    layer_count_per_stage = []

    for proposed_stage_i in range(stage.num_stages):
        proposed_stage = PipelineStageInfo(num_stages=stage.num_stages, current_stage=proposed_stage_i)
        layers = base_layers_per_stage + 1 if proposed_stage_i < extra_layers else base_layers_per_stage

        adjustment = 0
        if proposed_stage.is_current_stage_first:
            adjustment += num_virtual_layers_pre
        if proposed_stage.is_current_stage_last:
            adjustment += num_virtual_layers_post

        actual_layers = layers - adjustment

        if actual_layers <= 0:
            raise ValueError(f"Tried to distribute layers, but got {actual_layers} on "
                             f"stage {proposed_stage.current_stage}. Perhaps the pipeline is too long for this model?")

        layer_count_per_stage.append(actual_layers)

    start_layer_id = sum(layer_count_per_stage[:stage.current_stage])
    num_layers_in_stage = layer_count_per_stage[stage.current_stage]

    return start_layer_id, start_layer_id + num_layers_in_stage
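
A quick worked example: with 30 transformer blocks, 4 stages, and one virtual layer on each end, 32 virtual layers split into 8 per stage, so the first and last stages hold 7 real blocks each:

from d9d.pipelining.api import PipelineStageInfo, distribute_layers_for_pipeline_stage

for rank in range(4):
    stage = PipelineStageInfo(current_stage=rank, num_stages=4)
    print(rank, distribute_layers_for_pipeline_stage(
        num_layers=30, num_virtual_layers_pre=1, num_virtual_layers_post=1, stage=stage
    ))
# 0 (0, 7)
# 1 (7, 15)
# 2 (15, 23)
# 3 (23, 30)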

d9d.pipelining.factory

PipelineSchedule1F1BConfig

Bases: BaseModel

Configuration for Interleaved 1F1B and Interleaved Zero Bubble execution.

Supports assigning multiple stages per rank and sharding backward to dI and dW to reduce pipeline bubbles.

Source code in d9d/pipelining/factory/config.py
class PipelineSchedule1F1BConfig(BaseModel):
    """
    Configuration for Interleaved 1F1B and Interleaved Zero Bubble execution.

    Supports assigning multiple stages per rank and sharding backward to dI and dW
    to reduce pipeline bubbles.
    """

    schedule: Literal["1f1b"] = "1f1b"

    num_stages_per_rank: int
    zero_bubble: bool

PipelineScheduleDualPipeVConfig

Bases: BaseModel

Configuration for DualPipeV execution.

A bidirectional pipeline schedule for high-throughput training, utilizing V-shape topology and reciprocal forward/backward scheduling.

Source code in d9d/pipelining/factory/config.py
class PipelineScheduleDualPipeVConfig(BaseModel):
    """
    Configuration for DualPipeV execution.

    A bidirectional pipeline schedule for high-throughput training, utilizing
    V-shape topology and reciprocal forward/backward scheduling.
    """

    schedule: Literal["dual_pipe_v"] = "dual_pipe_v"

PipelineScheduleGPipeConfig

Bases: BaseModel

Configuration for GPipe execution.

This assumes a single stage per rank and processes all microbatches for the forward pass before switching to the backward pass.

Source code in d9d/pipelining/factory/config.py
class PipelineScheduleGPipeConfig(BaseModel):
    """
    Configuration for GPipe execution.

    This assumes a single stage per rank and processes all microbatches for the
    forward pass before switching to the backward pass.
    """

    schedule: Literal["gpipe"] = "gpipe"

PipelineScheduleInferenceConfig

Bases: BaseModel

Configuration for inference-only pipeline execution.

This schedule runs all forward passes sequentially without any backward passes.

Source code in d9d/pipelining/factory/config.py
class PipelineScheduleInferenceConfig(BaseModel):
    """
    Configuration for inference-only pipeline execution.

    This schedule runs all forward passes sequentially without any backward passes.
    """

    schedule: Literal["inference"] = "inference"

PipelineScheduleLoopedBFSConfig

Bases: BaseModel

Configuration for Looped Breadth-First Search execution.

Similar to GPipe, but supports multiple stages per rank (virtualization). It executes all available work for a specific stage before moving to the next.

Source code in d9d/pipelining/factory/config.py
class PipelineScheduleLoopedBFSConfig(BaseModel):
    """
    Configuration for Looped Breadth-First Search execution.

    Similar to GPipe, but supports multiple stages per rank (virtualization).
    It executes all available work for a specific stage before moving to the next.
    """

    schedule: Literal["looped_bfs"] = "looped_bfs"

    num_stages_per_rank: int

PipelineScheduleZeroBubbleVConfig

Bases: BaseModel

Configuration for Zero Bubble V (ZBV) execution.

A specialized V-shape topology schedule that splits backward passes into Input and Weight gradients to maximize overlap. Requires exactly 2 stages per rank.

Source code in d9d/pipelining/factory/config.py
class PipelineScheduleZeroBubbleVConfig(BaseModel):
    """
    Configuration for Zero Bubble V (ZBV) execution.

    A specialized V-shape topology schedule that splits backward passes into
    Input and Weight gradients to maximize overlap. Requires exactly 2 stages per rank.
    """
    schedule: Literal["zero_bubble_v"] = "zero_bubble_v"

build_schedule(dist_context, n_microbatches, schedule_config, model_provider, loss_fn, sharding_spec)

Constructs the pipeline schedule and instantiates model stages.

This function coordinates the creation of the distributed pipeline. It:

  1. Selects the appropriate PipelineProgramBuilder based on the config.
  2. Calculates the global stage topology mapping stages to ranks.
  3. Instantiates the local model stages for the current rank using model_provider.
  4. Wraps models in PipelineStage containers.
  5. Generates the execution program (action list).
  6. Builds the runtime executor.

Parameters:

  • dist_context (DistributedContext, required): The distributed context.
  • n_microbatches (int, required): Number of microbatches per global step.
  • schedule_config (AnyPipelineScheduleConfig, required): Configuration object determining the schedule strategy.
  • model_provider (Callable[[PipelineStageInfo], Module], required): A factory function that accepts stage info and returns an nn.Module for that specific stage.
  • loss_fn (Callable[[dict[str, Tensor], int], Tensor] | None, required): Optional loss function. Required if training (backward pass needed).
  • sharding_spec (PipelineShardingSpec, required): Specification for how data and states are sharded.

Returns:

  • tuple[PipelineScheduleInfo, list[Module]]: A tuple containing:
    1. PipelineScheduleInfo: The executable schedule and metadata.
    2. list[nn.Module]: The local PyTorch modules created for this rank.

Source code in d9d/pipelining/factory/factory.py
def build_schedule(
        dist_context: DistributedContext,
        n_microbatches: int,
        schedule_config: AnyPipelineScheduleConfig,
        model_provider: Callable[[PipelineStageInfo], nn.Module],
        loss_fn: Callable[[dict[str, torch.Tensor], int], torch.Tensor] | None,
        sharding_spec: PipelineShardingSpec
) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
    """
    Constructs the pipeline schedule and instantiates model stages.

    This function coordinates the creation of the distributed pipeline. It:
    1.  Selects the appropriate `PipelineProgramBuilder` based on the config.
    2.  Calculates the global stage topology mapping stages to ranks.
    3.  Instantiates the local model stages for the current rank using `model_provider`.
    4.  Wraps models in `PipelineStage` containers.
    5.  Generates the execution program (action list).
    6.  Builds the runtime executor.

    Args:
        dist_context: The distributed context.
        n_microbatches: Number of microbatches per global step.
        schedule_config: Configuration object determining the schedule strategy.
        model_provider: A factory function that accepts stage info and returns an `nn.Module`
            for that specific stage.
        loss_fn: Optional loss function. Required if training (backward pass needed).
        sharding_spec: Specification for how data and states are sharded.

    Returns:
        A tuple containing:
        1.  `PipelineScheduleInfo`: The executable schedule and metadata.
        2.  `list[nn.Module]`: The local PyTorch modules created for this rank.
    """

    program_builder = PIPELINE_PROGRAM_REGISTRY.program_for(schedule_config)
    mesh = dist_context.mesh_for(REGULAR_DOMAIN)["pp"]

    num_stages = program_builder.num_stages_per_rank * mesh.size()

    stage_to_host = build_stage_to_host_rank_topology(
        num_stages=num_stages,
        pp_size=mesh.size(),
        style=program_builder.topology_style
    )
    host_to_stage = invert_stage_to_host_rank_topology(stage_to_host)
    this_rank_stages = host_to_stage[mesh.get_local_rank()]

    stages = []
    modules = []
    has_first_stage = False
    has_last_stage = False

    for stage_idx in this_rank_stages:
        stage_info = PipelineStageInfo(
            num_stages=num_stages,
            current_stage=stage_idx
        )

        if stage_info.is_current_stage_first:
            has_first_stage = True
        if stage_info.is_current_stage_last:
            has_last_stage = True

        model = model_provider(stage_info)
        modules.append(model)
        stage = PipelineStage(
            info=stage_info,
            module=model,
            group=mesh.get_group(),
            stage_to_host_topology=stage_to_host
        )
        stages.append(stage)

    program = program_builder.compose(num_microbatches=n_microbatches, pp_size=mesh.size())
    schedule = PipelineScheduleExecutor(
        dist_context=dist_context,
        stages=stages,
        num_microbatches=n_microbatches,
        loss_fn=loss_fn,
        program=program,
        sharding_spec=sharding_spec
    )

    return PipelineScheduleInfo(
        schedule=schedule,
        has_first_stage=has_first_stage,
        has_last_stage=has_last_stage
    ), modules