Gradient Synchronization

Internal API Warning

If you are utilizing the standard d9d training infrastructure, you do not need to call these functions manually. The framework automatically handles gradient synchronization. This package is primarily intended for users extending d9d.

About

The d9d.internals.grad_sync package provides low-level primitives for synchronizing gradients in distributed training setups utilizing DTensor.

Unlike standard PyTorch DistributedDataParallel, which assumes a uniform communication strategy for the entire model, this package is designed to work with the heterogeneous distributions often found in ND-parallelism (e.g., mixtures of Data, Tensor, Sequence, and Pipeline parallelism). It inspects DTensor placements to automatically determine which dimensions require reduction (all-reduce) and groups parameters into efficient communication buckets.
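As a rough illustration of the placement inspection described above, the sketch below models placements as plain strings rather than real DTensor placement objects. It assumes (this is not confirmed by the d9d source) that a gradient needs an all-reduce along exactly those mesh dimensions on which the parameter is replicated. The function name and the mesh dimension names are hypothetical.

```python
def mesh_dims_needing_reduction(mesh_dim_names, placements):
    """Return the mesh dimensions along which a gradient would be all-reduced.

    Assumption: a parameter replicated along a mesh dimension sees different
    data on each rank of that dimension, so its gradient must be summed there.
    Sharded dimensions are skipped because each rank owns a distinct slice.
    """
    if len(mesh_dim_names) != len(placements):
        raise ValueError("one placement is required per mesh dimension")
    return [
        name
        for name, placement in zip(mesh_dim_names, placements)
        if placement == "replicate"
    ]


# Hypothetical 2-D mesh: data parallel x tensor parallel.
mesh = ("dp", "tp")
# A tensor-parallel weight: replicated over dp, sharded over tp.
tp_weight_dims = mesh_dims_needing_reduction(mesh, ("replicate", "shard(0)"))
# A normalization weight: replicated over both dimensions.
norm_weight_dims = mesh_dims_needing_reduction(mesh, ("replicate", "replicate"))
print(tp_weight_dims, norm_weight_dims)
```

Under this assumption, two parameters of the same model can require reduction over different sets of mesh dimensions, which is why a single model-wide communication strategy is not enough.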

Core Concepts

Bucketing & Flattening

Communication overhead is dominated by latency when reducing many small tensors. To mitigate this, GradientSynchronizer groups parameters into Buckets.

Inside a SyncGradientBucket, gradients for multiple parameters are flattened into a single contiguous block of memory. When a reduction is triggered, the system performs a single all_reduce operation on this large buffer instead of hundreds of small operations.

Parameters are grouped automatically based on:

  1. Device
  2. DType
  3. Associated DeviceMesh
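The grouping and flattening described above can be sketched in plain Python. This is a conceptual model, not the d9d implementation: ParamInfo and build_buckets are hypothetical names, devices, dtypes and meshes are represented as strings, and the greedy fill strategy under the size cap is an assumption. Each bucket entry records the offset at which a parameter's gradient would live inside the flat buffer.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ParamInfo:
    # Hypothetical stand-in for a parameter; real code would read these
    # fields from a DTensor parameter.
    name: str
    device: str
    dtype: str
    mesh: str
    numel: int
    itemsize: int  # bytes per element


def build_buckets(params, bucket_size_mb):
    """Group params by (device, dtype, mesh), then fill buckets greedily."""
    limit = bucket_size_mb * 1024 * 1024
    groups = defaultdict(list)
    for p in params:
        groups[(p.device, p.dtype, p.mesh)].append(p)

    buckets = []
    for key, members in groups.items():
        current, offset = [], 0  # offset is counted in elements
        for p in members:
            nbytes = p.numel * p.itemsize
            # Start a new bucket when adding this parameter would exceed the cap.
            if current and offset * p.itemsize + nbytes > limit:
                buckets.append((key, current))
                current, offset = [], 0
            # Record where this parameter's gradient lives in the flat buffer.
            current.append((p.name, offset, p.numel))
            offset += p.numel
        if current:
            buckets.append((key, current))
    return buckets


params = [
    ParamInfo("a", "cuda:0", "bf16", "dp", 300_000, 2),
    ParamInfo("b", "cuda:0", "bf16", "dp", 300_000, 2),
    ParamInfo("c", "cuda:0", "fp32", "dp", 1_000, 4),
    ParamInfo("d", "cuda:0", "bf16", "dp", 100_000, 2),
]
buckets = build_buckets(params, bucket_size_mb=1)
for key, entries in buckets:
    print(key, entries)
```

In this toy run, "a" and "b" cannot share a 1 MB bucket, so "b" starts a new bucket that "d" then joins, while the fp32 parameter "c" lands in its own bucket because its dtype differs.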

Asynchronous Reduction

In large-scale training, the effective batch size is often increased by accumulating gradients over multiple micro-batches before performing an optimizer step. This package manages the lifecycle of distributed DTensor gradients during this accumulation phase without relying on no_sync context managers.

  1. Local Accumulation: During the backward pass of the first \(N-1\) micro-batches, local gradients are accumulated into the bucket's buffer. Conceptually, because the parameter DTensor is Replicate, these intermediate local gradients are also treated as being in a Replicate state across the Data Parallel mesh, even though each rank holds different data.

  2. Automatic Triggering: Each bucket maintains an internal counter. The all_reduce communication is triggered only when the specific parameter group has reached the require_accumulations count. The trigger fires automatically inside the backward hook of the last micro-batch, so communication can immediately overlap with the computation of the remaining layers higher up in the model. The communication runs on a separate CUDA stream, which must be waited on before the gradients are used on your default stream.

  3. Synchronization: Once the asynchronous reduction completes, the flat buffer contains the globally summed gradient. The metadata of the contained parameter gradients is marked as Replicate, making them safe for the optimizer to consume without further synchronization.
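The three steps above can be modeled with a small pure-Python toy. It performs no real communication: other ranks are simulated by a list of dictionaries, the "all-reduce" is a plain sum, and the reduction runs synchronously rather than on a side stream. ToyBucket and its counter layout (one tick per parameter per micro-batch) are assumptions made for illustration, not the d9d implementation.

```python
class ToyBucket:
    """Pure-Python model of one gradient bucket (no real communication)."""

    def __init__(self, param_names, require_accumulations):
        self.buffer = {name: 0.0 for name in param_names}
        self.require_accumulations = require_accumulations
        self.hooks_seen = 0
        self.reduced = False

    def on_grad_ready(self, name, local_grad, peer_buffers):
        """Stand-in for a backward hook firing for one parameter."""
        self.buffer[name] += local_grad  # step 1: local accumulation
        self.hooks_seen += 1
        expected = len(self.buffer) * self.require_accumulations
        if self.hooks_seen == expected:
            # Step 2: last micro-batch, last parameter. Simulate the
            # all-reduce by summing the buffers of all other ranks.
            for key in self.buffer:
                self.buffer[key] += sum(peer[key] for peer in peer_buffers)
            # Step 3: the buffer now holds the globally summed gradient.
            self.reduced = True

    def zero_grad(self):
        """Reset gradients and the accumulation counter."""
        for key in self.buffer:
            self.buffer[key] = 0.0
        self.hooks_seen = 0
        self.reduced = False


bucket = ToyBucket(["w", "b"], require_accumulations=2)
peers = [{"w": 10.0, "b": 20.0}]  # accumulated gradients of one other rank
for micro_batch in range(2):
    bucket.on_grad_ready("w", 1.0, peers)
    bucket.on_grad_ready("b", 2.0, peers)
    print(micro_batch, bucket.reduced, bucket.buffer)
```

The reduction fires only during the second micro-batch, after both parameters have reported their gradients twice; the first micro-batch leaves the buffer holding purely local values.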

d9d.internals.grad_sync

Gradient synchronization utilities.

This package provides the infrastructure for manual gradient bucketing and asynchronous reduction, similar to DistributedDataParallel but exposed for internal framework usage with DTensors.

GradientSynchronizer

Manages gradient synchronization for distributed training.

This class handles the bucketing of parameters, memory allocation for flat gradient buffers, and the orchestration of asynchronous all-reduce operations during the backward pass.

__init__(param_groups, bucket_size_mb, require_accumulations)

Constructs a GradientSynchronizer.

Parameters:

  param_groups (list[list[Parameter]], required): List of parameter groups.
  bucket_size_mb (int, required): Maximal size of a single gradient bucket in MB.
  require_accumulations (int, required): Number of micro-batches to accumulate before reducing.

bind()

Initializes the synchronizer for training.

Groups parameters, creates buckets, allocates memory, and registers hooks. Must be called before the backward pass.

unbind()

Releases resources.

Destroys buckets, frees memory buffers, and removes hooks.

wait()

Waits for all bucket operations (async reductions) to complete.

zero_grad()

Resets gradients and accumulation counters for all managed parameters.
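A possible call order for these methods is sketched below. The ordering (bind once, then backward passes, wait, optimizer step, zero_grad, and finally unbind) is inferred from the method descriptions above and is an assumption, not a documented contract. To keep the example self-contained, it uses a hypothetical RecordingSynchronizer stand-in that only records calls; with d9d installed, a real GradientSynchronizer constructed with param_groups, bucket_size_mb and require_accumulations would take its place. The helper run_training_step and its callback parameters are also hypothetical.

```python
def run_training_step(synchronizer, micro_batches, backward_fn, optimizer_step_fn):
    """Drive one optimizer step using the documented synchronizer methods."""
    for micro_batch in micro_batches:
        # Hooks registered by bind() are expected to accumulate gradients
        # and launch the reduction during the last micro-batch.
        backward_fn(micro_batch)
    synchronizer.wait()       # make sure async reductions have finished
    optimizer_step_fn()       # gradients should now be safe to consume
    synchronizer.zero_grad()  # reset gradients and accumulation counters


class RecordingSynchronizer:
    """Hypothetical stand-in that only records calls; not the d9d class."""

    def __init__(self):
        self.calls = []

    def bind(self):
        self.calls.append("bind")

    def unbind(self):
        self.calls.append("unbind")

    def wait(self):
        self.calls.append("wait")

    def zero_grad(self):
        self.calls.append("zero_grad")


sync = RecordingSynchronizer()
sync.bind()  # must happen before the first backward pass
run_training_step(
    sync,
    micro_batches=[0, 1, 2],
    backward_fn=lambda mb: sync.calls.append(f"backward:{mb}"),
    optimizer_step_fn=lambda: sync.calls.append("optimizer_step"),
)
sync.unbind()  # release buffers and hooks when training is done
print(sync.calls)
```

In a real loop, the number of micro-batches per step would need to match the require_accumulations value passed to the constructor so that the reduction is triggered before wait() is called.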