About

The d9d.module.parallelism package provides high-level strategies for distributing model execution across device meshes.

These strategies are "Horizontal" in the sense that they operate within a single pipeline stage (intra-layer parallelism), as opposed to Pipeline Parallelism, which is "Vertical" (inter-layer).

Design

DTensor-First Architecture

d9d enforces a DTensor-first philosophy. We mandate that every trainable parameter in the distributed environment be represented as a torch.distributed.tensor.DTensor.

This constraint simplifies the system architecture significantly:

  • Universal Checkpointing: The checkpointing engine does not need to know about specific parallel strategies (like "This is DP" or "This is TP"). It simply inspects the DTensor.placements attribute to automatically determine how to gather, deduplicate, and save tensors (see the sketch after this list).
  • Native Synchronization: Gradient synchronization for replicated parameters is handled entirely by the d9d internals.
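
As a minimal sketch of the first point (this is not d9d's actual checkpoint engine; should_save_locally and rank_coord are hypothetical names), the placements alone tell a saver which ranks own unique data and which hold duplicate copies:

from torch.distributed.tensor import DTensor, Replicate

def should_save_locally(param: DTensor, rank_coord: tuple[int, ...]) -> bool:
    # Sharded mesh dimensions: every coordinate owns a unique slice and must write it.
    # Replicated mesh dimensions: only coordinate 0 writes, deduplicating identical copies.
    for mesh_dim, placement in enumerate(param.placements):
        if isinstance(placement, Replicate) and rank_coord[mesh_dim] != 0:
            return False
    return True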

Composition over Monoliths

We explicitly reject monolithic wrappers like torch.nn.parallel.DistributedDataParallel (DDP).

While DDP is efficient for pure Data Parallelism, it acts as a "black box" that assumes ownership of the entire model execution loop. Instead, d9d relies on PyTorch's parallelize_module API. This allows for fine-grained, per-submodule parallelism decisions:

  • Layer A can use Tensor Parallelism (Row/Col wise).
  • Layer B (e.g., a Router) can use Replicate Parallelism.
  • Layer C (e.g., MLP) can use Expert Parallelism.

By treating "Data Parallelism" simply as another tiling strategy ("Replicate") within the Tensor Parallel system, we achieve a unified interface for ND parallelism.

Strategies

Replicate Parallelism

parallelize_replicate implements Replicate Parallelism: it replicates parameters across the mesh and is used for Data Parallelism or Context Parallelism.

During the forward pass, it installs hooks that temporarily "unwrap" DTensor parameters into standard, local torch.Tensor objects. This allows standard PyTorch operations and custom kernels to run without modification, while accessing the module's state dict and parameters still yields DTensor objects.
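
Illustratively, assuming model has already been wrapped with parallelize_replicate and inputs is a local batch (both names are placeholders), the observable behavior is:

from torch.distributed.tensor import DTensor

# From the outside, parameters and the state dict expose DTensors...
assert all(isinstance(p, DTensor) for p in model.parameters())

# ...while forward() computes on plain local tensors, so DTensor-unaware
# custom kernels run unchanged.
output = model(inputs)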

Expert Parallelism (MoE)

Mixture of Experts (MoE) requires a unique parallel strategy where:

  1. Experts are sharded across the ep_shard mesh dimension (each GPU holds a subset of experts) and optionally replicated along ep_replicate.
  2. Routers are replicated (all GPUs have the same routing logic).

parallelize_expert_parallel applies this strategy to MoELayer modules. It shards the GroupedLinear weights along the expert dimension and, at the same time, effectively applies parallelize_replicate to the router.
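
For illustration only (the expert count and the experts.weight attribute path are hypothetical): with 64 experts and an ep_shard mesh dimension of size 8, each rank keeps 8 experts, and the sharded GroupedLinear weight exposes both views:

weight = moe_layer.experts.weight          # a DTensor sharded with Shard(0) on 'ep_shard'
assert weight.shape[0] == 64               # global (logical) expert dimension
assert weight.to_local().shape[0] == 8     # this rank's local shard: 64 / 8 experts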

Fully Sharded Data Parallel (FSDP)

parallelize_fsdp provides a thin wrapper around PyTorch's native fully_shard.

Difference from standard FSDP: standard FSDP averages gradients across the mesh (sum / world size) by default. d9d's wrapper forces gradients to be summed rather than averaged, which is required because gradient accumulation and normalization are handled externally by d9d.
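
A minimal sketch of why summed gradients compose with external accumulation (the training loop and total_tokens_this_step are illustrative, not d9d API): the trainer normalizes exactly once, by the true token count of the whole step, instead of compounding per-microbatch and per-mesh averages.

for micro_batch in micro_batches:
    loss = compute_loss_sum(model, micro_batch)  # per-token losses summed, not averaged
    loss.backward()                              # FSDP reduce-scatters gradient *sums*

for p in model.parameters():
    if p.grad is not None:
        p.grad.div_(total_tokens_this_step)      # single, exact normalization
optimizer.step()
optimizer.zero_grad()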

Usage Examples

Applying Replicate Parallelism

import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_replicate

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Get Dense Domain Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)  # pp x dp_replicate x dp_cp_shard x cp_replicate x tp

# 3. Define Model
model = MyCustomLayer(...)

# 4. Parallelize
parallelize_replicate(model, dense_mesh[['dp_replicate', 'cp_replicate']])

Applying Expert Parallelism

import torch
from d9d.core.dist_context import DistributedContext, EXPERT_DOMAIN
from d9d.module.parallelism.api import parallelize_expert_parallel
from d9d.module.block.moe import MoELayer

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Get Expert Domain Mesh
expert_mesh = ctx.mesh_for(EXPERT_DOMAIN)  # pp x ep_replicate x ep_shard

# 3. Define Model
model = MoELayer(...)

# 4. Parallelize
parallelize_expert_parallel(
    model, 
    mesh_experts=expert_mesh[['ep_replicate', 'ep_shard']],
    expert_shard_dim='ep_shard'
)

Applying FSDP/HSDP

import torch
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
from d9d.module.parallelism.api import parallelize_fsdp, parallelize_replicate

# 1. Create a Distributed Context
ctx: DistributedContext = ...

# 2. Define Model
model = MyCustomLayer(...)

# 3. Get Dense Domain Mesh
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)

# 4. Parallelize

# If you use replicated context parallelism, you must replicate along its mesh dimension
# manually, since PyTorch's FSDP only supports a 2D DeviceMesh (it can still be composed
# with other parallelisms).
parallelize_replicate(
    model,
    mesh=dense_mesh['cp_replicate']
)

# As with regular fully_shard(...), the first mesh dimension is for replication and the second for sharding
parallelize_fsdp(
    model, 
    mesh=dense_mesh[['dp_replicate', 'dp_cp_shard']]
)

d9d.module.parallelism.api

Horizontal parallelism strategies and utilities for d9d modules.

This package provides high-level helper functions to apply specific distributed parallelism strategies to PyTorch modules compatible with the d9d ecosystem.

parallelize_expert_parallel(module, mesh_experts, expert_shard_dim='ep_shard')

Applies Expert Parallelism to a MoE layer.

This function configures the provided Mixture of Experts layer for distributed execution.

It partitions the sparse experts across the specified dimension of the device mesh (Expert Parallelism) and replicates them along the other mesh dimensions.

Simultaneously, it configures the router to be fully replicated across the mesh.

Parameters:

  • module (MoELayer): The MoE layer instance to parallelize. Required.
  • mesh_experts (DeviceMesh): The device mesh containing the expert parallel resources. Required.
  • expert_shard_dim (str): The name of the mesh dimension where experts should be sharded. Default: 'ep_shard'.
Source code in d9d/module/parallelism/api/expert_parallel.py
def parallelize_expert_parallel(
        module: MoELayer,
        mesh_experts: DeviceMesh,
        expert_shard_dim: str = "ep_shard"
):
    """
    Applies Expert Parallelism to a MoE layer.

    This function configures the provided Mixture of Experts layer for distributed
    execution.

    It partitions the sparse experts across the specified dimension
    of the device mesh (Expert Parallelism) and replicates along other dims.

    Simultaneously, it configures the router to be fully replicated across
    the mesh.

    Args:
        module: The MoE layer instance to parallelize.
        mesh_experts: The device mesh containing the expert parallel resources.
        expert_shard_dim: The name of the mesh dimension where experts should be sharded.
    """

    parallelize_module(module, mesh_experts, ShardMoESparseExpertsParallel(shard_dim_name=expert_shard_dim))
    parallelize_module(module.router, mesh_experts, ToLocalParallel(
        param_placement=tuple(Replicate() for _ in range(mesh_experts.ndim)),
        grad_placement=tuple(Partial("sum") for _ in range(mesh_experts.ndim))
    ))

parallelize_fsdp(module, mesh, *args, **kwargs)

Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.

This function wraps the provided module with PyTorch's fully_shard API using the specified device mesh. Unlike standard FSDP usage, it explicitly configures the module to sum gradients across the mesh instead of averaging them and disables the internal all-sum-reduce hooks, so that d9d can handle gradient normalization and reduction across replicas externally.

Parameters:

  • module (Module): The module to shard. Required.
  • mesh (DeviceMesh): The device mesh over which to shard the module. Required.
  • *args (Any): Additional positional arguments passed to fully_shard. Default: ().
  • **kwargs (Any): Additional keyword arguments passed to fully_shard. Default: {}.
Source code in d9d/module/parallelism/api/fully_sharded.py
def parallelize_fsdp(
        module: nn.Module,
        mesh: DeviceMesh,
        *args: Any,
        **kwargs: Any
):
    """
    Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.

    This function wraps the provided module with PyTorch's ``fully_shard`` API using
    the specified device mesh. Unlike standard FSDP usage, this function explicitly
    configures the module to sum gradients across the mesh
    instead of averaging them and disables internal all-sum-reduce hooks.
    This is intended for d9d to handle gradient normalization and reduction across replicas externally.

    Args:
        module: The module to shard.
        mesh: The device mesh over which to shard the module.
        *args: Additional positional arguments passed to ``fully_shard``.
        **kwargs: Additional keyword arguments passed to ``fully_shard``.
    """

    fully_shard(module, *args, mesh=mesh, **kwargs)
    if not isinstance(module, FSDPModule):
        raise RuntimeError("Torch FSDP did not convert the module into FSDPModule")
    _force_fsdp_grad_reduction_policy(module)

parallelize_replicate(module, mesh)

Applies replicated parallelism to the module.

This function configures the provided module to be fully replicated across the given device mesh. It utilizes the ToLocalParallel style, which manages DTensor wrapping for parameters and gradients (via Replicate and Partial placements) while ensuring that the underlying computation sees standard local tensors during the forward pass.

This approach is effectively Data Parallelism managed via the DTensor APIs, allowing seamless integration of modules that require local tensor inputs into a broader distributed mesh context.

Parameters:

  • module (Module): The module to parallelize. Required.
  • mesh (DeviceMesh): The device mesh over which to replicate the module. Required.
Source code in d9d/module/parallelism/api/replicate_parallel.py
def parallelize_replicate(
    module: nn.Module,
    mesh: DeviceMesh,
):
    """
    Applies replicated parallelism to the module.

    This function configures the provided module to be fully replicated across the
    given device mesh. It utilizes the ``ToLocalParallel`` style, which manages
    ``DTensor`` wrapping for parameters and gradients (via ``Replicate``
    and ``Partial`` placements) while ensuring that the underlying computation
    sees standard local tensors during the forward pass.

    This approach is effectively Data Parallelism managed via the DTensor
    APIs, allowing seamless integration of modules that require local tensor inputs
    into a broader distributed mesh context.

    Args:
     module: The module to parallelize.
     mesh: The device mesh over which to replicate the module.
    """

    parallelize_module(module, mesh, ToLocalParallel(
        param_placement=tuple(Replicate() for _ in range(mesh.ndim)),
        grad_placement=tuple(Partial("sum") for _ in range(mesh.ndim))
    ))

d9d.module.parallelism.style

ShardMoESparseExpertsParallel

Bases: ParallelStyle

Parallel style that shards MoE experts across a specific mesh dimension.

This style is designed for MoELayer instances using GroupedLinear for experts. It splits the experts across the specified dimension of the device mesh (Expert Parallelism). Other dimensions in the mesh treat the parameters as Replicated.

It also initializes the necessary distributed communication groups within the MoE layer to handle token dispatching.

Source code in d9d/module/parallelism/style/shard_experts.py
class ShardMoESparseExpertsParallel(ParallelStyle):
    """
    Parallel style that shards MoE experts across a specific mesh dimension.

    This style is designed for ``MoELayer`` instances using ``GroupedLinear`` for experts.
    It splits the experts across the specified
    dimension of the device mesh (Expert Parallelism). Other dimensions in the
    mesh treat the parameters as Replicated.

    It also initializes the necessary distributed communication groups within the
    MoE layer to handle token dispatching.
    """

    def __init__(self, shard_dim_name: str):
        self._shard_dim_name = shard_dim_name

    def _partition_experts(self, module_name: str, mod: nn.Module, device_mesh: DeviceMesh):
        if not isinstance(mod, GroupedLinear):
            raise TypeError("This plan should be applied only on GroupedLinear")

        mesh_dim_names = device_mesh.mesh_dim_names

        if mesh_dim_names is None:
            raise ValueError("This plan should be applied only on named DeviceMeshes")

        placements = [
            Shard(0) if dim_name == self._shard_dim_name else Replicate()
            for dim_name
            in mesh_dim_names
        ]
        weight = nn.Parameter(
            distribute_tensor(mod.weight, device_mesh, placements),
            requires_grad=mod.weight.requires_grad
        )
        mod.weight = weight

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if not isinstance(module, MoELayer):
            raise TypeError("This plan should be applied only on MoELayer")

        module.enable_distributed_communicator(device_mesh.get_group(self._shard_dim_name))

        for submod in module.modules():
            if isinstance(submod, GroupedLinear):
                distribute_module(submod, device_mesh, self._partition_experts)

        return module

ToLocalParallel

Bases: ParallelStyle

Parallel style that distributes parameters and gradients but executes with local tensors.

This style wraps standard tensor distribution (via DTensor) but injects runtime hooks to temporarily unwrap DTensor parameters into local torch.Tensor during the forward pass.

This is useful for parallel strategies (like Replicate) where the underlying calculation logic is not DTensor-aware, but the parameters must remain distributed for gradient synchronization and for distributed checkpointing.

Source code in d9d/module/parallelism/style/to_local.py
class ToLocalParallel(ParallelStyle):
    """
    Parallel style that distributes parameters and gradients but executes with local tensors.

    This style wraps standard tensor distribution (via ``DTensor``) but injects
    runtime hooks to temporarily unwrap ``DTensor`` parameters into local ``torch.Tensor``
    during the forward pass.

    This is useful for parallel strategies (like Replicate)
    where the underlying calculation logic is not DTensor-aware, but the parameters must remain
    distributed for gradient synchronization and for distributed checkpointing.
    """

    def __init__(self, param_placement: tuple[Placement, ...], grad_placement: tuple[Placement, ...]):
        """
        Constructs ToLocalParallel object.

        Args:
            param_placement: Tuple of placements defining how parameters are distributed.
            grad_placement: Tuple of placements defining how gradients are synchronized.
        """

        self._grad_placement = grad_placement
        self._param_placement = param_placement

    def _distribute_params(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
        for param_name, param in module.named_parameters(recurse=False):
            new_param = nn.Parameter(
                distribute_tensor(param.data, device_mesh, self._param_placement),
                requires_grad=param.requires_grad
            )

            module.register_parameter(param_name, new_param)

    def _apply(self, master_module: nn.Module, device_mesh: DeviceMesh):
        patched_classes = {}
        original_classes = {}

        for submod_name, submod in master_module.named_modules():
            param_names = [name for name, p in submod.named_parameters(recurse=False)]
            patched_classes[submod_name] = _build_to_local_patched_class(submod, self._grad_placement, param_names)
            original_classes[submod_name] = submod.__class__

            distribute_module(
                submod,
                device_mesh,
                self._distribute_params
            )

        master_module.register_forward_pre_hook(_ModulePatch(patched_classes))
        master_module.register_forward_hook(_ModulePatch(original_classes))

__init__(param_placement, grad_placement)

Constructs ToLocalParallel object.

Parameters:

  • param_placement (tuple[Placement, ...]): Tuple of placements defining how parameters are distributed. Required.
  • grad_placement (tuple[Placement, ...]): Tuple of placements defining how gradients are synchronized. Required.
Source code in d9d/module/parallelism/style/to_local.py
def __init__(self, param_placement: tuple[Placement, ...], grad_placement: tuple[Placement, ...]):
    """
    Constructs ToLocalParallel object.

    Args:
        param_placement: Tuple of placements defining how parameters are distributed.
        grad_placement: Tuple of placements defining how gradients are synchronized.
    """

    self._grad_placement = grad_placement
    self._param_placement = param_placement