About

Warning:

If you are using the standard d9d training infrastructure, you do not need to call these functions manually; the framework handles gradient synchronization automatically. This package is primarily intended for users extending d9d.

The d9d.internals.grad_sync package provides low-level primitives for synchronizing gradients in distributed training setups that use DTensor.

Unlike standard PyTorch DistributedDataParallel, which assumes a uniform communication strategy for the entire model, this package is designed to work with the heterogeneous distributions often found in ND-parallelism (e.g., mixtures of Data, Tensor, Sequence, and Pipeline parallelism). It inspects DTensor placements to automatically determine which mesh dimensions require reduction (all-reduce) and groups parameters into efficient communication buckets.
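The placement inspection can be pictured as in the sketch below. This is illustrative only, not the package's actual logic; the helper name dims_requiring_reduction is hypothetical, and the sketch assumes parameters stored as DTensors whose Replicate placements mark the mesh dimensions that need a summed all-reduce.

from torch import nn
from torch.distributed.tensor import DTensor


def dims_requiring_reduction(param: nn.Parameter) -> list[int]:
    """Return the mesh dimensions across which this parameter's gradient must be summed."""
    if not isinstance(param.data, DTensor):
        return []  # plain tensors are left to whole-world reduction elsewhere
    return [
        mesh_dim
        for mesh_dim, placement in enumerate(param.data.placements)
        if placement.is_replicate()
    ]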

Core Concepts

Bucketing & Flattening

Communication overhead is dominated by latency when reducing many small tensors. To mitigate this, GradientSynchronizer groups parameters into Buckets.

Inside a SyncGradientBucket, gradients for multiple parameters are flattened into a single contiguous block of memory. When a reduction is triggered, the system performs a single all_reduce operation on this large buffer instead of hundreds of small operations.

Parameters are grouped automatically based on:

  1. Device
  2. DType
  3. Associated DeviceMesh
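As a hedged illustration of both ideas above (the actual d9d bucket classes differ), grouping can be pictured as a dictionary keyed by (device, dtype, mesh), and reducing one group as a single all_reduce over a flat buffer. The helper names below are hypothetical, and the reduction sketch assumes plain local gradient tensors that are already populated.

from collections import defaultdict

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor import DTensor


def group_params(params: list[nn.Parameter]) -> dict:
    """Group parameters by (device, dtype, device mesh) so each group can share one flat buffer."""
    groups: dict = defaultdict(list)
    for p in params:
        mesh = p.data.device_mesh if isinstance(p.data, DTensor) else None
        groups[(p.device, p.dtype, mesh)].append(p)
    return dict(groups)


def reduce_group(params: list[nn.Parameter], group: dist.ProcessGroup | None = None) -> None:
    """Sum one group's gradients with a single collective instead of many small ones."""
    flat = torch.cat([p.grad.reshape(-1) for p in params])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=group)
    offset = 0
    for p in params:  # copy the reduced values back into the individual .grad tensors
        n = p.grad.numel()
        p.grad.copy_(flat[offset:offset + n].view_as(p.grad))
        offset += n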

Asynchronous Reduction

In large-scale training, the effective batch size is often increased by accumulating gradients over multiple micro-batches before performing an optimizer step. This package manages the lifecycle of distributed DTensor gradients during this accumulation phase without requiring DDP-style no_sync context managers.

  1. Local Accumulation: During the backward pass of the first \(N-1\) micro-batches, local gradients are accumulated into the bucket's buffer. Conceptually, while the parameter DTensor is Replicated, these intermediate local gradients also carry a Replicate placement across the Data Parallel mesh, even though each rank holds different data until the reduction runs.

  2. Automatic Triggering: Each bucket maintains an internal counter. The all_reduce communication is triggered only once the parameter group has reached the require_accumulations count. This happens automatically inside the backward hook of the last micro-batch, allowing communication to overlap immediately with the computation of the remaining layers higher up in the model. The communication runs on a separate CUDA stream, which must be awaited before the gradients are used on the default stream.

  3. Synchronization: Once the asynchronous reduction completes, the flat buffer contains the globally summed gradient. The placement metadata of the contained parameter gradients is marked as Replicate, so the optimizer can later consume them without additional synchronization.
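The sketch below illustrates the trigger mechanism from steps 1 and 2. The class SimpleBucket and its methods are hypothetical, not part of d9d, and it assumes gradients have already been accumulated into the bucket's flat buffer by backward hooks.

import torch
import torch.distributed as dist
from torch import nn


class SimpleBucket:
    """Hypothetical bucket that counts accumulations and reduces asynchronously."""

    def __init__(self, params: list[nn.Parameter], require_accumulations: int,
                 comm_stream: torch.cuda.Stream, group: dist.ProcessGroup | None = None):
        self.params = params
        self.require_accumulations = require_accumulations
        self.comm_stream = comm_stream
        self.group = group
        # Flat buffer that backward hooks accumulate local gradients into.
        self.flat = torch.zeros(sum(p.numel() for p in params),
                                device=params[0].device, dtype=params[0].dtype)
        self.count = 0

    def on_backward(self) -> None:
        """Called from a backward hook once this bucket's gradients for a micro-batch are in."""
        self.count += 1
        if self.count < self.require_accumulations:
            return  # keep accumulating locally, no communication yet
        # Last micro-batch: launch the reduction on the communication stream so it
        # overlaps with the backward computation of layers that have not finished yet.
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            dist.all_reduce(self.flat, op=dist.ReduceOp.SUM, group=self.group)

    def wait(self) -> None:
        """Block the default stream until the asynchronous reduction has finished."""
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        self.count = 0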

d9d.internals.grad_sync

Gradient synchronization utilities.

This package provides the infrastructure for manual gradient bucketing and asynchronous reduction, similar to DistributedDataParallel but exposed for internal framework usage with DTensors.

GradientSynchronizer

Manages gradient synchronization for distributed training.

This class handles the bucketing of parameters, memory allocation for flat gradient buffers, and the orchestration of asynchronous all-reduce operations during the backward pass.
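For orientation, here is a minimal usage sketch based only on the methods documented below. The import path follows the package heading above; model, optimizer and micro_batches are assumed to be provided by the surrounding training script, and the exact integration inside d9d may differ.

from d9d.internals.grad_sync import GradientSynchronizer

# model, optimizer and micro_batches are assumed to exist in the surrounding script.
synchronizer = GradientSynchronizer(
    param_groups=[list(model.parameters())],   # a single group here; real setups may split further
    bucket_size_mb=25,
    require_accumulations=len(micro_batches),
)
synchronizer.bind()          # allocate flat buffers and register backward hooks

for micro_batch in micro_batches:
    loss = model(micro_batch).mean()
    loss.backward()          # hooks accumulate; the last micro-batch triggers the async all_reduce

synchronizer.wait()          # make the default stream wait for the asynchronous reduction
optimizer.step()
synchronizer.zero_grad()     # reset gradients and accumulation counters for the next step

synchronizer.unbind()        # free buffers and remove hooks when training is finished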

Source code in d9d/internals/grad_sync/synchronizer.py
class GradientSynchronizer:
    """
    Manages gradient synchronization for distributed training.

    This class handles the bucketing of parameters, memory allocation for flat
    gradient buffers, and the orchestration of asynchronous all-reduce operations
    during the backward pass.
    """

    def __init__(
            self,
            param_groups: list[list[nn.Parameter]],
            bucket_size_mb: int,
            require_accumulations: int
    ):
        """
        Constructs a GradientSynchronizer.

        Args:
            param_groups: List of parameter groups.
            bucket_size_mb: Maximal size of a single gradient bucket in MB.
            require_accumulations: Number of micro-batches to accumulate before reducing.
        """

        self._param_groups = param_groups
        self._bucket_size_mb = bucket_size_mb
        self._require_accumulations = require_accumulations

        self._communicate_stream: torch.cuda.Stream | None = None
        self._can_sync: bool
        self._buckets: list[AbstractGradientBucket] = []

    def bind(self):
        """
        Initializes the synchronizer for training.

        Groups parameters, creates buckets, allocates memory, and registers hooks.
        Must be called before the backward pass.
        """

        stream = torch.cuda.Stream()
        self._communicate_stream = stream
        self._buckets = _fill_buckets(
            _group_params_for_buckets(self._param_groups),
            bucket_size_mb=self._bucket_size_mb,
            require_accumulations=self._require_accumulations,
            communicate_stream=stream
        )

        for bucket in self._buckets:
            bucket.bind()

    def unbind(self):
        """
        Releases resources.

        Destroys buckets, frees memory buffers, and removes hooks.
        """

        for bucket in self._buckets:
            bucket.unbind()

        self._buckets = []
        self._communicate_stream = None

    def wait(self):
        """
        Waits for all bucket operations (async reductions) to complete.
        """

        torch.cuda.current_stream().wait_stream(self._communicate_stream)

        for bucket in self._buckets:
            bucket.mark_sync()

    def zero_grad(self):
        """
        Resets gradients and accumulation counters for all managed parameters.
        """

        for bucket in self._buckets:
            bucket.zero_grad()

__init__(param_groups, bucket_size_mb, require_accumulations)

Constructs a GradientSynchronizer.

Parameters:

Name                    Type                    Description                                                Default
param_groups            list[list[Parameter]]   List of parameter groups.                                  required
bucket_size_mb          int                     Maximal size of a single gradient bucket in MB.            required
require_accumulations   int                     Number of micro-batches to accumulate before reducing.     required
Source code in d9d/internals/grad_sync/synchronizer.py
def __init__(
        self,
        param_groups: list[list[nn.Parameter]],
        bucket_size_mb: int,
        require_accumulations: int
):
    """
    Constructs a GradientSynchronizer.

    Args:
        param_groups: List of parameter groups.
        bucket_size_mb: Maximal size of a single gradient bucket in MB.
        require_accumulations: Number of micro-batches to accumulate before reducing.
    """

    self._param_groups = param_groups
    self._bucket_size_mb = bucket_size_mb
    self._require_accumulations = require_accumulations

    self._communicate_stream: torch.cuda.Stream | None = None
    self._can_sync: bool
    self._buckets: list[AbstractGradientBucket] = []

bind()

Initializes the synchronizer for training.

Groups parameters, creates buckets, allocates memory, and registers hooks. Must be called before the backward pass.

Source code in d9d/internals/grad_sync/synchronizer.py
def bind(self):
    """
    Initializes the synchronizer for training.

    Groups parameters, creates buckets, allocates memory, and registers hooks.
    Must be called before the backward pass.
    """

    stream = torch.cuda.Stream()
    self._communicate_stream = stream
    self._buckets = _fill_buckets(
        _group_params_for_buckets(self._param_groups),
        bucket_size_mb=self._bucket_size_mb,
        require_accumulations=self._require_accumulations,
        communicate_stream=stream
    )

    for bucket in self._buckets:
        bucket.bind()

unbind()

Releases resources.

Destroys buckets, frees memory buffers, and removes hooks.

Source code in d9d/internals/grad_sync/synchronizer.py
def unbind(self):
    """
    Releases resources.

    Destroys buckets, frees memory buffers, and removes hooks.
    """

    for bucket in self._buckets:
        bucket.unbind()

    self._buckets = []
    self._communicate_stream = None

wait()

Waits for all bucket operations (async reductions) to complete.

Source code in d9d/internals/grad_sync/synchronizer.py
def wait(self):
    """
    Waits for all bucket operations (async reductions) to complete.
    """

    torch.cuda.current_stream().wait_stream(self._communicate_stream)

    for bucket in self._buckets:
        bucket.mark_sync()

zero_grad()

Resets gradients and accumulation counters for all managed parameters.

Source code in d9d/internals/grad_sync/synchronizer.py
def zero_grad(self):
    """
    Resets gradients and accumulation counters for all managed parameters.
    """

    for bucket in self._buckets:
        bucket.zero_grad()