About
The d9d.core.dist_ops package provides high-level wrappers around torch.distributed collective operations.
While PyTorch's native distributed library is powerful, it often requires significant boilerplate, notably the manual pre-allocation of output buffers (e.g., creating a list of empty tensors for all_gather).
d9d simplifies this by handling buffer allocation automatically. It also introduces specialized operators for variadic shapes, allowing ranks to exchange tensors even when they do not know the incoming tensor shapes beforehand.
Usage Examples
Gathering Tensors
Gathering tensors of identical shapes from all ranks. d9d automatically allocates buffers for this operation.
```python
import torch

from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Each rank has a tensor of the same shape (e.g., [2, 2])
# but different values
local_tensor = torch.ones((2, 2), device="cuda") * rank

# Gather
gathered_tensors = all_gather(local_tensor, group=group)

for i, t in enumerate(gathered_tensors):
    print(f"From rank {i}: {t}")
```
Gathering Tensors with Variadic Shapes
Gathering tensors where dimensions differ across ranks.
```python
import torch

from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather_variadic_shape

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Rank 0 has shape [1], Rank 1 has shape [2], ...
local_tensor = torch.randn((rank + 1,), device="cuda")

# Gather; the shape mismatch is handled automatically
gathered_tensors = all_gather_variadic_shape(local_tensor, group=group)

for i, t in enumerate(gathered_tensors):
    print(f"Rank {i} sent shape: {t.shape}")
```
Object Communication
Sending arbitrary Python objects between ranks. These objects must be picklable.
```python
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather_object

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Local data
my_metadata = {
    "rank": rank,
    "the-strongest": "satoru-gojo"
}

# Gather
results = all_gather_object(my_metadata, group=group)

for data in results:
    print(f"Rank {data['rank']} sent {data}")
```
d9d.core.dist_ops
This module provides high-level wrappers around torch.distributed collective operations.
all_gather(tensor, group, async_op=False)
Gathers tensors from the whole process group to all ranks.
This function assumes that tensors on all ranks have the same shape and dtype as the tensor on the current rank. It automatically allocates the output buffer list.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The local tensor to send. | *required* |
| `group` | `ProcessGroup` | The process group to work on. | *required* |
| `async_op` | `bool` | Whether the operation should be asynchronous. | `False` |

Returns:

| Type | Description |
|---|---|
| `list[Tensor] \| tuple[list[Tensor], Work]` | If `async_op` is False: a list of gathered tensors. |
| `list[Tensor] \| tuple[list[Tensor], Work]` | If `async_op` is True: a tuple containing `(buffer_list, work_handle)`. |
Source code in d9d/core/dist_ops/tensor.py
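For asynchronous use, the documented return value is a `(buffer_list, work_handle)` tuple. A minimal sketch, assuming `group` and `local_tensor` are set up as in the usage example above:

```python
# Sketch: asynchronous all_gather (assumes `group` and `local_tensor`
# from the usage example above)
buffers, work = all_gather(local_tensor, group=group, async_op=True)

# ... overlap independent computation here ...

work.wait()  # the tensors in `buffers` are only valid after the work handle completes
for i, t in enumerate(buffers):
    print(f"From rank {i}: {t}")
```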
all_gather_object(obj, group)
Gathers picklable objects from the whole process group to all ranks.
This acts as a wrapper around torch.distributed.all_gather_object that automatically initializes the output buffer list on all ranks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `obj` | `T` | The local object to send. Must be picklable. | *required* |
| `group` | `ProcessGroup` | The process group to work on. | *required* |

Returns:

| Type | Description |
|---|---|
| `list[T]` | A list of objects containing the data gathered from all ranks. |
Source code in d9d/core/dist_ops/object.py
all_gather_variadic_shape(tensor, group, async_op=False)
Gathers tensors of different shapes from the whole process group to all ranks.
Unlike standard all_gather, this function first communicates the shape of the tensor on every rank, allowing for dynamic sizing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The local tensor to send. | *required* |
| `group` | `ProcessGroup` | The process group to work on. | *required* |
| `async_op` | `bool` | Whether the final data gathering should be asynchronous. Note that shape gathering is always synchronous. | `False` |

Returns:

| Type | Description |
|---|---|
| `list[Tensor] \| tuple[list[Tensor], Work]` | If `async_op` is False: a list of gathered tensors of varying shapes. |
| `list[Tensor] \| tuple[list[Tensor], Work]` | If `async_op` is True: a tuple containing `(buffer_list, work_handle)`. |
Source code in d9d/core/dist_ops/tensor.py
gather(tensor, group, group_dst, async_op=False)
Gathers tensors from the process group to a specific destination rank.
This function assumes that tensors on all ranks have the same shape and dtype as the tensor on the current rank. It automatically allocates the output buffer list on the destination.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The local tensor to send. | *required* |
| `group` | `ProcessGroup` | The process group to work on. | *required* |
| `group_dst` | `int` | The rank within the group that will receive the tensors. | *required* |
| `async_op` | `bool` | Whether the operation should be asynchronous. | `False` |

Returns:

| Type | Description |
|---|---|
| `list[Tensor] \| tuple[list[Tensor] \| None, Work] \| None` | If `async_op` is False: a list of tensors on the destination rank, None elsewhere. |
| `list[Tensor] \| tuple[list[Tensor] \| None, Work] \| None` | If `async_op` is True: a tuple containing `(buffer_list, work_handle)`. |
Source code in d9d/core/dist_ops/tensor.py
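A short usage sketch for `gather`, modeled on the `all_gather` example above; the `DistributedContext` setup and the choice of destination rank 0 are assumptions:

```python
import torch

from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import gather

# Setup (assumed, as in the usage examples above)
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

local_tensor = torch.ones((2, 2), device="cuda") * rank

# Gather everything onto group rank 0; other ranks receive None
gathered = gather(local_tensor, group=group, group_dst=0)
if gathered is not None:
    for i, t in enumerate(gathered):
        print(f"From rank {i}: {t}")
```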
gather_object(obj, group, group_dst)
Gathers picklable objects from the whole process group to a specific destination rank.
This acts as a wrapper around torch.distributed.gather_object that automatically initializes the output buffer list on the destination rank.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `obj` | `T` | The local object to send. Must be picklable. | *required* |
| `group` | `ProcessGroup` | The process group to work on. | *required* |
| `group_dst` | `int` | The rank within the group that will receive the objects. | *required* |

Returns:

| Type | Description |
|---|---|
| `list[T] \| None` | A list of objects from all ranks on the destination rank; None on other ranks. |
Source code in d9d/core/dist_ops/object.py
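A usage sketch for `gather_object`, following the object-communication example above; the setup and destination rank 0 are assumptions:

```python
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import gather_object

# Setup (assumed, as in the usage examples above)
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

my_metadata = {"rank": rank, "elements": rank + 1}

# Only group rank 0 receives the list; other ranks get None
results = gather_object(my_metadata, group=group, group_dst=0)
if results is not None:
    for data in results:
        print(f"Rank {data['rank']} sent {data}")
```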
gather_variadic_shape(tensor, group, group_dst)
Gathers tensors of different shapes from the process group to a specific rank.
This function coordinates shape exchange and uses point-to-point communication (isend/irecv) to gather tensors that may differ in shape across ranks.
This function does not currently support `async_op`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The local tensor to send. | *required* |
| `group` | `ProcessGroup` | The process group to work on. | *required* |
| `group_dst` | `int` | The rank within the group that will receive the tensors. | *required* |

Returns:

| Type | Description |
|---|---|
| `list[Tensor] \| None` | A list of tensors of varying shapes on the destination rank; None on other ranks. |
Source code in d9d/core/dist_ops/tensor.py
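A usage sketch for `gather_variadic_shape`, mirroring the variadic-shape example above; the setup and destination rank 0 are assumptions:

```python
import torch

from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import gather_variadic_shape

# Setup (assumed, as in the usage examples above)
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Rank 0 sends shape [1], rank 1 sends shape [2], ...
local_tensor = torch.randn((rank + 1,), device="cuda")

# Only group rank 0 receives the list; other ranks get None
gathered = gather_variadic_shape(local_tensor, group=group, group_dst=0)
if gathered is not None:
    for i, t in enumerate(gathered):
        print(f"Rank {i} sent shape: {t.shape}")
```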