Autograd Extensions
About
The d9d.core.autograd package provides utilities to exert fine-grained control over PyTorch's automatic differentiation engine.
The Global Grad Context
Why
The primary purpose of the Global Grad Context is to work around a specific limitation of torch.autograd.Function
regarding partial backward passes, which are critical for advanced distributed training schedules like Zero-Bubble Pipeline Parallelism.
In standard PyTorch operations (like torch.matmul), the autograd engine is highly optimized.
If you perform a backward pass specifying only a subset of inputs (e.g., torch.autograd.backward(..., inputs=[activations])),
PyTorch will intelligently skip computing gradients for parameters (weights) to save compute.
However, custom torch.autograd.Function implementations do not share this intelligence.
PyTorch sets ctx.needs_input_grad to True for every input that has requires_grad=True, regardless of whether
that specific edge is actually being computed in the current backward() call.
This behavior makes it impossible to implement split-backward pipeline schedules (where activation gradients and weight gradients are computed at different times) using custom operations (like GroupedGEMM) without performing redundant calculations.
For more details, see PyTorch Issue #174017.
How it Works
To bypass this limitation, d9d introduces the GlobalGradContext. It acts as a side-channel state manager that
allows the training loop to explicitly signal its intent to the custom operators.
- Orchestrator: The training loop sets the context (e.g., "I only want Input gradients now").
- Operator: The custom backward checks this context. Even if PyTorch reports needs_input_grad=True, the operator verifies with GlobalGradContext before computing.
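To make the mechanics concrete, here is a minimal, self-contained sketch of what such a side-channel context might look like. The class and method names mirror the ones documented below, but the implementation is illustrative, not d9d's actual code:

```python
from contextlib import contextmanager
from enum import Enum


class GradDirection(str, Enum):
    # Mirrors the GradDirection enum described below; illustrative only.
    inputs = "inputs"
    weight = "weight"


class GlobalGradContext:
    """Illustrative sketch of a global gradient-direction switch."""

    def __init__(self):
        # All directions enabled by default.
        self._enabled = frozenset(GradDirection)

    def check_direction(self, direction):
        # None means "no semantic tag"; always compute in that case.
        return direction is None or direction in self._enabled

    @contextmanager
    def with_directions(self, *directions):
        # Override the enabled set, restoring the previous state on exit.
        previous = self._enabled
        self._enabled = frozenset(directions)
        try:
            yield
        finally:
            self._enabled = previous


GLOBAL_GRAD_CONTEXT = GlobalGradContext()
```

With something of this shape in place, the training loop (orchestrator) narrows the enabled directions for a region of code, and custom operators consult check_direction before doing any work.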
Usage
In Custom Autograd Functions
When writing a custom operation, you must tag your gradients with a semantic GradDirection and check the context before computation.
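As a hedged sketch of that pattern, the following custom matmul checks a context object shaped like the one described above before computing each gradient edge. The op and all names here are hypothetical stand-ins (not d9d's GroupedGEMM or its real context), and the stub context is inlined so the example is self-contained:

```python
from contextlib import contextmanager
from enum import Enum

import torch


class GradDirection(str, Enum):  # stand-in for d9d's enum
    inputs = "inputs"
    weight = "weight"


class GlobalGradContext:  # stand-in for d9d's context
    def __init__(self):
        self._enabled = frozenset(GradDirection)

    def check_direction(self, direction):
        return direction is None or direction in self._enabled

    @contextmanager
    def with_directions(self, *directions):
        previous = self._enabled
        self._enabled = frozenset(directions)
        try:
            yield
        finally:
            self._enabled = previous


GLOBAL_GRAD_CONTEXT = GlobalGradContext()


class SplitBackwardMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        # needs_input_grad alone is unreliable during partial backward passes,
        # so also consult the global context before each computation.
        if ctx.needs_input_grad[0] and GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.inputs):
            grad_x = grad_out @ w.t()
        if ctx.needs_input_grad[1] and GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight):
            grad_w = x.t() @ grad_out
        return grad_x, grad_w
```

Returning None for a disabled edge tells the autograd engine that no gradient flows along it, so the redundant computation is skipped entirely.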
In Training Loops
By default, the GLOBAL_GRAD_CONTEXT is set to compute both input and weight gradients.
The d9d pipelining API configures it for split-backward automatically. So, if you use the Trainer, everything will work out of the box.
If you use your own training loop implementation, you have to configure the context manually.
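Manual configuration might look like the following sketch of a zero-bubble-style schedule. GLOBAL_GRAD_CONTEXT and GradDirection are inlined stand-ins for the real d9d objects, and the recording function stands in for a custom op's backward:

```python
from contextlib import contextmanager
from enum import Enum


class GradDirection(str, Enum):  # stand-in for d9d's enum
    inputs = "inputs"
    weight = "weight"


class GlobalGradContext:  # stand-in for d9d's context
    def __init__(self):
        self._enabled = frozenset(GradDirection)

    def check_direction(self, direction):
        return direction is None or direction in self._enabled

    @contextmanager
    def with_directions(self, *directions):
        previous = self._enabled
        self._enabled = frozenset(directions)
        try:
            yield
        finally:
            self._enabled = previous


GLOBAL_GRAD_CONTEXT = GlobalGradContext()
computed = []


def custom_backward():
    # Stands in for a custom op's backward pass.
    if GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.inputs):
        computed.append("dX")
    if GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight):
        computed.append("dW")


# Split-backward schedule: activation gradients now, weight gradients later.
with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs):
    custom_backward()  # appends only "dX"
with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.weight):
    custom_backward()  # appends only "dW"
```

Because with_directions restores the previous state on exit, the default behavior (both edges enabled) resumes after each phase.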
API Reference
d9d.core.autograd
GLOBAL_GRAD_CONTEXT = GlobalGradContext()
module-attribute
The singleton instance of GlobalGradContext.
Custom autograd functions should call GLOBAL_GRAD_CONTEXT.check_direction()
during their backward pass.
GlobalGradContext
Global state manager for controlling gradient computation in custom autograd functions.
This context addresses a limitation in PyTorch where custom torch.autograd.Function
implementations set ctx.needs_input_grad to True for all edges requiring grad,
even during partial backward passes (e.g., torch.autograd.backward(inputs=...)).
For additional information on this limitation, please refer to PyTorch Issue #174017.
This class allows:
- For the training code - to explicitly signal which gradient edges (inputs vs weights) should currently be computed, allowing custom ops to skip unnecessary computations.
- For module code - to check whether it's required to compute a gradient edge.
__init__()
Constructs a GlobalGradContext object with all directions enabled by default.
check_direction(direction)
Checks if the gradient calculation for the given direction is currently enabled.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| direction | GradDirection \| None | The direction to check (inputs or weights). If None, returns True. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if the direction is enabled or None is passed, False otherwise. |
with_directions(*directions)
Context manager that sets the enabled gradient directions.
This overrides the current state for the duration of the context and restores the previous state afterwards.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *directions | GradDirection | The gradient directions to enable. | () |
GradDirection
Bases: StrEnum
Enum representing the specific gradient edges to compute.
This is used to manually control gradient flow in custom autograd functions during split backward passes.
Attributes:
| Name | Type | Description |
|---|---|---|
| inputs | | Mark gradient edge as pointing to the module's inputs (activations). |
| weight | | Mark gradient edge as pointing to the module's parameters (weights). |