
Distributed Context

About

The d9d.core.dist_context package is the Source of Truth for the distributed execution environment.

In large-scale model training, ensuring that every rank agrees on the topology, global rank mapping, and communication groups is critical. This package provides the DistributedContext class, which serves as the central repository for this configuration.

It is extremely important to use this context for all distributed assertions (e.g., "Am I the main process?", "Which rank is my pipeline peer?") rather than checking raw os.environ variables or initializing ad-hoc process groups, which can lead to silent inconsistencies.

Comparison with Other Frameworks

The problem of managing distributed topology is solved in different ways across distributed training frameworks.

Megatron-LM (parallel_state)

Megatron-LM manages topology via a module often called mpu (Model Parallel Unit) or core.parallel_state.

Megatron historically relies on global variables and manual rank arithmetic. To find a peer rank, developers often write modulo arithmetic by hand (e.g., rank % tp_size). This is flexible but brittle and error-prone.
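To make the fragility concrete, here is a small illustrative sketch (not actual Megatron-LM code) of this kind of manual rank arithmetic. The layout assumption (tensor-parallel ranks are contiguous) is baked into every formula, so changing the rank ordering silently breaks all of them:

```python
# Illustrative sketch of Megatron-style manual rank arithmetic.
# WORLD_SIZE and TP_SIZE are example values, not real configuration.

WORLD_SIZE = 16
TP_SIZE = 4


def tp_group_rank(global_rank: int) -> int:
    # Position of this rank inside its tensor-parallel group.
    # Only correct if TP ranks are laid out contiguously.
    return global_rank % TP_SIZE


def tp_group_peers(global_rank: int) -> list[int]:
    # All global ranks sharing this rank's tensor-parallel group.
    start = (global_rank // TP_SIZE) * TP_SIZE
    return list(range(start, start + TP_SIZE))


# Rank 6 sits in the second TP group (ranks 4..7), at position 2.
print(tp_group_rank(6))   # -> 2
print(tp_group_peers(6))  # -> [4, 5, 6, 7]
```

Every such formula duplicates an implicit assumption about rank layout; a mesh abstraction centralizes that assumption instead.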

HuggingFace Accelerate (PartialState)

Accelerate uses a class called PartialState to abstract the environment.

We find Accelerate's utility methods quite useful. d9d implements similar helpers, such as wait_world() (similar to wait_for_everyone()) and properties like is_main_process or is_local_main_process.

PartialState is primarily designed for "Flat" Data Parallelism (DDP/FSDP) and does not support complex multidimensional parallelisms natively.

PartialState is implemented as a singleton: instantiating it anywhere in the code returns the exact same global state. This obscures the flow of dependencies and can trigger initialization of your ProcessGroups and distributed environment in unexpected places in your code.
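A minimal sketch of the singleton pattern involved (this is not the actual Accelerate implementation) shows why the initialization point becomes unpredictable: whichever constructor call happens first triggers the side effects, possibly deep inside an unrelated import.

```python
# Toy singleton, illustrating the pattern only.
class SingletonState:
    _shared = None

    def __new__(cls):
        if cls._shared is None:
            cls._shared = super().__new__(cls)
            # In a real framework, process-group setup would happen
            # here, as a side effect of the *first* constructor call.
            cls._shared.initialized_by = "first caller"
        return cls._shared


a = SingletonState()  # initializes the shared global state
b = SingletonState()  # returns the exact same object
print(a is b)         # -> True
```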

TorchTitan (ParallelDims)

TorchTitan is the most similar framework to d9d in spirit, as both are built on top of native PyTorch 2.x DeviceMesh abstractions.

However, ParallelDims in TorchTitan is more of a mesh factory than a controller for the global distributed environment.

d9d (DistributedContext)

d9d positions DistributedContext as the explicit controller for the entire distributed environment.

  • DistributedContext is a standard object that is instantiated and passed explicitly to dependent components. This ensures that the initialization of process groups happens exactly when and where the developer intends, making the initialization flow transparent.
  • It replaces manual rank arithmetic with PyTorch's native, formalized DeviceMesh abstractions.
  • Functionally, it elevates the mesh system into an active runtime controller. It bundles timeout management, context-aware logging, and node-level synchronization.
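The explicit-dependency style from the first bullet can be sketched in plain Python. `FakeContext` and `CheckpointWriter` are hypothetical stand-ins, not d9d classes; the point is that components receive the context through their signatures rather than reaching for a global:

```python
# Hedged sketch of explicit dependency passing; names are illustrative.
class FakeContext:
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

    @property
    def is_main_process(self) -> bool:
        return self.rank == 0


class CheckpointWriter:
    def __init__(self, dist_ctx: FakeContext):
        # The dependency is visible in the constructor: no hidden globals,
        # no surprise initialization on first access.
        self.dist_ctx = dist_ctx

    def should_write(self) -> bool:
        return self.dist_ctx.is_main_process


ctx = FakeContext(rank=0, world_size=8)
writer = CheckpointWriter(ctx)
print(writer.should_write())  # -> True
```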

DeviceMesh Domains

Modern architectures require different parallelism strategies for different parts of the model (e.g., standard dense layers vs. Mixture-of-Experts layers). d9d handles this by abstracting these strategies into specific DeviceMesh Domains.

The underlying physical GPUs are fixed, but how we view them changes depending on what we are working with (distributing MoE layers, dense layers, or the input batch). The DeviceMesh object for a specific domain is retrieved via dist_ctx.mesh_for(domain_name).

Demonstration Video:

To better understand domains, we have prepared a quick demonstration video on YouTube.

Regular Domain (regular)

  • Identifier: REGULAR_DOMAIN or "regular"
  • Purpose: The most granular mesh view for fully decomposed parallelism. Used for setting up logging and seeding.
  • Dimensions:
    1. pp: Pipeline Parallel
    2. dp_replicate: Data Parallel (DDP style)
    3. dp_shard: Data Parallel (FSDP style)
    4. cp_shard: Context Parallel (FSDP style)
    5. cp_replicate: Context Parallel (DDP style)
    6. tp: Tensor Parallelism

Expert Domain (expert)

  • Identifier: EXPERT_DOMAIN or "expert"
  • Purpose: Mesh view optimized for distributing MoE (Mixture of Experts) layers. Sparse expert layers are intended to be sharded across the ep_shard dimension and replicated across the ep_replicate dimension.
  • Dimensions:
    1. pp: Pipeline Parallel
    2. ep_replicate: Combined Replication Dimension ((DP * CP) // EP)
    3. ep_shard: Expert Parallel Dimension

Dense Domain (dense)

  • Identifier: DENSE_DOMAIN or "dense"
  • Purpose: Mesh view for distributing dense layers.
  • Dimensions:
    1. pp: Pipeline Parallel
    2. dp_replicate: Data Parallel for replication using HSDP
    3. dp_cp_shard: Merged Data and Context Parallel dimension for sharding using HSDP
    4. cp_replicate: Context Parallel for replication
    5. tp: Tensor Parallel

Batch Domain (batch)

  • Identifier: BATCH_DOMAIN or "batch"
  • Purpose: Mesh view for distributing the batch tensor and setting up DataLoader sharding.
  • Dimensions:
    1. pp: Pipeline Parallel
    2. dp: Data Parallel
    3. cp: Context Parallel
    4. tp: Tensor Parallel

Flat Domain (flat)

  • Identifier: FLAT_DOMAIN or "flat"
  • Purpose: Mesh view with a single dimension.
  • Dimensions:
    1. world: World Size
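The five domain views above can be summarized as a plain mapping. The dimension tuples are taken directly from this page; the dictionary itself is illustrative, not a d9d data structure (the constants REGULAR_DOMAIN etc. resolve to the lowercase strings shown):

```python
# Summary of the DeviceMesh domain views described above.
DOMAIN_DIMS = {
    "regular": ("pp", "dp_replicate", "dp_shard", "cp_shard", "cp_replicate", "tp"),
    "expert": ("pp", "ep_replicate", "ep_shard"),
    "dense": ("pp", "dp_replicate", "dp_cp_shard", "cp_replicate", "tp"),
    "batch": ("pp", "dp", "cp", "tp"),
    "flat": ("world",),
}

# Every domain is a view over the same physical ranks, so the product
# of its dimension sizes always equals the world size.
print(sorted(DOMAIN_DIMS))  # -> ['batch', 'dense', 'expert', 'flat', 'regular']
```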

Usage

Initialization

The system is usually initialized via DeviceMeshParameters.

from d9d.core.dist_context import DeviceMeshParameters

# Define the topology
params = DeviceMeshParameters(
    pipeline_parallel=2,
    data_parallel_replicate=8,
    data_parallel_shard=1,
    context_parallel_replicate=1,
    context_parallel_shard=1,
    expert_parallel=8,
    tensor_parallel=1
)

dist_ctx = params.build()
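A useful sanity check implied by these parameters: the product of the parallelism degrees must equal the number of launched processes. The check below is an illustration of that arithmetic, not d9d's actual validation code; note that expert_parallel re-partitions the combined DP * CP ranks rather than adding an extra factor (an assumption based on the expert-domain description above):

```python
# World-size arithmetic for the topology defined above.
degrees = {
    "pipeline_parallel": 2,
    "data_parallel_replicate": 8,
    "data_parallel_shard": 1,
    "context_parallel_replicate": 1,
    "context_parallel_shard": 1,
    "tensor_parallel": 1,
    # expert_parallel (8) is excluded: it re-slices DP * CP,
    # it does not multiply the world size.
}

world_size = 1
for degree in degrees.values():
    world_size *= degree

print(world_size)  # -> 16: launch 16 processes for this topology
```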

Accessing DeviceMesh Domains

from torch.distributed import DeviceMesh
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN

dist_ctx: DistributedContext = ...

mesh_dense: DeviceMesh = dist_ctx.mesh_for(DENSE_DOMAIN)

Rank Utilities

Accessing rank information.

if dist_ctx.is_main_process:
    print("I am Global Rank 0 (Master)")

if dist_ctx.is_local_main_process:
    print("I am Rank 0 on this specific node")

# Synchronize
dist_ctx.wait_world()

Context Managers

Control execution flow across ranks.

# Ensure only one process per node downloads a file
with dist_ctx.local_main_process_first():
    if dist_ctx.is_local_main_process:
        download_dataset()
    # Others wait here implicitly
# All resume together

d9d.core.dist_context

This package configures the distributed environment and device meshes.

DeviceMeshParameters

Bases: BaseModel

Configuration parameters for initializing Distributed Device Meshes.

Attributes:

Name Type Description
pipeline_parallel int

Degree of pipeline parallelism (PP).

data_parallel_replicate int

Degree of data parallel replication (DDP).

data_parallel_shard int

Degree of data parallel sharding (FSDP).

context_parallel_replicate int

Degree of context parallel (CP) replication.

context_parallel_shard int

Degree of context parallel (CP) sharding (FSDP style).

tensor_parallel int

Degree of tensor parallelism (TP).

expert_parallel int

Degree of expert parallelism (EP/MoE).

has_data_parallel_replicate property

Checks if data parallel replication is enabled (degree > 1).

has_data_parallel_shard property

Checks if data parallel sharding is enabled (degree > 1).

has_expert_parallel property

Checks if expert parallelism is enabled (degree > 1).

has_pipeline_parallel property

Checks if pipeline parallelism is enabled (degree > 1).

is_distributed property

Checks if any form of parallelism is enabled.

build(log_level=logging.INFO)

Initializes the DistributedContext using these parameters.

Returns:

Type Description
DistributedContext

A new DistributedContext instance containing the initialized device meshes.

DistributedContext

Acts as the single source of truth for the distributed execution environment.

It serves as the central repository for the distributed configuration, managing the creation and synchronization of PyTorch DeviceMeshes for different domains (the regular domain, the expert parallel domain, and so on).

All assertions regarding rank placement, group memberships, and parallel topology must be derived from this context to ensure consistency.

current_device property

Returns the CUDA device associated with this rank.

is_local_main_process property

Checks if the current process is the rank 0 on the specific node.

is_main_process property

Checks if the current process is the global rank 0.

local_rank property

Returns the rank of the current process within its node.

logger property

Returns the logger instance configured for distributed logging.

master_addr property

Returns the IP address or domain name of the master node.

mesh_params property

Returns the parameters used to initialize this context.

node_rank property

Returns the index of the node this process is running on.

num_nodes property

Returns the total number of nodes in the cluster.

local_main_process_first()

Context manager that executes the block on the local main process first.

Other local ranks wait at the entrance. The local main process waits at the exit to synchronize before continuing.

main_process_first()

Context manager that executes the block on the global main process first.

All other ranks wait at the entrance. The global main process waits at the exit to synchronize before continuing.

mesh_for(domain)

Returns the device mesh view associated with a specific logical domain.

Available Domains and Dimensions
  • regular (REGULAR_DOMAIN): The most granular mesh for fully decomposed parallelism. Dimensions: ('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')
  • expert (EXPERT_DOMAIN): Mesh optimized for distributing MoE (Mixture of Experts) layers. Dimensions: ('pp', 'ep_replicate', 'ep_shard')
  • dense (DENSE_DOMAIN): Mesh optimized for distributing dense layers. Dimensions: ('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')
  • batch (BATCH_DOMAIN): Mesh optimized for distributing input data. Dimensions: ('pp', 'dp', 'cp', 'tp')
  • flat (FLAT_DOMAIN): Mesh containing a single dimension with all the processes. Dimensions: ('world',)

Parameters:

Name Type Description Default
domain str

The name of the domain to retrieve.

required

Returns:

Type Description
DeviceMesh

The PyTorch DeviceMesh configured for the requested domain.

Raises:

Type Description
ValueError

If the specified domain does not exist.

set_timeout(timeout_seconds)

Updates the NCCL/process group timeout for all underlying meshes.

Parameters:

Name Type Description Default
timeout_seconds float

New timeout duration in seconds.

required

wait_world()

Blocks process execution until all ranks reach this point.

build_dist_logger(qualifier, level)

Configures and returns a logger instance for d9d.

The logger is configured to write to stdout with a formatter that includes the provided rank qualifier, allowing for easier debugging in distributed logs.

Parameters:

Name Type Description Default
qualifier str

A string identifying the current rank's position in the mesh.

required
level int

Log level to set by default.

required

Returns:

Type Description
Logger

A configured logging.Logger instance.