This section details the internals of the d9d.pipelining module. It is intended for those who wish to implement new layouts, schedules, or modify the execution engine.

Architecture

The Idea

d9d decouples the Schedule Structure from the Runtime Execution.

  1. You write a builder (e.g., 1F1B, DualPipe) that generates a linear list of logical Actions (e.g., Forward(Stage=0, MB=1), Backward(Stage=0, MB=0)). If you want, d9d can automatically inject Send/Recv actions into your compute-only schedule based on data dependencies, preventing deadlocks.
  2. You run a dumb virtual machine that simply iterates the action list and executes each action.

This makes implementing complex research schedules (like Zero Bubble or DualPipeV) significantly easier than managing state machines or recursive calls.

Core Components

PipelineStage (infra/stage/stage.py)

Encapsulates a user nn.Module. It is not responsible for deciding when to run. Instead, it provides atomic pipeline stage capabilities (such as forward and backward passes) to the actions and the executor.

Consists of:

  • Computation Handlers:
    • ForwardComputeHandler: Performs forward pass, caches inputs/outputs for backward passes.
    • BackwardComputeHandler: Performs backward pass, capable of splitting backward passes into backward_input (dI) and backward_weight (dW) for advanced schedules.
  • Communication Handlers: Contain and manage the P2P buffers for both forward and backward passes.
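
To make this division of labour concrete, the sketch below shows roughly how actions and the executor drive a stage through these capabilities. It is a simplified illustration based on the PipelineStage API documented later on this page, not a complete training step; the stage object, the pipeline inputs, the per-microbatch inputs and the loss tensor are assumed to already exist.

# Illustrative sketch only: how actions exercise a PipelineStage's atomic capabilities.
# `stage`, `pipeline_inputs`, `microbatch_inputs` and `loss` are assumed to exist.

# One-time setup performed by the executor: allocate P2P buffers and
# decide whether backward state must be kept.
stage.configure_buffers(
    num_microbatches=8,
    has_backward=True,
    pipeline_inputs=pipeline_inputs,
)

# A ForwardComputeAction would invoke the forward capability:
stage.forward_one_chunk(microbatch_index=0, pipeline_inputs=microbatch_inputs)

# A ForwardSendAction asks the communication handler for the P2P ops to post:
send_ops = stage.get_fwd_send_ops(microbatch_index=0)

# A BackwardFullInputComputeAction later triggers the backward capability:
stage.backward_one_chunk(microbatch_index=0, loss=loss, full_backward=True)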

Actions (infra/schedule/component/runtime/action.py)

The atomic instructions for the pipeline virtual machine.

  • ForwardComputeAction: Run forward on specific microbatch.
  • BackwardFullInputComputeAction: Run backward. Can be configured to compute gradients for inputs-only or inputs+weights.
  • BackwardWeightComputeAction: Compute gradients for weights (used in Zero Bubble schedules).
  • ForwardSendAction / ForwardReceiveAction / BackwardSendAction / BackwardReceiveAction: Network IO.
  • ComposeAction: Composes multiple actions into a single one. Used for Forward/Backward overlap in schedules such as DualPipeV.

Actions are designed to be declarative and immutable.

Programs

A Program is simply dict[int, list[ActionBase]] — a mapping of Rank ID to a sequential list of Actions.
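
For illustration, a hand-written program for two ranks (one stage each) and two microbatches might look like the snippet below. It is compute-only and purely hypothetical; real builders generate these lists programmatically and usually run them through add_communication_ops afterwards. The import assumes BackwardFullInputComputeAction is exported from the runtime package alongside ForwardComputeAction.

from d9d.pipelining.infra.schedule.component.runtime import (
    ActionBase,
    BackwardFullInputComputeAction,
    ForwardComputeAction,
)

# Rank 0 hosts stage 0, rank 1 hosts stage 1; two microbatches each.
program: dict[int, list[ActionBase]] = {
    0: [
        ForwardComputeAction(stage_idx=0, microbatch_idx=0),
        ForwardComputeAction(stage_idx=0, microbatch_idx=1),
        BackwardFullInputComputeAction(stage_idx=0, microbatch_idx=0, full_backward=True),
        BackwardFullInputComputeAction(stage_idx=0, microbatch_idx=1, full_backward=True),
    ],
    1: [
        ForwardComputeAction(stage_idx=1, microbatch_idx=0),
        ForwardComputeAction(stage_idx=1, microbatch_idx=1),
        BackwardFullInputComputeAction(stage_idx=1, microbatch_idx=0, full_backward=True),
        BackwardFullInputComputeAction(stage_idx=1, microbatch_idx=1, full_backward=True),
    ],
}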

Executor (infra/schedule/component/runtime/executor.py)

The PipelineScheduleExecutor is the runtime engine.

It:

  1. Shards global inputs into microbatches.
  2. Iterates through the Program action list.
  3. Dispatches each Action, which performs its computation or communication workload; a simplified sketch of this loop follows the list.
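
Conceptually, the dispatch loop is little more than the sketch below; it is simplified from the real PipelineScheduleExecutor.step shown later on this page (input sharding, profiling and logging omitted).

# Simplified mental model of the executor's dispatch loop.
def run_program(my_program: list, ctx) -> None:
    for action in my_program:
        # Every action knows how to execute itself against the shared context
        # (stages, communication handler, sharded inputs, loss handler).
        action.apply(ctx)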

Comparison with PyTorch

The d9d pipelining implementation is heavily inspired by and borrows concepts from the torch.distributed.pipelining API (e.g., ZeroBubble implementation), but refactors the codebase significantly to improve clarity, type safety, and modularity.

The main architectural differences lie in the strict separation of concerns and composition over inheritance:

  1. Decomposed Stage Logic:

    • PyTorch: Uses a monolithic _PipelineStageBase class that simultaneously manages P2P buffer allocation, gradient accumulation state, and forward/backward execution logic.
    • d9d: Adopts a compositional approach. The PipelineStage class is a thin orchestrator that delegates responsibilities to dedicated handlers.
  2. Polymorphic Actions vs Enumeration:

    • PyTorch: Represents schedule instructions using a single generic _Action NamedTuple combined with an Enum (_ComputationType.FORWARD, _ComputationType.SEND_F, etc.).
    • d9d: Uses a class hierarchy for actions (ForwardComputeAction, ForwardSendAction, ComposeAction). This allows the runtime executor to use structural pattern matching (match/case) rather than large if/elif blocks checking enums, allows different actions to carry different metadata (e.g. full_backward flag), and improves static type checking. A sketch contrasting the two dispatch styles follows this list.
  3. Builder Pattern vs Schedule Classes:

    • PyTorch: Often couples the schedule definition with the runtime object (e.g., Schedule1F1B class contains both the logic to generate the ordering and the logic to execute it).
    • d9d: Strictly separates the Program Builder (which generates the list of actions) from the Executor (which runs the actions). This makes it easier to inspect a schedule plan before execution or swap scheduling algorithms without changing the runtime driver.
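
As a rough, hypothetical illustration of point 2 (simplified code in both styles, not copied from either codebase):

import enum

from d9d.pipelining.infra.schedule.component.runtime.action import (
    BackwardFullInputComputeAction,
    ForwardComputeAction,
)


# Hypothetical enum standing in for something like PyTorch's _ComputationType.
class ComputationType(enum.Enum):
    FORWARD = "forward"
    SEND_F = "send_f"


# Enum-style dispatch: one generic action record forces the runtime to branch.
def run_enum_action(action):
    if action.computation_type == ComputationType.FORWARD:
        ...  # run forward
    elif action.computation_type == ComputationType.SEND_F:
        ...  # post the send
    # ... one branch per enum member


# Class-hierarchy dispatch: each action is its own frozen dataclass, so a runtime
# could match structurally, and each class carries only the fields it needs.
def run_typed_action(action):
    match action:
        case ForwardComputeAction(stage_idx=stage, microbatch_idx=mb):
            ...  # run forward for (stage, mb)
        case BackwardFullInputComputeAction(full_backward=True):
            ...  # run a full backward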

Building Custom Schedules

To build a new schedule, you create a PipelineProgramBuilder.

Implement the Builder

You must subclass PipelineProgramBuilder and implement its abstract methods and properties:

from collections import defaultdict

from d9d.pipelining.infra.schedule.component.program import PipelineProgramBuilder, build_stage_to_host_rank_topology, ScheduleStyle, add_communication_ops
from d9d.pipelining.infra.schedule.component.runtime import ActionBase, ForwardComputeAction


class MyFancyScheduleBuilder(PipelineProgramBuilder):
    def __init__(self, stages_per_rank: int):
        self._stages_per_rank = stages_per_rank

    @property
    def num_stages_per_rank(self) -> int:
        return self._stages_per_rank

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop

    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        # Map logical stages to ranks
        stage_to_rank = build_stage_to_host_rank_topology(num_stages=self._stages_per_rank * pp_size,
                                                          style=ScheduleStyle.loop,
                                                          pp_size=pp_size)

        actions = defaultdict(list)

        # 1. Generate Compute Schedule
        for rank in range(pp_size):
            # ... custom logic to decide order of Fwd/Bwd ...
            actions[rank].append(ForwardComputeAction(stage_idx=..., microbatch_idx=...))

        # 2. Inject Communications (Magic Pass)
        # This analyzes data dependencies between stages and inserts Send/Recvs
        return add_communication_ops(actions, stage_to_rank, num_stages=self._stages_per_rank * pp_size)
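
Once the placeholder compute logic above is filled in, the builder can be exercised directly, which is handy for inspecting a plan before wiring it into the executor (illustrative snippet; the short forms such as 0F0 come from each action's __str__):

builder = MyFancyScheduleBuilder(stages_per_rank=1)
program = builder.compose(num_microbatches=8, pp_size=4)

for rank, rank_actions in sorted(program.items()):
    print(rank, " ".join(str(action) for action in rank_actions))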

Registering

Add your configuration to factory/config.py and register the builder in factory/factory.py.

d9d.pipelining.infra.stage

PipelineStage

Represents a single structural stage in a Pipelined Model.

This class acts as an orchestrator that combines StageCommunicationHandler (for I/O) and Forward/BackwardComputeHandler (for execution). It abstracts away the complexity of buffer management, distributed communication, and gradient calculation from the scheduler.

Source code in d9d/pipelining/infra/stage/stage.py
class PipelineStage:
    """
    Represents a single structural stage in a Pipelined Model.

    This class acts as an orchestrator that combines `StageCommunicationHandler` (for I/O)
    and `Forward/BackwardComputeHandler` (for execution). It abstracts away the complexity
    of buffer management, distributed communication, and gradient calculation from the scheduler.
    """

    def __init__(
            self,
            info: PipelineStageInfo,
            module: nn.Module,
            group: dist.ProcessGroup,
            stage_to_host_topology: dict[int, int]
    ):
        """
        Constructs a PipelineStage object.

        Args:
            info: Metadata about the stage (index, total stages).
            module: The PyTorch module executed by this stage.
            group: The distributed process group for pipeline communications.
            stage_to_host_topology: Dict mapping stage ID to PP rank hosting it.
        """

        self._info = info
        self._module = module
        self._group = group
        self._stage_to_host_topology = stage_to_host_topology

        self._has_backward = False

        self._forward_comm: StageCommunicationHandler | None = None
        self._backward_comm: StageCommunicationHandler | None = None

        self._forward_comp = ForwardComputeHandler(
            stage_index=info.current_stage,
            module=module
        )
        self._backward_comp = BackwardComputeHandler(
            stage_index=info.current_stage,
            module=module
        )

    @property
    def info(self) -> PipelineStageInfo:
        return self._info

    def configure_buffers(
            self,
            num_microbatches: int,
            has_backward: bool,
            pipeline_inputs: dict[str, torch.Tensor]
    ):
        """
        Initializes the communication handlers and buffers for the stage.

        This must be called before execution to establish P2P buffer sizes and directions.

        Args:
            num_microbatches: Total number of microbatches to process.
            has_backward: Whether this pipeline stage should store information for a backward pass.
            pipeline_inputs: Pipeline input data.
        """

        self._has_backward = has_backward

        prev_stage_idx = None if self._info.is_current_stage_first else self._info.current_stage - 1
        next_stage_idx = None if self._info.is_current_stage_last else self._info.current_stage + 1

        with torch.device("meta"):
            if not isinstance(self._module, ModuleSupportsPipelining):
                raise TypeError("Module does not implement ModuleSupportsPipelining protocol")
            inputs_meta = self._module.infer_stage_inputs_from_pipeline_inputs(
                inputs=pipeline_inputs,
                n_microbatches=num_microbatches
            )
            outputs_meta = self._module.infer_stage_outputs_from_pipeline_inputs(
                inputs=pipeline_inputs,
                n_microbatches=num_microbatches
            )

        self._forward_comm = StageCommunicationHandler(
            name="fwd",
            stage_index=self._info.current_stage,
            num_microbatches=num_microbatches,
            input_stage_index=prev_stage_idx,
            input_args=inputs_meta,
            output_stage_index=next_stage_idx,
            output_args=outputs_meta,
            group=self._group,
            stage_idx_to_host_rank=self._stage_to_host_topology
        )
        self._forward_comm.set_input_requires_grad_(requires_grad=has_backward)

        if has_backward:
            # for grad - current stage receives OUTPUTS as inputs and sends INPUTS as outputs
            # because it is reversed forward
            self._backward_comm = StageCommunicationHandler(
                name="bwd",
                stage_index=self._info.current_stage,
                num_microbatches=num_microbatches,
                input_stage_index=next_stage_idx,
                input_args=outputs_meta,
                output_stage_index=prev_stage_idx,
                output_args=inputs_meta,
                group=self._group,
                stage_idx_to_host_rank=self._stage_to_host_topology
            )
        else:
            self._backward_comm = None

    def set_local_fwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
        """
        Sets local forward inputs manually.

        Used for the V-shape schedulers.
        """

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        self._forward_comm.set_inputs_local(inputs, microbatch_index)

    def get_local_fwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
        return self._forward_comp.get_outputs(microbatch_index)

    def pop_local_bwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
        """
        Retrieves local backward outputs (gradients).
        """

        if not self._has_backward:
            raise ValueError()

        return self._backward_comp.pop_for_sending(microbatch_index)

    def set_local_bwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
        """
        Sets local backward inputs (output gradients) manually.
        """

        if not self._has_backward:
            raise ValueError()

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        self._backward_comm.set_inputs_local(inputs, microbatch_index)

    def get_fwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to receive forward inputs for the given microbatch."""

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        return self._forward_comm.create_receive_ops(microbatch_index)

    def get_fwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to send forward outputs for the given microbatch."""

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        fwd_result = self._forward_comp.get_outputs(microbatch_index)
        return self._forward_comm.create_send_ops(fwd_result)

    def get_bwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to receive backward gradients for the given microbatch."""

        if not self._has_backward:
            return []

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        return self._backward_comm.create_receive_ops(microbatch_index)

    def get_bwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to send backward gradients for the given microbatch."""

        if not self._has_backward:
            return []

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        bwd_result = self._backward_comp.pop_for_sending(microbatch_index)
        return self._backward_comm.create_send_ops(bwd_result)

    def forward_one_chunk(
            self,
            microbatch_index: int,
            pipeline_inputs: dict[str, torch.Tensor],
            pipeline_kwargs: dict[str, Any] | None = None,
    ):
        """
        Executes a forward pass for a single microbatch chunk.

        Fetches inputs from the communication buffer (or `pipeline_inputs` if first stage),
        runs the computation, and caches the result.

        Args:
            microbatch_index: The microbatch index.
            pipeline_inputs: Inputs provided locally (only used if this is the first stage).
            pipeline_kwargs: Additional arguments for the module.

        Returns:
            The output tensors of the forward pass.
        """

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        if self._info.is_current_stage_first:
            inputs = pipeline_inputs
        else:
            inputs = self._forward_comm.get_inputs(microbatch_index)

        kwargs = pipeline_kwargs or {}

        self._forward_comp.run(
            microbatch_index=microbatch_index,
            inputs=inputs,
            kwargs=kwargs
        )

    def backward_one_chunk(
            self,
            microbatch_index: int,
            loss: torch.Tensor | None = None,
            full_backward: bool = True
    ):
        """
        Executes a backward pass for a single microbatch chunk.

        Can perform either a full backward or just the input gradients (if `full_backward=False`).
        It fetches required data from forward cache and communication buffers.

        Args:
            microbatch_index: The microbatch index.
            loss: The loss tensor (only used if this is the last stage).
            full_backward: If True, computes grads for inputs and weights. If False, only for inputs.
        """

        if not self._has_backward:
            raise ValueError()

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        inputs, fwd_outputs = self._forward_comp.pop_inputs_outputs(microbatch_index)

        outputs: dict[str, torch.Tensor]
        outputs_grad: dict[str, torch.Tensor] | None

        if self._info.is_current_stage_last:
            if loss is None:
                raise ValueError("Cannot perform backward on last stage without loss specified")
            outputs = {"loss": loss}
            outputs_grad = None
        else:
            outputs = fwd_outputs
            outputs_grad = self._backward_comm.get_inputs(microbatch_index)

        if full_backward:
            self._backward_comp.backward_full(
                microbatch_index=microbatch_index,
                inputs=inputs,
                outputs=outputs,
                outputs_grad=outputs_grad
            )
        else:
            self._backward_comp.backward_input(
                microbatch_index=microbatch_index,
                inputs=inputs,
                outputs=outputs,
                outputs_grad=outputs_grad
            )

        if self._info.is_current_stage_last and not self._info.is_current_stage_first:
            for t in fwd_outputs.values():
                if not t._is_view():  # noqa: SLF001
                    t.detach_()

    def backward_weight_one_chunk(self, microbatch_index: int):
        """
        Executes the weight gradient accumulation part of the backward pass.

        This assumes `backward_one_chunk(..., full_backward=False)` was already called
        for this microbatch.

        Args:
            microbatch_index: The microbatch index.
        """

        if not self._has_backward:
            raise ValueError()

        self._backward_comp.backward_weight(microbatch_index=microbatch_index)

    def reset(self):
        """Resets the internal state of communication handlers, clearing gradients on buffers."""

        if self._forward_comm is not None:
            self._forward_comm.reset()
        if self._backward_comm is not None:
            self._backward_comm.reset()

__init__(info, module, group, stage_to_host_topology)

Constructs a PipelineStage object.

Parameters:

  • info (PipelineStageInfo): Metadata about the stage (index, total stages). Required.
  • module (Module): The PyTorch module executed by this stage. Required.
  • group (ProcessGroup): The distributed process group for pipeline communications. Required.
  • stage_to_host_topology (dict[int, int]): Dict mapping stage ID to PP rank hosting it. Required.
Source code in d9d/pipelining/infra/stage/stage.py
def __init__(
        self,
        info: PipelineStageInfo,
        module: nn.Module,
        group: dist.ProcessGroup,
        stage_to_host_topology: dict[int, int]
):
    """
    Constructs a PipelineStage object.

    Args:
        info: Metadata about the stage (index, total stages).
        module: The PyTorch module executed by this stage.
        group: The distributed process group for pipeline communications.
        stage_to_host_topology: Dict mapping stage ID to PP rank hosting it.
    """

    self._info = info
    self._module = module
    self._group = group
    self._stage_to_host_topology = stage_to_host_topology

    self._has_backward = False

    self._forward_comm: StageCommunicationHandler | None = None
    self._backward_comm: StageCommunicationHandler | None = None

    self._forward_comp = ForwardComputeHandler(
        stage_index=info.current_stage,
        module=module
    )
    self._backward_comp = BackwardComputeHandler(
        stage_index=info.current_stage,
        module=module
    )

backward_one_chunk(microbatch_index, loss=None, full_backward=True)

Executes a backward pass for a single microbatch chunk.

Can perform either a full backward or just the input gradients (if full_backward=False). It fetches required data from forward cache and communication buffers.

Parameters:

  • microbatch_index (int): The microbatch index. Required.
  • loss (Tensor | None): The loss tensor (only used if this is the last stage). Default: None.
  • full_backward (bool): If True, computes grads for inputs and weights. If False, only for inputs. Default: True.
Source code in d9d/pipelining/infra/stage/stage.py
def backward_one_chunk(
        self,
        microbatch_index: int,
        loss: torch.Tensor | None = None,
        full_backward: bool = True
):
    """
    Executes a backward pass for a single microbatch chunk.

    Can perform either a full backward or just the input gradients (if `full_backward=False`).
    It fetches required data from forward cache and communication buffers.

    Args:
        microbatch_index: The microbatch index.
        loss: The loss tensor (only used if this is the last stage).
        full_backward: If True, computes grads for inputs and weights. If False, only for inputs.
    """

    if not self._has_backward:
        raise ValueError()

    if self._backward_comm is None:
        raise ValueError("You must configure stage buffers first")

    inputs, fwd_outputs = self._forward_comp.pop_inputs_outputs(microbatch_index)

    outputs: dict[str, torch.Tensor]
    outputs_grad: dict[str, torch.Tensor] | None

    if self._info.is_current_stage_last:
        if loss is None:
            raise ValueError("Cannot perform backward on last stage without loss specified")
        outputs = {"loss": loss}
        outputs_grad = None
    else:
        outputs = fwd_outputs
        outputs_grad = self._backward_comm.get_inputs(microbatch_index)

    if full_backward:
        self._backward_comp.backward_full(
            microbatch_index=microbatch_index,
            inputs=inputs,
            outputs=outputs,
            outputs_grad=outputs_grad
        )
    else:
        self._backward_comp.backward_input(
            microbatch_index=microbatch_index,
            inputs=inputs,
            outputs=outputs,
            outputs_grad=outputs_grad
        )

    if self._info.is_current_stage_last and not self._info.is_current_stage_first:
        for t in fwd_outputs.values():
            if not t._is_view():  # noqa: SLF001
                t.detach_()

backward_weight_one_chunk(microbatch_index)

Executes the weight gradient accumulation part of the backward pass.

This assumes backward_one_chunk(..., full_backward=False) was already called for this microbatch.

Parameters:

  • microbatch_index (int): The microbatch index. Required.
Source code in d9d/pipelining/infra/stage/stage.py
def backward_weight_one_chunk(self, microbatch_index: int):
    """
    Executes the weight gradient accumulation part of the backward pass.

    This assumes `backward_one_chunk(..., full_backward=False)` was already called
    for this microbatch.

    Args:
        microbatch_index: The microbatch index.
    """

    if not self._has_backward:
        raise ValueError()

    self._backward_comp.backward_weight(microbatch_index=microbatch_index)

configure_buffers(num_microbatches, has_backward, pipeline_inputs)

Initializes the communication handlers and buffers for the stage.

This must be called before execution to establish P2P buffer sizes and directions.

Parameters:

  • num_microbatches (int): Total number of microbatches to process. Required.
  • has_backward (bool): Whether this pipeline stage should store information for a backward pass. Required.
  • pipeline_inputs (dict[str, Tensor]): Pipeline input data. Required.
Source code in d9d/pipelining/infra/stage/stage.py
def configure_buffers(
        self,
        num_microbatches: int,
        has_backward: bool,
        pipeline_inputs: dict[str, torch.Tensor]
):
    """
    Initializes the communication handlers and buffers for the stage.

    This must be called before execution to establish P2P buffer sizes and directions.

    Args:
        num_microbatches: Total number of microbatches to process.
        has_backward: Whether this pipeline stage should store information for a backward pass.
        pipeline_inputs: Pipeline input data.
    """

    self._has_backward = has_backward

    prev_stage_idx = None if self._info.is_current_stage_first else self._info.current_stage - 1
    next_stage_idx = None if self._info.is_current_stage_last else self._info.current_stage + 1

    with torch.device("meta"):
        if not isinstance(self._module, ModuleSupportsPipelining):
            raise TypeError("Module does not implement ModuleSupportsPipelining protocol")
        inputs_meta = self._module.infer_stage_inputs_from_pipeline_inputs(
            inputs=pipeline_inputs,
            n_microbatches=num_microbatches
        )
        outputs_meta = self._module.infer_stage_outputs_from_pipeline_inputs(
            inputs=pipeline_inputs,
            n_microbatches=num_microbatches
        )

    self._forward_comm = StageCommunicationHandler(
        name="fwd",
        stage_index=self._info.current_stage,
        num_microbatches=num_microbatches,
        input_stage_index=prev_stage_idx,
        input_args=inputs_meta,
        output_stage_index=next_stage_idx,
        output_args=outputs_meta,
        group=self._group,
        stage_idx_to_host_rank=self._stage_to_host_topology
    )
    self._forward_comm.set_input_requires_grad_(requires_grad=has_backward)

    if has_backward:
        # for grad - current stage receives OUTPUTS as inputs and sends INPUTS as outputs
        # because it is reversed forward
        self._backward_comm = StageCommunicationHandler(
            name="bwd",
            stage_index=self._info.current_stage,
            num_microbatches=num_microbatches,
            input_stage_index=next_stage_idx,
            input_args=outputs_meta,
            output_stage_index=prev_stage_idx,
            output_args=inputs_meta,
            group=self._group,
            stage_idx_to_host_rank=self._stage_to_host_topology
        )
    else:
        self._backward_comm = None

forward_one_chunk(microbatch_index, pipeline_inputs, pipeline_kwargs=None)

Executes a forward pass for a single microbatch chunk.

Fetches inputs from the communication buffer (or pipeline_inputs if first stage), runs the computation, and caches the result.

Parameters:

  • microbatch_index (int): The microbatch index. Required.
  • pipeline_inputs (dict[str, Tensor]): Inputs provided locally (only used if this is the first stage). Required.
  • pipeline_kwargs (dict[str, Any] | None): Additional arguments for the module. Default: None.

Returns:

  The output tensors of the forward pass.

Source code in d9d/pipelining/infra/stage/stage.py
def forward_one_chunk(
        self,
        microbatch_index: int,
        pipeline_inputs: dict[str, torch.Tensor],
        pipeline_kwargs: dict[str, Any] | None = None,
):
    """
    Executes a forward pass for a single microbatch chunk.

    Fetches inputs from the communication buffer (or `pipeline_inputs` if first stage),
    runs the computation, and caches the result.

    Args:
        microbatch_index: The microbatch index.
        pipeline_inputs: Inputs provided locally (only used if this is the first stage).
        pipeline_kwargs: Additional arguments for the module.

    Returns:
        The output tensors of the forward pass.
    """

    if self._forward_comm is None:
        raise ValueError("You must configure stage buffers first")

    if self._info.is_current_stage_first:
        inputs = pipeline_inputs
    else:
        inputs = self._forward_comm.get_inputs(microbatch_index)

    kwargs = pipeline_kwargs or {}

    self._forward_comp.run(
        microbatch_index=microbatch_index,
        inputs=inputs,
        kwargs=kwargs
    )

get_bwd_recv_ops(microbatch_index)

Returns P2P ops to receive backward gradients for the given microbatch.

Source code in d9d/pipelining/infra/stage/stage.py
def get_bwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
    """Returns P2P ops to receive backward gradients for the given microbatch."""

    if not self._has_backward:
        return []

    if self._backward_comm is None:
        raise ValueError("You must configure stage buffers first")

    return self._backward_comm.create_receive_ops(microbatch_index)

get_bwd_send_ops(microbatch_index)

Returns P2P ops to send backward gradients for the given microbatch.

Source code in d9d/pipelining/infra/stage/stage.py
def get_bwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
    """Returns P2P ops to send backward gradients for the given microbatch."""

    if not self._has_backward:
        return []

    if self._backward_comm is None:
        raise ValueError("You must configure stage buffers first")

    bwd_result = self._backward_comp.pop_for_sending(microbatch_index)
    return self._backward_comm.create_send_ops(bwd_result)

get_fwd_recv_ops(microbatch_index)

Returns P2P ops to receive forward inputs for the given microbatch.

Source code in d9d/pipelining/infra/stage/stage.py
def get_fwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
    """Returns P2P ops to receive forward inputs for the given microbatch."""

    if self._forward_comm is None:
        raise ValueError("You must configure stage buffers first")

    return self._forward_comm.create_receive_ops(microbatch_index)

get_fwd_send_ops(microbatch_index)

Returns P2P ops to send forward outputs for the given microbatch.

Source code in d9d/pipelining/infra/stage/stage.py
def get_fwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
    """Returns P2P ops to send forward outputs for the given microbatch."""

    if self._forward_comm is None:
        raise ValueError("You must configure stage buffers first")

    fwd_result = self._forward_comp.get_outputs(microbatch_index)
    return self._forward_comm.create_send_ops(fwd_result)

pop_local_bwd_output(microbatch_index)

Retrieves local backward outputs (gradients).

Source code in d9d/pipelining/infra/stage/stage.py
def pop_local_bwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
    """
    Retrieves local backward outputs (gradients).
    """

    if not self._has_backward:
        raise ValueError()

    return self._backward_comp.pop_for_sending(microbatch_index)

reset()

Resets the internal state of communication handlers, clearing gradients on buffers.

Source code in d9d/pipelining/infra/stage/stage.py
def reset(self):
    """Resets the internal state of communication handlers, clearing gradients on buffers."""

    if self._forward_comm is not None:
        self._forward_comm.reset()
    if self._backward_comm is not None:
        self._backward_comm.reset()

set_local_bwd_input(inputs, microbatch_index)

Sets local backward inputs (output gradients) manually.

Source code in d9d/pipelining/infra/stage/stage.py
def set_local_bwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
    """
    Sets local backward inputs (output gradients) manually.
    """

    if not self._has_backward:
        raise ValueError()

    if self._backward_comm is None:
        raise ValueError("You must configure stage buffers first")

    self._backward_comm.set_inputs_local(inputs, microbatch_index)

set_local_fwd_input(inputs, microbatch_index)

Sets local forward inputs manually.

Used for the V-shape schedulers.

Source code in d9d/pipelining/infra/stage/stage.py
def set_local_fwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
    """
    Sets local forward inputs manually.

    Used for the V-shape schedulers.
    """

    if self._forward_comm is None:
        raise ValueError("You must configure stage buffers first")

    self._forward_comm.set_inputs_local(inputs, microbatch_index)

d9d.pipelining.infra.schedule.component.runtime

Pipelining Runtime Package.

ActionBase

Bases: ABC

Abstract base class for all pipeline schedule actions.

An action represents an atomic unit of work in a pipeline schedule, such as computing a microbatch or sending/receiving a tensor.
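
As a minimal sketch, a custom action can subclass ActionBase as shown below. The no-op action is hypothetical and assumes ActionBase, ActionContext and ActionWorkType are importable from the runtime action module; real actions operate on the context's stages and communication handlers, as the concrete actions further below show.

import dataclasses

from d9d.pipelining.infra.schedule.component.runtime.action import (
    ActionBase,
    ActionContext,
    ActionWorkType,
)


@dataclasses.dataclass(frozen=True, slots=True)
class NoOpMarkerAction(ActionBase):
    """Hypothetical marker action that performs no work (e.g. for debugging a plan)."""

    label: str

    def apply(self, ctx: ActionContext):
        # A real action would use ctx.stages / ctx.communications here.
        pass

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.compute

    @property
    def has_backward_work(self) -> bool:
        return False

    def __str__(self) -> str:
        return f"MARK_{self.label}"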

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
class ActionBase(abc.ABC):
    """
    Abstract base class for all pipeline schedule actions.

    An action represents an atomic unit of work in a pipeline schedule,
    such as computing a microbatch or sending/receiving a tensor.
    """

    @abc.abstractmethod
    def apply(self, ctx: ActionContext):
        """
        Executes the action logic using the provided context.

        Args:
            ctx: The runtime context containing stages, data, and communication handlers.
        """

        ...

    @property
    @abc.abstractmethod
    def work_type(self) -> ActionWorkType:
        """Returns the classification of work this action performs."""
        ...

    @property
    @abc.abstractmethod
    def has_backward_work(self) -> bool:
        """Returns True if this action involves backward pass computations."""
        ...

    @abc.abstractmethod
    def __str__(self) -> str:
        """Returns a short string representation of the action for logging/visualization."""
        ...

has_backward_work abstractmethod property

Returns True if this action involves backward pass computations.

work_type abstractmethod property

Returns the classification of work this action performs.

__str__() abstractmethod

Returns a short string representation of the action for logging/visualization.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@abc.abstractmethod
def __str__(self) -> str:
    """Returns a short string representation of the action for logging/visualization."""
    ...

apply(ctx) abstractmethod

Executes the action logic using the provided context.

Parameters:

  • ctx (ActionContext): The runtime context containing stages, data, and communication handlers. Required.
Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@abc.abstractmethod
def apply(self, ctx: ActionContext):
    """
    Executes the action logic using the provided context.

    Args:
        ctx: The runtime context containing stages, data, and communication handlers.
    """

    ...

BackwardFullInputComputeAction dataclass

Bases: ActionBase

Action to perform backward computation with respect to inputs.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage.
  • microbatch_idx (int): The integer index of the microbatch to compute.
  • full_backward (bool): If True, performs a full backward pass including inputs and weights. If False, may only compute gradients w.r.t inputs (depending on schedule implementation).

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class BackwardFullInputComputeAction(ActionBase):
    """
    Action to perform backward computation with respect to inputs.

    Attributes:
        stage_idx: The integer index of the pipeline stage.
        microbatch_idx: The integer index of the microbatch to compute.
        full_backward: If True, performs a full backward pass including inputs
            and weights. If False, may only compute gradients w.r.t inputs
            (depending on schedule implementation).
    """

    stage_idx: int
    microbatch_idx: int
    full_backward: bool

    def apply(self, ctx: ActionContext):
        # todo unshard
        stage = ctx.stages[self.stage_idx]

        if not stage.info.is_current_stage_last and self.stage_idx + 1 not in ctx.stages:
            ctx.communications.wait_bwd_recv(self.stage_idx, self.microbatch_idx)

        if stage.info.is_current_stage_last and ctx.loss is not None:
            loss = ctx.loss.acquire_loss(self.microbatch_idx)
        else:
            loss = None

        stage.backward_one_chunk(
            microbatch_index=self.microbatch_idx,
            full_backward=self.full_backward,
            loss=loss
        )

        if not stage.info.is_current_stage_first and self.stage_idx - 1 in ctx.stages:
            ctx.stages[self.stage_idx - 1].set_local_bwd_input(
                microbatch_index=self.microbatch_idx,
                inputs=stage.pop_local_bwd_output(self.microbatch_idx)
            )

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.compute

    @property
    def has_backward_work(self) -> bool:
        return True

    def __str__(self) -> str:
        letter = "B" if self.full_backward else "I"
        return f"{self.stage_idx}{letter}{self.microbatch_idx}"

BackwardReceiveAction dataclass

Bases: ActionBase

Action to schedule a backward pass gradient receive operation.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage expecting the receive operation.
  • microbatch_idx (int): The integer index of the microbatch being received.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class BackwardReceiveAction(ActionBase):
    """
    Action to schedule a backward pass gradient receive operation.

    Attributes:
        stage_idx: The integer index of the pipeline stage expecting the receive operation.
        microbatch_idx: The integer index of the microbatch being received.
    """

    stage_idx: int
    microbatch_idx: int

    def apply(self, ctx: ActionContext):
        ctx.communications.schedule_bwd_recv(self.stage_idx, self.microbatch_idx)

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.communicate

    @property
    def has_backward_work(self) -> bool:
        return True

    def __str__(self) -> str:
        return f"{self.stage_idx}RECV_B{self.microbatch_idx}"

BackwardSendAction dataclass

Bases: ActionBase

Action to schedule a backward pass gradient send operation.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage initiating the send operation.
  • microbatch_idx (int): The integer index of the microbatch being sent.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class BackwardSendAction(ActionBase):
    """
    Action to schedule a backward pass gradient send operation.

    Attributes:
        stage_idx: The integer index of the pipeline stage initiating the send operation.
        microbatch_idx: The integer index of the microbatch being sent.
    """

    stage_idx: int
    microbatch_idx: int

    def apply(self, ctx: ActionContext):
        ctx.communications.schedule_bwd_send(self.stage_idx, self.microbatch_idx)

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.communicate

    @property
    def has_backward_work(self) -> bool:
        return True

    def __str__(self) -> str:
        return f"{self.stage_idx}SEND_B{self.microbatch_idx}"

BackwardWeightComputeAction dataclass

Bases: ActionBase

Action to perform gradient accumulation on weights.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage.
  • microbatch_idx (int): The integer index of the microbatch to compute.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class BackwardWeightComputeAction(ActionBase):
    """
    Action to perform gradient accumulation on weights.

    Attributes:
        stage_idx: The integer index of the pipeline stage.
        microbatch_idx: The integer index of the microbatch to compute.
    """

    stage_idx: int
    microbatch_idx: int

    def apply(self, ctx: ActionContext):
        # todo unshard
        stage = ctx.stages[self.stage_idx]

        stage.backward_weight_one_chunk(
            microbatch_index=self.microbatch_idx
        )

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.compute

    @property
    def has_backward_work(self) -> bool:
        return True

    def __str__(self) -> str:
        return f"{self.stage_idx}W{self.microbatch_idx}"

ComposeAction dataclass

Bases: ActionBase

Composite action scheduling multiple sub-actions sequentially.

Used for forward/backward overlapping.

Attributes:

  • actions (tuple[ActionBase, ...]): A tuple of sub-actions to be executed sequentially.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class ComposeAction(ActionBase):
    """
    Composite action scheduling multiple sub-actions sequentially.

    Used for forward/backward overlapping.

    Attributes:
        actions: A tuple of sub-actions to be executed sequentially.
    """

    actions: tuple[ActionBase, ...]

    def apply(self, ctx: ActionContext):
        for act in self.actions:
            act.apply(ctx)

    @property
    def work_type(self) -> ActionWorkType:
        sub_work_types = {x.work_type for x in self.actions}
        if len(sub_work_types) != 1:
            raise ValueError("")
        return next(iter(sub_work_types))

    @property
    def has_backward_work(self) -> bool:
        return any(x.has_backward_work for x in self.actions)

    def __str__(self) -> str:
        return "|".join(map(str, self.actions))

ForwardComputeAction dataclass

Bases: ActionBase

Action to perform forward computation for a specific microbatch.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage.
  • microbatch_idx (int): The integer index of the microbatch to compute.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class ForwardComputeAction(ActionBase):
    """
    Action to perform forward computation for a specific microbatch.

    Attributes:
        stage_idx: The integer index of the pipeline stage.
        microbatch_idx: The integer index of the microbatch to compute.
    """

    stage_idx: int
    microbatch_idx: int

    def apply(self, ctx: ActionContext):
        # todo check unsharded
        stage = ctx.stages[self.stage_idx]

        if not stage.info.is_current_stage_first and self.stage_idx - 1 not in ctx.stages:
            ctx.communications.wait_fwd_recv(self.stage_idx, self.microbatch_idx)

        stage.forward_one_chunk(
            microbatch_index=self.microbatch_idx,
            pipeline_inputs=ctx.pipeline_inputs_microbatches[self.microbatch_idx],
            pipeline_kwargs=ctx.pipeline_kwargs_microbatches[self.microbatch_idx]
        )
        result = stage.get_local_fwd_output(self.microbatch_idx)

        if stage.info.is_current_stage_last and ctx.loss is not None:
            ctx.loss.compute_loss(result, self.microbatch_idx)

        if not stage.info.is_current_stage_last and self.stage_idx + 1 in ctx.stages:
            ctx.stages[self.stage_idx + 1].set_local_fwd_input(
                inputs=result,
                microbatch_index=self.microbatch_idx
            )

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.compute

    @property
    def has_backward_work(self) -> bool:
        return False

    def __str__(self) -> str:
        return f"{self.stage_idx}F{self.microbatch_idx}"

ForwardReceiveAction dataclass

Bases: ActionBase

Action to schedule a forward pass tensor receive operation.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage expecting the receive operation.
  • microbatch_idx (int): The integer index of the microbatch being received.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class ForwardReceiveAction(ActionBase):
    """
    Action to schedule a forward pass tensor receive operation.

    Attributes:
        stage_idx: The integer index of the pipeline stage expecting the receive operation.
        microbatch_idx: The integer index of the microbatch being received.
    """

    stage_idx: int
    microbatch_idx: int

    def apply(self, ctx: ActionContext):
        ctx.communications.schedule_fwd_recv(self.stage_idx, self.microbatch_idx)

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.communicate

    @property
    def has_backward_work(self) -> bool:
        return True

    def __str__(self) -> str:
        return f"{self.stage_idx}RECV_F{self.microbatch_idx}"

ForwardSendAction dataclass

Bases: ActionBase

Action to schedule a forward pass tensor send operation.

Attributes:

  • stage_idx (int): The integer index of the pipeline stage initiating the send operation.
  • microbatch_idx (int): The integer index of the microbatch being sent.

Source code in d9d/pipelining/infra/schedule/component/runtime/action.py
@dataclasses.dataclass(frozen=True, slots=True)
class ForwardSendAction(ActionBase):
    """
    Action to schedule a forward pass tensor send operation.

    Attributes:
        stage_idx: The integer index of the pipeline stage initiating the send operation.
        microbatch_idx: The integer index of the microbatch being sent.
    """

    stage_idx: int
    microbatch_idx: int

    def apply(self, ctx: ActionContext):
        ctx.communications.schedule_fwd_send(self.stage_idx, self.microbatch_idx)

    @property
    def work_type(self) -> ActionWorkType:
        return ActionWorkType.communicate

    @property
    def has_backward_work(self) -> bool:
        return False

    def __str__(self) -> str:
        return f"{self.stage_idx}SEND_F{self.microbatch_idx}"

PipelineScheduleExecutor

Bases: PipelineSchedule

Executes a defined pipeline schedule by interpreting a sequence of actions.

Source code in d9d/pipelining/infra/schedule/component/runtime/executor.py
class PipelineScheduleExecutor(PipelineSchedule):
    """Executes a defined pipeline schedule by interpreting a sequence of actions."""

    def __init__(
            self,
            dist_context: DistributedContext,
            stages: list[PipelineStage],
            num_microbatches: int,
            loss_fn: LossFn | None,
            program: dict[int, list[ActionBase]],
            sharding_spec: PipelineShardingSpec
    ):
        """
        Constructs the schedule executor.

        Args:
            dist_context: The distributed context.
            stages: List of stages managed by this executor.
            num_microbatches: Number of microbatches the global batch is split into.
            loss_fn: Function to compute loss.
            program: The execution plan mapping rank ID to a list of actions.
            sharding_spec: Sharding specification for input and output tensors.
        """

        self._dist_ctx = dist_context
        self._stages = {stage.info.current_stage: stage for stage in stages}
        self._num_microbatches = num_microbatches
        self._program = program

        self._has_backward = any(any(
            action.has_backward_work for action in sub_program
        ) for sub_program in program.values())

        self._comm_handler = PipelineCommunicationHandler(self._stages)
        if loss_fn is None:
            self._loss_handler = None
        else:
            self._loss_handler = PipelineLossHandler(loss_fn)

        # these could be late-initialized on configure_buffers \/
        self._input_data_sharding_spec = sharding_spec.input_data
        self._input_kwargs_sharding_spec = sharding_spec.input_kwargs

    def configure_buffers(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        if self._input_data_sharding_spec is None:
            self._input_data_sharding_spec = shard_spec_on_dim(inputs, dim=0)
        if self._input_kwargs_sharding_spec is None:
            self._input_kwargs_sharding_spec = shard_spec_on_dim(kwargs, dim=0)

        for stage in self._stages.values():
            stage.configure_buffers(
                num_microbatches=self._num_microbatches,
                pipeline_inputs=inputs,
                has_backward=self._has_backward
            )

    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        if self._input_data_sharding_spec is None or self._input_kwargs_sharding_spec is None:
            raise ValueError("Please configure sharding specs first")

        self._dist_ctx.logger.debug("Begin pipeline step")
        pp_group = self._dist_ctx.mesh_for(REGULAR_DOMAIN).get_group("pp")

        for stage in self._stages.values():
            stage.reset()

        # Shard inputs and kwargs to microbatches
        inputs_shard = shard_tree(
            inputs,
            num_shards=self._num_microbatches,
            sharding_spec=self._input_data_sharding_spec,
            enforce_even_split=True
        )
        kwargs_shard = shard_tree(
            kwargs,
            num_shards=self._num_microbatches,
            sharding_spec=self._input_kwargs_sharding_spec,
            enforce_even_split=True
        )

        my_program = self._program[pp_group.rank()]

        for action in my_program:
            with record_function(str(action)):
                self._dist_ctx.logger.debug(f"Running pipeline action {action}")
                action.apply(ActionContext(
                    loss=self._loss_handler,
                    stages=self._stages,
                    communications=self._comm_handler,
                    pipeline_inputs_microbatches=inputs_shard,
                    pipeline_kwargs_microbatches=kwargs_shard
                ))

        self._dist_ctx.logger.debug("Waiting for potentially hanging PP send comms")
        self._comm_handler.wait_send_all()  # finalize just in case
        self._dist_ctx.logger.debug("End pipeline step")

__init__(dist_context, stages, num_microbatches, loss_fn, program, sharding_spec)

Constructs the schedule executor.

Parameters:

  • dist_context (DistributedContext): The distributed context. Required.
  • stages (list[PipelineStage]): List of stages managed by this executor. Required.
  • num_microbatches (int): Number of microbatches the global batch is split into. Required.
  • loss_fn (LossFn | None): Function to compute loss. Required.
  • program (dict[int, list[ActionBase]]): The execution plan mapping rank ID to a list of actions. Required.
  • sharding_spec (PipelineShardingSpec): Sharding specification for input and output tensors. Required.
Source code in d9d/pipelining/infra/schedule/component/runtime/executor.py
def __init__(
        self,
        dist_context: DistributedContext,
        stages: list[PipelineStage],
        num_microbatches: int,
        loss_fn: LossFn | None,
        program: dict[int, list[ActionBase]],
        sharding_spec: PipelineShardingSpec
):
    """
    Constructs the schedule executor.

    Args:
        dist_context: The distributed context.
        stages: List of stages managed by this executor.
        num_microbatches: Number of microbatches the global batch is split into.
        loss_fn: Function to compute loss.
        program: The execution plan mapping rank ID to a list of actions.
        sharding_spec: Sharding specification for input and output tensors.
    """

    self._dist_ctx = dist_context
    self._stages = {stage.info.current_stage: stage for stage in stages}
    self._num_microbatches = num_microbatches
    self._program = program

    self._has_backward = any(any(
        action.has_backward_work for action in sub_program
    ) for sub_program in program.values())

    self._comm_handler = PipelineCommunicationHandler(self._stages)
    if loss_fn is None:
        self._loss_handler = None
    else:
        self._loss_handler = PipelineLossHandler(loss_fn)

    # these could be late-initialized on configure_buffers \/
    self._input_data_sharding_spec = sharding_spec.input_data
    self._input_kwargs_sharding_spec = sharding_spec.input_kwargs

d9d.pipelining.infra.schedule.component.program

Pipeline Schedule Building Components.

This package provides the core building blocks and compiler passes used to generate execution schedules for distributed pipelines.

PipelineProgramBuilder

Bases: ABC

Abstract interface for building pipeline execution schedules.

Source code in d9d/pipelining/infra/schedule/component/program/base.py
class PipelineProgramBuilder(abc.ABC):
    """Abstract interface for building pipeline execution schedules."""

    @abc.abstractmethod
    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        """
        Generates the execution program for all ranks in the pipeline.

        Args:
            num_microbatches: Number of microbatches per step.
            pp_size: Number of pipeline parallel ranks.

        Returns:
            A dictionary mapping rank indices to their list of sequential actions.
        """
        ...

    @property
    @abc.abstractmethod
    def num_stages_per_rank(self) -> int:
        """Returns the number of model stages designated for each rank."""

        ...

    @property
    @abc.abstractmethod
    def topology_style(self) -> ScheduleStyle:
        """Returns the topology style strategy used to assign stages to ranks."""
        ...

num_stages_per_rank abstractmethod property

Returns the number of model stages designated for each rank.

topology_style abstractmethod property

Returns the topology style strategy used to assign stages to ranks.

compose(num_microbatches, pp_size) abstractmethod

Generates the execution program for all ranks in the pipeline.

Parameters:

- num_microbatches (int, required): Number of microbatches per step.
- pp_size (int, required): Number of pipeline parallel ranks.

Returns:

- dict[int, list[ActionBase]]: A dictionary mapping rank indices to their list of sequential actions.

Source code in d9d/pipelining/infra/schedule/component/program/base.py
@abc.abstractmethod
def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
    """
    Generates the execution program for all ranks in the pipeline.

    Args:
        num_microbatches: Number of microbatches per step.
        pp_size: Number of pipeline parallel ranks.

    Returns:
        A dictionary mapping rank indices to their list of sequential actions.
    """
    ...
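
As an illustration of this contract, here is a sketch of a deliberately naive builder: one stage per rank, all forwards first, then all backwards (GPipe-style). It uses only components documented in this package; the module paths are inferred from the "Source code in" captions, and the class name is made up for the example.

from d9d.pipelining.infra.schedule.component.program.base import PipelineProgramBuilder
from d9d.pipelining.infra.schedule.component.program.communications import add_communication_ops
from d9d.pipelining.infra.schedule.component.program.topology import (
    ScheduleStyle,
    build_stage_to_host_rank_topology,
)
# ActionBase is assumed to live alongside the concrete actions in action.py.
from d9d.pipelining.infra.schedule.component.runtime.action import (
    ActionBase,
    BackwardFullInputComputeAction,
    ForwardComputeAction,
)


class NaiveAllForwardAllBackwardBuilder(PipelineProgramBuilder):
    """Hypothetical example builder: one stage per rank, all forwards, then all backwards."""

    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        num_stages = self.num_stages_per_rank * pp_size
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=num_stages, style=self.topology_style
        )

        compute: dict[int, list[ActionBase]] = {rank: [] for rank in range(pp_size)}
        for stage, rank in stage_to_rank.items():
            # All forward microbatches for this stage...
            for mb in range(num_microbatches):
                compute[rank].append(ForwardComputeAction(stage_idx=stage, microbatch_idx=mb))
            # ...then all full backward passes.
            for mb in range(num_microbatches):
                compute[rank].append(
                    BackwardFullInputComputeAction(
                        stage_idx=stage, microbatch_idx=mb, full_backward=True
                    )
                )

        # Let the compiler pass inject Send/Recv pairs and order them safely.
        return add_communication_ops(
            compute_actions=compute, stage_to_rank=stage_to_rank, num_stages=num_stages
        )

    @property
    def num_stages_per_rank(self) -> int:
        return 1

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop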

ScheduleStyle

Bases: StrEnum

Defines the strategy for mapping logical stages to physical ranks.

Attributes:

- loop: Assigns stages in a round-robin circular fashion (mod pp_size).
- v: Assigns stages in a zig-zag V-shape pattern. Useful for interleaved 1F1B schedules.

Source code in d9d/pipelining/infra/schedule/component/program/topology.py
class ScheduleStyle(StrEnum):
    """
    Defines the strategy for mapping logical stages to physical ranks.

    Attributes:
        loop: Assigns stages in a round-robin circular fashion (mod pp_size).
        v: Assigns stages in a zig-zag V-shape pattern. Useful for interleaved 1F1B schedules.
    """

    loop = "loop"
    v = "v"

add_communication_ops(compute_actions, stage_to_rank, num_stages)

Injects communication actions into a computation-only schedule.

This function iterates through the provided compute schedule and simulates execution. When a compute action produces a result needed by a different rank, it injects Send/Receive pairs. It also reorders actions to ensure that Receive operations occur before the Computes that depend on them, preventing deadlocks.

Parameters:

- compute_actions (dict[int, list[ActionBase]], required): Initial schedule containing only compute operations.
- stage_to_rank (dict[int, int], required): Mapping from stage index to rank index.
- num_stages (int, required): Total number of pipeline stages.

Returns:

- dict[int, list[ActionBase]]: A new schedule dictionary including both compute and communication actions.

Raises:

- RuntimeError: If the schedule simulation enters a deadlock state.

Source code in d9d/pipelining/infra/schedule/component/program/communications.py
def add_communication_ops(
        compute_actions: dict[int, list[ActionBase]],
        stage_to_rank: dict[int, int],
        num_stages: int,
) -> dict[int, list[ActionBase]]:
    """
    Injects communication actions into a computation-only schedule.

    This function iterates through the provided compute schedule and simulates execution.
    When a compute action produces a result needed by a different rank, it injects
    Send/Receive pairs. It also reorders actions to ensure that Receive
    operations occur before the Computes that depend on them, preventing deadlocks.

    Args:
        compute_actions: Initial schedule containing only compute operations.
        stage_to_rank: Mapping from stage index to rank index.
        num_stages: Total number of pipeline stages.

    Returns:
        A new schedule dictionary including both compute and communication actions.

    Raises:
        RuntimeError: If the schedule simulation enters a deadlock state.
    """

    compute_actions = copy.deepcopy(compute_actions)

    full_actions: dict[int, list[ActionBase]] = {rank: [] for rank in compute_actions}
    completed_events: dict[int, set[ActionBase]] = {rank: set() for rank in compute_actions}

    while compute_actions:
        progress = False

        for rank in sorted(compute_actions.keys()):
            if not compute_actions[rank]:
                del compute_actions[rank]
                continue

            current_action = compute_actions[rank][0]
            sub_actions = _get_sub_actions(current_action)

            # Check readiness
            if not check_action_communication_dependencies_fulfilled(
                    current_action, completed_events[rank], num_stages
            ):
                continue

            # Execute
            full_actions[rank].append(current_action)
            compute_actions[rank].pop(0)
            progress = True

            for sub_action in sub_actions:
                completed_events[rank].add(sub_action)

                comm_pkg = _create_communications_for_action(
                    sub_action,
                    num_stages=num_stages,
                    stage_to_rank=stage_to_rank
                )
                if comm_pkg:
                    # Add Send locally
                    full_actions[rank].append(comm_pkg.send)
                    completed_events[rank].add(comm_pkg.send)

                    # Add Recv remotely and unblock target
                    full_actions[comm_pkg.sends_to_rank].append(comm_pkg.recv)
                    completed_events[comm_pkg.sends_to_rank].add(comm_pkg.recv)

        if not progress and compute_actions:
            raise RuntimeError("Deadlock in schedule simulation")

    return full_actions
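
A tiny end-to-end illustration (import paths inferred from the source captions): injecting communications into a forward-only, two-stage compute schedule adds a forward send on rank 0 and places the matching receive on rank 1 ahead of its compute.

from d9d.pipelining.infra.schedule.component.program.communications import add_communication_ops
from d9d.pipelining.infra.schedule.component.runtime.action import ForwardComputeAction

compute_only = {
    0: [ForwardComputeAction(stage_idx=0, microbatch_idx=0)],
    1: [ForwardComputeAction(stage_idx=1, microbatch_idx=0)],
}

full = add_communication_ops(
    compute_actions=compute_only,
    stage_to_rank={0: 0, 1: 1},
    num_stages=2,
)

# Expected shape of the result:
#   rank 0: [ForwardComputeAction(stage 0), ForwardSendAction(...)]
#   rank 1: [ForwardReceiveAction(...), ForwardComputeAction(stage 1)]
for rank, actions in full.items():
    print(rank, [type(action).__name__ for action in actions])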

build_stage_to_host_rank_topology(pp_size, num_stages, style)

Constructs the mapping from stage index to rank index.

Parameters:

- pp_size (int, required): Number of pipeline parallel ranks.
- num_stages (int, required): Total number of model stages.
- style (ScheduleStyle, required): The topology style to use for assignment.

Returns:

- dict[int, int]: A dictionary mapping stage IDs to Rank IDs.

Raises:

- ValueError: If the style is unknown or if V-style parameters are invalid (num_stages must be divisible by pp_size).

Source code in d9d/pipelining/infra/schedule/component/program/topology.py
def build_stage_to_host_rank_topology(
    pp_size: int, num_stages: int, style: ScheduleStyle
) -> dict[int, int]:
    """
    Constructs the mapping from stage index to rank index.

    Args:
        pp_size: Number of pipeline parallel ranks.
        num_stages: Total number of model stages.
        style: The topology style to use for assignment.

    Returns:
        A dictionary mapping stage IDs to Rank IDs.

    Raises:
        ValueError: If the style is unknown or if V-style parameters are invalid
            (num_stages must be divisible by pp_size).
    """

    match style:
        case ScheduleStyle.loop:
            return {stage_index: stage_index % pp_size for stage_index in range(num_stages)}
        case ScheduleStyle.v:
            if num_stages % pp_size != 0:
                raise ValueError(
                    f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
                )

            result = {}
            rank_index = 0
            for stage_index in range(num_stages):
                result[stage_index] = rank_index
                if (stage_index + 1) % pp_size == 0:
                    continue
                if (stage_index // pp_size) % 2 == 0:
                    rank_index += 1
                else:
                    rank_index -= 1
            return result
        case _:
            raise ValueError()
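
For example, tracing the implementation above with pp_size=4 and num_stages=8 (import path inferred from the source caption):

from d9d.pipelining.infra.schedule.component.program.topology import (
    ScheduleStyle,
    build_stage_to_host_rank_topology,
)

loop = build_stage_to_host_rank_topology(pp_size=4, num_stages=8, style=ScheduleStyle.loop)
v = build_stage_to_host_rank_topology(pp_size=4, num_stages=8, style=ScheduleStyle.v)

# Round-robin: stage i lives on rank i % pp_size.
assert loop == {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}

# V-shape: the assignment folds back, so rank 0 hosts the first and the last stage.
assert v == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}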

invert_stage_to_host_rank_topology(stage_to_host)

Inverts the topology mapping to list execution stages per rank.

Parameters:

- stage_to_host (dict[int, int], required): Mapping from stage index to rank index.

Returns:

- dict[int, list[int]]: A dictionary where keys are Rank IDs and values are lists of Stage IDs managed by that rank.

Source code in d9d/pipelining/infra/schedule/component/program/topology.py
def invert_stage_to_host_rank_topology(
        stage_to_host: dict[int, int]
) -> dict[int, list[int]]:
    """
    Inverts the topology mapping to list execution stages per rank.

    Args:
        stage_to_host: Mapping from stage index to rank index.

    Returns:
        A dictionary where keys are Rank IDs and values are lists of Stage IDs
        managed by that rank.
    """

    host_to_stage = defaultdict(list)
    for stage_idx, host in stage_to_host.items():
        host_to_stage[host].append(stage_idx)
    return dict(host_to_stage)

d9d.pipelining.infra.schedule.program

Pipeline Schedule Implementations.

DualPipeVPipelineProgramBuilder

Bases: PipelineProgramBuilder

Builder for the DualPipeV Pipeline Parallelism schedule.

DualPipeV is a specialized bi-directional pipeline schedule designed for high throughput training. It requires exactly 2 stages per pipeline rank (V-shape) and utilizes split backward passes (Input gradients vs Weight gradients) to fill pipeline bubbles.

References

https://github.com/deepseek-ai/DualPipe
https://hackmd.io/@ufotalent/r1lVXsa9Jg

Source code in d9d/pipelining/infra/schedule/program/dualpipev.py
class DualPipeVPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for the DualPipeV Pipeline Parallelism schedule.

    DualPipeV is a specialized bi-directional pipeline schedule designed for high
    throughput training. It requires exactly 2 stages per pipeline rank (V-shape)
    and utilizes split backward passes (Input gradients vs Weight gradients)
    to fill pipeline bubbles.

    References:
        https://github.com/deepseek-ai/DualPipe
        https://hackmd.io/@ufotalent/r1lVXsa9Jg
    """

    def __init__(self):
        """
        Constructs the DualPipeV builder.
        """

    @staticmethod
    def _build_for_rank(  # noqa: C901
            rank: int, stage_to_rank: dict[int, int], num_microbatches: int, pp_size: int
    ) -> list[ActionBase]:
        compute_actions: list[ActionBase] = []

        # Identify local stages: s0 is Phase 0, s1 is Phase 1
        my_stages = sorted([s for s, r in stage_to_rank.items() if r == rank])
        s0, s1 = my_stages[0], my_stages[1]

        # Track microbatch indices for each stage and operation type
        # f_idx: Next Forward microbatch
        # b_idx: Next Backward microbatch (Input or Full)
        f_idx = {s0: 0, s1: 0}
        b_idx = {s0: 0, s1: 0}

        # Queue for Zero Bubble optimization: stores (stage, mb_idx) for deferred weight grads
        weight_queue: deque[tuple[int, int]] = deque()

        # --- Helper Functions for Action Emission ---

        def _add_f(stage: int):
            compute_actions.append(
                ForwardComputeAction(stage_idx=stage, microbatch_idx=f_idx[stage])
            )
            f_idx[stage] += 1

        def _add_b_full(stage: int):
            compute_actions.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage,
                    microbatch_idx=b_idx[stage],
                    full_backward=True,
                )
            )
            b_idx[stage] += 1

        def _add_b_input(stage: int):
            mb = b_idx[stage]
            compute_actions.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage,
                    microbatch_idx=mb,
                    full_backward=False,
                )
            )
            weight_queue.append((stage, mb))
            b_idx[stage] += 1

        def _pop_w():
            if not weight_queue:
                return
            s, mb = weight_queue.popleft()
            compute_actions.append(
                BackwardWeightComputeAction(stage_idx=s, microbatch_idx=mb)
            )

        def _add_overlap_f_b(stage_f: int, stage_b: int, b_is_full: bool):
            """Emit overlapped Forward and Backward actions."""
            mb_f = f_idx[stage_f]
            mb_b = b_idx[stage_b]

            act_f = ForwardComputeAction(stage_idx=stage_f, microbatch_idx=mb_f)

            act_b = BackwardFullInputComputeAction(
                stage_idx=stage_b, microbatch_idx=mb_b, full_backward=b_is_full
            )
            if not b_is_full:
                weight_queue.append((stage_b, mb_b))

            f_idx[stage_f] += 1
            b_idx[stage_b] += 1

            # Note: d9d infra treats ComposeAction sequentially in simulation,
            # but runtime may overlap them.
            compute_actions.append(ComposeAction(actions=(act_f, act_b)))

        # Step 1: nF0 (Startup Phase 0)
        step_1 = (pp_size - rank - 1) * 2
        for _ in range(step_1):
            _add_f(s0)

        # Step 2: nF0F1 (Forward fill)
        step_2 = rank + 1
        for _ in range(step_2):
            _add_f(s0)
            _add_f(s1)

        # Step 3: nI1W1F1 (Mixed Phase with Zero Bubble)
        step_3 = pp_size - rank - 1
        for _ in range(step_3):
            _add_b_input(s1)  # Backward Input Phase 1
            _pop_w()  # Weight Phase (accumulated from prev)
            _add_f(s1)  # Forward Phase 1

        # Step 4: The Main Loop (Interleaved Forward/Backward)
        step_4 = num_microbatches - 2 * pp_size + rank + 1
        for i in range(step_4):
            # Sub-step A: F0 & B1
            if i == 0 and rank == pp_size - 1:
                # Specific case for last rank on first iter: do not overlap
                _add_f(s0)
                _add_b_full(s1)
            else:
                # Overlap F0 and B1 (usually full backward unless we were in ZB mode,
                # but DualPipeV main loop defaults to full for simplicity unless tuned)
                # DeepSeek impl uses standard backward here (zb=False).
                _add_overlap_f_b(stage_f=s0, stage_b=s1, b_is_full=True)

            # Sub-step B: F1 & B0
            # Overlap F1 and B0 (Full)
            _add_overlap_f_b(stage_f=s1, stage_b=s0, b_is_full=True)

        # Step 5: Cooldown F1/B0
        step_5 = pp_size - rank - 1
        for _ in range(step_5):
            _add_b_full(s1)
            _add_overlap_f_b(stage_f=s1, stage_b=s0, b_is_full=True)

        # Step 6: Cooldown B1/B0 with Zero Bubble ramp-up
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            # Phase 1 Backward
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True

            if enable_zb:
                _add_b_input(s1)
            else:
                _add_b_full(s1)

            # Phase 0 Backward
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True

            if enable_zb:
                _add_b_input(s0)
            else:
                _add_b_full(s0)

        # Step 7: Zero Bubble Weights + B0
        step_7 = pp_size - rank - 1
        for _ in range(step_7):
            _pop_w()
            # DeepSeek source explicitly uses enable_zb=True here for chunk 0
            _add_b_input(s0)

        # Step 8: Flush Weights
        step_8 = rank + 1
        for _ in range(step_8):
            _pop_w()

        return compute_actions

    def compose(
            self, num_microbatches: int, pp_size: int
    ) -> dict[int, list[ActionBase]]:
        num_stages = self.num_stages_per_rank * pp_size

        if num_microbatches < num_stages:
            raise ValueError(
                f"DualPipeV requires num_microbatches ({num_microbatches}) >= "
                f"num_stages ({num_stages})."
            )

        # Ranks hold stages in a V pattern (e.g., Rank 0 holds Stage 0 and Stage N-1).
        # We rely on the sorted order of local steps to determine Phase 0 (Forward-going)
        # and Phase 1 (Backward-coming).
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.v
        )

        compute_actions: dict[int, list[ActionBase]] = {r: [] for r in range(pp_size)}

        for rank in range(pp_size):
            compute_actions[rank] = self._build_for_rank(
                rank=rank,
                pp_size=pp_size,
                num_microbatches=num_microbatches,
                stage_to_rank=stage_to_rank
            )

        # 4. Inject Communication Operations
        # This wrapper handles dependency analysis and inserts Send/Recv/Wait ops.
        return add_communication_ops(
            compute_actions=compute_actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages
        )

    @property
    def num_stages_per_rank(self) -> int:
        return 2

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.v

__init__()

Constructs the DualPipeV builder.

Source code in d9d/pipelining/infra/schedule/program/dualpipev.py
def __init__(self):
    """
    Constructs the DualPipeV builder.
    """

Interleaved1F1BPipelineProgramBuilder

Bases: PipelineProgramBuilder

Builder for Interleaved Pipeline Parallelism schedules.

This builder supports:

  1. Standard Interleaved 1F1B: Assigns multiple stages per rank and prioritizes depth-first execution. (See https://arxiv.org/pdf/2104.04473)
  2. Interleaved Zero Bubble (ZB1P): Extends 1F1B by splitting backward passes into Input Gradients and Weight Gradients. Weight gradients are delayed to fill pipeline bubbles. (See https://arxiv.org/pdf/2401.10241)
Source code in d9d/pipelining/infra/schedule/program/interleaved.py
class Interleaved1F1BPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for Interleaved Pipeline Parallelism schedules.

    This builder supports:

    1.  **Standard Interleaved 1F1B**: Assigns multiple stages per rank and prioritizes
        depth-first execution. (See https://arxiv.org/pdf/2104.04473)
    2.  **Interleaved Zero Bubble (ZB1P)**: Extends 1F1B by splitting backward passes
        into Input Gradients and Weight Gradients. Weight gradients are delayed
        to fill pipeline bubbles. (See https://arxiv.org/pdf/2401.10241)
    """

    def __init__(self, num_stages_per_rank: int, enable_zero_bubble: bool = False):
        """
        Constructs the Interleaved 1F1B builder.

        Args:
            num_stages_per_rank: Number of stages per rank.
            enable_zero_bubble: If True, uses the ZB1P schedule variant which
                splits backward passes to reduce bubble size.
        """
        self._num_stages_per_rank = num_stages_per_rank
        self._enable_zero_bubble = enable_zero_bubble

    def _get_warmup_ops(
            self,
            rank: int,
            microbatches_per_round: int,
            pp_size: int,
            n_microbatches: int,
            multiply_factor: int,
    ) -> int:
        """
        Calculates the number of warmup steps required before entering steady state.
        """
        warmups_ops_last_stage = (self._num_stages_per_rank - 1) * microbatches_per_round
        warmup_ops = warmups_ops_last_stage + multiply_factor * ((pp_size - 1) - rank)
        return min(warmup_ops, n_microbatches * self._num_stages_per_rank)

    def compose(
            self, num_microbatches: int, pp_size: int
    ) -> dict[int, list[ActionBase]]:
        """
        Generates the execution program for all ranks.

        Args:
            num_microbatches: Total microbatches. Must be divisible by the derived
                number of rounds.
            pp_size: Number of pipeline ranks.

        Returns:
            A dictionary mapping rank indices to their list of sequential actions.
        """
        num_stages = self.num_stages_per_rank * pp_size

        if num_stages % pp_size != 0:
            raise ValueError(
                f"num_stages ({num_stages}) must be divisible by pp_size ({pp_size}) "
                "for interleaved schedules."
            )

        # 1. Topology Setup
        # Use Loop/Round-Robin assignment: Rank 0 gets Stage 0, PP, 2*PP...
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.loop
        )

        num_rounds = max(1, num_microbatches // pp_size)

        if num_microbatches % num_rounds != 0:
            raise ValueError(
                f"microbatches ({num_microbatches}) must be divisible by rounds ({num_rounds})."
            )

        microbatches_per_round = num_microbatches // num_rounds

        # 2. Schedule Generation
        actions: dict[int, list[ActionBase]] = {}

        # Zero Bubble 1f1b uses a shorter warmup heuristic (factor 1) than Standard (factor 2)
        warmup_multiplier = 1 if self._enable_zero_bubble else 2

        for rank in range(pp_size):
            actions[rank] = self._generate_rank_schedule(
                rank=rank,
                pp_size=pp_size,
                n_microbatches=num_microbatches,
                microbatches_per_round=microbatches_per_round,
                multiply_factor=warmup_multiplier,
            )

        # 3. Communication Injection
        return add_communication_ops(
            compute_actions=actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages,
        )

    def _generate_rank_schedule(  # noqa: C901
            self,
            rank: int,
            pp_size: int,
            n_microbatches: int,
            microbatches_per_round: int,
            multiply_factor: int,
    ) -> list[ActionBase]:
        """
        Generates the sequential list of compute actions for a specific rank.
        """
        rank_actions: list[ActionBase] = []

        # -- State Tracking --
        # Map: stage_idx -> next_microbatch_idx
        fwd_counters: dict[int, int] = defaultdict(int)
        bwd_counters: dict[int, int] = defaultdict(int)

        # FIFO Queue for deferred weight gradients in Zero Bubble
        # Stores: (stage_idx, microbatch_idx)
        pending_weights: deque[tuple[int, int]] = deque()

        # -- Helpers --

        def get_global_stage(local_idx: int) -> int:
            """Converts a local virtual stage index (0..N) to global stage ID."""
            return (local_idx * pp_size) + rank

        def get_fwd_local_idx(op_idx: int) -> int:
            return (op_idx // microbatches_per_round) % self._num_stages_per_rank

        def get_bwd_local_idx(op_idx: int, warmup_offset: int) -> int:
            return (self._num_stages_per_rank
                    - 1
                    - ((op_idx - warmup_offset) // microbatches_per_round) % self._num_stages_per_rank)

        def emit_forward(op_idx: int):
            local_idx = get_fwd_local_idx(op_idx)
            stage = get_global_stage(local_idx)
            mb = fwd_counters[stage]

            rank_actions.append(ForwardComputeAction(stage_idx=stage, microbatch_idx=mb))
            fwd_counters[stage] += 1

        def emit_backward(op_idx: int, warmup_offset: int):
            local_idx = get_bwd_local_idx(op_idx, warmup_offset)
            stage = get_global_stage(local_idx)
            mb = bwd_counters[stage]

            # In Zero Bubble, we split: Backward Input (Now) + Backward Weight (Later)
            # In Standard 1F1B, we do full backward now.
            is_full = not self._enable_zero_bubble

            rank_actions.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage,
                    microbatch_idx=mb,
                    full_backward=is_full
                )
            )

            if self._enable_zero_bubble:
                pending_weights.append((stage, mb))

            bwd_counters[stage] += 1

        def try_emit_weight_zb(op_idx: int, warmup_offset: int):
            if not self._enable_zero_bubble or not pending_weights:
                return

            steps_into_1f1b = op_idx - warmup_offset
            # The earliest reasonable time to start weaving in weights is proportional to rank depth
            if steps_into_1f1b >= rank:
                w_stage, w_mb = pending_weights.popleft()
                rank_actions.append(
                    BackwardWeightComputeAction(stage_idx=w_stage, microbatch_idx=w_mb)
                )

        # -- Execution Phase Math --

        warmup_ops = self._get_warmup_ops(
            rank, microbatches_per_round, pp_size, n_microbatches, multiply_factor
        )
        total_microbatch_ops = self._num_stages_per_rank * n_microbatches
        fwd_bwd_ops = total_microbatch_ops - warmup_ops
        cooldown_ops = total_microbatch_ops - fwd_bwd_ops

        # Combine into one sequence for iteration, but handle logic per phase
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

        # -- Main Schedule Loop --

        for op in range(total_ops):

            # Phase 1: Warmup (Forward Only)
            if op < warmup_ops:
                emit_forward(op)

            # Phase 2: Steady State (1F1B)
            elif op < warmup_ops + fwd_bwd_ops:
                emit_forward(op)
                emit_backward(op, warmup_offset=warmup_ops)
                try_emit_weight_zb(op, warmup_offset=warmup_ops)

            # Phase 3: Cooldown (Backward Only)
            else:
                emit_backward(op, warmup_offset=warmup_ops)
                try_emit_weight_zb(op, warmup_offset=warmup_ops)

        # -- Post-Loop: Flush Remaining Weights (ZB only) --
        while pending_weights:
            w_stage, w_mb = pending_weights.popleft()
            rank_actions.append(
                BackwardWeightComputeAction(stage_idx=w_stage, microbatch_idx=w_mb)
            )

        return rank_actions

    @property
    def num_stages_per_rank(self) -> int:
        return self._num_stages_per_rank

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop

__init__(num_stages_per_rank, enable_zero_bubble=False)

Constructs the Interleaved 1F1B builder.

Parameters:

- num_stages_per_rank (int, required): Number of stages per rank.
- enable_zero_bubble (bool, default False): If True, uses the ZB1P schedule variant which splits backward passes to reduce bubble size.

Source code in d9d/pipelining/infra/schedule/program/interleaved.py
def __init__(self, num_stages_per_rank: int, enable_zero_bubble: bool = False):
    """
    Constructs the Interleaved 1F1B builder.

    Args:
        num_stages_per_rank: Number of stages per rank.
        enable_zero_bubble: If True, uses the ZB1P schedule variant which
            splits backward passes to reduce bubble size.
    """
    self._num_stages_per_rank = num_stages_per_rank
    self._enable_zero_bubble = enable_zero_bubble

compose(num_microbatches, pp_size)

Generates the execution program for all ranks.

Parameters:

- num_microbatches (int, required): Total microbatches. Must be divisible by the derived number of rounds.
- pp_size (int, required): Number of pipeline ranks.

Returns:

- dict[int, list[ActionBase]]: A dictionary mapping rank indices to their list of sequential actions.

Source code in d9d/pipelining/infra/schedule/program/interleaved.py
def compose(
        self, num_microbatches: int, pp_size: int
) -> dict[int, list[ActionBase]]:
    """
    Generates the execution program for all ranks.

    Args:
        num_microbatches: Total microbatches. Must be divisible by the derived
            number of rounds.
        pp_size: Number of pipeline ranks.

    Returns:
        A dictionary mapping rank indices to their list of sequential actions.
    """
    num_stages = self.num_stages_per_rank * pp_size

    if num_stages % pp_size != 0:
        raise ValueError(
            f"num_stages ({num_stages}) must be divisible by pp_size ({pp_size}) "
            "for interleaved schedules."
        )

    # 1. Topology Setup
    # Use Loop/Round-Robin assignment: Rank 0 gets Stage 0, PP, 2*PP...
    stage_to_rank = build_stage_to_host_rank_topology(
        pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.loop
    )

    num_rounds = max(1, num_microbatches // pp_size)

    if num_microbatches % num_rounds != 0:
        raise ValueError(
            f"microbatches ({num_microbatches}) must be divisible by rounds ({num_rounds})."
        )

    microbatches_per_round = num_microbatches // num_rounds

    # 2. Schedule Generation
    actions: dict[int, list[ActionBase]] = {}

    # Zero Bubble 1f1b uses a shorter warmup heuristic (factor 1) than Standard (factor 2)
    warmup_multiplier = 1 if self._enable_zero_bubble else 2

    for rank in range(pp_size):
        actions[rank] = self._generate_rank_schedule(
            rank=rank,
            pp_size=pp_size,
            n_microbatches=num_microbatches,
            microbatches_per_round=microbatches_per_round,
            multiply_factor=warmup_multiplier,
        )

    # 3. Communication Injection
    return add_communication_ops(
        compute_actions=actions,
        stage_to_rank=stage_to_rank,
        num_stages=num_stages,
    )
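
A usage sketch (module path inferred from the source caption): the same builder covers both variants; only the flag differs.

from d9d.pipelining.infra.schedule.program.interleaved import Interleaved1F1BPipelineProgramBuilder

# Standard interleaved 1F1B: full backward passes, warmup factor 2.
standard = Interleaved1F1BPipelineProgramBuilder(num_stages_per_rank=2)

# ZB1P variant: backward passes are split into input- and weight-gradient actions.
zero_bubble = Interleaved1F1BPipelineProgramBuilder(num_stages_per_rank=2, enable_zero_bubble=True)

# 8 stages over 4 ranks, 8 microbatches (divisible by the derived number of rounds).
program = zero_bubble.compose(num_microbatches=8, pp_size=4)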

LoopedBFSPipelineProgramBuilder

Bases: PipelineProgramBuilder

Builder for the Breadth-First Pipeline Parallelism schedule.

This schedule runs all available forward microbatches for local stages first. If configured for training, it then runs backwards in reverse topological order.

References

https://arxiv.org/pdf/2211.05953

Source code in d9d/pipelining/infra/schedule/program/bfs.py
class LoopedBFSPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for the Breadth-First Pipeline Parallelism schedule.

    This schedule runs all available forward microbatches for local stages first.
    If configured for training, it then runs backwards in reverse topological order.

    References:
        https://arxiv.org/pdf/2211.05953
    """

    def __init__(self, num_stages_per_rank: int, inference_mode: bool = False):
        """
        Constructs the LoopedBFS builder.

        Args:
            num_stages_per_rank: Number of stages per rank.
            inference_mode: If True, only forward passes are scheduled. If False,
                both forward and backward passes are scheduled.
        """
        self._num_stages_per_rank = num_stages_per_rank
        self._inference_mode = inference_mode

    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        num_stages = self._num_stages_per_rank * pp_size
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size,
            num_stages=num_stages,
            style=ScheduleStyle.loop
        )

        compute_actions: dict[int, list[ActionBase]] = {r: [] for r in range(pp_size)}

        for rank in range(pp_size):
            my_stages = [s for s in range(num_stages) if stage_to_rank[s] == rank]

            # Schedule all Forwards
            # In Breadth-First loops, we finish all microbatches for the current stage
            # before moving to the next stage assigned to this rank.
            for stage_idx in my_stages:
                for mb_idx in range(num_microbatches):
                    compute_actions[rank].append(
                        ForwardComputeAction(
                            stage_idx=stage_idx,
                            microbatch_idx=mb_idx
                        )
                    )

            # Schedule all Backwards (Reverse order) - Only if training
            if not self._inference_mode:
                for stage_idx in reversed(my_stages):
                    for mb_idx in reversed(range(num_microbatches)):
                        compute_actions[rank].append(
                            BackwardFullInputComputeAction(
                                stage_idx=stage_idx,
                                microbatch_idx=mb_idx,
                                full_backward=True
                            )
                        )

        return add_communication_ops(
            compute_actions=compute_actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages
        )

    @property
    def num_stages_per_rank(self) -> int:
        return self._num_stages_per_rank

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop

__init__(num_stages_per_rank, inference_mode=False)

Constructs the LoopedBFS builder.

Parameters:

- num_stages_per_rank (int, required): Number of stages per rank.
- inference_mode (bool, default False): If True, only forward passes are scheduled. If False, both forward and backward passes are scheduled.

Source code in d9d/pipelining/infra/schedule/program/bfs.py
def __init__(self, num_stages_per_rank: int, inference_mode: bool = False):
    """
    Constructs the LoopedBFS builder.

    Args:
        num_stages_per_rank: Number of stages per rank.
        inference_mode: If True, only forward passes are scheduled. If False,
            both forward and backward passes are scheduled.
    """
    self._num_stages_per_rank = num_stages_per_rank
    self._inference_mode = inference_mode
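
A usage sketch (module path inferred from the source caption), using inference mode to produce a forward-only program:

from d9d.pipelining.infra.schedule.program.bfs import LoopedBFSPipelineProgramBuilder

# Forward-only program (e.g. for pipelined inference): no backward actions are emitted.
builder = LoopedBFSPipelineProgramBuilder(num_stages_per_rank=2, inference_mode=True)
program = builder.compose(num_microbatches=4, pp_size=2)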

ZeroBubbleVPipelineProgramBuilder

Bases: PipelineProgramBuilder

Builder for the Zero Bubble V (ZBV) Pipeline Schedule.

This schedule requires exactly two stages per rank, organized in a V-shape topology, and applies the Zero Bubble optimization: backward passes are split into Input and Weight gradients to improve pipeline throughput.

References

https://arxiv.org/pdf/2401.10241, Section 6

Source code in d9d/pipelining/infra/schedule/program/zerobubblev.py
class ZeroBubbleVPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for the Zero Bubble V (ZBV) Pipeline Schedule.

    This schedule requires exactly two stages per rank, organized in a V-shape
    topology, and applies the Zero Bubble optimization: backward passes are split
    into Input and Weight gradients to improve pipeline throughput.

    References:
        https://arxiv.org/pdf/2401.10241, Section 6
    """

    def __init__(self):
        """Constructs the ZBV builder."""

    def compose(
            self, num_microbatches: int, pp_size: int
    ) -> dict[int, list[ActionBase]]:
        num_stages = self.num_stages_per_rank * pp_size

        # 1. Topology
        # V-style: Rank 0 gets Stage 0 & Stage N-1. Rank 1 gets Stage 1 & Stage N-2...
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.v
        )

        actions: dict[int, list[ActionBase]] = {}

        for rank in range(pp_size):
            actions[rank] = self._generate_rank_schedule(
                rank=rank,
                pp_size=pp_size,
                num_stages=num_stages,
                target_microbatches=num_microbatches,
            )

        # 2. Inject Communications
        return add_communication_ops(
            compute_actions=actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages
        )

    def _generate_rank_schedule(  # noqa: C901
            self,
            rank: int,
            pp_size: int,
            num_stages: int,
            target_microbatches: int,
    ) -> list[ActionBase]:
        # ZBV logic assumes the pipeline is fully saturated to define the loop bounds.
        # We simulate enough steps to cover the topology startup, then filter
        # down to the user's requested microbatches at the end.
        simulated_n_micro = max(2 * pp_size - 1, target_microbatches)

        rank_ops: list[ActionBase] = []

        # -- Stage Identification (V-Shape) --
        # s0: The "Forward-going" chunk (e.g., Stage 0 for Rank 0)
        # s1: The "Backward-coming" chunk (e.g., Stage N-1 for Rank 0)
        s0 = rank
        s1 = num_stages - 1 - rank

        # -- Counters --
        # Track next microbatch index for each operation type on each chunk.
        # F: Forward, I: Backward Input, W: Backward Weight
        f0_cnt = 0
        b0_cnt = 0  # Input Grad Counter (Chunk 0)
        w0_cnt = 0  # Weight Grad Counter (Chunk 0)

        f1_cnt = 0
        b1_cnt = 0  # Input Grad Counter (Chunk 1)
        w1_cnt = 0  # Weight Grad Counter (Chunk 1)

        # -- Helpers --

        def emit_f(stage: int, idx: int):
            rank_ops.append(ForwardComputeAction(stage_idx=stage, microbatch_idx=idx))

        def emit_i_and_w(stage: int, idx: int):
            rank_ops.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage, microbatch_idx=idx, full_backward=False
                )
            )
            rank_ops.append(
                BackwardWeightComputeAction(stage_idx=stage, microbatch_idx=idx)
            )

        def emit_i(stage: int, idx: int):
            rank_ops.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage, microbatch_idx=idx, full_backward=False
                )
            )

        def emit_w(stage: int, idx: int):
            rank_ops.append(
                BackwardWeightComputeAction(stage_idx=stage, microbatch_idx=idx)
            )

        # -- Phase 1: Warmup 1 (Chunk 0 Forwards) --
        warmup_n1 = 2 * (pp_size - rank) - 1
        for _ in range(warmup_n1):
            emit_f(s0, f0_cnt)
            f0_cnt += 1

        # -- Phase 2: Warmup 2 (Interleave F1, F0) --
        warmup_n2 = rank
        for _ in range(warmup_n2):
            emit_f(s1, f1_cnt)
            f1_cnt += 1
            emit_f(s0, f0_cnt)
            f0_cnt += 1

        # -- Phase 3: Warmup 3 (F1, then B1 I+W) --
        warmup_n3 = pp_size - rank
        for _ in range(warmup_n3):
            emit_f(s1, f1_cnt)
            f1_cnt += 1

            emit_i_and_w(s1, b1_cnt)
            b1_cnt += 1
            w1_cnt += 1

        # -- Phase 4: Stable State --
        while f1_cnt < f0_cnt or f0_cnt < simulated_n_micro:
            # Emit F0 if within bounds
            if f0_cnt < simulated_n_micro:
                emit_f(s0, f0_cnt)
                f0_cnt += 1

            # Emit B0 (I+W)
            emit_i_and_w(s0, b0_cnt)
            b0_cnt += 1
            w0_cnt += 1

            # Emit F1
            emit_f(s1, f1_cnt)
            f1_cnt += 1

            # Emit B1 (I+W)
            emit_i_and_w(s1, b1_cnt)
            b1_cnt += 1
            w1_cnt += 1

        # -- Phase 5: Cooldown 1 (Splitting I and W) --
        # In cooldown, the I and W streams diverge to fill bubbles.
        cooldown_n1 = rank
        for _ in range(cooldown_n1):
            emit_i(s0, b0_cnt)
            b0_cnt += 1

            emit_i(s1, b1_cnt)
            b1_cnt += 1

        # -- Phase 6: Cooldown 2 (I0, then W0) --
        cooldown_n2 = pp_size - rank
        for _ in range(cooldown_n2):
            # Input Grad Chunk 0
            emit_i(s0, b0_cnt)
            b0_cnt += 1

            # Weight Grad Chunk 0 (delayed from previous steps)
            emit_w(s0, w0_cnt)
            w0_cnt += 1

        # -- Phase 7: Flush Remaining Weights --

        # Flush W1
        while w1_cnt < b1_cnt:
            emit_w(s1, w1_cnt)
            w1_cnt += 1

        # Flush W0
        while w0_cnt < b0_cnt:
            emit_w(s0, w0_cnt)
            w0_cnt += 1

        # -- Integrity Check --
        if not (w0_cnt == b0_cnt == f0_cnt):
            raise RuntimeError(
                f"ZBV Schedule Failed (Chunk 0): F={f0_cnt}, I={b0_cnt}, W={w0_cnt}"
            )
        if not (w1_cnt == b1_cnt == f1_cnt):
            raise RuntimeError(
                f"ZBV Schedule Failed (Chunk 1): F={f1_cnt}, I={b1_cnt}, W={w1_cnt}"
            )

        # -- Post-Process: Filter to Target Microbatches --
        # Remove any actions involving simulated microbatches beyond the user's request.
        final_ops: list[ActionBase] = []
        for action in rank_ops:
            if isinstance(action, (ForwardComputeAction,
                                   BackwardFullInputComputeAction,
                                   BackwardWeightComputeAction)):
                if action.microbatch_idx < target_microbatches:
                    final_ops.append(action)
            else:
                final_ops.append(action)

        return final_ops

    @property
    def num_stages_per_rank(self) -> int:
        return 2

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.v

__init__()

Constructs the ZBV builder.

Source code in d9d/pipelining/infra/schedule/program/zerobubblev.py
def __init__(self):
    """Constructs the ZBV builder."""

d9d.pipelining.training

PipelinedLRScheduler

Bases: LRSchedulerProtocol

Wrapper that manages multiple LR schedulers for a pipeline parallel rank.

Similar to PipelinedOptimizer, this aggregates schedulers corresponding to multiple model stages hosted on the current rank.

Source code in d9d/pipelining/training/scheduler.py
class PipelinedLRScheduler(LRSchedulerProtocol):
    """
    Wrapper that manages multiple LR schedulers for a pipeline parallel rank.

    Similar to `PipelinedOptimizer`, this aggregates schedulers corresponding to
    multiple model stages hosted on the current rank.
    """

    def __init__(self, mesh_pp: DeviceMesh, schedulers: list[LRSchedulerProtocol]):
        self._mesh_pp = mesh_pp
        self._schedulers = schedulers

    def state_dict(self) -> dict[str, Any]:
        pp_rank = self._mesh_pp.get_local_rank()
        return {
            f"pp_{pp_rank}_stage_{i}": scheduler.state_dict()
            for i, scheduler in enumerate(self._schedulers)
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pp_rank = self._mesh_pp.get_local_rank()
        for i, scheduler in enumerate(self._schedulers):
            scheduler.load_state_dict(state_dict[f"pp_{pp_rank}_stage_{i}"])

    def step(self) -> None:
        for scheduler in self._schedulers:
            scheduler.step()

PipelinedOptimizer

Bases: OptimizerProtocol

Wrapper that manages multiple optimizers for a pipeline parallel rank.

In a pipeline parallel setup, a single rank might host multiple stages, each having its own parameters and optimizer. This class aggregates them into a single interface.

Source code in d9d/pipelining/training/optimizer.py
class PipelinedOptimizer(OptimizerProtocol):
    """
    Wrapper that manages multiple optimizers for a pipeline parallel rank.

    In a pipeline parallel setup, a single rank might host multiple stages, each having its own parameters
    and optimizer.
    This class aggregates them into a single interface.
    """

    def __init__(self, mesh_pp: DeviceMesh, optimizers: list[Optimizer]):
        super().__init__()

        self._mesh_pp = mesh_pp
        self._optimizers = optimizers

    def state_dict(self) -> dict[str, Any]:
        pp_rank = self._mesh_pp.get_local_rank()
        return {
            f"pp_{pp_rank}_stage_{i}": optimizer.state_dict()
            for i, optimizer in enumerate(self._optimizers)
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pp_rank = self._mesh_pp.get_local_rank()
        for i, optimizer in enumerate(self._optimizers):
            optimizer.load_state_dict(state_dict[f"pp_{pp_rank}_stage_{i}"])

    def step(self) -> None:
        for optimizer in self._optimizers:
            optimizer.step()

    def zero_grad(self) -> None:
        for optimizer in self._optimizers:
            optimizer.zero_grad()
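
As a sketch of the aggregation (the pipeline DeviceMesh and the per-stage optimizers are assumed to exist already), checkpoint entries are namespaced by pipeline rank and local stage index:

from d9d.pipelining.training.optimizer import PipelinedOptimizer  # path inferred from the source caption

# mesh_pp, opt_stage_a and opt_stage_b are assumed to be built elsewhere
# (a pipeline DeviceMesh and one torch Optimizer per locally hosted stage).
optimizer = PipelinedOptimizer(mesh_pp=mesh_pp, optimizers=[opt_stage_a, opt_stage_b])

optimizer.zero_grad()
# ... run a pipeline step, then:
optimizer.step()

state = optimizer.state_dict()
# On pp rank 1 hosting two stages, the keys follow the pattern shown in the source:
#   {"pp_1_stage_0": <state of opt_stage_a>, "pp_1_stage_1": <state of opt_stage_b>}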