
Horizontal Parallelism

About

The d9d.module.parallelism package provides high-level strategies for distributing model execution across device meshes.

These strategies are "Horizontal" in the sense that they function within a specific stage of a pipeline (intra-layer parallelism), as opposed to Pipeline Parallelism which is "Vertical" (inter-layer).

Design

DTensor-First Architecture

d9d enforces a DTensor-first philosophy. We mandate that every trainable parameter in the distributed environment be represented as a torch.distributed.tensor.DTensor.

This constraint simplifies the system architecture significantly:

  • Universal Checkpointing: The checkpointing engine does not need to know about specific parallel strategies (like "This is DP" or "This is TP"). It simply inspects the DTensor.placements attribute to automatically determine how to gather, deduplicate, and save tensors.
  • Native Synchronization: Gradient synchronization for replicated parameters is handled entirely by d9d internals, which know which tensor dimensions are Replicated.

Composition over Monoliths

We explicitly reject monolithic wrappers like torch.nn.parallel.DistributedDataParallel (DDP).

While DDP is efficient for pure Data Parallelism, it acts as a "black box" that assumes ownership of the entire model execution loop. Instead, d9d relies on PyTorch's parallelize_module API. This allows for fine-grained, per-submodule parallelism decisions:

  • Layer A can use Tensor Parallelism (Row/Col wise).
  • Layer B (e.g., a Router) can use Replicate Parallelism.
  • Layer C (e.g., MLP) can use Expert Parallelism.

By treating "Data Parallelism" simply as another tiling strategy ("Replicate") within the Tensor Parallel system, we achieve a unified interface for ND parallelism.

Strategies

Replicate Parallelism

parallelize_replicate implements Replicate Parallelism: it replicates parameters across the mesh and is used for Data Parallelism or Context Parallelism.

During the forward pass, it installs hooks that temporarily "unwrap" DTensor parameters into standard, local torch.Tensor objects. This allows standard PyTorch operations and custom kernels to run without modification, while accessing the module's state dict and parameters still yields DTensor objects.
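
The mechanics can be illustrated with a small stand-in for DTensor (the FakeDTensor and Layer classes below are purely illustrative and not part of d9d or PyTorch):

```python
class FakeDTensor:
    """Illustrative stand-in for torch.distributed.tensor.DTensor:
    holds a local value and knows how to unwrap it."""

    def __init__(self, local):
        self._local = local

    def to_local(self):
        return self._local


class Layer:
    def __init__(self, weight):
        # The parameter is stored in its distributed (wrapped) form.
        self.weight = FakeDTensor(weight)

    def __call__(self, x):
        # Forward pre-hook behaviour: temporarily unwrap to a plain
        # local value so the computation itself is DTensor-unaware.
        local_w = self.weight.to_local()
        return local_w * x

    def state_dict(self):
        # State-dict access still sees the wrapped (DTensor) form.
        return {"weight": self.weight}


layer = Layer(3.0)
print(layer(2.0))                                    # 6.0, computed on the local value
print(type(layer.state_dict()["weight"]).__name__)   # FakeDTensor
```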

Expert Parallelism (MoE)

Mixture of Experts (MoE) requires a unique parallel strategy where:

  1. Experts are sharded across the ep_shard mesh dimension (each GPU holds a subset of experts), optionally replicating along ep_replicate.
  2. Routers are replicated (all GPUs have the same routing logic).
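
The expert-to-rank assignment implied by step 1 can be sketched in plain Python (the function name is hypothetical; d9d's actual partitioning may differ):

```python
def experts_for_rank(num_experts: int, ep_shard_size: int, shard_rank: int) -> list[int]:
    """Contiguous split of expert ids across the ep_shard dimension;
    ranks along ep_replicate would hold identical copies of this slice."""
    if num_experts % ep_shard_size != 0:
        raise ValueError("num_experts must divide evenly across ep_shard")
    per_rank = num_experts // ep_shard_size
    start = shard_rank * per_rank
    return list(range(start, start + per_rank))


# 8 experts sharded over 4 ranks: each rank owns 2 experts.
print(experts_for_rank(8, 4, 0))  # [0, 1]
print(experts_for_rank(8, 4, 3))  # [6, 7]
```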

parallelize_expert_parallel applies sharding to MoELayer modules. It shards the GroupedLinear weights along the expert dimension. Simultaneously, it effectively applies parallelize_replicate to the router.

Fully Sharded Data Parallel (FSDP)

parallelize_fsdp provides a thin wrapper around PyTorch's native fully_shard.

Difference from standard FSDP:

  • Standard FSDP averages gradients across the mesh (sum / world size) by default. d9d's wrapper forces gradients to be summed rather than averaged. This is required by d9d's gradient accumulation logic, which is handled externally.
  • parallelize_fsdp strictly requires a 1D DeviceMesh. To use it in multi-dimensional meshes (e.g., combining Replication and Sharding), use parallelize_hsdp or apply parallelize_replicate to the other dimensions manually first.
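
Why summation matters can be shown with plain arithmetic: if each reduction pre-averaged gradients, the external accumulation loop could not apply a single global normalization itself. A sketch of the external normalization (the numbers are illustrative):

```python
world_size = 4
accum_steps = 2

# Per-rank, per-microbatch gradients for one scalar parameter:
# grads[rank][step]
grads = [[1.0, 3.0], [2.0, 2.0], [0.0, 4.0], [1.0, 1.0]]

# FSDP with forced summation: each reduction produces a plain sum.
summed = sum(grads[rank][step] for rank in range(world_size)
             for step in range(accum_steps))

# The external logic applies exactly one normalization at the end,
# e.g. by the total number of contributing micro-batches.
normalized = summed / (world_size * accum_steps)
print(normalized)  # 1.75
```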

Hybrid Sharded Data Parallel (HSDP)

parallelize_hsdp is a high-level composite strategy for combining Full Sharding with Replicate Parallelism.

parallelize_hsdp accepts a multi-dimensional mesh and a target shard_dim. It treats every dimension other than shard_dim as a replication dimension, applies parallelize_replicate across those dimensions, and applies parallelize_fsdp along the shard dimension.
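
The dimension split described above can be sketched as a small helper (hypothetical, operating only on mesh dimension names):

```python
def split_hsdp_dims(mesh_dim_names: tuple[str, ...], shard_dim: str):
    """Partition mesh dimensions into the FSDP shard dimension and
    the remaining replication dimensions, as parallelize_hsdp does."""
    if shard_dim not in mesh_dim_names:
        raise ValueError(f"{shard_dim!r} not found in mesh dimensions")
    replicate_dims = tuple(d for d in mesh_dim_names if d != shard_dim)
    return replicate_dims, shard_dim


dims = ("dp_replicate", "dp_cp_shard", "cp_replicate")
print(split_hsdp_dims(dims, "dp_cp_shard"))
# (('dp_replicate', 'cp_replicate'), 'dp_cp_shard')
```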

Usage Examples

Replicate Parallelism

import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_replicate

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Get Dense Domain Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)  # pp x dp_replicate x dp_cp_shard x cp_replicate x tp

# 3. Define Model
model = MyCustomLayer(...)

# 4. Parallelize
parallelize_replicate(model, dense_mesh['dp_replicate', 'cp_replicate'])

Applying Expert Parallelism

import torch
from d9d.core.dist_context import DistributedContext, EXPERT_DOMAIN
from d9d.module.parallelism.api import parallelize_expert_parallel
from d9d.module.block.moe import MoELayer

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Get Expert Domain Mesh
expert_mesh = ctx.mesh_for(EXPERT_DOMAIN)  # pp x ep_replicate x ep_shard

# 3. Define Model
model = MoELayer(...)

# 4. Parallelize
parallelize_expert_parallel(
    model,
    mesh_experts=expert_mesh['ep_replicate', 'ep_shard'],
    expert_shard_dim='ep_shard'
)

Applying FSDP

import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_fsdp, parallelize_replicate

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Define Model
model = MyCustomLayer(...)

# 3. Get Dense Domain Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)

# 4. Parallelize
parallelize_fsdp(
    model,
    mesh=dense_mesh['dp_cp_shard']
)

Applying HSDP

import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_hsdp

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Get Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)  # pp x dp_replicate x dp_cp_shard x cp_replicate x tp

# 3. Define Model
model = MyCustomLayer(...)

# 4. Parallelize
parallelize_hsdp(
    model,
    mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
    shard_dim="dp_cp_shard"
)

d9d.module.parallelism.api

Horizontal parallelism strategies and utilities for d9d modules.

This package provides high-level helper functions to apply specific distributed parallelism strategies to PyTorch modules compatible with the d9d ecosystem.

parallelize_expert_parallel(module, mesh_experts, expert_shard_dim='ep_shard')

Applies Expert Parallelism to a MoE layer.

This function configures the provided Mixture of Experts layer for distributed execution.

It partitions the sparse experts across the specified dimension of the device mesh (Expert Parallelism) and replicates them along the other dimensions.

Simultaneously, it configures the router to be fully replicated across the mesh.

Parameters:

  • module (MoELayer, required): The MoE layer instance to parallelize.
  • mesh_experts (DeviceMesh, required): The device mesh containing the expert parallel resources.
  • expert_shard_dim (str, default 'ep_shard'): The name of the mesh dimension where experts should be sharded.

parallelize_fsdp(module, mesh, *args, **kwargs)

Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.

This function wraps the provided module with PyTorch's fully_shard API using the specified device mesh. Unlike standard FSDP usage, this function explicitly configures the module to sum gradients across the mesh instead of averaging them, and disables the internal gradient-averaging hooks. This allows d9d to handle gradient normalization and reduction across replicas externally.

Parameters:

  • module (Module, required): The module to shard.
  • mesh (DeviceMesh, required): The device mesh over which to shard the module.
  • *args (Any): Additional positional arguments passed to fully_shard.
  • **kwargs (Any): Additional keyword arguments passed to fully_shard.

parallelize_hsdp(module, mesh, shard_dim='dp_cp_shard', *fsdp_args, **fsdp_kwargs)

Applies Hybrid Sharded Data Parallelism (HSDP) to a module.

This function decomposes the provided device mesh into sharding dimensions and replication dimensions. It applies replication parallelism across the replication dimensions and Fully Sharded Data Parallelism (FSDP) across the specified shard dimension.

Parameters:

  • module (Module, required): The module to parallelize.
  • mesh (DeviceMesh, required): The device mesh over which to distribute the module.
  • shard_dim (str, default 'dp_cp_shard'): The name of the mesh dimension used for FSDP sharding. Any dimension in the mesh not matching this name is treated as a replication dimension.
  • *fsdp_args (Any): Positional arguments passed to the underlying FSDP parallelizer.
  • **fsdp_kwargs (Any): Keyword arguments passed to the underlying FSDP parallelizer.

Raises:

  • ValueError: If the device mesh does not have named dimensions.

parallelize_replicate(module, mesh)

Applies replicated parallelism to the module.

This function configures the provided module to be fully replicated across the given device mesh. It utilizes the ToLocalParallel style, which manages DTensor wrapping for parameters and gradients (via Replicate placements) while ensuring that the underlying computation sees standard local tensors during the forward pass.

This approach is effectively Data Parallelism managed via the DTensor APIs, allowing seamless integration of modules that require local tensor inputs into a broader distributed mesh context.

Parameters:

  • module (Module, required): The module to parallelize.
  • mesh (DeviceMesh, required): The device mesh over which to replicate the module.

d9d.module.parallelism.style

ShardMoESparseExpertsParallel

Bases: ParallelStyle

Parallel style that shards MoE experts across a specific mesh dimension.

This style is designed for MoELayer instances using GroupedLinear for experts. It splits the experts across the specified dimension of the device mesh (Expert Parallelism). Along the other mesh dimensions, the parameters are treated as Replicated.

It also initializes the necessary distributed communication groups within the MoE layer to handle token dispatching.

ToLocalParallel

Bases: ParallelStyle

Parallel style that distributes parameters and gradients but executes with local tensors.

This style wraps standard tensor distribution (via DTensor) but injects runtime hooks to temporarily unwrap DTensor parameters into local torch.Tensor during the forward pass.

This is useful for parallel strategies (like Replicate) where the underlying calculation logic is not DTensor-aware, but the parameters must remain distributed for gradient synchronization and for distributed checkpointing.

__init__(param_placement, grad_placement)

Constructs a ToLocalParallel object.

Parameters:

  • param_placement (tuple[Placement, ...], required): Tuple of placements defining how parameters are distributed.
  • grad_placement (tuple[Placement, ...], required): Tuple of placements defining how gradients are synchronized.