Gradient Synchronization
Internal API Warning
If you are using the standard d9d training infrastructure, you do not need to call these functions manually; the framework handles gradient synchronization automatically. This package is intended primarily for users extending d9d.
About
The d9d.internals.grad_sync package provides low-level primitives for synchronizing gradients in distributed training setups that use DTensor.
Unlike standard PyTorch DistributedDataParallel which assumes a uniform communication strategy for the entire model, this package is designed to work with heterogeneous distributions often found in ND-parallelism (e.g., mixtures of Data, Tensor, Sequence, and Pipeline parallelism). It inspects DTensor placements to automatically determine which dimensions require reduction (all-reduce) and groups parameters into efficient communication buckets.
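The placement-inspection idea above can be sketched as follows. The placement classes here are simplified stand-ins for the real torch.distributed.tensor placement types, and dims_requiring_reduction is a hypothetical helper, not d9d API: in DTensor semantics, a gradient that is Partial on a mesh dimension holds only the local rank's contribution and needs an all-reduce there, while Shard and Replicate dimensions do not.

```python
# Hypothetical sketch: deciding which mesh dimensions need reduction by
# inspecting simplified DTensor-style placements. Real placements live in
# torch.distributed.tensor and carry more information than these stand-ins.

from dataclasses import dataclass


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Partial:
    pass


def dims_requiring_reduction(grad_placements):
    """Return mesh dimensions whose local gradients are partial sums."""
    return [i for i, p in enumerate(grad_placements) if isinstance(p, Partial)]


# A parameter replicated across the DP dimension (dim 0) and sharded across
# the TP dimension (dim 1) typically yields a gradient that is Partial on DP:
grad_placements = [Partial(), Shard(0)]
print(dims_requiring_reduction(grad_placements))  # [0]
```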
Core Concepts
Bucketing & Flattening
Communication overhead is dominated by latency when reducing many small tensors. To mitigate this, GradientSynchronizer groups parameters into Buckets.
Inside a SyncGradientBucket, gradients for multiple parameters are flattened into a single contiguous block of memory. When a reduction is triggered, the system performs a single all_reduce operation on this large buffer instead of hundreds of small operations.
Parameters are grouped automatically based on:
- Device
- DType
- Associated DeviceMesh
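The grouping rule above can be illustrated with a small sketch. ParamMeta and group_into_buckets are hypothetical names for illustration only; the real synchronizer works on torch Parameters and DeviceMesh objects, and additionally caps each bucket at bucket_size_mb:

```python
# Hypothetical sketch of bucket grouping: parameters are keyed by
# (device, dtype, mesh), and a group is split whenever it would exceed
# the bucket size cap (expressed here in elements rather than MB).

from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ParamMeta:
    name: str
    device: str   # e.g. "cuda:0"
    dtype: str    # e.g. "bf16"
    mesh: str     # identifier of the DeviceMesh the DTensor lives on
    numel: int


def group_into_buckets(params, bucket_size_elems):
    """Group params by (device, dtype, mesh); split groups at the size cap."""
    by_key = defaultdict(list)
    for p in params:
        by_key[(p.device, p.dtype, p.mesh)].append(p)
    buckets = []
    for key, group in by_key.items():
        current, current_size = [], 0
        for p in group:
            if current and current_size + p.numel > bucket_size_elems:
                buckets.append((key, current))
                current, current_size = [], 0
            current.append(p)
            current_size += p.numel
        if current:
            buckets.append((key, current))
    return buckets


params = [
    ParamMeta("w1", "cuda:0", "bf16", "dp_mesh", 600),
    ParamMeta("w2", "cuda:0", "bf16", "dp_mesh", 600),
    ParamMeta("w3", "cuda:0", "fp32", "dp_mesh", 100),
]
buckets = group_into_buckets(params, bucket_size_elems=1000)
# w1 and w2 share a key but together exceed the cap, so they land in
# separate buckets; w3 has a different dtype and gets its own bucket.
print(len(buckets))  # 3
```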
Asynchronous Reduction
In large-scale training, the effective batch size is often increased by accumulating gradients over multiple micro-batches before performing an optimizer step. This package manages the lifecycle of distributed DTensor gradients during this accumulation phase without relying on no_sync-style context managers.
- Local Accumulation: During the backward pass of the first \(N-1\) micro-batches, local gradients are accumulated into the bucket's buffer. Conceptually, while the parameter DTensor is Replicate, these intermediate local gradients also represent a Replicate state across the Data Parallel mesh (although each rank holds different data).
- Automatic Triggering: Each bucket maintains an internal counter. The all_reduce communication is triggered only once the parameter group has reached the require_accumulations count. The trigger fires automatically inside the backward hook of the last micro-batch, allowing communication to overlap immediately with the computation of the remaining layers higher up in the model. The communication runs on a separate CUDA stream, which must be awaited before the gradients are used on the default stream.
- Synchronization: Once the asynchronous reduction completes, the flat buffer contains the globally summed gradient. The metadata of the contained parameter gradients is marked Replicate, making them safe for the optimizer to consume without further synchronization.
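The accumulate-then-trigger mechanism described above can be sketched in miniature. SketchBucket is an illustrative stand-in, not d9d code: the real bucket holds a flat tensor buffer and launches an asynchronous all_reduce on a separate CUDA stream, whereas here the reduction is simulated for a single process.

```python
# Hypothetical sketch of the per-bucket accumulation counter: gradients
# are summed locally into a flat buffer, and the reduction fires only
# when require_accumulations micro-batches have been observed.

class SketchBucket:
    def __init__(self, size, require_accumulations):
        self.buffer = [0.0] * size          # stands in for the flat gradient buffer
        self.require_accumulations = require_accumulations
        self.count = 0
        self.reduced = False

    def on_backward(self, local_grad):
        """Called from the backward hook with this micro-batch's gradient."""
        for i, g in enumerate(local_grad):
            self.buffer[i] += g             # local accumulation, no communication
        self.count += 1
        if self.count == self.require_accumulations:
            self.all_reduce()               # trigger only on the last micro-batch

    def all_reduce(self):
        # In the real system this launches an async all_reduce on a
        # separate CUDA stream; here we simply mark the bucket reduced.
        self.reduced = True


bucket = SketchBucket(size=2, require_accumulations=3)
for grad in ([1.0, 1.0], [2.0, 0.0], [0.0, 3.0]):
    bucket.on_backward(grad)
print(bucket.buffer, bucket.reduced)  # [3.0, 4.0] True
```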
d9d.internals.grad_sync
Gradient synchronization utilities.
This package provides the infrastructure for manual gradient bucketing and asynchronous reduction, similar to DistributedDataParallel but exposed for internal framework usage with DTensors.
GradientSynchronizer
Manages gradient synchronization for distributed training.
This class handles the bucketing of parameters, memory allocation for flat gradient buffers, and the orchestration of asynchronous all-reduce operations during the backward pass.
__init__(param_groups, bucket_size_mb, require_accumulations)
Constructs a GradientSynchronizer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| param_groups | list[list[Parameter]] | List of parameter groups. | required |
| bucket_size_mb | int | Maximal size of a single gradient bucket in MB. | required |
| require_accumulations | int | Number of micro-batches to accumulate before reducing. | required |
bind()
Initializes the synchronizer for training.
Groups parameters, creates buckets, allocates memory, and registers hooks. Must be called before the backward pass.
unbind()
Releases resources.
Destroys buckets, frees memory buffers, and removes hooks.
wait()
Waits for all bucket operations (async reductions) to complete.
zero_grad()
Resets gradients and accumulation counters for all managed parameters.
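Putting the API together, the expected call order is bind, then the backward passes (hooks fire automatically), then wait before the optimizer step, zero_grad to start the next accumulation window, and unbind at shutdown. The class below is a runnable stand-in that only tracks this ordering; the real GradientSynchronizer additionally allocates buffers, registers hooks, and launches the asynchronous reductions described above.

```python
# Sketch of the GradientSynchronizer lifecycle as a minimal state machine.
# This is illustrative only: method bodies here just record call order.

class LifecycleSketch:
    def __init__(self):
        self.state = "unbound"
        self.calls = []

    def bind(self):
        # Group params, create buckets, allocate memory, register hooks.
        assert self.state == "unbound", "bind() must precede the backward pass"
        self.state = "bound"
        self.calls.append("bind")

    def wait(self):
        # Block until all asynchronous reductions have completed.
        assert self.state == "bound", "wait() requires a bound synchronizer"
        self.calls.append("wait")

    def zero_grad(self):
        # Reset gradients and accumulation counters for the next window.
        assert self.state == "bound"
        self.calls.append("zero_grad")

    def unbind(self):
        # Destroy buckets, free buffers, remove hooks.
        assert self.state == "bound"
        self.state = "unbound"
        self.calls.append("unbind")


sync = LifecycleSketch()
sync.bind()        # before training
# ... run require_accumulations backward passes; hooks fire automatically ...
sync.wait()        # await async reductions before touching gradients
# ... optimizer step is now safe: gradients are globally reduced ...
sync.zero_grad()   # reset for the next accumulation window
sync.unbind()      # at shutdown
print(sync.calls)  # ['bind', 'wait', 'zero_grad', 'unbind']
```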