Pipeline Parallelism
The d9d Approach
d9d implements a modern, highly modular pipelining engine designed for performance, stability and customization.
Dynamic Shapes & Algorithmic Shape Inference
To run P2P (Point-to-Point) communication, the receiver must know the shape of the incoming tensor in order to pre-allocate buffers. d9d asks your model to implement a lightweight protocol (`ModuleSupportsPipelining`) that calculates stage input and output shapes from the batch input shapes mathematically, without performing a heavy forward pass or distributed graph tracing.
This allows Dynamic Shapes (e.g., varying sequence lengths) to be supported efficiently across runs.
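For example, a language-model stage can derive its output metadata purely from the batch shape. The sketch below is illustrative: the tensor names, the hidden size, and the use of meta tensors and per-microbatch shapes are assumptions rather than part of the d9d API.

```python
import torch


class ShapeInferenceSketch:
    hidden_size = 4096  # illustrative model dimension

    def infer_stage_outputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        # Output shapes follow from plain arithmetic on the batch shape; no
        # forward pass or graph tracing is needed. We assume per-microbatch
        # shapes are expected, hence the division by n_microbatches.
        batch, seq_len = inputs["input_ids"].shape
        hidden = torch.empty(
            batch // n_microbatches, seq_len, self.hidden_size, device="meta"
        )
        return {"hidden_states": hidden}
```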
Construction Consistency (No Patching)
A common anti-pattern in distributed training is "Instantiate-then-Delete": creating a huge model on the CPU/meta device and then hacking it apart with `del model.layers[N:]`.
We reject this pattern because of:
- Fragility: Changes to the model architecture require changes to the external slicing script.
- Leaky Abstractions: Forward methods become full of `if self.layer is not None` checks.
- Invalid States: The model object exists in a "zombie" state until sliced.
In d9d, models are Pipeline-Aware. Each pipeline rank constructs only the sub-graph it owns. The object returned is compliant, complete, and valid immediately.
Making Models Compatible
The Protocol
Implementing the Protocol
To use Pipeline Parallelism in d9d, your model must implement the d9d.pipelining.api.ModuleSupportsPipelining protocol to allow the framework to manage memory and buffer allocations.
Forward Compatibility
- Pipelined models currently only support outputting a dictionary (`dict[str, torch.Tensor]`). However, we plan to support arbitrary PyTrees in future releases.
- The keys in the dictionary returned by your `forward` method must strictly match the keys in the dictionary calculated by `infer_stage_outputs_from_pipeline_inputs`.
- The named arguments accepted by your `forward` method must strictly match the keys returned by `infer_stage_inputs_from_pipeline_inputs`.
This allows the communication handler to map tensor names to P2P buffers deterministically.
Example
Below is a skeleton of a Transformer-like model implemented for d9d pipelining.
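The method names and the `PipelineStageInfo` / `distribute_layers_for_pipeline_stage` helpers come from the API reference further down this page; the module layout, tensor names, dimensions, and the use of meta tensors and per-microbatch shapes are illustrative assumptions, not the exact reference implementation.

```python
import torch
from torch import nn

from d9d.pipelining.api import (
    PipelineStageInfo,
    distribute_layers_for_pipeline_stage,
)


class PipelinedTransformer(nn.Module):
    """Transformer-like skeleton: each rank constructs only the sub-graph it owns."""

    def __init__(
        self,
        stage: PipelineStageInfo,
        num_layers: int = 32,
        hidden_size: int = 1024,
        vocab_size: int = 50_000,
    ):
        super().__init__()
        self.stage = stage
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # Reserve one 'virtual layer' on each end for the embedding and LM head.
        start, end = distribute_layers_for_pipeline_stage(
            num_layers=num_layers,
            num_virtual_layers_pre=1,
            num_virtual_layers_post=1,
            stage=stage,
        )
        # Only the first stage owns the embedding; only the last owns the head.
        if stage.is_current_stage_first:
            self.embed = nn.Embedding(vocab_size, hidden_size)
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(hidden_size, nhead=8, batch_first=True)
             for _ in range(start, end)]
        )
        if stage.is_current_stage_last:
            self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        hidden_states: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        # Argument names match infer_stage_inputs_from_pipeline_inputs;
        # output keys match infer_stage_outputs_from_pipeline_inputs.
        if self.stage.is_current_stage_first:
            hidden_states = self.embed(input_ids)
        for block in self.blocks:
            hidden_states = block(hidden_states)
        if self.stage.is_current_stage_last:
            return {"logits": self.lm_head(hidden_states)}
        return {"hidden_states": hidden_states}

    def _microbatch_dims(self, inputs, n_microbatches):
        # Assumes a dim-0 split of the global batch into microbatches.
        batch, seq_len = inputs["input_ids"].shape
        return batch // n_microbatches, seq_len

    def infer_stage_inputs_from_pipeline_inputs(self, inputs, n_microbatches):
        mb, seq_len = self._microbatch_dims(inputs, n_microbatches)
        if self.stage.is_current_stage_first:
            return {"input_ids": torch.empty(mb, seq_len, dtype=torch.long, device="meta")}
        return {"hidden_states": torch.empty(mb, seq_len, self.hidden_size, device="meta")}

    def infer_stage_outputs_from_pipeline_inputs(self, inputs, n_microbatches):
        mb, seq_len = self._microbatch_dims(inputs, n_microbatches)
        if self.stage.is_current_stage_last:
            return {"logits": torch.empty(mb, seq_len, self.vocab_size, device="meta")}
        return {"hidden_states": torch.empty(mb, seq_len, self.hidden_size, device="meta")}
```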
Using the Pipeline
Supported Schedules
| Example JSON | Description |
|---|---|
| `{"schedule": "inference"}` | Configuration for inference-only pipeline execution. Runs all forward passes sequentially without any backward passes. |
| `{"schedule": "gpipe"}` | Standard GPipe execution. Assumes a single stage per rank and processes all microbatches for the forward pass before switching to the backward pass. |
| `{"schedule": "looped_bfs", "num_stages_per_rank": 2}` | Looped Breadth-First Search execution. Supports multiple stages per rank (virtualization) and executes all work for a specific stage before moving to the next. |
| `{"schedule": "1f1b", "num_stages_per_rank": 1, "zero_bubble": true}` | Interleaved 1F1B and Interleaved Zero Bubble execution. Supports multiple stages per rank. Shards backward passes into dI and dW when `zero_bubble` is enabled. |
| `{"schedule": "zero_bubble_v"}` | Zero Bubble V (ZBV) execution. A specialized V-shape topology schedule that splits backward passes into Input and Weight gradients. Requires exactly 2 stages per rank. |
| `{"schedule": "dual_pipe_v"}` | DualPipeV execution. A bidirectional pipeline schedule for high-throughput training using V-shape topology and reciprocal forward/backward scheduling. |
Batch Sharding
Pipelining works by splitting the input batch into N microbatches.
By default, d9d assumes all input and output tensors should be split along dimension 0.
However, if your inputs require a different sharding strategy, you can customize this via `PipelineShardingSpec`.
Please see the sharding utils docs.
Usage within the Trainer
Pipelining is available in the Trainer framework. When configuring the Trainer, simply provide an AnyPipelineScheduleConfig in your training arguments. The Trainer handles the construction of the schedule and the distribution of layers automatically.
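For example, an interleaved 1F1B schedule could be configured as shown below. This is a sketch: the config fields follow the JSON examples in the table above, but the exact training-arguments field that receives the config depends on your Trainer setup.

```python
from d9d.pipelining.factory import PipelineSchedule1F1BConfig

pipeline_schedule = PipelineSchedule1F1BConfig(
    schedule="1f1b",        # discriminator value for this schedule type
    num_stages_per_rank=2,  # two virtual stages per rank
    zero_bubble=True,       # split backward passes into dI and dW
)
# Pass `pipeline_schedule` wherever your training arguments expect an
# AnyPipelineScheduleConfig.
```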
Advanced - Manual Usage
If you want to use pipelining outside the Trainer (e.g., in custom loops), use the `build_schedule` factory.
The `build_schedule` function requires a model provider: instead of passing an instantiated model, you pass a function that accepts a `PipelineStageInfo` and returns the `nn.Module` for that stage. This ensures construction consistency.
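A minimal sketch of manual usage is shown below. It reuses the `PipelinedTransformer` skeleton from the Example section above and assumes a `DistributedContext` named `dist_context` is already available from your setup; the "logits" key in the loss callback is likewise an assumption about the model's outputs.

```python
import torch
from torch import nn

from d9d.pipelining.api import PipelineStageInfo
from d9d.pipelining.factory import PipelineScheduleGPipeConfig, build_schedule


def model_provider(stage: PipelineStageInfo) -> nn.Module:
    # Construct only the sub-graph owned by this stage.
    return PipelinedTransformer(stage)


def loss_fn(outputs: dict[str, torch.Tensor], microbatch_idx: int) -> torch.Tensor:
    # PipelineLossFn signature: stage outputs -> scalar loss.
    return outputs["logits"].float().mean()


schedule_info, local_modules = build_schedule(
    dist_context=dist_context,  # assumed: your DistributedContext
    n_microbatches=8,
    schedule_config=PipelineScheduleGPipeConfig(schedule="gpipe"),
    model_provider=model_provider,
    callback=loss_fn,
)
```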
d9d.pipelining.api
Pipelining API that is intended to be accessible by the end user.
PipelineLossFn = Callable[[dict[str, torch.Tensor], int], torch.Tensor]
module-attribute
Callback function type for calculating loss in the final pipeline stage.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `outputs` | `dict[str, Tensor]` | A dictionary mapping output names to tensors produced by the model. | required |
| `microbatch_idx` | `int` | The index of the current micro-batch being processed. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The computed loss tensor (scalar). |
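As an illustration, a cross-entropy loss callback might capture pre-sharded labels in a closure and select them by `microbatch_idx`. This is a sketch: the "logits" key and the per-microbatch label list are assumptions about your model and data pipeline.

```python
import torch
import torch.nn.functional as F


def make_loss_fn(labels_per_microbatch: list[torch.Tensor]):
    def loss_fn(outputs: dict[str, torch.Tensor], microbatch_idx: int) -> torch.Tensor:
        # The callback only receives outputs and the microbatch index,
        # so labels are looked up from the enclosing scope.
        logits = outputs["logits"]
        labels = labels_per_microbatch[microbatch_idx]
        return F.cross_entropy(logits.flatten(0, 1), labels.flatten())

    return loss_fn
```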
PipelineResultFn = Callable[[dict[str, torch.Tensor], int], Any]
module-attribute
Callback function type for handling results from a final pipeline stage.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `outputs` | `dict[str, Tensor]` | A dictionary mapping output names to tensors produced by the stage. | required |
| `microbatch_idx` | `int` | The index of the current micro-batch being processed. | required |

Returns:

| Type | Description |
|---|---|
| `Any` | Anything; the return value is not used. |
ModuleSupportsPipelining
Bases: Protocol
Protocol for modules that support pipeline parallelism metadata inference.
Classes implementing this protocol enable the framework to pre-calculate tensor shapes and types required for inter-stage communication (p2p) without executing the full forward pass.
infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)
Infers the input tensors metadata for the current pipeline stage based on global batch inputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `dict[str, Tensor]` | Global inputs for the pipeline. | required |
| `n_microbatches` | `int` | Number of microbatches the global batch is split into. | required |

Returns:

| Type | Description |
|---|---|
| `dict[str, Tensor]` | Dictionary of input tensors expected by this specific stage locally. |
infer_stage_outputs_from_pipeline_inputs(inputs, n_microbatches)
Infers the output tensors metadata for the current pipeline stage based on global batch inputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `dict[str, Tensor]` | Global inputs for the pipeline (typically a batch). | required |
| `n_microbatches` | `int` | Number of microbatches the global batch is split into. | required |

Returns:

| Type | Description |
|---|---|
| `dict[str, Tensor]` | Dictionary of output tensors produced by this specific stage locally. |
PipelineSchedule
Bases: ABC
Abstract base class defining the interface for pipeline execution schedules.
configure_buffers(inputs, kwargs, sharding_spec)
abstractmethod
Configures internal state and buffers based on input shapes.
This method allows the schedule to pre-allocate memory or set up sharding specifications based on the structure of the input data before execution begins.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `dict[str, Tensor]` | A dictionary of input tensors. | required |
| `kwargs` | `dict[str, Any]` | A dictionary of keyword arguments. | required |
| `sharding_spec` | `PipelineShardingSpec \| None` | A specification defining how inputs and kwargs should be split into micro-batches. If None, assumes standard split-by-zero-dim behavior. | required |
step(inputs, kwargs)
abstractmethod
Executes a single pipeline step using the provided inputs.
This typically involves distributing inputs across microbatches, executing forward and backward passes according to the specific schedule logic, and handling communications between stages.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `dict[str, Tensor]` | A dictionary of global input tensors. | required |
| `kwargs` | `dict[str, Any]` | A dictionary of global keyword arguments. | required |
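Put together, one global step might look like the following sketch, where `schedule` is a schedule executor obtained from your setup (e.g., via `build_schedule`) and `batch` is your global input dictionary; the tensor names are illustrative.

```python
inputs = {"input_ids": batch["input_ids"]}
kwargs: dict = {}

# Configure buffers for the current shapes (None = default dim-0 split), then
# run forward/backward for all microbatches according to the schedule logic.
schedule.configure_buffers(inputs, kwargs, sharding_spec=None)
schedule.step(inputs, kwargs)
```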
PipelineStageInfo
dataclass
Holds information about the current position within the distributed pipeline.
Attributes:

| Name | Type | Description |
|---|---|---|
| `current_stage` | `int` | The 0-based index of the current pipeline stage. |
| `num_stages` | `int` | The total number of stages in the pipeline. |
is_current_stage_first
property
Determines if this is the first stage in the pipeline.
Returns:

| Type | Description |
|---|---|
| `bool` | True if `current_stage` is 0. |
is_current_stage_last
property
Determines if this is the last stage in the pipeline.
Returns:

| Type | Description |
|---|---|
| `bool` | True if `current_stage` is the last index. |
distribute_layers_for_pipeline_stage(num_layers, num_virtual_layers_pre, num_virtual_layers_post, stage)
Calculates the layer index range for a specific pipeline stage.
This function distributes a given number of layers across multiple pipeline stages as evenly as possible. It accounts for additional, non-layer computational load on the first and last stages (e.g., embeddings and the LM head) by using the concept of 'virtual layers' to reserve capacity.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_layers` | `int` | The total number of primary model layers to be distributed (e.g., the transformer blocks). | required |
| `num_virtual_layers_pre` | `int` | The number of 'virtual' layers representing the computational cost of modules on the first stage, before the main layers (e.g., token and positional embeddings). | required |
| `num_virtual_layers_post` | `int` | The number of 'virtual' layers representing the computational cost of modules on the last stage, after the main layers (e.g., the final layer normalization and LM head). | required |
| `stage` | `PipelineStageInfo` | An object containing total stages and current stage index. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[int, int]` | A tuple `(start_index, end_index)` representing the slice of layers for the given stage. The `start_index` is inclusive and the `end_index` is exclusive. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the pipeline configuration results in a stage having zero or negative layers assigned (pipeline too long for the model size). |
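For example, the split of 30 transformer blocks across a 4-stage pipeline, with one virtual layer reserved on each end, can be inspected as in the sketch below; the exact boundaries depend on the distribution algorithm.

```python
from d9d.pipelining.api import PipelineStageInfo, distribute_layers_for_pipeline_stage

for rank in range(4):
    stage = PipelineStageInfo(current_stage=rank, num_stages=4)
    start, end = distribute_layers_for_pipeline_stage(
        num_layers=30,
        num_virtual_layers_pre=1,   # embeddings on the first stage
        num_virtual_layers_post=1,  # final norm + LM head on the last stage
        stage=stage,
    )
    print(f"stage {rank}: owns layers [{start}, {end})")
```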
d9d.pipelining.factory
AnyPipelineScheduleConfig = Annotated[PipelineScheduleInferenceConfig | PipelineScheduleGPipeConfig | PipelineScheduleLoopedBFSConfig | PipelineSchedule1F1BConfig | PipelineScheduleZeroBubbleVConfig | PipelineScheduleDualPipeVConfig, Field(discriminator='schedule')]
module-attribute
Union of all supported pipeline schedule configuration types.
This type alias uses a Pydantic discriminator on the schedule field to allow
polymorphic validation and serialization of specific schedule configs (e.g.
Inference, GPipe, 1F1B, ZeroBubble, etc.).
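For instance, assuming Pydantic v2, a raw dict (e.g., loaded from JSON) can be validated into the matching config class via a `TypeAdapter`:

```python
from pydantic import TypeAdapter

from d9d.pipelining.factory import AnyPipelineScheduleConfig

adapter = TypeAdapter(AnyPipelineScheduleConfig)
config = adapter.validate_python({"schedule": "zero_bubble_v"})
# `config` is now a PipelineScheduleZeroBubbleVConfig instance.
```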
PipelineSchedule1F1BConfig
Bases: BaseModel
Configuration for Interleaved 1F1B and Interleaved Zero Bubble execution.
Supports assigning multiple stages per rank and sharding the backward pass into dI and dW to reduce pipeline bubbles.
PipelineScheduleDualPipeVConfig
Bases: BaseModel
Configuration for DualPipeV execution.
A bidirectional pipeline schedule for high-throughput training, utilizing V-shape topology and reciprocal forward/backward scheduling.
PipelineScheduleGPipeConfig
Bases: BaseModel
Configuration for GPipe execution.
This assumes a single stage per rank and processes all microbatches for the forward pass before switching to the backward pass.
PipelineScheduleInferenceConfig
Bases: BaseModel
Configuration for inference-only pipeline execution.
This schedule runs all forward passes sequentially without any backward passes.
PipelineScheduleLoopedBFSConfig
Bases: BaseModel
Configuration for Looped Breadth-First Search execution.
Similar to GPipe, but supports multiple stages per rank (virtualization). It executes all available work for a specific stage before moving to the next.
PipelineScheduleZeroBubbleVConfig
Bases: BaseModel
Configuration for Zero Bubble V (ZBV) execution.
A specialized V-shape topology schedule that splits backward passes into Input and Weight gradients to maximize overlap. Requires exactly 2 stages per rank.
build_schedule(dist_context, n_microbatches, schedule_config, model_provider, callback)
Constructs the pipeline schedule and instantiates model stages.
This function coordinates the creation of the pipeline. If the context is
distributed, it builds a parallel schedule (PipelineScheduleExecutor) by
calculating topology and creating stages for the current rank. If the
context is local, it builds an offline schedule (OfflinePipelineExecutor)
for direct execution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dist_context` | `DistributedContext` | The distributed context. | required |
| `n_microbatches` | `int` | Number of microbatches per global step. | required |
| `schedule_config` | `AnyPipelineScheduleConfig` | Configuration object determining the schedule strategy. | required |
| `model_provider` | `Callable[[PipelineStageInfo], Module]` | A factory function that accepts stage info and returns an `nn.Module` for that stage. | required |
| `callback` | `PipelineLossFn \| PipelineResultFn` | Callback that either computes the loss (if training) or processes pipeline outputs (if not training). | required |
Returns:

| Type | Description |
|---|---|
| `tuple[PipelineScheduleInfo, list[Module]]` | A tuple containing the schedule info (executor and metadata) and the list of local PyTorch modules created for this rank. |