About
The d9d.core.autograd package provides utilities to exert fine-grained control over PyTorch's automatic differentiation engine.
The Global Grad Context
Why
The primary purpose of the Global Grad Context is to work around a specific limitation of torch.autograd.Function
regarding partial backward passes, which are critical for advanced distributed training schedules such as Zero-Bubble Pipeline Parallelism.
In standard PyTorch operations (like torch.matmul), the autograd engine is highly optimized.
If you perform a backward pass specifying only a subset of inputs (e.g., torch.autograd.backward(..., inputs=[activations])),
PyTorch will intelligently skip computing gradients for parameters (weights) to save compute.
However, custom torch.autograd.Function implementations do not share this intelligence.
PyTorch sets ctx.needs_input_grad to True for every input that has requires_grad=True, regardless of whether
that specific edge is actually being computed in the current backward() call.
This behavior makes it impossible to implement split-backward pipeline schedules (where activation gradients and weight gradients are computed at different times) using custom operations (like GroupedGEMM) without performing redundant calculations.
For more details, see PyTorch Issue #174017.
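To make the limitation concrete, the following minimal sketch (the Probe class, tensor names and shapes are illustrative, not part of d9d) shows that ctx.needs_input_grad reports True for the weight even when the backward call only requests the input gradient:

```python
import torch


class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        # Prints (True, True) even though the backward call below
        # only asks for the gradient of `x`.
        print("needs_input_grad:", ctx.needs_input_grad)
        return grad_output @ w.t(), x.t() @ grad_output


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
out = Probe.apply(x, w)

# Partial backward pass: only the gradient w.r.t. `x` is requested.
torch.autograd.backward(out.sum(), inputs=[x])
```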
How it Works
To bypass this limitation, d9d introduces the GlobalGradContext. It acts as a side-channel state manager that
allows the training loop to explicitly signal its intent to the custom operators.
- Orchestrator: The training loop sets the context (e.g., "I only want Input gradients now").
- Operator: The custom backward checks this context. Even if PyTorch says needs_input_grad=True, the operator will verify with GlobalGradContext before computation.
Usage
In Custom Autograd Functions
When writing a custom operation, you must tag your gradients with a semantic GradDirection and check the context before computation.
```python
import torch

from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection


class MyCustomOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weight):
        # Save which direction 'inputs' and 'weight' correspond to
        ctx.dir_inputs = GradDirection.inputs
        ctx.dir_weight = GradDirection.weight
        ctx.save_for_backward(inputs, weight)
        return torch.matmul(inputs, weight)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weight = ctx.saved_tensors
        grad_input = grad_weight = None

        # Check 1: Does PyTorch need it? AND Check 2: Does the context allow it?

        # Calculate input gradients (activations)
        if ctx.needs_input_grad[0] and GLOBAL_GRAD_CONTEXT.check_direction(ctx.dir_inputs):
            grad_input = torch.matmul(grad_output, weight.t())

        # Calculate weight gradients
        if ctx.needs_input_grad[1] and GLOBAL_GRAD_CONTEXT.check_direction(ctx.dir_weight):
            grad_weight = torch.matmul(inputs.t(), grad_output)

        return grad_input, grad_weight
```
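Continuing the example, the op is called like any other autograd function. With the default context (both directions enabled), a plain backward pass populates both gradients; the shapes below are illustrative:

```python
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)

out = MyCustomOp.apply(x, w)
out.sum().backward()  # default context: both x.grad and w.grad are computed
```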
In Training Loops
By default, the GLOBAL_GRAD_CONTEXT is set to compute both input and weight gradients.
The d9d pipelining API configures it for split backward automatically, so if you use the Trainer, everything works out of the box.
If you use your own training loop implementation, you have to configure the context manually.
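A minimal sketch of such manual configuration, reusing MyCustomOp from above; the tensor shapes, retain_graph handling and phase ordering are illustrative, not a prescribed schedule:

```python
import torch

from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
loss = MyCustomOp.apply(x, w).sum()

# Phase 1: compute activation gradients only.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs)
torch.autograd.backward(loss, inputs=[x], retain_graph=True)

# Phase 2 (later in the schedule): compute weight gradients only.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.weight)
torch.autograd.backward(loss, inputs=[w])

# Restore the default so subsequent steps compute both edges again.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs, GradDirection.weight)
```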
API Reference
d9d.core.autograd
GLOBAL_GRAD_CONTEXT = GlobalGradContext()
module-attribute
The singleton instance of GlobalGradContext.
Custom autograd functions should call GLOBAL_GRAD_CONTEXT.check_direction() on this instance
during their backward pass.
GlobalGradContext
Global state manager for controlling gradient computation in custom autograd functions.
This context addresses a limitation in PyTorch where custom torch.autograd.Function
implementations set ctx.needs_input_grad to True for all edges requiring grad,
even during partial backward passes (e.g., torch.autograd.backward(inputs=...)).
For additional information on this limitation, please refer to a related issue.
This class allows:
- the training code to explicitly signal which gradient edges (inputs vs. weights) should currently be computed, allowing custom ops to skip unnecessary computations;
- module code to check whether a given gradient edge currently needs to be computed.
__init__()
Constructs a GlobalGradContext object with all directions enabled by default.
check_direction(direction)
Checks if the gradient calculation for the given direction is currently enabled.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `direction` | `GradDirection \| None` | The direction to check (inputs or weights). If None, returns True. | required |

Returns:

| Type | Description |
|---|---|
| `bool` | True if the direction is enabled or None is passed, False otherwise. |
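A short illustration of the documented behaviour, using the GLOBAL_GRAD_CONTEXT singleton:

```python
from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection

GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs)

GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.inputs)  # True
GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight)  # False
GLOBAL_GRAD_CONTEXT.check_direction(None)                  # True (None always passes)
```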
set_directions(*directions)
Sets the enabled gradient directions, overriding the current state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*directions` | `GradDirection` | Variable number of GradDirection enums to enable. | `()` |
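For example, a schedule can narrow the enabled directions and later restore the default; this sketch shows only the toggling itself:

```python
from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection

# Enable only activation gradients; each call overrides the previous state.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs)

# Re-enable both directions (the default state).
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs, GradDirection.weight)
```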
GradDirection
Bases: StrEnum
Enum representing the specific gradient edges to compute.
This is used to manually control gradient flow in custom autograd functions during split backward passes.
Attributes:
| Name | Type | Description |
|---|---|---|
| `inputs` | | Mark gradient edge as pointing to the module's inputs (activations). |
| `weight` | | Mark gradient edge as pointing to the module's parameters (weights). |