Gradient Norm & Clipping
Internal API Warning
If you are using the standard d9d training infrastructure, you do not need to call these functions manually; the framework handles gradient clipping automatically. This package is primarily intended for users extending the internals of d9d.
About
The d9d.internals.grad_norm package handles the calculation and clipping of gradient norms in complex distributed environments.
Standard PyTorch `clip_grad_norm_` functions are not fully aware of heterogeneous ND-Parallelism strategies (mixing Pipeline, Data, Tensor, and Context Parallelism). This package ensures that the global norm is correctly calculated across all parallel dimensions and that DTensor sharding is handled without unnecessary full-tensor materialization.
Concepts
Distributed Heterogeneity
Some parameters may be sharded across a TP/FSDP mesh while others are replicated, and the model itself may be split across pipeline stages.
To handle this, we decompose the problem:
- Local Norm: Calculate the norm of the tensor shards actually present in GPU memory (using `to_local()`).
- Horizontal Reduction: Perform `all_reduce` strictly on the meshes where parameters are sharded. This ensures that sharded parameters contribute correctly to the global norm, while replicated parameters do not trigger double-counting or unnecessary communication for norm calculation.
- Pipeline Reduction: Finally, norms are summed across the Pipeline Parallel mesh, as different stages hold completely different parameters.
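The three reductions above can be sketched in plain Python. This is an illustrative, single-process model of the math only; the data layout, rank counts, and list-based "shards" are assumptions, while the real implementation operates on DTensor shards and process groups:

```python
import math

# Sharded parameter: the logical gradient [3, 4, 12] split across 3 TP ranks.
sharded_shards = [[3.0], [4.0], [12.0]]
# Replicated parameter: an identical copy held by every data-parallel rank.
replicated_copies = [[5.0], [5.0]]

def sq_norm(xs):
    return sum(x * x for x in xs)

# 1. Local norm: squared norm of the shard actually held in memory.
local_sq = [sq_norm(s) for s in sharded_shards]

# 2. Horizontal reduction: all_reduce(SUM) over the sharded mesh only.
sharded_sq = sum(local_sq)                     # 9 + 16 + 144 = 169

# Replicated parameters contribute exactly once, with no communication.
replicated_sq = sq_norm(replicated_copies[0])  # 25

# 3. Pipeline reduction would sum these partial squares across stages;
# with a single stage, the global norm is simply:
global_norm = math.sqrt(sharded_sq + replicated_sq)
```

Note that only squared norms cross rank boundaries; the square root is taken once, after every reduction has completed.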
Grouping & Overlap
To optimize performance, `group_parameters_for_norm` groups parameters into `GradNormGroup` buckets. This grouping is based on:
- Sharding Strategy: Parameters sharded on the same mesh are grouped together so their norms can be reduced in a single collective operation.
- Device & DType: Ensures compatibility for local math operations.
The system attempts to overlap communication with computation. Groups containing sharded tensors are prioritized so their all_reduce operations can run asynchronously while local norms for other groups are being computed.
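The bucketing idea can be illustrated with plain dictionaries. The metadata fields below are stand-ins mirroring the description above, not d9d's actual data structures:

```python
from collections import defaultdict

# Hypothetical per-parameter metadata: sharding mesh, device, and dtype.
params = [
    {"name": "w1", "mesh": "tp", "device": "cuda:0", "dtype": "bf16"},
    {"name": "w2", "mesh": "tp", "device": "cuda:0", "dtype": "bf16"},
    {"name": "b1", "mesh": None, "device": "cuda:0", "dtype": "fp32"},
]

groups = defaultdict(list)
for p in params:
    # One bucket per (sharding mesh, device, dtype): members can share a
    # single local-norm computation and, if sharded, a single all_reduce.
    groups[(p["mesh"], p["device"], p["dtype"])].append(p["name"])
```

Grouping by mesh is what makes the overlap possible: each sharded bucket needs exactly one collective, which can be launched asynchronously while the next bucket's local norms are computed.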
Mathematical Correctness
The goal of distributed gradient clipping is to calculate the Global Norm (\(\|\mathbf{g}\|\)) of a single model instance, regardless of how that model is physically fragmented across GPUs.
Let the total set of model parameters \(\mathcal{P}\) be divided into disjoint subsets based on parallelism strategy:
- \(\mathcal{P}_{pp}\): Sets of parameters residing on different Pipeline stages.
- \(\mathcal{P}_{sharded}\): Parameters split across a TP/EP/FSDP group.
- \(\mathcal{P}_{repl}\): Parameters replicated across other groups.
The definition of the global \(L_2\) norm, with \(G_w\) denoting the gradient of parameter \(w\), is:

$$ \|\mathbf{g}\| = \sqrt{\sum_{w \in \mathcal{P}} \|G_w\|^2} $$
We prove that our strategy of separating aggregation logic based on placement prevents double-counting.
Proof for Sharded Parameters (TP/EP/FSDP)
For a parameter \(w \in \mathcal{P}_{sharded}\), the logical gradient tensor \(G\) is split into physical shards \(G_1, G_2, \dots, G_k\) across \(k\) devices. By the definition of the Frobenius norm:

$$ \|G\|^2 = \sum_{i=1}^{k} \|G_i\|^2 $$
Strategy: We calculate local squared norms and apply `all_reduce(op=SUM)`, which recovers exactly \(\|G\|^2\).
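A quick single-process check of this identity, with a toy gradient split as if over two ranks:

```python
# The squared Frobenius norm of a full gradient equals the sum of its
# shards' squared norms, so summing local squared norms is exact.
full = [1.0, 2.0, 2.0, 4.0]
shards = [full[:2], full[2:]]                 # as if split over 2 TP ranks

full_sq = sum(x * x for x in full)            # 25.0
reduced_sq = sum(sum(x * x for x in s) for s in shards)  # what all_reduce(SUM) yields
```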
Proof for Replicated Parameters (DP)
For a parameter \(w \in \mathcal{P}_{repl}\), the logical gradient tensor \(G\) is identical on all \(k\) devices (assuming DP synchronization has occurred). $$ G_{rank_1} = G_{rank_2} = \dots = G $$
If we were to sum these (as we did for TP), we would obtain: $$ \sum_{rank=1}^{k} \|G_{rank}\|^2 = k \cdot \|G\|^2 \quad (\text{Incorrect: Double Counting}) $$
Strategy: We group these parameters separately and do not communicate.
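Numerically, summing replicas inflates the squared norm by exactly the factor \(k\):

```python
# Every DP rank holds the same gradient, so an all_reduce(SUM) of the
# squared norms would be k times too large; counting once is correct.
g = [3.0, 4.0]                                           # identical on all k ranks
k = 4
naive_sq = sum(sum(x * x for x in g) for _ in range(k))  # 100.0: double counting
correct_sq = sum(x * x for x in g)                       # 25.0: counted once
```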
Proof for Pipeline Parallelism (PP)
Pipeline stages hold disjoint sets of parameters. The total norm is simply the sum of the norms of the stages.
Strategy: We apply `all_reduce(op=SUM)` across the PP mesh.
Result
The final formula used by d9d, with \(G_{w,i}\) denoting the \(i\)-th shard of the gradient of \(w\), ensures \(1:1\) correspondence with a single-device baseline:

$$ \|\mathbf{g}\| = \sqrt{ \sum_{\mathcal{P}_{pp}} \left( \sum_{w \in \mathcal{P}_{sharded}} \sum_{i=1}^{k} \|G_{w,i}\|^2 + \sum_{w \in \mathcal{P}_{repl}} \|G_w\|^2 \right) } $$
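A worked example comparing the distributed computation against a single-device baseline; the parameter layout across stages and ranks is hypothetical:

```python
import math

# Single-device baseline: all of the model's gradients, concatenated.
baseline = [3.0, 4.0, 12.0, 5.0, 8.0, 6.0]
baseline_norm = math.sqrt(sum(x * x for x in baseline))

# Hypothetical distributed layout:
#   stage 0: [3, 4, 12] sharded over 3 TP ranks, plus [5] replicated on DP ranks
#   stage 1: [8, 6] held unsharded
stage0_sq = 3.0**2 + 4.0**2 + 12.0**2   # all_reduce(SUM) over the TP mesh
stage0_sq += 5.0**2                     # replicated: counted once, no comms
stage1_sq = 8.0**2 + 6.0**2

# Pipeline reduction: all_reduce(SUM) over the PP mesh, then one sqrt.
distributed_norm = math.sqrt(stage0_sq + stage1_sq)
```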
d9d.internals.grad_norm
clip_grad_norm_distributed_(parameter_groups, max_norm, norm_type, pp_mesh)
Clips gradient norms in a fully distributed environment.
This function calculates the global gradient norm across all dimensions of parallelism (Horizontal - DP/CP/TP/EP/..., and Pipeline) and scales the gradients in-place to ensure the norm does not exceed max_norm.
It accurately handles DTensors by identifying their sharding placements and performing reductions only on the necessary process groups.
Overlaps communication and computation if possible.
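Once the global norm is known, the in-place scaling step itself mirrors the standard single-device recipe. A minimal pure-Python sketch; the `eps` guard and the exact coefficient formula are assumptions following the convention of `torch.nn.utils.clip_grad_norm_`, not d9d's verified internals:

```python
def clip_(grads, total_norm, max_norm, eps=1e-6):
    # Scale every gradient in place by max_norm / total_norm when the
    # global norm exceeds max_norm; otherwise leave gradients untouched.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for g in grads:
            for i in range(len(g)):
                g[i] *= clip_coef
    return grads

grads = [[3.0, 4.0]]                        # global L2 norm = 5.0
clip_(grads, total_norm=5.0, max_norm=1.0)  # scales by roughly 0.2
```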
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `parameter_groups` | `ParametersForNorm` | Dictionary grouping parameters by synchronization requirements, typically created by `group_parameters_for_norm`. | required |
| `max_norm` | `float \| None` | The maximum allowed norm of the gradients. If `None`, the function calculates and returns the global norm without modifying the gradients. | required |
| `norm_type` | `float` | The type of the norm to calculate (e.g., `2.0` for L2 norm, `inf` for max norm). | required |
| `pp_mesh` | `DeviceMesh \| None` | The device mesh representing the pipeline parallel dimension, needed to reduce norms across pipeline stages. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The calculated global gradient norm. |
group_parameters_for_norm(parameters)
Groups parameters based on their distributed tensor characteristics.
Groups parameters by their sharding meshes, device, and gradient data type.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `parameters` | `Iterable[Parameter]` | The iterable of parameters to group. | required |

Returns:

| Type | Description |
|---|---|
| `ParametersForNorm` | A dictionary mapping synchronization groups to lists of parameters. |