About

The d9d.core.dist_ops package provides high-level wrappers around torch.distributed collective operations.

While PyTorch's native distributed library is powerful, it often requires significant boilerplate code - specifically the manual pre-allocation of output buffers (e.g., creating a list of empty tensors for all_gather).

d9d simplifies this by handling buffer allocation automatically. It also introduces specialized operators for handling Variadic Shapes, allowing ranks to exchange tensors even when they do not know the incoming tensor shapes beforehand.

Usage Examples

Gathering Tensors

Gathering tensors of identical shapes from all ranks. d9d automatically allocates buffers for this operation.

import torch
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Each rank has a tensor of the same shape (e.g., [2, 2])
# but different values
local_tensor = torch.ones((2, 2), device="cuda") * rank

# Gather
gathered_tensors = all_gather(local_tensor, group=group)

for i, t in enumerate(gathered_tensors):
    print(f"From rank {i}: {t}")

Gathering Tensors with Variadic Shapes

Gathering tensors where dimensions differ across ranks.

import torch
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather_variadic_shape

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Rank 0 has shape [1], Rank 1 has shape [2], ...
local_tensor = torch.randn((rank + 1,), device="cuda")

# Gather
# The system automatically handles the shape mismatch
gathered_tensors = all_gather_variadic_shape(local_tensor, group=group)

for i, t in enumerate(gathered_tensors):
    print(f"Rank {i} sent shape: {t.shape}")

Object Communication

Sending arbitrary Python objects between ranks. These objects must be picklable.

import torch.distributed as dist
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather_object

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Local data
my_metadata = {
    "rank": rank,
    "the-strongest": "satoru-gojo"
}

# Gather
results = all_gather_object(my_metadata, group=group)

for data in results:
    print(f"Rank {data['rank']} sent {data}")

d9d.core.dist_ops

This module provides high-level wrappers around torch.distributed collective operations.

all_gather(tensor, group, async_op=False)

Gathers tensors from the whole process group to all ranks.

This function assumes that tensors on all ranks have the same shape and dtype as the tensor on the current rank. It automatically allocates the output buffer list.

Parameters:

Name Type Description Default
tensor Tensor

The local tensor to send.

required
group ProcessGroup

The process group to work on.

required
async_op bool

Whether the operation should be asynchronous.

False

Returns:

Type Description
list[Tensor] | tuple[list[Tensor], Work]

If async_op is False: A list of gathered tensors.

list[Tensor] | tuple[list[Tensor], Work]

If async_op is True: A tuple containing (buffer_list, work_handle).

Source code in d9d/core/dist_ops/tensor.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def all_gather(
        tensor: torch.Tensor,
        group: dist.ProcessGroup,
        async_op: bool = False
) -> list[torch.Tensor] | tuple[list[torch.Tensor], dist.Work]:
    """
    Gathers tensors from the whole process group to all ranks.

    Every rank is assumed to contribute a tensor with the same shape and
    dtype as the local one; the output buffer list is allocated here so
    callers do not have to pre-allocate it themselves.

    Args:
        tensor: The local tensor to send.
        group: The process group to work on.
        async_op: Whether the operation should be asynchronous.

    Returns:
        If async_op is False: A list of gathered tensors.
        If async_op is True: A tuple containing (buffer_list, work_handle).
    """

    # One receive buffer per rank, mirroring the local tensor's layout.
    buffers = [torch.empty_like(tensor) for _ in range(group.size())]
    handle = dist.all_gather(
        buffers,
        tensor,
        group=group,
        async_op=async_op
    )
    return (buffers, handle) if async_op else buffers

all_gather_object(obj, group)

Gathers picklable objects from the whole process group to all ranks.

This acts as a wrapper around torch.distributed.all_gather_object that automatically initializes the output buffer list on all ranks.

Parameters:

Name Type Description Default
obj T

The local object to send. Must be picklable.

required
group ProcessGroup

The process group to work on.

required

Returns:

Type Description
list[T]

A list of objects containing the data gathered from all ranks.

