About
Warning:
If you are using the standard d9d training infrastructure, you do not need to call these functions manually; the framework handles gradient synchronization automatically. This package is primarily intended for users extending d9d.
The d9d.internals.grad_sync package provides low-level primitives for synchronizing gradients in distributed training setups that use DTensor.
Unlike standard PyTorch DistributedDataParallel, which assumes a uniform communication strategy for the entire model, this package is designed to work with the heterogeneous distributions often found in ND-parallelism (e.g., mixtures of Data, Tensor, Sequence, and Pipeline parallelism). It inspects DTensor placements to determine automatically which mesh dimensions require reduction (all-reduce) and groups parameters into efficient communication buckets.
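As a rough illustration of that placement inspection (not the actual d9d implementation, and assuming PyTorch's public torch.distributed.tensor module), the mesh dimensions that need an all-reduce can be read off a parameter's DTensor placements; the helper below is hypothetical:

```python
# Illustration only: deriving reduction dimensions from DTensor placements.
# The real logic lives inside d9d.internals.grad_sync and may differ in detail.
import torch
from torch.distributed.tensor import DTensor, Replicate


def mesh_dims_needing_reduction(param: torch.nn.Parameter) -> list[int]:
    """Mesh dimensions on which the parameter is replicated.

    Local gradients differ across ranks along these dimensions, so these are
    the dimensions that need an all_reduce once accumulation finishes.
    """
    if not isinstance(param.data, DTensor):
        return []  # plain tensors carry no placement information
    return [
        dim
        for dim, placement in enumerate(param.data.placements)
        if isinstance(placement, Replicate)
    ]
```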
Core Concepts
Bucketing & Flattening
Communication overhead is dominated by latency when reducing many small tensors. To mitigate this, GradientSynchronizer groups parameters into Buckets.
Inside a SyncGradientBucket, gradients for multiple parameters are flattened into a single contiguous block of memory. When a reduction is triggered, the system performs a single all_reduce operation on this large buffer instead of hundreds of small operations.
Parameters are grouped automatically based on:
- Device
- DType
- Associated DeviceMesh
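A minimal sketch of this grouping and flattening scheme, treating parameters as plain local tensors and ignoring the bucket size limit; the helper names are hypothetical and not part of the d9d API:

```python
# Illustrative sketch of bucketing and flattening; d9d itself works on DTensor
# local shards and splits groups by a size limit, which is omitted here.
from collections import defaultdict

import torch
from torch.distributed.tensor import DTensor


def bucket_key(param: torch.nn.Parameter):
    # Group by device, dtype, and the associated DeviceMesh (by identity).
    data = param.data
    mesh = data.device_mesh if isinstance(data, DTensor) else None
    return (data.device, data.dtype, id(mesh))


def build_flat_buckets(params: list[torch.nn.Parameter]):
    groups = defaultdict(list)
    for p in params:
        groups[bucket_key(p)].append(p)

    buckets = {}
    for key, group in groups.items():
        local = [p.data.to_local() if isinstance(p.data, DTensor) else p.data
                 for p in group]
        total = sum(t.numel() for t in local)
        # One contiguous block per bucket: a single all_reduce over `flat`
        # replaces many small per-parameter reductions.
        flat = torch.zeros(total, dtype=local[0].dtype, device=local[0].device)
        views, offset = [], 0
        for t in local:
            views.append(flat[offset:offset + t.numel()].view_as(t))
            offset += t.numel()
        buckets[key] = (flat, views)
    return buckets
```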
Asynchronous Reduction
In large-scale training, the effective batch size is often increased by accumulating gradients over multiple micro-batches before performing an optimizer step. This package manages the lifecycle of distributed DTensor gradients during this accumulation phase without relying on simple no_sync context managers; a usage sketch follows the list below.
- Local Accumulation: During the backward pass of the first \(N-1\) micro-batches, local gradients are accumulated into the bucket's buffer. Conceptually, while the parameter DTensor is Replicated, these intermediate local gradients also represent a Replicate state across the Data Parallel mesh (although each rank holds different data).
- Automatic Triggering: Each bucket maintains an internal counter. The all_reduce communication is triggered only when the specific parameter group has reached the require_accumulations count. This trigger fires automatically inside the backward hook of the last micro-batch, allowing communication to overlap immediately with the computation of the remaining layers higher up in the model. The communication runs on a separate CUDA stream that must be awaited before the gradients are used on the default stream.
- Synchronization: Once the asynchronous reduction completes, the flat buffer contains the globally summed gradient. The metadata of the contained parameter gradients is marked as Replicate, making them safe for the optimizer to consume without further synchronization.
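Putting the pieces together, the accumulation lifecycle looks roughly like the sketch below. It uses only the GradientSynchronizer API documented on this page; model, optimizer, loss_fn, and data_loader are placeholders, and the standard d9d trainer normally drives this loop for you.

```python
# Hedged usage sketch: everything except the GradientSynchronizer calls
# (bind, wait, zero_grad, unbind) is a placeholder.
from d9d.internals.grad_sync import GradientSynchronizer

synchronizer = GradientSynchronizer(
    param_groups=[list(model.parameters())],
    bucket_size_mb=25,        # illustrative bucket size
    require_accumulations=4,  # micro-batches per optimizer step
)
synchronizer.bind()  # allocate flat buffers and register backward hooks

for micro_batches in data_loader:        # yields 4 micro-batches per step
    for micro_batch in micro_batches:
        loss_fn(model(micro_batch)).backward()
        # The first 3 backward passes only accumulate locally; the 4th one
        # triggers the bucketed all_reduce from inside the backward hooks.
    synchronizer.wait()       # join the separate communication stream
    optimizer.step()          # gradients are now globally reduced (Replicate)
    synchronizer.zero_grad()  # reset buffers and accumulation counters

synchronizer.unbind()  # free buffers and remove hooks when training ends
```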
d9d.internals.grad_sync
Gradient synchronization utilities.
This package provides the infrastructure for manual gradient bucketing and asynchronous reduction, similar to DistributedDataParallel but exposed for internal framework usage with DTensors.
GradientSynchronizer
Manages gradient synchronization for distributed training.
This class handles the bucketing of parameters, memory allocation for flat gradient buffers, and the orchestration of asynchronous all-reduce operations during the backward pass.
Source code in d9d/internals/grad_sync/synchronizer.py
__init__(param_groups, bucket_size_mb, require_accumulations)
Constructs a GradientSynchronizer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| param_groups | list[list[Parameter]] | List of parameter groups. | required |
| bucket_size_mb | int | Maximal size of a single gradient bucket in MB. | required |
| require_accumulations | int | Number of micro-batches to accumulate before reducing. | required |
Source code in d9d/internals/grad_sync/synchronizer.py
bind()
Initializes the synchronizer for training.
Groups parameters, creates buckets, allocates memory, and registers hooks. Must be called before the backward pass.
Source code in d9d/internals/grad_sync/synchronizer.py
unbind()
Releases resources.
Destroys buckets, frees memory buffers, and removes hooks.
Source code in d9d/internals/grad_sync/synchronizer.py
wait()
Waits for all bucket operations (async reductions) to complete.
Source code in d9d/internals/grad_sync/synchronizer.py
zero_grad()
Resets gradients and accumulation counters for all managed parameters.
Source code in d9d/internals/grad_sync/synchronizer.py