Distributed Context
About
The d9d.core.dist_context package is the Source of Truth for the distributed execution environment.
In large-scale model training, ensuring that every rank agrees on the topology, global rank mapping, and communication groups is critical. This package provides the DistributedContext class, which serves as the central repository for this configuration.
It is extremely important to use this context for all distributed assertions (e.g., "Am I the main process?", "Which rank is my pipeline peer?") rather than checking raw os.environ variables or initializing ad-hoc process groups, which can lead to silent inconsistencies.
Comparison with Other Frameworks
The problem of managing distributed topology is solved in different ways across distributed training frameworks.
Megatron-LM (parallel_state)
Megatron-LM manages topology via a module often called mpu (Model Parallel Unit) or core.parallel_state.
Megatron historically relies on global variables and manual rank arithmetic. To find a peer rank, developers often write code involving modulo operations (e.g., rank % tp_size). This is flexible but error-prone and brittle.
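As a hedged illustration (the helper names and rank layout below are assumptions for the sketch, not Megatron's actual API), the kind of manual rank arithmetic this approach requires looks like:

```python
# Sketch of Megatron-style manual rank arithmetic (illustrative only; the
# real code lives in megatron.core.parallel_state). Assume ranks are laid
# out row-major with tensor parallel (TP) as the innermost dimension.
def tp_rank(rank: int, tp_size: int) -> int:
    # Position of this rank inside its tensor-parallel group.
    return rank % tp_size

def tp_group_ranks(rank: int, tp_size: int) -> list[int]:
    # All ranks sharing this rank's tensor-parallel group.
    start = (rank // tp_size) * tp_size
    return list(range(start, start + tp_size))

# With world_size=8 and tp_size=4, rank 6 sits in the second TP group.
print(tp_rank(6, 4))          # 2
print(tp_group_ranks(6, 4))   # [4, 5, 6, 7]
```

A single change to the assumed dimension order silently invalidates every such formula, which is exactly the brittleness a mesh abstraction is meant to remove.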
HuggingFace Accelerate (PartialState)
Accelerate uses a class called PartialState to abstract the environment.
We find Accelerate's utility methods quite useful. d9d implements similar helpers, such as wait_world() (similar to wait_for_everyone()) and properties like is_main_process or is_local_main_process.
PartialState is primarily designed for "Flat" Data Parallelism (DDP/FSDP) and does not support complex multidimensional parallelisms natively.
PartialState is implemented as a singleton: instantiating it anywhere in the code returns the exact same global state. This makes the flow of dependencies unclear and can lead to your ProcessGroups and distributed environment being initialized in unexpected places in your code.
TorchTitan (ParallelDims)
TorchTitan is the most similar framework to d9d in spirit, as both are built on top of native PyTorch 2.x DeviceMesh abstractions.
However, ParallelDims in TorchTitan is more of a mesh factory than a global distributed environment controller.
d9d (DistributedContext)
d9d positions DistributedContext as the explicit controller for the entire distributed environment.
- `DistributedContext` is a standard object that is instantiated and passed explicitly to dependent components. This ensures that the initialization of process groups happens exactly when and where the developer intends, making the initialization flow transparent.
- It replaces manual rank arithmetic with formalized, PyTorch-native `DeviceMesh` abstractions.
- Functionally, it elevates the mesh system into an active runtime controller: it bundles timeout management, context-aware logging, and node-level synchronization.
DeviceMesh Domains
Modern architectures require different parallelism strategies for different parts of the model (e.g., standard dense layers vs. Mixture-of-Experts layers). d9d handles this by abstracting these strategies into specific DeviceMesh Domains.
The underlying physical GPUs are immutable, but how we view them changes depending on what we are working with (distributing MoE layers, dense layers, or the input batch). The `DeviceMesh` object for a specific domain is retrieved via `dist_ctx.mesh_for(domain_name)`.
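To illustrate the "same GPUs, different views" idea in plain Python (a conceptual sketch only, not d9d code): eight physical ranks can be reshaped into different logical grids without the hardware changing.

```python
# Conceptual sketch: the same 8 physical ranks viewed as different meshes.
ranks = list(range(8))  # the immutable physical GPUs

def view(ranks, *dims):
    # Reshape a flat rank list into a nested row-major grid, mimicking how
    # a DeviceMesh imposes a logical topology on fixed hardware.
    out = ranks
    for d in reversed(dims[1:]):
        out = [out[i:i + d] for i in range(0, len(out), d)]
    return out

flat_view = view(ranks, 8)        # [0, 1, 2, 3, 4, 5, 6, 7]
batch_view = view(ranks, 2, 4)    # e.g. dp=2, tp=4: [[0,1,2,3],[4,5,6,7]]
```

Every view addresses the same devices; only the grouping used for collectives differs.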
Demonstration Video:
To better understand domains, we have prepared a quick demonstration video on YouTube.
Regular Domain (regular)
- Identifier: `REGULAR_DOMAIN` or `"regular"`
- Purpose: The most granular mesh view for fully decomposed parallelism. Used for setting up logging and seeding.
- Dimensions:
  - `pp`: Pipeline Parallel
  - `dp_replicate`: Data Parallel (DDP style)
  - `dp_shard`: Data Parallel (FSDP style)
  - `cp_shard`: Context Parallel (FSDP style)
  - `cp_replicate`: Context Parallel (DDP style)
  - `tp`: Tensor Parallel
Expert Domain (expert)
- Identifier: `EXPERT_DOMAIN` or `"expert"`
- Purpose: Mesh view optimized for distributing MoE (Mixture of Experts) layers. It is intended that sparse expert layers be sharded across the `ep_shard` dimension and replicated across the `ep_replicate` dimension.
- Dimensions:
  - `pp`: Pipeline Parallel
  - `ep_replicate`: Combined Replication Dimension (`(DP * CP) // EP`)
  - `ep_shard`: Expert Parallel Dimension
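The `ep_replicate` degree follows directly from the formula above; for example (the sizes here are hypothetical):

```python
# Hypothetical degrees: ep_replicate absorbs whatever data/context
# parallelism is left over after carving out expert parallelism.
dp, cp, ep = 4, 2, 4

ep_replicate = (dp * cp) // ep   # 2
ep_shard = ep                    # 4

# The total number of ranks per pipeline stage is preserved.
assert ep_replicate * ep_shard == dp * cp
print(ep_replicate, ep_shard)    # 2 4
```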
Dense Domain (dense)
- Identifier: `DENSE_DOMAIN` or `"dense"`
- Purpose: Mesh view for distributing dense layers.
- Dimensions:
  - `pp`: Pipeline Parallel
  - `dp_replicate`: Data Parallel for replication using HSDP
  - `dp_cp_shard`: Merged Data and Context Parallel dimension for sharding using HSDP
  - `cp_replicate`: Context Parallel for replication
  - `tp`: Tensor Parallel
Batch Domain (batch)
- Identifier: `BATCH_DOMAIN` or `"batch"`
- Purpose: Mesh view for distributing the batch tensor and setting up DataLoader sharding.
- Dimensions:
  - `pp`: Pipeline Parallel
  - `dp`: Data Parallel
  - `cp`: Context Parallel
  - `tp`: Tensor Parallel
Flat Domain (flat)
- Identifier: `FLAT_DOMAIN` or `"flat"`
- Purpose: Mesh view with a single dimension.
- Dimensions:
  - `world`: World Size
Usage
Initialization
The system is usually initialized via DeviceMeshParameters.
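A minimal initialization sketch, assuming the attribute names from the `DeviceMeshParameters` reference below; the specific degrees are hypothetical and must be consistent with the world size you launch with (e.g. via `torchrun`):

```python
from d9d.core.dist_context import DeviceMeshParameters

# Hypothetical degrees for a 16-GPU job: 2 (pp) * 4 (dp_shard) * 2 (tp) = 16.
params = DeviceMeshParameters(
    pipeline_parallel=2,
    data_parallel_replicate=1,
    data_parallel_shard=4,
    context_parallel_replicate=1,
    context_parallel_shard=1,
    tensor_parallel=2,
    expert_parallel=4,   # carved out of dp * cp for MoE layers
)

# build() initializes process groups and all domain meshes exactly once,
# where the developer intends.
dist_ctx = params.build()
```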
Accessing DeviceMesh Domains
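A sketch of retrieving per-domain views, assuming a built `DistributedContext` named `dist_ctx`; the returned objects are standard `torch.distributed.device_mesh.DeviceMesh` instances, so the usual mesh indexing applies:

```python
# The same physical ranks, viewed through different logical topologies.
dense_mesh = dist_ctx.mesh_for("dense")     # or DENSE_DOMAIN
expert_mesh = dist_ctx.mesh_for("expert")   # or EXPERT_DOMAIN

# Standard DeviceMesh operations on a named dimension:
tp_submesh = dense_mesh["tp"]
tp_group = tp_submesh.get_group()       # ProcessGroup for collectives
tp_rank = tp_submesh.get_local_rank()   # this rank's index along tp

# Unknown domain names raise ValueError (see the mesh_for reference below).
```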
Rank Utilities
Accessing rank information.
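Typical gating patterns built on the rank properties documented in the reference below, assuming a built context `dist_ctx`:

```python
# Run once globally (e.g. write the final checkpoint).
if dist_ctx.is_main_process:
    dist_ctx.logger.info("global rank 0: saving checkpoint")

# Run once per node (e.g. warm a node-local cache).
if dist_ctx.is_local_main_process:
    print(f"node {dist_ctx.node_rank}/{dist_ctx.num_nodes}, "
          f"local rank {dist_ctx.local_rank}, device {dist_ctx.current_device}")
```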
Context Managers
Control execution flow across ranks.
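A sketch of the synchronization helpers, assuming a built context `dist_ctx`; `load_dataset_cache` and `warm_local_cache` are hypothetical placeholders:

```python
# Global rank 0 populates a shared cache before any other rank reads it.
with dist_ctx.main_process_first():
    dataset = load_dataset_cache()   # hypothetical helper

# Per-node variant: rank 0 on each node goes first (node-local disk).
with dist_ctx.local_main_process_first():
    warm_local_cache()               # hypothetical helper

# Explicit barrier: block until every rank reaches this point.
dist_ctx.wait_world()
```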
d9d.core.dist_context
This package configures the distributed environment and device meshes.
DeviceMeshParameters
Bases: BaseModel
Configuration parameters for initializing Distributed Device Meshes.
Attributes:
| Name | Type | Description |
|---|---|---|
| `pipeline_parallel` | `int` | Degree of pipeline parallelism (PP). |
| `data_parallel_replicate` | `int` | Degree of data parallel replication (DDP). |
| `data_parallel_shard` | `int` | Degree of data parallel sharding (FSDP). |
| `context_parallel_replicate` | `int` | Degree of context parallel (CP) replication. |
| `context_parallel_shard` | `int` | Degree of context parallel (FSCP) sharding. |
| `tensor_parallel` | `int` | Degree of tensor parallelism (TP). |
| `expert_parallel` | `int` | Degree of expert parallelism (EP/MoE). |
has_data_parallel_replicate
property
Checks if data parallel replication is enabled (degree > 1).
has_data_parallel_shard
property
Checks if data parallel sharding is enabled (degree > 1).
has_expert_parallel
property
Checks if expert parallelism is enabled (degree > 1).
has_pipeline_parallel
property
Checks if pipeline parallelism is enabled (degree > 1).
is_distributed
property
Checks if any form of parallelism is enabled.
build(log_level=logging.INFO)
Initializes the DistributedContext using these parameters.
Returns:
| Type | Description |
|---|---|
| `DistributedContext` | A new `DistributedContext` instance containing the initialized device meshes. |
DistributedContext
Acts as the single source of truth for the distributed execution environment.
It acts as the central repository for the distributed configuration, managing the creation and synchronization of PyTorch DeviceMeshes for different domains (Regular domain, Expert Parallel domain, ...).
All assertions regarding rank placement, group memberships, and parallel topology must be derived from this context to ensure consistency.
current_device
property
Returns the CUDA device associated with this rank.
is_local_main_process
property
Checks if the current process is the rank 0 on the specific node.
is_main_process
property
Checks if the current process is the global rank 0.
local_rank
property
Returns the rank of the current process within its node.
logger
property
Returns the logger instance configured for distributed logging.
master_addr
property
Returns the IP address or domain name of the master node.
mesh_params
property
Returns the parameters used to initialize this context.
node_rank
property
Returns the index of the node this process is running on.
num_nodes
property
Returns the total number of nodes in the cluster.
local_main_process_first()
Context manager that executes the block on the local main process first.
Other local ranks wait at the entrance. The local main process waits at the exit to synchronize before continuing.
main_process_first()
Context manager that executes the block on the global main process first.
All other ranks wait at the entrance. The global main process waits at the exit to synchronize before continuing.
mesh_for(domain)
Returns the device mesh view associated with a specific logical domain.
Available Domains and Dimensions
- `regular` (`REGULAR_DOMAIN`): The most granular mesh for fully decomposed parallelism. Dimensions: `('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')`
- `expert` (`EXPERT_DOMAIN`): Mesh optimized for distributing MoE (Mixture of Experts) layers. Dimensions: `('pp', 'replicate', 'ep')`
- `dense` (`DENSE_DOMAIN`): Mesh optimized for distributing dense layers. Dimensions: `('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')`
- `batch` (`BATCH_DOMAIN`): Mesh optimized for distributing input data. Dimensions: `('pp', 'dp', 'cp', 'tp')`
- `flat` (`FLAT_DOMAIN`): Mesh containing a single dimension with all the processes. Dimensions: `('world',)`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `domain` | `str` | The name of the domain to retrieve. | *required* |
Returns:
| Type | Description |
|---|---|
| `DeviceMesh` | The PyTorch `DeviceMesh` configured for the requested domain. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the specified domain does not exist. |
set_timeout(timeout_seconds)
Updates the NCCL/process group timeout for all underlying meshes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `timeout_seconds` | `float` | New timeout duration in seconds. | *required* |
wait_world()
Blocks process execution until all ranks reach this point.
build_dist_logger(qualifier, level)
Configures and returns a logger instance for d9d.
The logger is configured to write to stdout with a formatter that includes the provided rank qualifier, allowing for easier debugging in distributed logs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `qualifier` | `str` | A string identifying the current rank's position in the mesh. | *required* |
| `level` | `int` | Default log level to set. | *required* |
Returns:
| Type | Description |
|---|---|
| `Logger` | A configured `logging.Logger` instance. |