The d9d Approach
d9d implements a modern, highly modular pipelining engine designed for performance, stability and customization.
Dynamic Shapes & Algorithmic Shape Inference
To run P2P (Point-to-Point) communication, the receiver must know the shape of the incoming tensor so it can pre-allocate buffers. d9d asks your model to implement a lightweight protocol (ModuleSupportsPipelining) that calculates stage input and output shapes mathematically from the batch input shapes, without running a heavy forward pass or performing distributed graph tracing.
Because shapes are computed rather than traced, d9d can support dynamic shapes (e.g., varying sequence lengths) efficiently across runs.
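The arithmetic this protocol relies on can be illustrated with a minimal, framework-free sketch. It assumes a decoder-only model whose first stage consumes token IDs and whose last stage also emits logits (as in the example later on this page). ToyStage and infer_stage_shapes are hypothetical names, and shapes are returned as tuples instead of the empty tensors that d9d's protocol methods return.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyStage:
    """Hypothetical stand-in for PipelineStageInfo."""

    current_stage: int
    num_stages: int


def infer_stage_shapes(batch_shape, n_microbatches, hidden_dim, vocab_size, stage):
    """Return (input_shapes, output_shapes) for one microbatch on `stage`.

    Pure integer arithmetic on the global batch shape: no forward pass, no tracing.
    """
    batch_size, seq_len = batch_shape
    if batch_size % n_microbatches != 0:
        raise ValueError("batch size must be divisible by n_microbatches")
    mb = batch_size // n_microbatches
    is_first = stage.current_stage == 0
    is_last = stage.current_stage == stage.num_stages - 1
    # The first stage receives token IDs; later stages receive hidden states.
    if is_first:
        inputs = {"input_ids": (mb, seq_len)}
    else:
        inputs = {"hidden_states": (mb, seq_len, hidden_dim)}
    outputs = {"hidden_states": (mb, seq_len, hidden_dim)}
    if is_last:
        outputs["logits"] = (mb, seq_len, vocab_size)
    return inputs, outputs
```

When the sequence length changes between steps, recomputing these tuples costs a handful of integer operations, which is why dynamic shapes stay cheap under this approach.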
Construction Consistency (No Patching)
A common anti-pattern in distributed training is "Instantiate-then-Delete": creating a huge model on the CPU or meta device and then hacking it apart with del model.layers[N:].
We reject this pattern because of:
- Fragility: Changes to the model architecture require changes to the external slicing script.
- Leaky Abstractions: Forward methods become full of if self.layer is not None checks.
- Invalid States: The model object exists in a "zombie" state until it is sliced.
In d9d, models are Pipeline-Aware. Each pipeline rank constructs only the sub-graph it owns. The object returned is compliant, complete, and valid immediately.
Making Models Compatible
The Protocol
Implementing the Protocol
To use Pipeline Parallelism in d9d, your model must implement the d9d.pipelining.api.ModuleSupportsPipelining protocol to allow the framework to manage memory and buffer allocations.
Forward Compatibility
Pipelined models currently only support outputting a dictionary (dict[str, torch.Tensor]). We plan to support arbitrary PyTrees in future releases.
- The keys in the dictionary returned by your forward method must strictly match the keys in the dictionary computed by infer_stage_outputs_from_pipeline_inputs.
- The named arguments accepted by your forward method must strictly match the keys in the dictionary computed by infer_stage_inputs_from_pipeline_inputs.
This allows the communication handler to map tensor names to P2P buffers deterministically.
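A quick way to catch naming mismatches during development is to compare the names directly. The helper below is a hypothetical debugging aid, not part of d9d. It treats the input rule as "every key returned by infer_stage_inputs_from_pipeline_inputs names an argument of forward" (which is how a forward accepting both input_ids and hidden_states serves every stage) and the output rule as exact key equality.

```python
import inspect
from typing import Any, Callable


def check_pipeline_contract(
    forward: Callable[..., dict],
    inferred_inputs: dict[str, Any],
    inferred_outputs: dict[str, Any],
    actual_outputs: dict[str, Any],
) -> None:
    """Raise ValueError if forward() and the infer_* methods disagree on names."""
    params = inspect.signature(forward).parameters.values()
    # Names that can be passed by keyword (self is already bound on methods).
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    unknown_inputs = set(inferred_inputs) - accepted
    if unknown_inputs:
        raise ValueError(f"forward() has no argument for inferred inputs: {sorted(unknown_inputs)}")
    if set(actual_outputs) != set(inferred_outputs):
        raise ValueError(
            f"output keys differ: forward returned {sorted(actual_outputs)}, "
            f"inferred {sorted(inferred_outputs)}"
        )
```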
Example
Below is a skeleton of a Transformer-like model implemented for d9d pipelining.
```python
import torch
from torch import nn

from d9d.pipelining.api import PipelineStageInfo, distribute_layers_for_pipeline_stage


class MyModelChunk(nn.Module):
    def __init__(self, stage: PipelineStageInfo, config):
        super().__init__()
        self.stage = stage
        self.config = config

        # 1. Determine which layers live on this stage
        self.start_layer, self.end_layer = distribute_layers_for_pipeline_stage(
            config.n_layers, num_virtual_layers_pre=1, num_virtual_layers_post=1, stage=stage
        )

        # 2. Build sub-modules (using ModuleDict for compatibility)
        self.layers = nn.ModuleDict({
            str(layer): TransformerBlock(...)
            for layer in range(self.start_layer, self.end_layer)
        })

        # Only build embeddings on the first stage
        if stage.is_current_stage_first:
            self.embed = nn.Embedding(...)

        # Only build the head on the last stage
        if stage.is_current_stage_last:
            self.head = nn.Linear(...)

    def forward(self, input_ids=None, hidden_states=None):
        # Run embeddings only on the first stage
        if self.stage.is_current_stage_first:
            x = self.embed(input_ids)
        else:
            x = hidden_states

        # Run local layers
        for layer_idx in range(self.start_layer, self.end_layer):
            x = self.layers[str(layer_idx)](x)

        outputs = {
            "hidden_states": x
        }

        # Last stage logic
        if self.stage.is_current_stage_last:
            logits = self.head(x)
            outputs['logits'] = logits

        return outputs

    # --- Protocol Implementation ---

    def infer_stage_inputs_from_pipeline_inputs(self, inputs: dict[str, torch.Tensor], n_microbatches: int):
        batch_size = inputs['input_ids'].shape[0]
        micro_batch_size = batch_size // n_microbatches
        seq_len = inputs['input_ids'].shape[1]

        if self.stage.is_current_stage_first:
            # First stage receives raw input IDs
            return {"input_ids": torch.empty((micro_batch_size, seq_len), dtype=torch.long)}
        else:
            # All later stages receive hidden states from the previous stage
            return {"hidden_states": torch.empty((micro_batch_size, seq_len, self.config.hidden_dim))}

    def infer_stage_outputs_from_pipeline_inputs(self, inputs: dict[str, torch.Tensor], n_microbatches: int):
        batch_size = inputs['input_ids'].shape[0]
        micro_batch_size = batch_size // n_microbatches
        seq_len = inputs['input_ids'].shape[1]

        outputs = {"hidden_states": torch.empty((micro_batch_size, seq_len, self.config.hidden_dim))}

        if self.stage.is_current_stage_last:
            # Last stage outputs logits too
            outputs["logits"] = torch.empty((micro_batch_size, seq_len, self.config.vocab_size))

        return outputs
```
Using the Pipeline
Supported Schedules
| Example JSON | Description |
|---|---|
| {"schedule": "inference"} | Configuration for inference-only pipeline execution. Runs all forward passes sequentially without any backward passes. |
| {"schedule": "gpipe"} | Standard GPipe execution. Assumes a single stage per rank and processes all microbatches for the forward pass before switching to the backward pass. |
| {"schedule": "looped_bfs", "num_stages_per_rank": 2} | Looped Breadth-First Search execution. Supports multiple stages per rank (virtualization) and executes all work for a specific stage before moving to the next. |
| {"schedule": "1f1b", "num_stages_per_rank": 1, "zero_bubble": true} | Interleaved 1F1B and Interleaved Zero Bubble execution. Supports multiple stages per rank. Splits backward passes into dI and dW computations when zero_bubble is enabled. |
| {"schedule": "zero_bubble_v"} | Zero Bubble V (ZBV) execution. A specialized V-shape topology schedule that splits backward passes into Input and Weight gradients. Requires exactly 2 stages per rank. |
| {"schedule": "dual_pipe_v"} | DualPipeV execution. A bidirectional pipeline schedule for high-throughput training using V-shape topology and reciprocal forward/backward scheduling. |
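To see why the schedule choice matters, recall the standard GPipe bubble estimate: with p stages, m microbatches and uniform per-stage cost, each direction of the pipeline takes m + p - 1 time slots but any given rank does useful work in only m of them, so the idle fraction is (p - 1) / (m + p - 1). The sketch below just evaluates that textbook formula. It is not a d9d API, and real bubbles also depend on how evenly layers are balanced across stages.

```python
from fractions import Fraction


def gpipe_bubble_fraction(num_stages: int, num_microbatches: int) -> Fraction:
    """Idle fraction of a GPipe-style schedule, assuming equal cost per stage.

    Forward and backward each need (m + p - 1) slots to fill and drain the
    pipeline, of which only m are useful on every rank.
    """
    if num_stages < 1 or num_microbatches < 1:
        raise ValueError("num_stages and num_microbatches must be positive")
    p, m = num_stages, num_microbatches
    return Fraction(p - 1, m + p - 1)
```

For example, 4 stages and 32 microbatches leave 3/35 (about 8.6%) of the time idle. The interleaved and zero-bubble schedules in the table aim to reduce this further.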
Batch Sharding
Pipelining works by splitting the input batch into N microbatches.
By default, d9d assumes all input and output tensors should be split along dimension 0.
However, if your inputs require a different sharding strategy, you can customize it via PipelineShardingSpec.
Please see the sharding utils docs.
```python
from torch.distributed.tensor import Shard

from d9d.pipelining.api import PipelineShardingSpec

# Example: split 'images' on dim 1, but replicate 'camera_angles' across all microbatches
my_spec = PipelineShardingSpec(
    input_data={
        "images": Shard(1),
        "camera_angles": None,
    },
    # input_kwargs can be defined similarly
)
```
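The semantics described above (shard along a chosen dimension, replicate when the spec is None, default to dimension 0) can be mimicked with plain Python lists. This is a hypothetical illustration, not d9d's implementation: spec values are plain integer dimensions standing in for Shard(dim) objects, and every split must be even.

```python
from typing import Any, Optional


def split_even(seq: list, n: int) -> list:
    """Split a list into n equal consecutive chunks (dimension 0)."""
    if len(seq) % n != 0:
        raise ValueError(f"cannot split length {len(seq)} evenly into {n} parts")
    size = len(seq) // n
    return [seq[i * size:(i + 1) * size] for i in range(n)]


def split_along(x: list, dim: int, n: int) -> list:
    """Split a nested list along `dim` into n pieces."""
    if dim == 0:
        return split_even(x, n)
    # Split every row along dim - 1, then regroup the k-th piece of each row.
    per_row = [split_along(row, dim - 1, n) for row in x]
    return [[pieces[k] for pieces in per_row] for k in range(n)]


def shard_inputs(inputs: dict[str, Any], spec: dict[str, Optional[int]], n: int) -> list[dict[str, Any]]:
    """Return one input dict per microbatch.

    spec maps a name to the dimension to shard on, or to None to replicate;
    names missing from spec default to dimension 0.
    """
    microbatches: list[dict[str, Any]] = [{} for _ in range(n)]
    for name, value in inputs.items():
        dim = spec.get(name, 0)
        pieces = [value] * n if dim is None else split_along(value, dim, n)
        for k in range(n):
            microbatches[k][name] = pieces[k]
    return microbatches
```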
Usage within the Trainer
Pipelining is available in the Trainer framework. When configuring the Trainer, simply provide an AnyPipelineScheduleConfig in your training arguments. The Trainer handles the construction of the schedule and the distribution of layers automatically.
Advanced - Manual Usage
If you want to use pipelining outside the Trainer (e.g., in a custom training loop), use the build_schedule factory.
The build_schedule function requires a model provider. Instead of passing an instantiated model, you pass a function that accepts PipelineStageInfo and returns the nn.Module for that stage. This ensures construction consistency.
```python
from torch import Tensor
from torch.distributed.tensor import Shard
import torch.nn.functional as F

from d9d.core.dist_context import DistributedContext
from d9d.core.sharding import shard_tree
from d9d.pipelining.factory import build_schedule, PipelineSchedule1F1BConfig
from d9d.pipelining.api import PipelineShardingSpec


# 0. Define an object that manages loss calculation across steps
class PipelineLossHandler:
    def __init__(self, num_microbatches: int):
        self._shard_spec = {
            'target': Shard(0)
        }
        self._num_microbatches = num_microbatches
        self._targets = None

    def set_targets(self, targets: Tensor):
        # Split the full-batch targets into one shard per microbatch
        self._targets = shard_tree(
            {'target': targets},
            sharding_spec=self._shard_spec,
            num_shards=self._num_microbatches,
            enforce_even_split=True,
        )

    def compute_loss(self, outputs: dict[str, Tensor], microbatch_idx: int):
        # Implement any custom logic here
        current_target = self._targets[microbatch_idx]
        logits = outputs['logits']
        return F.cross_entropy(logits.view(-1, logits.shape[-1]), current_target.view(-1))


# 1. Define configuration
dist_context: DistributedContext = ...
model_config = ...
n_microbatches = 32

schedule_config = PipelineSchedule1F1BConfig(
    num_stages_per_rank=4,  # 4 virtual stages per rank
    zero_bubble=True,  # Enable ZB1P optimization
)

# 2. Build the schedule, model shards and loss compute function
loss_handler = PipelineLossHandler(num_microbatches=n_microbatches)

schedule_info, modules = build_schedule(
    dist_context=dist_context,
    n_microbatches=n_microbatches,  # Must match the loss handler's microbatch count
    schedule_config=schedule_config,
    model_provider=lambda stage: MyModelChunk(stage, model_config),  # Factory function
    loss_fn=loss_handler.compute_loss,
    sharding_spec=PipelineShardingSpec(),  # Default sharding across dim 0
)

# 3. Execution
# The schedule object exposes a simple step API
inputs = {"input_ids": ...}  # Full batch
loss_handler.set_targets(...)  # Set targets for full batch

schedule_info.schedule.configure_buffers(inputs, kwargs={})  # Pre-allocate buffers
schedule_info.schedule.step(inputs, kwargs={})
```
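The model_provider pattern means each rank only ever constructs the stages it owns. The sketch below is hypothetical: StageInfo stands in for PipelineStageInfo, and the list of owned stages is passed in by hand, whereas build_schedule derives it from the schedule configuration. It shows that the provider is called once per owned stage and never for anything else, so no full model is ever materialized.

```python
from dataclasses import dataclass
from typing import Callable, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class StageInfo:
    """Hypothetical stand-in for d9d's PipelineStageInfo."""

    current_stage: int
    num_stages: int


def build_local_stages(
    provider: Callable[[StageInfo], T], owned_stages: list[int], num_stages: int
) -> list[T]:
    """Call the provider once per stage this rank owns; no other stage is ever built."""
    return [provider(StageInfo(current_stage=s, num_stages=num_stages)) for s in owned_stages]
```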
d9d.pipelining.api
Public pipelining API intended for end users.
ModuleSupportsPipelining
Bases: Protocol
Protocol for modules that support pipeline parallelism metadata inference.
Classes implementing this protocol enable the framework to pre-calculate tensor shapes and types required for inter-stage communication (p2p) without executing the full forward pass.
Source code in d9d/pipelining/api/module.py
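Because this is a typing.Protocol, a model satisfies it structurally: it only needs the two methods, not a base class. The sketch below uses a hypothetical mirror of the protocol and marks it runtime_checkable so that isinstance works; whether d9d's own ModuleSupportsPipelining is runtime-checkable is not stated on this page. Shape tuples stand in for tensors.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsStageShapeInference(Protocol):
    """Hypothetical mirror of the two ModuleSupportsPipelining methods."""

    def infer_stage_inputs_from_pipeline_inputs(
        self, inputs: dict[str, Any], n_microbatches: int
    ) -> dict[str, Any]: ...

    def infer_stage_outputs_from_pipeline_inputs(
        self, inputs: dict[str, Any], n_microbatches: int
    ) -> dict[str, Any]: ...


class IdentityStage:
    """Toy stage (no base class) whose per-microbatch output shape equals its input shape.

    Inputs map names to global shape tuples, e.g. {"x": (batch, features)}.
    """

    def infer_stage_inputs_from_pipeline_inputs(self, inputs, n_microbatches):
        # Only the batch dimension (dim 0) is divided across microbatches.
        return {k: (shape[0] // n_microbatches, *shape[1:]) for k, shape in inputs.items()}

    def infer_stage_outputs_from_pipeline_inputs(self, inputs, n_microbatches):
        return self.infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)
```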
infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)
Infers the input tensor metadata for the current pipeline stage based on global batch inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | dict[str, Tensor] | Global inputs for the pipeline. | required |
| n_microbatches | int | Number of microbatches the global batch is split into. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary of input tensors expected locally by this specific stage. |
Source code in d9d/pipelining/api/module.py
infer_stage_outputs_from_pipeline_inputs(inputs, n_microbatches)
Infers the output tensor metadata for the current pipeline stage based on global batch inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | dict[str, Tensor] | Global inputs for the pipeline (typically a batch). | required |
| n_microbatches | int | Number of microbatches the global batch is split into. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary of output tensors produced locally by this specific stage. |
Source code in d9d/pipelining/api/module.py
PipelineSchedule
Bases: ABC
Abstract base class defining the interface for pipeline execution schedules.
Source code in d9d/pipelining/api/schedule.py
configure_buffers(inputs, kwargs)
abstractmethod
Configures internal state and buffers based on input shapes.
This method allows the schedule to pre-allocate memory or set up sharding specifications based on the structure of the input data before execution begins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | dict[str, Tensor] | A dictionary of input tensors. | required |
| kwargs | dict[str, Any] | A dictionary of keyword arguments. | required |
Source code in d9d/pipelining/api/schedule.py
step(inputs, kwargs)
abstractmethod
Executes a single pipeline step using the provided inputs.
This typically involves distributing inputs across microbatches, executing forward and backward passes according to the specific schedule logic, and handling communications between stages.
Parameters:
| Name | Description |
|---|---|
| inputs | A dictionary of global input tensors. |
| kwargs | A dictionary of global keyword arguments. |
Source code in d9d/pipelining/api/schedule.py
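To make the configure_buffers/step contract concrete, here is a toy, single-process subclass of a hypothetical mirror of the ABC. It only runs forward passes over plain lists, ignores kwargs, and returns per-microbatch outputs for inspection; none of that is claimed about d9d's real schedules. It does mirror the documented order: buffers are configured from the input shapes before step is called.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class ScheduleSketch(ABC):
    """Hypothetical mirror of the two abstract methods documented for PipelineSchedule."""

    @abstractmethod
    def configure_buffers(self, inputs: dict[str, Any], kwargs: dict[str, Any]) -> None: ...

    @abstractmethod
    def step(self, inputs: dict[str, Any], kwargs: dict[str, Any]) -> Any: ...


class ToyForwardSchedule(ScheduleSketch):
    """Single-process, forward-only toy: splits list inputs along dim 0 and chains stages."""

    def __init__(self, stages: list[Callable[[dict], dict]], n_microbatches: int):
        self._stages = stages
        self._n = n_microbatches
        self._mb_size: Optional[int] = None

    def configure_buffers(self, inputs, kwargs):
        # Derive the microbatch size once from the global input shapes.
        sizes = {len(v) for v in inputs.values()}
        if len(sizes) != 1:
            raise ValueError("all inputs must share the same batch size")
        (batch_size,) = sizes
        if batch_size % self._n != 0:
            raise ValueError("batch size must be divisible by n_microbatches")
        self._mb_size = batch_size // self._n

    def step(self, inputs, kwargs):
        # kwargs are ignored in this toy.
        if self._mb_size is None:
            raise RuntimeError("call configure_buffers() before step()")
        results = []
        for i in range(self._n):
            lo, hi = i * self._mb_size, (i + 1) * self._mb_size
            x = {name: value[lo:hi] for name, value in inputs.items()}
            for stage in self._stages:
                x = stage(x)  # each stage maps a dict of microbatch inputs to a dict of outputs
            results.append(x)
        return results
```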
PipelineStageInfo
dataclass
Holds information about the current position within the distributed pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| current_stage | int | The 0-based index of the current pipeline stage. |
| num_stages | int | The total number of stages in the pipeline. |
Source code in d9d/pipelining/api/module.py
is_current_stage_first
property
Determines if this is the first stage in the pipeline.
Returns:
| Type | Description |
|---|---|
| bool | True if current_stage is 0. |
is_current_stage_last
property
Determines if this is the last stage in the pipeline.
Returns:
| Type | Description |
|---|---|
| bool | True if current_stage is the last index (num_stages - 1). |
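These two properties reduce to simple index checks. A hypothetical re-creation of the documented semantics is shown below. Note that with a single stage both properties are true, so that one stage builds both the embeddings and the head.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageInfoSketch:
    """Hypothetical re-creation of the documented PipelineStageInfo semantics."""

    current_stage: int  # 0-based index of this stage
    num_stages: int  # total number of stages in the pipeline

    @property
    def is_current_stage_first(self) -> bool:
        return self.current_stage == 0

    @property
    def is_current_stage_last(self) -> bool:
        return self.current_stage == self.num_stages - 1
```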
distribute_layers_for_pipeline_stage(num_layers, num_virtual_layers_pre, num_virtual_layers_post, stage)
Calculates the layer index range for a specific pipeline stage.
This function distributes a given number of layers across multiple pipeline stages as evenly as possible. It accounts for additional, non-layer computational load on the first and last stages (e.g., embeddings and the LM head) by using the concept of 'virtual layers' to reserve capacity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_layers | int | The total number of primary model layers to be distributed (e.g., the transformer blocks). | required |
| num_virtual_layers_pre | int | The number of 'virtual' layers representing the computational cost of modules on the first stage, before the main layers (e.g., token and positional embeddings). | required |
| num_virtual_layers_post | int | The number of 'virtual' layers representing the computational cost of modules on the last stage, after the main layers (e.g., the final layer normalization and LM head). | required |
| stage | PipelineStageInfo | An object containing the total number of stages and the current stage index. | required |
Returns:
| Type | Description |
|---|---|
| tuple[int, int] | A tuple (start_index, end_index) representing the slice of layers for the given stage. start_index is inclusive and end_index is exclusive. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the pipeline configuration results in a stage having zero or negative layers assigned (pipeline too long for the model size). |
Source code in d9d/pipelining/api/module.py
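One way to realize the 'virtual layers' idea is to add the virtual layers to the real ones, split the total as evenly as possible, and then subtract the virtual share from the first and last stages. The sketch below follows that approach. It is an assumption for illustration (in particular, giving any remainder to the earliest stages), d9d's exact rounding may differ, and it takes plain integers instead of a PipelineStageInfo.

```python
def distribute_layers_sketch(
    num_layers: int,
    num_virtual_layers_pre: int,
    num_virtual_layers_post: int,
    current_stage: int,
    num_stages: int,
) -> tuple[int, int]:
    """Return (start, end) layer indices for current_stage; end is exclusive."""
    total = num_layers + num_virtual_layers_pre + num_virtual_layers_post
    base, remainder = divmod(total, num_stages)
    # Even split of the total load; earlier stages absorb any remainder (an assumption).
    shares = [base + (1 if i < remainder else 0) for i in range(num_stages)]
    # Convert load into real layers: the first/last stages spend part of their share
    # on embeddings / the head (the virtual layers).
    real = list(shares)
    real[0] -= num_virtual_layers_pre
    real[-1] -= num_virtual_layers_post
    if any(n <= 0 for n in real):
        raise ValueError("pipeline too long for the model size")
    start = sum(real[:current_stage])
    return start, start + real[current_stage]
```

With 30 layers, one virtual layer on each end and 4 stages, every stage carries 8 units of work: the first and last stages get 7 real layers each and the middle stages get 8.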
d9d.pipelining.factory
PipelineSchedule1F1BConfig
Bases: BaseModel
Configuration for Interleaved 1F1B and Interleaved Zero Bubble execution.
Supports assigning multiple stages per rank and splitting the backward pass into dI (input-gradient) and dW (weight-gradient) computations to reduce pipeline bubbles.
Source code in d9d/pipelining/factory/config.py
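The dI/dW split rests on the gradient identities of a linear layer Y = X W: the input gradient dX = dY W^T is needed immediately by the previous stage, while the weight gradient dW = X^T dY is only needed before the optimizer step, so a scheduler can postpone it to fill idle slots. The pure-Python sketch below computes the two pieces independently; it illustrates the math only, not how d9d splits autograd work.

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # (n x k) @ (k x m) -> (n x m)
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def linear_backward_split(x: Matrix, w: Matrix, grad_y: Matrix) -> tuple[Matrix, Matrix]:
    """Backward of Y = X @ W, computed as two independent pieces."""
    grad_input = matmul(grad_y, transpose(w))   # dI = dY @ W^T: sent upstream right away
    grad_weight = matmul(transpose(x), grad_y)  # dW = X^T @ dY: can be deferred
    return grad_input, grad_weight
```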
PipelineScheduleDualPipeVConfig
Bases: BaseModel
Configuration for DualPipeV execution.
A bidirectional pipeline schedule for high-throughput training, utilizing V-shape topology and reciprocal forward/backward scheduling.
Source code in d9d/pipelining/factory/config.py
PipelineScheduleGPipeConfig
Bases: BaseModel
Configuration for GPipe execution.
This assumes a single stage per rank and processes all microbatches for the forward pass before switching to the backward pass.
Source code in d9d/pipelining/factory/config.py
PipelineScheduleInferenceConfig
Bases: BaseModel
Configuration for inference-only pipeline execution.
This schedule runs all forward passes sequentially without any backward passes.
Source code in d9d/pipelining/factory/config.py
PipelineScheduleLoopedBFSConfig
Bases: BaseModel
Configuration for Looped Breadth-First Search execution.
Similar to GPipe, but supports multiple stages per rank (virtualization). It executes all available work for a specific stage before moving to the next.
Source code in d9d/pipelining/factory/config.py
PipelineScheduleZeroBubbleVConfig
Bases: BaseModel
Configuration for Zero Bubble V (ZBV) execution.
A specialized V-shape topology schedule that splits backward passes into Input and Weight gradients to maximize overlap. Requires exactly 2 stages per rank.
Source code in d9d/pipelining/factory/config.py
build_schedule(dist_context, n_microbatches, schedule_config, model_provider, loss_fn, sharding_spec)
Constructs the pipeline schedule and instantiates model stages.
This function coordinates the creation of the distributed pipeline. It:
1. Selects the appropriate PipelineProgramBuilder based on the config.
2. Calculates the global stage topology mapping stages to ranks.
3. Instantiates the local model stages for the current rank using model_provider.
4. Wraps models in PipelineStage containers.
5. Generates the execution program (action list).
6. Builds the runtime executor.
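Step 2 (mapping stages to ranks) can be illustrated with the two placements commonly used by these schedule families: round-robin for looped/interleaved schedules, and a V shape (two stages per rank, with the first rank holding both the first and the last stage) for ZBV-style schedules. The functions below are hypothetical and assume these layouts; d9d's actual topology code is not shown on this page.

```python
def looped_stage_to_rank(num_ranks: int, stages_per_rank: int) -> dict[int, int]:
    """Round-robin placement: stage s lives on rank s % num_ranks."""
    return {s: s % num_ranks for s in range(num_ranks * stages_per_rank)}


def v_shape_stage_to_rank(num_ranks: int) -> dict[int, int]:
    """V-shape placement: rank r holds stage r (going down) and stage 2p-1-r (coming back)."""
    mapping = {}
    for r in range(num_ranks):
        mapping[r] = r
        mapping[2 * num_ranks - 1 - r] = r
    return mapping


def local_stages(mapping: dict[int, int], rank: int) -> list[int]:
    """Stages a given rank must instantiate via the model provider."""
    return sorted(s for s, r in mapping.items() if r == rank)
```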
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context. | required |
| n_microbatches | int | Number of microbatches per global step. | required |
| schedule_config | AnyPipelineScheduleConfig | Configuration object determining the schedule strategy. | required |
| model_provider | Callable[[PipelineStageInfo], Module] | A factory function that accepts stage info and returns the nn.Module for that stage. | required |
| loss_fn | Callable[[dict[str, Tensor], int], Tensor] \| None | Optional loss function. Required for training (when a backward pass is needed). | required |
| sharding_spec | PipelineShardingSpec | Specification for how data and states are sharded. | required |
Returns:
| Type | Description |
|---|---|
| tuple[PipelineScheduleInfo, list[Module]] | A tuple containing the PipelineScheduleInfo for the built schedule and the list of model modules instantiated for the current rank. |
Source code in d9d/pipelining/factory/factory.py