About
The d9d.core.dist_context package is the Source of Truth for the distributed execution environment.
In large-scale model training, ensuring that every rank agrees on the topology, global rank mapping, and communication groups is critical. This package provides the DistributedContext class, which serves as the central repository for this configuration.
It is extremely important to use this context for all distributed assertions (e.g., "Am I the main process?", "Which rank is my pipeline peer?") rather than checking raw os.environ variables or initializing ad-hoc process groups, which can lead to silent inconsistencies.
Comparison with Other Frameworks
The problem of managing distributed topology is solved in different ways across distributed training frameworks.
Megatron-LM (parallel_state)
Megatron-LM manages topology via a module often called mpu (Model Parallel Unit) or core.parallel_state.
Megatron historically relies on global variables and manual rank arithmetic. To find a peer rank, developers often write code involving modulo operations (e.g., rank % tp_size). This is flexible but error-prone and brittle.
HuggingFace Accelerate (PartialState)
Accelerate uses a class called PartialState to abstract the environment.
We find Accelerate's utility methods quite useful. d9d implements similar helpers, such as wait_world() (similar to wait_for_everyone()) and properties like is_main_process or is_local_main_process.
PartialState is primarily designed for "Flat" Data Parallelism (DDP/FSDP) and does not support complex multidimensional parallelisms natively.
PartialState is implemented as a Singleton: instantiating it anywhere in the code returns the exact same global state. This obscures the flow of dependencies and can cause process groups and the distributed environment to be initialized in unexpected places in your code.
TorchTitan (ParallelDims)
TorchTitan is the most similar framework to d9d in spirit, as both are built on top of native PyTorch 2.x DeviceMesh abstractions.
However, ParallelDims in TorchTitan acts more like a mesh factory than a controller for the whole distributed environment.
d9d (DistributedContext)
d9d positions DistributedContext as the explicit controller for the entire distributed environment.
- DistributedContext is a standard object that is instantiated and passed explicitly to dependent components (see the sketch below). This ensures that process groups are initialized exactly when and where the developer intends, making the initialization flow transparent.
- It replaces manual rank arithmetic with formalized, PyTorch-native DeviceMesh abstractions.
- Functionally, it elevates the mesh system into an active runtime controller, bundling timeout management, context-aware logging, and node-level synchronization.
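As an illustration of this explicit style, a script builds the context once and hands it to every component that needs topology information. The sketch below assumes a two-GPU DDP launch; Trainer is a hypothetical component, not a d9d class.
from d9d.core.dist_context import DeviceMeshParameters, DistributedContext, DENSE_DOMAIN

class Trainer:
    # Hypothetical component: receives the context explicitly instead of reading globals
    def __init__(self, dist_ctx: DistributedContext):
        self.dist_ctx = dist_ctx
        self.dense_mesh = dist_ctx.mesh_for(DENSE_DOMAIN)

    def log_start(self) -> None:
        if self.dist_ctx.is_main_process:
            self.dist_ctx.logger.info("starting training run")

# Process groups are created exactly here, not as an import side effect
dist_ctx = DeviceMeshParameters(
    pipeline_parallel=1,
    data_parallel_replicate=2,
    data_parallel_shard=1,
    context_parallel_replicate=1,
    context_parallel_shard=1,
    expert_parallel=1,
    tensor_parallel=1
).build()
Trainer(dist_ctx).log_start()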
DeviceMesh Domains
Modern architectures require different parallelism strategies for different parts of the model (e.g., standard dense layers vs. Mixture-of-Experts layers). d9d handles this by abstracting these strategies into specific DeviceMesh Domains.
The underlying physical GPUs never change, but how we view them depends on what we are working with (distributing MoE layers, dense layers, or the input batch). The DeviceMesh object for a specific domain is retrieved via dist_ctx.mesh_for(domain_name); a short sketch follows the domain descriptions below.
Demonstration Video:
For a better understanding of domains, we have prepared a quick demonstration video on YouTube.
Regular Domain (regular)
- Identifier: REGULAR_DOMAIN or "regular"
- Purpose: The most granular mesh view for fully decomposed parallelism. Used for setting up logging and seeding.
- Dimensions:
  - pp: Pipeline Parallel
  - dp_replicate: Data Parallel (DDP style)
  - dp_shard: Data Parallel (FSDP style)
  - cp_shard: Context Parallel (FSDP style)
  - cp_replicate: Context Parallel (DDP style)
  - tp: Tensor Parallelism
Expert Domain (expert)
- Identifier: EXPERT_DOMAIN or "expert"
- Purpose: Mesh view optimized for distributing MoE (Mixture of Experts) layers. It is intended that sparse expert layers should be sharded across the ep_shard dimension and replicated across the ep_replicate dimension.
- Dimensions:
  - pp: Pipeline Parallel
  - ep_replicate: Combined Replication Dimension ((DP * CP) // EP; see the worked example below this list)
  - ep_shard: Expert Parallel Dimension
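For example, with the topology from the Usage section below (dp_replicate = 8, all shard degrees 1, expert_parallel = 8), this would presumably give ep_shard = 8 and ep_replicate = (8 * 1) // 8 = 1.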
Dense Domain (dense)
- Identifier: DENSE_DOMAIN or "dense"
- Purpose: Mesh view for distributing dense layers.
- Dimensions:
  - pp: Pipeline Parallel
  - dp_replicate: Data Parallel for replication using HSDP
  - dp_cp_shard: Merged Data and Context Parallel dimension for sharding using HSDP
  - cp_replicate: Context Parallel for replication
  - tp: Tensor Parallel
Batch Domain (batch)
- Identifier: BATCH_DOMAIN or "batch"
- Purpose: Mesh view for distributing the batch tensor and setting up DataLoader sharding.
- Dimensions:
  - pp: Pipeline Parallel
  - dp: Data Parallel
  - cp: Context Parallel
  - tp: Tensor Parallel
Flat Domain (flat)
- Identifier: FLAT_DOMAIN or "flat"
- Purpose: Mesh view with a single dimension.
- Dimensions:
  - world: World Size
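To make the different views concrete, the sketch below prints the dimension names and size of each domain mesh using standard PyTorch DeviceMesh APIs (it assumes the other domain constants are exported from the package alongside DENSE_DOMAIN):
from d9d.core.dist_context import (
    DistributedContext,
    REGULAR_DOMAIN,
    EXPERT_DOMAIN,
    DENSE_DOMAIN,
    BATCH_DOMAIN,
    FLAT_DOMAIN,
)

dist_ctx: DistributedContext = ...

for domain in (REGULAR_DOMAIN, EXPERT_DOMAIN, DENSE_DOMAIN, BATCH_DOMAIN, FLAT_DOMAIN):
    mesh = dist_ctx.mesh_for(domain)
    # mesh_dim_names and size() are regular torch.distributed DeviceMesh APIs
    dist_ctx.logger.info(f"{domain}: dims={mesh.mesh_dim_names} size={mesh.size()}")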
Usage
Initialization
The system is usually initialized via DeviceMeshParameters.
from d9d.core.dist_context import DeviceMeshParameters
# Define the topology
params = DeviceMeshParameters(
    pipeline_parallel=2,
    data_parallel_replicate=8,
    data_parallel_shard=1,
    context_parallel_replicate=1,
    context_parallel_shard=1,
    expert_parallel=8,
    tensor_parallel=1
)
dist_ctx = params.build()
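After build(), the context can be queried immediately; for instance (a small sketch, assuming the script was launched with one process per GPU):
dist_ctx.logger.info(
    f"node {dist_ctx.node_rank}/{dist_ctx.num_nodes}, device {dist_ctx.current_device}"
)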
Accessing DeviceMesh Domains
from torch.distributed import DeviceMesh
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN
dist_ctx: DistributedContext = ...
mesh_dense: DeviceMesh = dist_ctx.mesh_for(DENSE_DOMAIN)
Rank Utilities
Accessing rank information.
if dist_ctx.is_main_process:
    print("I am Global Rank 0 (Master)")

if dist_ctx.is_local_main_process:
    print("I am Rank 0 on this specific node")

# Synchronize
dist_ctx.wait_world()
Context Managers
Control execution flow across ranks.
# Ensure only one process per node downloads a file
with dist_ctx.local_main_process_first():
    if dist_ctx.is_local_main_process:
        download_dataset()
    # Non-main ranks enter the block only after the local main process is done
# All ranks resume together here
d9d.core.dist_context
This package configures the distributed environment and device meshes.
DeviceMeshParameters
Bases: BaseModel
Configuration parameters for initializing Distributed Device Meshes.
Attributes:
| Name | Type | Description |
|---|---|---|
| pipeline_parallel | int | Degree of pipeline parallelism (PP). |
| data_parallel_replicate | int | Degree of data parallel replication (DDP). |
| data_parallel_shard | int | Degree of data parallel sharding (FSDP). |
| context_parallel_replicate | int | Degree of context parallel (CP) replication. |
| context_parallel_shard | int | Degree of context parallel (FSCP) sharding. |
| tensor_parallel | int | Degree of tensor parallelism (TP). |
| expert_parallel | int | Degree of expert parallelism (EP/MoE). |
Source code in d9d/core/dist_context/params.py
has_data_parallel_replicate
property
Checks if data parallel replication is enabled (degree > 1).
has_data_parallel_shard
property
Checks if data parallel sharding is enabled (degree > 1).
has_expert_parallel
property
Checks if expert parallelism is enabled (degree > 1).
has_pipeline_parallel
property
Checks if pipeline parallelism is enabled (degree > 1).
is_distributed
property
Checks if any form of parallelism is enabled.
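The boolean properties above allow branching on the configured topology without re-deriving it from the raw degrees; a minimal sketch:
from d9d.core.dist_context import DeviceMeshParameters

params = DeviceMeshParameters(
    pipeline_parallel=2,
    data_parallel_replicate=8,
    data_parallel_shard=1,
    context_parallel_replicate=1,
    context_parallel_shard=1,
    expert_parallel=8,
    tensor_parallel=1
)
if params.has_pipeline_parallel:
    print("a pipeline schedule is required")
if params.has_expert_parallel:
    print("MoE layers will be distributed across experts")
if not params.is_distributed:
    print("single-process run")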
build(log_level=logging.INFO)
Initializes the DistributedContext using these parameters.
Returns:
| Type | Description |
|---|---|
| DistributedContext | A new DistributedContext instance containing the initialized device meshes. |
Source code in d9d/core/dist_context/params.py
DistributedContext
Acts as the single source of truth for the distributed execution environment.
It serves as the central repository for the distributed configuration, managing the creation and synchronization of PyTorch DeviceMeshes for the different domains (Regular domain, Expert Parallel domain, ...).
All assertions regarding rank placement, group memberships, and parallel topology must be derived from this context to ensure consistency.
Source code in d9d/core/dist_context/configured.py
current_device
property
Returns the CUDA device associated with this rank.
is_local_main_process
property
Checks if the current process is the rank 0 on the specific node.
is_main_process
property
Checks if the current process is the global rank 0.
logger
property
Returns the logger instance configured for distributed logging.
master_addr
property
Returns the IP address or domain name of the master node.
mesh_params
property
Returns the parameters used to initialize this context.
node_rank
property
Returns the index of the node this process is running on.
num_nodes
property
Returns the total number of nodes in the cluster.
local_main_process_first()
Context manager that executes the block on the local main process first.
Other local ranks wait at the entrance. The local main process waits at the exit to synchronize before continuing.
Source code in d9d/core/dist_context/configured.py
main_process_first()
Context manager that executes the block on the global main process first.
All other ranks wait at the entrance. The global main process waits at the exit to synchronize before continuing.
Source code in d9d/core/dist_context/configured.py
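For example, this is handy when only the global rank 0 should populate a shared cache; a hedged sketch (download_model is a hypothetical helper):
with dist_ctx.main_process_first():
    if dist_ctx.is_main_process:
        download_model()  # hypothetical helper; runs on global rank 0 only
    # Other ranks enter the block only after rank 0 has left it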
mesh_for(domain)
Returns the device mesh view associated with a specific logical domain.
Available Domains and Dimensions
- regular (REGULAR_DOMAIN): The most granular mesh for fully decomposed parallelism. Dimensions: ('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')
- expert (EXPERT_DOMAIN): Mesh optimized for distributing MoE (Mixture of Experts) layers. Dimensions: ('pp', 'replicate', 'ep')
- dense (DENSE_DOMAIN): Mesh optimized for distributing dense layers. Dimensions: ('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')
- batch (BATCH_DOMAIN): Mesh optimized for distributing input data. Dimensions: ('pp', 'dp', 'cp', 'tp')
- flat (FLAT_DOMAIN): Mesh containing a single dimension with all the processes. Dimensions: ('world')
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| domain | str | The name of the domain to retrieve. | required |
Returns:
| Type | Description |
|---|---|
| DeviceMesh | The PyTorch DeviceMesh configured for the requested domain. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the specified domain does not exist. |
Source code in d9d/core/dist_context/configured.py
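Because the returned object is a regular PyTorch DeviceMesh, sub-meshes for individual dimensions can be sliced out by name; a sketch assuming the dense domain described above:
from d9d.core.dist_context import DistributedContext, DENSE_DOMAIN

dist_ctx: DistributedContext = ...
dense_mesh = dist_ctx.mesh_for(DENSE_DOMAIN)
tp_mesh = dense_mesh["tp"]              # 1D mesh over this rank's tensor-parallel peers
shard_mesh = dense_mesh["dp_cp_shard"]  # dimension used for HSDP sharding
dist_ctx.logger.info(f"tp={tp_mesh.size()} shard={shard_mesh.size()}")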
set_timeout(timeout_seconds)
Updates the NCCL/process group timeout for all underlying meshes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| timeout_seconds | float | New timeout duration in seconds. | required |
Source code in d9d/core/dist_context/configured.py
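A typical use is temporarily raising the timeout around an operation that is expected to block for a long time; a sketch (save_checkpoint is a hypothetical helper):
dist_ctx.set_timeout(3600)  # allow up to an hour for the upcoming collectives
save_checkpoint(model)      # hypothetical long-running, collective-heavy operation
dist_ctx.set_timeout(600)   # restore a tighter timeout afterwards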
wait_world()
Blocks process execution until all ranks reach this point.
Source code in d9d/core/dist_context/configured.py
build_dist_logger(qualifier, level)
Configures and returns a logger instance for d9d.
The logger is configured to write to stdout with a formatter that includes the provided rank qualifier, allowing for easier debugging in distributed logs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| qualifier | str | A string identifying the current rank's position in the mesh. | required |
| level | int | Log level to set by default. | required |
Returns:
| Type | Description |
|---|---|
| Logger | A configured logging.Logger instance. |
Source code in d9d/core/dist_context/log.py
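A minimal sketch of calling this helper directly (in normal use the logger property of DistributedContext already returns a configured instance; the qualifier string below is only an example, and the import path assumes the function is exported from the package):
import logging
from d9d.core.dist_context import build_dist_logger

logger = build_dist_logger(qualifier="pp0/dp3", level=logging.INFO)
logger.info("hello from this rank")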