Source code in d9d/core/dist_ops/object.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def all_gather_object(
        obj: T,
        group: dist.ProcessGroup
) -> list[T]:
    """
    Gathers picklable objects from the whole process group to all ranks.

    This acts as a wrapper around torch.distributed.all_gather_object that
    automatically initializes the output buffer list on all ranks.

    Args:
        obj: The local object to send. Must be picklable.
        group: The process group to work on.

    Returns:
        A list of objects containing the data gathered from all ranks.
    """
    # We initialize with None placeholders, but cast to list[T] because we
    # know dist.all_gather_object will populate every slot with a real
    # object. (The previous comment incorrectly referenced dist.gather_object.)
    save_list = cast(list[T], [None] * group.size())
    dist.all_gather_object(
        save_list,
        obj,
        group=group
    )
    return save_list

all_gather_variadic_shape(tensor, group, async_op=False)

Gathers tensors of different shapes from the whole process group to all ranks.

Unlike standard all_gather, this function first communicates the shape of the tensor on every rank, allowing the output buffers to be sized dynamically.

Parameters:

Name Type Description Default
tensor Tensor

The local tensor to send.

required
group ProcessGroup

The process group to work on.

required
async_op bool

Whether the final data gathering should be asynchronous. Note that shape gathering is always synchronous.

False

Returns:

Type Description
list[Tensor] | tuple[list[Tensor], Work]

If async_op is False: A list of gathered tensors of varying shapes.

list[Tensor] | tuple[list[Tensor], Work]

If async_op is True: A tuple containing (buffer_list, work_handle).

Source code in d9d/core/dist_ops/tensor.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def all_gather_variadic_shape(
        tensor: torch.Tensor,
        group: dist.ProcessGroup,
        async_op: bool = False
) -> list[torch.Tensor] | tuple[list[torch.Tensor], dist.Work]:
    """
    Gathers tensors of different shapes from the whole process group to all ranks.

    Unlike standard all_gather, this first exchanges the per-rank tensor
    shapes so that correctly sized receive buffers can be allocated before
    the data collective runs.

    Args:
        tensor: The local tensor to send.
        group: The process group to work on.
        async_op: Whether the final data gathering should be asynchronous.
                  Note that shape gathering is always synchronous.

    Returns:
        If async_op is False: A list of gathered tensors of varying shapes.
        If async_op is True: A tuple containing (buffer_list, work_handle).
    """

    # Synchronous shape exchange; required to size the buffers below.
    shapes = _all_gather_shapes(tensor, group)

    buffers = [
        torch.empty(tuple(shape), dtype=tensor.dtype, device=tensor.device)
        for shape in shapes
    ]
    handle = dist.all_gather(
        buffers,
        tensor,
        group=group,
        async_op=async_op
    )
    return (buffers, handle) if async_op else buffers

gather(tensor, group, group_dst, async_op=False)

Gathers tensors from the process group to a specific destination rank.

This function assumes that tensors on all ranks have the same shape and dtype as the tensor on the current rank. It automatically allocates the output buffer list on the destination.

Parameters:

Name Type Description Default
tensor Tensor

The local tensor to send.

required
group ProcessGroup

The process group to work on.

required
group_dst int

The rank within the group that will receive the tensors.

required
async_op bool

Whether the operation should be asynchronous.

False

Returns:

Type Description
list[Tensor] | tuple[list[Tensor] | None, Work] | None

If async_op is False: A list of tensors on the destination rank, None elsewhere.

list[Tensor] | tuple[list[Tensor] | None, Work] | None

If async_op is True: A tuple containing (buffer_list, work_handle).

Source code in d9d/core/dist_ops/tensor.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def gather(
        tensor: torch.Tensor,
        group: dist.ProcessGroup,
        group_dst: int,
        async_op: bool = False
) -> list[torch.Tensor] | tuple[list[torch.Tensor] | None, dist.Work] | None:
    """
    Gathers tensors from the process group to a specific destination rank.

    Every rank is assumed to hold a tensor with the same shape and dtype as
    the local one. The receive buffer list is allocated here, but only on
    the destination rank; all other ranks pass None.

    Args:
        tensor: The local tensor to send.
        group: The process group to work on.
        group_dst: The rank within the group that will receive the tensors.
        async_op: Whether the operation should be asynchronous.

    Returns:
        If async_op is False: A list of tensors on the destination rank, None elsewhere.
        If async_op is True: A tuple containing (buffer_list, work_handle).
    """

    # Only the destination rank needs (or is allowed to supply) buffers.
    is_dst = group.rank() == group_dst
    buffers = [torch.empty_like(tensor) for _ in range(group.size())] if is_dst else None

    handle = dist.gather(
        tensor,
        buffers,
        group=group,
        group_dst=group_dst,
        async_op=async_op
    )

    return (buffers, handle) if async_op else buffers

