About

The d9d.core.autograd package provides utilities to exert fine-grained control over PyTorch's automatic differentiation engine.

The Global Grad Context

Why

The primary purpose of the Global Grad Context is to work around a limitation of torch.autograd.Function with respect to partial backward passes, which are critical for advanced distributed training schedules such as Zero-Bubble Pipeline Parallelism.

For standard PyTorch operations (like torch.matmul), the autograd engine handles partial backward passes efficiently. If you request gradients only for a subset of inputs (e.g., torch.autograd.backward(..., inputs=[activations])), PyTorch skips computing gradients for the parameters (weights) to save compute.
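
As a quick illustration (a minimal sketch with arbitrary shapes), the weight gradient is simply never materialized when only the activation gradient is requested:

import torch

activations = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(8, 2, requires_grad=True)
loss = torch.matmul(activations, weight).sum()

# Ask only for the gradient with respect to the activations.
torch.autograd.backward(loss, inputs=[activations])

assert activations.grad is not None
assert weight.grad is None  # the weight gradient edge was never evaluated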

However, custom torch.autograd.Function implementations do not share this intelligence. PyTorch sets ctx.needs_input_grad to True for every input that has requires_grad=True, regardless of whether that specific edge is actually needed by the current backward() call.
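
A minimal probe (the class name Probe is purely illustrative) makes the difference visible: even when only the activation gradient is requested, the custom backward still sees both flags set:

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weight):
        ctx.save_for_backward(inputs, weight)
        return torch.matmul(inputs, weight)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weight = ctx.saved_tensors
        print(ctx.needs_input_grad)  # prints (True, True) even for a partial backward
        return torch.matmul(grad_output, weight.t()), torch.matmul(inputs.t(), grad_output)

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
torch.autograd.backward(Probe.apply(x, w).sum(), inputs=[x])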

This behavior makes it impossible to implement split-backward pipeline schedules (where activation gradients and weight gradients are computed at different times) using custom operations (like GroupedGEMM) without performing redundant calculations.

For more details, see PyTorch Issue #174017.

How it Works

To bypass this limitation, d9d introduces the GlobalGradContext. It acts as a side-channel state manager that allows the training loop to explicitly signal its intent to the custom operators.

  1. Orchestrator: The training loop sets the context (e.g., "I only want input gradients now").
  2. Operator: The custom backward checks this context. Even if PyTorch reports needs_input_grad=True, the operator verifies with the GlobalGradContext before computing that gradient.

Usage

In Custom Autograd Functions

When writing a custom operation, you must tag your gradients with a semantic GradDirection and check the context before computation.

import torch
from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection

class MyCustomOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weight):
        # Save which direction 'inputs' and 'weight' correspond to
        ctx.dir_inputs = GradDirection.inputs
        ctx.dir_weight = GradDirection.weight
        ctx.save_for_backward(inputs, weight)
        return torch.matmul(inputs, weight)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weight = ctx.saved_tensors
        grad_input = grad_weight = None

        # Check 1: Does PyTorch need it? AND Check 2: Does Context allow it?

        # Calculate Input Gradients (Activation)
        if ctx.needs_input_grad[0] and GLOBAL_GRAD_CONTEXT.check_direction(ctx.dir_inputs):
            grad_input = torch.matmul(grad_output, weight.t())

        # Calculate Weight Gradients
        if ctx.needs_input_grad[1] and GLOBAL_GRAD_CONTEXT.check_direction(ctx.dir_weight):
            grad_weight = torch.matmul(inputs.t(), grad_output)

        return grad_input, grad_weight

In Training Loops

By default, the GLOBAL_GRAD_CONTEXT is set to compute both input and weight gradients.

The d9d pipelining API configures it for split backward automatically, so if you use the Trainer, everything works out of the box.

If you implement your own training loop, you have to configure the context manually, as sketched below.
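
The following is a minimal sketch of such manual configuration, reusing MyCustomOp from above. It shows only the context-management pattern, not the actual d9d pipelining schedule; the tensors are placeholders.

import torch
from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection

activations = torch.randn(4, 8, requires_grad=True)  # placeholder for a pipeline stage input
weight = torch.randn(8, 2, requires_grad=True)       # placeholder for a stage parameter
loss = MyCustomOp.apply(activations, weight).sum()

# Phase 1: compute only the activation gradient (e.g. to send to the previous stage).
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs)
torch.autograd.backward(loss, inputs=[activations], retain_graph=True)

# Phase 2: later in the schedule, compute only the weight gradient.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.weight)
torch.autograd.backward(loss, inputs=[weight])

# Restore the default so subsequent code computes both edges again.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.inputs, GradDirection.weight)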

API Reference

d9d.core.autograd

GLOBAL_GRAD_CONTEXT = GlobalGradContext() (module attribute)

The singleton instance of GlobalGradContext.

Custom autograd functions should call GLOBAL_GRAD_CONTEXT.check_direction() during their backward pass to decide which gradient edges to compute.

GlobalGradContext

Global state manager for controlling gradient computation in custom autograd functions.

This context addresses a limitation in PyTorch where custom torch.autograd.Function implementations set ctx.needs_input_grad to True for all edges requiring grad, even during partial backward passes (e.g., torch.autograd.backward(inputs=...)).

For additional information on this limitation, refer to PyTorch issue #174017 (https://github.com/pytorch/pytorch/issues/174017).

This class allows:

  1. Training code to explicitly signal which gradient edges (inputs vs. weights) should currently be computed, allowing custom ops to skip unnecessary computations.
  2. Module code to check whether a given gradient edge needs to be computed.
Source code in d9d/core/autograd/grad_context.py
class GlobalGradContext:
    """
    Global state manager for controlling gradient computation in custom autograd functions.

    This context addresses a limitation in PyTorch where custom `torch.autograd.Function`
    implementations set `ctx.needs_input_grad` to True for all edges requiring grad,
    even during partial backward passes (e.g., `torch.autograd.backward(inputs=...)`).

    For additional information on this limitation, please refer to a
        [related issue](https://github.com/pytorch/pytorch/issues/174017).

    This class allows:

    1. For the training code - to explicitly signal which gradient edges (inputs vs weights)
        should currently be computed, allowing custom ops to skip unnecessary computations.
    2. For module code - to check whether it's required to compute a gradient edge.
    """

    def __init__(self):
        """Constructs a GlobalGradContext object with all directions enabled by default."""

        # both directions by default
        self._enabled_directions: set[GradDirection] = {GradDirection.inputs, GradDirection.weight}

    def check_direction(self, direction: GradDirection | None) -> bool:
        """
        Checks if the gradient calculation for the given direction is currently enabled.

        Args:
            direction: The direction to check (inputs or weights). If None,
                returns True.

        Returns:
            True if the direction is enabled or None is passed, False otherwise.
        """

        if direction is None:
            return True

        return direction in self._enabled_directions

    def set_directions(self, *directions: GradDirection):
        """
        Sets the enabled gradient directions, overriding the current state.

        Args:
            *directions: Variable number of GradDirection enums to enable.
        """

        self._enabled_directions = set(directions)

__init__()

Constructs a GlobalGradContext object with all directions enabled by default.

Source code in d9d/core/autograd/grad_context.py
def __init__(self):
    """Constructs a GlobalGradContext object with all directions enabled by default."""

    # both directions by default
    self._enabled_directions: set[GradDirection] = {GradDirection.inputs, GradDirection.weight}

check_direction(direction)

Checks if the gradient calculation for the given direction is currently enabled.

Parameters:

  direction (GradDirection | None): The direction to check (inputs or weights). If None, returns True. Required.

Returns:

  bool: True if the direction is enabled or None is passed, False otherwise.

Source code in d9d/core/autograd/grad_context.py
def check_direction(self, direction: GradDirection | None) -> bool:
    """
    Checks if the gradient calculation for the given direction is currently enabled.

    Args:
        direction: The direction to check (inputs or weights). If None,
            returns True.

    Returns:
        True if the direction is enabled or None is passed, False otherwise.
    """

    if direction is None:
        return True

    return direction in self._enabled_directions

set_directions(*directions)

Sets the enabled gradient directions, overriding the current state.

Parameters:

  *directions (GradDirection): Variable number of GradDirection enums to enable. Default: ().
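
For example (a small usage sketch; the assertions just illustrate the resulting state):

from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection

# Enable only the weight edge.
GLOBAL_GRAD_CONTEXT.set_directions(GradDirection.weight)
assert GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight)
assert not GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.inputs)

# Calling with no arguments disables every tagged direction (None still passes).
GLOBAL_GRAD_CONTEXT.set_directions()
assert GLOBAL_GRAD_CONTEXT.check_direction(None)
assert not GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight)
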
Source code in d9d/core/autograd/grad_context.py
def set_directions(self, *directions: GradDirection):
    """
    Sets the enabled gradient directions, overriding the current state.

    Args:
        *directions: Variable number of GradDirection enums to enable.
    """

    self._enabled_directions = set(directions)

GradDirection

Bases: StrEnum

Enum representing the specific gradient edges to compute.

This is used to manually control gradient flow in custom autograd functions during split backward passes.

Attributes:

  inputs: Mark gradient edge as pointing to the module's inputs (activations).
  weight: Mark gradient edge as pointing to the module's parameters (weights).

Source code in d9d/core/autograd/grad_context.py
class GradDirection(StrEnum):
    """
    Enum representing the specific gradient edges to compute.

    This is used to manually control gradient flow in custom autograd functions
    during split backward passes.

    Attributes:
        inputs: Mark gradient edge as pointing to the module's inputs (activations).
        weight: Mark gradient edge as pointing to the module's parameters (weights).
    """

    inputs = "inputs"
    weight = "weights"