About
The d9d.module.parallelism package provides high-level strategies for distributing model execution across device meshes.
These strategies are "Horizontal" in the sense that they operate within a single pipeline stage (intra-layer parallelism), as opposed to Pipeline Parallelism, which is "Vertical" (inter-layer).
Design
DTensor-First Architecture
d9d enforces a DTensor-first philosophy. We mandate that every trainable parameter in the distributed environment be represented as a torch.distributed.tensor.DTensor.
This constraint simplifies the system architecture significantly:
- Universal Checkpointing: The checkpointing engine does not need to know about specific parallel strategies (like "this is DP" or "this is TP"). It simply inspects the DTensor.placements attribute to automatically determine how to gather, deduplicate, and save tensors.
- Native Synchronization: Gradient synchronization for replicated parameters is handled entirely by the d9d internals.
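As an illustration of what the checkpointing engine relies on, here is a minimal sketch using only public DTensor APIs (the mesh shape and tensor are made up, and the snippet assumes an initialized process group, e.g. launched via torchrun with 8 ranks):
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
# Hypothetical 2D mesh: data-parallel replicas x tensor-parallel shards
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
# A weight replicated over "dp" and sharded over "tp"
weight = distribute_tensor(torch.randn(1024, 1024), mesh, [Replicate(), Shard(0)])
# A checkpointing engine only needs the placements to decide what to do:
# Replicate() -> save once (deduplicate), Shard(dim) -> gather along dim
print(weight.placements)  # (Replicate(), Shard(dim=0))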
Composition over Monoliths
We explicitly reject monolithic wrappers like torch.nn.parallel.DistributedDataParallel (DDP).
While DDP is efficient for pure Data Parallelism, it acts as a "black box" that assumes ownership of the entire model execution loop.
Instead, d9d relies on PyTorch's parallelize_module API. This allows for fine-grained, per-submodule parallelism decisions:
- Layer A can use Tensor Parallelism (Row/Col wise).
- Layer B (e.g., a Router) can use Replicate Parallelism.
- Layer C (e.g., MLP) can use Expert Parallelism.
By treating "Data Parallelism" simply as another tiling strategy ("Replicate") within the Tensor Parallel system, we achieve a unified interface for ND parallelism.
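To make the composition concrete, here is a rough sketch with PyTorch's parallelize_module on a made-up two-projection block (the layer names, dimensions, and mesh are hypothetical, not d9d's):
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
class TinyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.up_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)
    def forward(self, x):
        return self.down_proj(self.up_proj(x))
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
block = TinyBlock(1024)
# Per-submodule decisions instead of one monolithic wrapper; d9d styles such as
# ToLocalParallel or ShardMoESparseExpertsParallel plug into the same mechanism.
parallelize_module(block, tp_mesh, {
    "up_proj": ColwiseParallel(),
    "down_proj": RowwiseParallel(),
})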
Strategies
Replicate Parallelism
parallelize_replicate implements Replicate Parallelism: it replicates parameters across the mesh and is used for Data Parallelism or Context Parallelism.
During the forward pass, it installs hooks that temporarily "unwrap" DTensor parameters into standard, local torch.Tensor objects. This allows standard PyTorch operations and custom kernels to run without modification, while accessing the module's state dict and parameters still yields DTensor objects.
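A brief sketch of the resulting behavior (MyCustomLayer, mesh, and the input shape are placeholders; see the Usage Examples below for how the mesh is obtained):
import torch
from torch.distributed.tensor import DTensor
from d9d.module.parallelism.api import parallelize_replicate
model = MyCustomLayer(...)          # placeholder module
parallelize_replicate(model, mesh)  # mesh: the replication DeviceMesh
# Parameters and state_dict entries remain distributed...
assert all(isinstance(p, DTensor) for p in model.parameters())
# ...but forward runs on plain local tensors, so ordinary PyTorch ops and
# custom kernels need no DTensor awareness.
out = model(torch.randn(2, 1024, device="cuda"))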
Expert Parallelism (MoE)
Mixture of Experts (MoE) requires a unique parallel strategy where:
1. Experts are sharded across the ep_shard mesh dimension (each GPU holds a subset of experts), optionally replicating along ep_replicate.
2. Routers are replicated (all GPUs have the same routing logic).
parallelize_expert_parallel applies sharding to MoELayer modules. It shards the GroupedLinear weights along the expert dimension. Simultaneously, it effectively applies parallelize_replicate to the router.
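Conceptually, the expert weights end up distributed as in this hedged sketch (the shapes are made up and GroupedLinear's actual layout may differ; assumes an initialized process group with 8 ranks):
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
# Hypothetical expert weight: [num_experts, d_in, d_out]
experts_weight = torch.randn(64, 1024, 4096)
# ep_replicate x ep_shard mesh, e.g. 2 replicas x 4 expert shards
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("ep_replicate", "ep_shard"))
# Shard dim 0 (the expert dimension) over "ep_shard" and replicate over
# "ep_replicate": each "ep_shard" rank holds 64 / 4 = 16 experts.
dist_weight = distribute_tensor(experts_weight, mesh, [Replicate(), Shard(0)])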
Fully Sharded Data Parallel (FSDP)
parallelize_fsdp provides a thin wrapper around PyTorch's native fully_shard.
Difference from standard FSDP: standard FSDP averages gradients across the mesh (sum / world size) by default. d9d's wrapper forces gradients to be summed rather than averaged, which is required because gradient accumulation and normalization are handled externally by d9d.
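Because gradients arrive summed, something outside the wrapper has to apply the normalization once, after the last accumulation step; a rough sketch of that external step (d9d's actual implementation may differ):
import torch
def normalize_summed_grads(model: torch.nn.Module, num_replicas: int, accum_steps: int) -> None:
    # Summed-not-averaged gradients are divided exactly once, covering both
    # the replica count and the number of accumulated micro-batches.
    scale = 1.0 / (num_replicas * accum_steps)
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(scale)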
Usage Examples
Applying Replicate Parallelism
import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_replicate
# 1. Create a Distributed Context
ctx: DistributedContext = ...
# 2. Get Dense Domain Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN) # pp x dp_replicate x dp_cp_shard x cp_replicate x tp
# 3. Define Model
model = MyCustomLayer(...)
# 4. Parallelize
parallelize_replicate(model, dense_mesh[['dp_replicate', 'cp_replicate']])
Applying Expert Parallelism
import torch
from d9d.core.dist_context import DistributedContext, EXPERT_DOMAIN
from d9d.module.parallelism.api import parallelize_expert_parallel
from d9d.module.block.moe import MoELayer
# 1. Create a Distributed Context
ctx: DistributedContext = ...
# 2. Get Expert Domain Mesh
expert_mesh = ctx.mesh_for(EXPERT_DOMAIN) # pp x ep_replicate x ep_shard
# 3. Define Model
model = MoELayer(...)
# 4. Parallelize
parallelize_expert_parallel(
model,
mesh_experts=expert_mesh[['ep_replicate', 'ep_shard']],
expert_shard_dim='ep_shard'
)
Applying FSDP/HSDP
import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_fsdp, parallelize_replicate
# 1. Create a Distributed Context
ctx: DistributedContext = ...
# 2. Define Model
model = MyCustomLayer(...)
# 3. Get Dense Domain Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)
# 4. Parallelize
# If using replicated context parallelism, manually replicate along its mesh dimension first,
# since PyTorch FSDP only supports a 2D DeviceMesh (it can still be composed with other parallelisms)
parallelize_replicate(
model,
mesh=dense_mesh['cp_replicate']
)
# As with regular fully_shard(...), the first mesh dimension is used for replication and the second for sharding
parallelize_fsdp(
model,
mesh=dense_mesh[['dp_replicate', 'dp_cp_shard']]
)
d9d.module.parallelism.api
Horizontal parallelism strategies and utilities for d9d modules.
This package provides high-level helper functions to apply specific distributed parallelism strategies to PyTorch modules compatible with the d9d ecosystem.
parallelize_expert_parallel(module, mesh_experts, expert_shard_dim='ep_shard')
Applies Expert Parallelism to a MoE layer.
This function configures the provided Mixture of Experts layer for distributed execution.
It partitions the sparse experts across the specified dimension of the device mesh (Expert Parallelism) and replicates along other dims.
Simultaneously, it configures the router to be fully replicated across the mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | MoELayer | The MoE layer instance to parallelize. | required |
| mesh_experts | DeviceMesh | The device mesh containing the expert parallel resources. | required |
| expert_shard_dim | str | The name of the mesh dimension where experts should be sharded. | 'ep_shard' |
Source code in d9d/module/parallelism/api/expert_parallel.py
parallelize_fsdp(module, mesh, *args, **kwargs)
Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.
This function wraps the provided module with PyTorch's fully_shard API using
the specified device mesh. Unlike standard FSDP usage, this function explicitly
configures the module to sum gradients across the mesh
instead of averaging them, and disables the internal all-reduce hooks.
This allows d9d to handle gradient normalization and reduction across replicas externally.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to shard. | required |
| mesh | DeviceMesh | The device mesh over which to shard the module. | required |
| *args | Any | Additional positional arguments passed to fully_shard. | () |
| **kwargs | Any | Additional keyword arguments passed to fully_shard. | {} |
Source code in d9d/module/parallelism/api/fully_sharded.py
parallelize_replicate(module, mesh)
Applies replicated parallelism to the module.
This function configures the provided module to be fully replicated across the
given device mesh. It utilizes the ToLocalParallel style, which manages
DTensor wrapping for parameters and gradients (via Replicate
and Partial placements) while ensuring that the underlying computation
sees standard local tensors during the forward pass.
This approach is effectively Data Parallelism managed via the DTensor APIs, allowing seamless integration of modules that require local tensor inputs into a broader distributed mesh context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to parallelize. | required |
| mesh | DeviceMesh | The device mesh over which to replicate the module. | required |
Source code in d9d/module/parallelism/api/replicate_parallel.py
d9d.module.parallelism.style
ShardMoESparseExpertsParallel
Bases: ParallelStyle
Parallel style that shards MoE experts across a specific mesh dimension.
This style is designed for MoELayer instances using GroupedLinear for experts.
It splits the experts across the specified
dimension of the device mesh (Expert Parallelism). Other dimensions in the
mesh treat the parameters as Replicated.
It also initializes the necessary distributed communication groups within the MoE layer to handle token dispatching.
Source code in d9d/module/parallelism/style/shard_experts.py
ToLocalParallel
Bases: ParallelStyle
Parallel style that distributes parameters and gradients but executes with local tensors.
This style wraps standard tensor distribution (via DTensor) but injects
runtime hooks to temporarily unwrap DTensor parameters into local torch.Tensor
during the forward pass.
This is useful for parallel strategies (like Replicate) where the underlying calculation logic is not DTensor-aware, but the parameters must remain distributed for gradient synchronization and for distributed checkpointing.
Source code in d9d/module/parallelism/style/to_local.py
__init__(param_placement, grad_placement)
Constructs ToLocalParallel object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| param_placement | tuple[Placement, ...] | Tuple of placements defining how parameters are distributed. | required |
| grad_placement | tuple[Placement, ...] | Tuple of placements defining how gradients are synchronized. | required |
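Given this constructor, applying the style directly might look like the following sketch (module and mesh are placeholders; parallelize_replicate is the higher-level entry point that wires up something equivalent):
from torch.distributed.tensor import Partial, Replicate
from torch.distributed.tensor.parallel import parallelize_module
from d9d.module.parallelism.style import ToLocalParallel
# Replicated parameters with partial (pending-reduction) gradients on a 1D mesh.
style = ToLocalParallel(
    param_placement=(Replicate(),),
    grad_placement=(Partial(),),
)
parallelize_module(module, mesh, style)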
Source code in d9d/module/parallelism/style/to_local.py