gather_object(obj, group, group_dst)

Gathers picklable objects from the whole process group to a specific destination rank.

This acts as a wrapper around torch.distributed.gather_object that automatically initializes the output buffer list on the destination rank.

Parameters:

Name Type Description Default
obj T

The local object to send. Must be picklable.

required
group ProcessGroup

The process group to work on.

required
group_dst int

The rank within the group that will receive the objects.

required

Returns:

Type Description
list[T] | None

A list of objects from all ranks on the destination rank; None on other ranks.

Source code in d9d/core/dist_ops/object.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def gather_object(
        obj: T,
        group: dist.ProcessGroup,
        group_dst: int
) -> list[T] | None:
    """
    Gathers picklable objects from the whole process group to a specific destination rank.

    This acts as a wrapper around torch.distributed.gather_object that
    automatically initializes the output buffer list on the destination rank.

    Args:
        obj: The local object to send. Must be picklable.
        group: The process group to work on.
        group_dst: The rank within the group that will receive the objects.

    Returns:
        A list of objects from all ranks on the destination rank; None on other ranks.
    """

    is_dst = group.rank() == group_dst
    # None placeholders, cast to list[T]: dist.gather_object fills every
    # slot with a real object on the destination rank. Non-destination
    # ranks must pass None.
    buffer = cast(list[T], [None] * group.size()) if is_dst else None
    dist.gather_object(
        obj,
        buffer,
        group=group,
        group_dst=group_dst
    )
    return buffer

gather_variadic_shape(tensor, group, group_dst)

Gathers tensors of different shapes from the process group to a specific rank.

This function coordinates shape exchange and uses point-to-point communication (isend/irecv) to gather tensors that may differ in shape across ranks.

This function does not currently support async_op.

Parameters:

Name Type Description Default
tensor Tensor

The local tensor to send.

required
group ProcessGroup

The process group to work on.

required
group_dst int

The rank within the group that will receive the tensors.

required

Returns:

Type Description
list[Tensor] | None

A list of tensors of varying shapes on the destination rank; None on other ranks.

Source code in d9d/core/dist_ops/tensor.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def gather_variadic_shape(
        tensor: torch.Tensor,
        group: dist.ProcessGroup,
        group_dst: int
) -> list[torch.Tensor] | None:
    """
    Gathers tensors of different shapes from the process group to a specific rank.

    This function coordinates shape exchange and uses point-to-point communication
    (isend/irecv) to gather tensors that may differ in shape across ranks.

    Currently, does not support async_op.

    Args:
        tensor: The local tensor to send.
        group: The process group to work on.
        group_dst: The rank within the group that will receive the tensors.

    Returns:
        A list of tensors of varying shapes on the destination rank; None on other ranks.
    """

    is_current_dst = group.rank() == group_dst

    # Synchronous shape exchange so the destination can size its buffers.
    all_shape = _all_gather_shapes(tensor, group)

    if is_current_dst:
        all_recv_futures: list[dist.Work] = []
        all_result: list[torch.Tensor] = cast(list[torch.Tensor], [None for _ in range(group.size())])
        for group_src_i in range(group.size()):
            if group_src_i == group_dst:
                # The destination's own contribution never goes over the wire.
                all_result[group_src_i] = tensor
                continue
            all_result[group_src_i] = torch.empty(
                tuple(all_shape[group_src_i]), dtype=tensor.dtype, device=tensor.device
            )
            all_recv_future = dist.irecv(all_result[group_src_i], group=group, group_src=group_src_i)
            all_recv_future = cast(dist.Work, all_recv_future)  # we know we are on dst rank
            all_recv_futures.append(all_recv_future)
        for recv_future in all_recv_futures:
            recv_future.wait()
        return all_result
    else:
        # Bug fix: the isend work handle was previously discarded, so the
        # function could return before the transfer completed and the caller
        # might mutate/free `tensor` mid-send. Since this API is documented
        # as synchronous, block until the send has completed.
        send_work = dist.isend(tensor=tensor, group=group, group_dst=group_dst)
        if send_work is not None:  # isend returns None for ranks outside the group
            send_work.wait()
        return None