Horizontal Parallelism
About
The d9d.module.parallelism package provides high-level strategies for distributing model execution across device meshes.
These strategies are "Horizontal" in the sense that they function within a specific stage of a pipeline (intra-layer parallelism), as opposed to Pipeline Parallelism which is "Vertical" (inter-layer).
Design
DTensor-First Architecture
d9d enforces a DTensor-first philosophy. We mandate that every trainable parameter in the distributed environment be represented as a torch.distributed.tensor.DTensor.
This constraint simplifies the system architecture significantly:
- Universal Checkpointing: The checkpointing engine does not need to know about specific parallel strategies (like "this is DP" or "this is TP"). It simply inspects the DTensor.placements attribute to automatically determine how to gather, deduplicate, and save tensors.
- Native Synchronization: Gradient synchronization for replicated parameters is handled entirely by the d9d internals, which know which tensor dimensions are Replicated.
Composition over Monoliths
We explicitly reject monolithic wrappers like torch.nn.parallel.DistributedDataParallel (DDP).
While DDP is efficient for pure Data Parallelism, it acts as a "black box" that assumes ownership of the entire model execution loop.
Instead, d9d relies on PyTorch's parallelize_module API. This allows for fine-grained, per-submodule parallelism decisions:
- Layer A can use Tensor Parallelism (Row/Col wise).
- Layer B (e.g., a Router) can use Replicate Parallelism.
- Layer C (e.g., MLP) can use Expert Parallelism.
By treating "Data Parallelism" simply as another tiling strategy ("Replicate") within the Tensor Parallel system, we achieve a unified interface for ND parallelism.
Strategies
Replicate Parallelism
parallelize_replicate implements Replicate Parallelism. It replicates parameters across the mesh and is used for Data Parallelism or Context Parallelism.
During the forward pass, it installs hooks that temporarily "unwrap" DTensor parameters into standard, local torch.Tensor objects. This allows standard PyTorch operations and custom kernels to run without modification, while accessing the module's state dict and parameters still yields DTensor objects.
Expert Parallelism (MoE)
Mixture of Experts (MoE) requires a unique parallel strategy where:
1. Experts are sharded across the ep_shard mesh dimension (each GPU holds a subset of experts), optionally replicating along ep_replicate.
2. Routers are replicated (all GPUs have the same routing logic).
parallelize_expert_parallel applies sharding to MoELayer modules. It shards the GroupedLinear weights along the expert dimension. Simultaneously, it effectively applies parallelize_replicate to the router.
Fully Sharded Data Parallel (FSDP)
parallelize_fsdp provides a thin wrapper around PyTorch's native fully_shard.
Difference from standard FSDP:
- Standard FSDP averages gradients across the mesh (Sum / WorldSize) by default. d9d's wrapper forces gradients to be summed rather than averaged. This is required because our gradient accumulation logic is handled externally.
parallelize_fsdp strictly requires a 1D DeviceMesh. To use it in multi-dimensional meshes (e.g., combining Replication and Sharding), use parallelize_hsdp or apply parallelize_replicate to the other dimensions manually first.
Hybrid Sharded Data Parallel (HSDP)
parallelize_hsdp is a high-level composite strategy for mixing Full Sharding with Replicate Parallel.
parallelize_hsdp accepts a multi-dimensional mesh and a target shard_dim. It identifies all dimensions other than shard_dim as Replication Dimensions.
It applies parallelize_replicate to the replication dimensions.
It applies parallelize_fsdp to the specific sharding dimension.
Usage Examples
Replicate Parallelism
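A minimal sketch of applying parallelize_replicate, assuming a torchrun launch across 8 ranks (the model and mesh dimension name here are illustrative):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_replicate

# Illustrative 1D mesh over 8 ranks; launch with torchrun --nproc-per-node=8.
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
parallelize_replicate(model, mesh)

# Parameters are now Replicate-placed DTensors, but the forward pass
# still sees plain local tensors thanks to the unwrap hooks.
```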
Applying Expert Parallelism
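A sketch of applying parallelize_expert_parallel, assuming a 2D expert mesh and an already-constructed MoELayer (construction not shown; mesh shape and dimension names are illustrative):

```python
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_expert_parallel

# 2D expert mesh: replicate experts across 2 groups, shard across 4 ranks each.
mesh = init_device_mesh(
    "cuda", (2, 4), mesh_dim_names=("ep_replicate", "ep_shard")
)

moe_layer = ...  # a d9d MoELayer instance
parallelize_expert_parallel(moe_layer, mesh, expert_shard_dim="ep_shard")
# Experts are sharded along "ep_shard" and replicated along "ep_replicate";
# the router is fully replicated across the mesh.
```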
Applying FSDP
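A sketch of applying parallelize_fsdp, assuming an already-constructed model exposing its transformer blocks as model.layers (a common but illustrative layout). Note the 1D mesh requirement:

```python
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_fsdp

# parallelize_fsdp strictly requires a 1D mesh.
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))

for block in model.layers:       # shard each transformer block separately
    parallelize_fsdp(block, mesh)
parallelize_fsdp(model, mesh)    # then wrap the root module

# Gradients are summed, not averaged; d9d normalizes them externally.
```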
Applying HSDP
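A sketch of applying parallelize_hsdp, assuming an already-constructed model and a 2-node, 8-GPUs-per-node layout (mesh shape and the replicate dimension name are illustrative; "dp_cp_shard" is the documented default shard_dim):

```python
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_hsdp

# 2D mesh: replicate across 2 nodes, shard across 8 ranks within each node.
mesh = init_device_mesh(
    "cuda", (2, 8), mesh_dim_names=("dp_replicate", "dp_cp_shard")
)

parallelize_hsdp(model, mesh, shard_dim="dp_cp_shard")
# "dp_replicate" is treated as a replication dimension (parallelize_replicate
# is applied there), while FSDP sharding is applied along "dp_cp_shard".
```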
d9d.module.parallelism.api
Horizontal parallelism strategies and utilities for d9d modules.
This package provides high-level helper functions to apply specific distributed parallelism strategies to PyTorch modules compatible with the d9d ecosystem.
parallelize_expert_parallel(module, mesh_experts, expert_shard_dim='ep_shard')
Applies Expert Parallelism to a MoE layer.
This function configures the provided Mixture of Experts layer for distributed execution.
It partitions the sparse experts across the specified dimension of the device mesh (Expert Parallelism) and replicates along other dims.
Simultaneously, it configures the router to be fully replicated across the mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | MoELayer | The MoE layer instance to parallelize. | required |
| mesh_experts | DeviceMesh | The device mesh containing the expert parallel resources. | required |
| expert_shard_dim | str | The name of the mesh dimension where experts should be sharded. | 'ep_shard' |
parallelize_fsdp(module, mesh, *args, **kwargs)
Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.
This function wraps the provided module with PyTorch's fully_shard API using the specified device mesh. Unlike standard FSDP usage, this function explicitly configures the module to sum gradients across the mesh instead of averaging them, and disables FSDP's internal all-reduce hooks. This is intended for d9d to handle gradient normalization and reduction across replicas externally.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to shard. | required |
| mesh | DeviceMesh | The device mesh over which to shard the module. | required |
| *args | Any | Additional positional arguments passed to fully_shard. | () |
| **kwargs | Any | Additional keyword arguments passed to fully_shard. | {} |
parallelize_hsdp(module, mesh, shard_dim='dp_cp_shard', *fsdp_args, **fsdp_kwargs)
Applies Hybrid Sharded Data Parallelism (HSDP) to a module.
This function decomposes the provided device mesh into sharding dimensions and replication dimensions. It applies replication parallelism across the replication dimensions and Fully Sharded Data Parallelism (FSDP) across the specified shard dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to parallelize. | required |
| mesh | DeviceMesh | The device mesh over which to distribute the module. | required |
| shard_dim | str | The name of the mesh dimension used for FSDP sharding. Any dimension in the mesh not matching this name will be treated as a replication dimension. | 'dp_cp_shard' |
| *fsdp_args | Any | Positional arguments passed to the underlying FSDP parallelizer. | () |
| **fsdp_kwargs | Any | Keyword arguments passed to the underlying FSDP parallelizer. | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If the device mesh does not have named dimensions. |
parallelize_replicate(module, mesh)
Applies replicated parallelism to the module.
This function configures the provided module to be fully replicated across the
given device mesh. It utilizes the ToLocalParallel style, which manages
DTensor wrapping for parameters and gradients (via Replicate placements)
while ensuring that the underlying computation sees standard local tensors during the forward pass.
This approach is effectively Data Parallelism managed via the DTensor APIs, allowing seamless integration of modules that require local tensor inputs into a broader distributed mesh context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to parallelize. | required |
| mesh | DeviceMesh | The device mesh over which to replicate the module. | required |
d9d.module.parallelism.style
ShardMoESparseExpertsParallel
Bases: ParallelStyle
Parallel style that shards MoE experts across a specific mesh dimension.
This style is designed for MoELayer instances using GroupedLinear for experts.
It splits the experts across the specified
dimension of the device mesh (Expert Parallelism). Other dimensions in the
mesh treat the parameters as Replicated.
It also initializes the necessary distributed communication groups within the MoE layer to handle token dispatching.
ToLocalParallel
Bases: ParallelStyle
Parallel style that distributes parameters and gradients but executes with local tensors.
This style wraps standard tensor distribution (via DTensor) but injects
runtime hooks to temporarily unwrap DTensor parameters into local torch.Tensor
during the forward pass.
This is useful for parallel strategies (like Replicate) where the underlying calculation logic is not DTensor-aware, but the parameters must remain distributed for gradient synchronization and for distributed checkpointing.