
Distributed Operations

About

The d9d.core.dist_ops package provides high-level wrappers around torch.distributed collective operations.

While PyTorch's native distributed library is powerful, it often requires significant boilerplate: most notably, manually pre-allocating output buffers (e.g., creating a list of empty tensors for all_gather).

d9d simplifies this by handling buffer allocation automatically. It also introduces specialized operators for handling Variadic Shapes, allowing ranks to exchange tensors even when they do not know the incoming tensor shapes beforehand.
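The boilerplate being removed can be illustrated with a small single-process sketch. Plain Python lists stand in for tensors and ranks here; this is conceptual, not the actual d9d implementation:

```python
def all_gather_sim(local_per_rank):
    """Single-process model of all_gather: every rank ends up with a
    copy of every rank's local data. With raw torch.distributed you
    pre-allocate the output buffers yourself; d9d does it for you."""
    world_size = len(local_per_rank)
    # The manual buffer pre-allocation that d9d hides:
    buffers = [[None] * len(local_per_rank[0]) for _ in range(world_size)]
    for src, data in enumerate(local_per_rank):
        buffers[src][:] = data
    return buffers  # identical on every rank after the collective

# Three "ranks", each holding a [2]-shaped "tensor" filled with its rank id
gathered = all_gather_sim([[0, 0], [1, 1], [2, 2]])
```

The real collectives below follow the same contract: the caller passes only its local tensor, and the library allocates and returns the output list.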

Usage Examples

Gathering Tensors

Gathering tensors of identical shapes from all ranks. d9d automatically allocates buffers for this operation.

import torch
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Each rank has a tensor of the same shape (e.g., [2, 2])
# but different values
local_tensor = torch.ones((2, 2), device="cuda") * rank

# Gather
gathered_tensors = all_gather(local_tensor, group=group)

for i, t in enumerate(gathered_tensors):
    print(f"From rank {i}: {t}")

Gathering Tensors with Variadic Shapes

Gathering tensors where dimensions differ across ranks.

import torch
from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather_variadic_shape

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Rank 0 has shape [1], Rank 1 has shape [2], ...
local_tensor = torch.randn((rank + 1,), device="cuda")

# Gather
# The system automatically handles the shape mismatch
gathered_tensors = all_gather_variadic_shape(local_tensor, group=group)

for i, t in enumerate(gathered_tensors):
    print(f"Rank {i} sent shape: {t.shape}")

Object Communication

Sending arbitrary Python objects between ranks. These objects must be picklable.

from d9d.core.dist_context import DistributedContext, REGULAR_DOMAIN
from d9d.core.dist_ops import all_gather_object

# Setup
ctx: DistributedContext = ...
group = ctx.mesh_for(REGULAR_DOMAIN).get_group()
rank = ctx.mesh_for(REGULAR_DOMAIN).get_rank()

# Local data
my_metadata = {
    "rank": rank,
    "the-strongest": "satoru-gojo"
}

# Gather
results = all_gather_object(my_metadata, group=group)

for data in results:
    print(f"Rank {data['rank']} sent {data}")

d9d.core.dist_ops

This module provides high-level wrappers around torch.distributed collective operations.

all_gather(tensor, group, async_op=False)

Gathers tensors from the whole process group to all ranks.

This function assumes that tensors on all ranks have the same shape and dtype as the tensor on the current rank. It automatically allocates the output buffer list.

Parameters:

- tensor (Tensor, required): The local tensor to send.
- group (ProcessGroup, required): The process group to work on.
- async_op (bool, default False): Whether the operation should be asynchronous.

Returns:

- list[Tensor] | tuple[list[Tensor], Work]: If async_op is False, a list of gathered tensors. If async_op is True, a tuple containing (buffer_list, work_handle).
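With async_op=True, the call returns the pre-allocated buffers together with a work handle, and the buffers are valid only after waiting. The calling pattern can be sketched as follows; since no process group is initialized here, a stub Work class and a single-rank stand-in replace the real collective:

```python
class _StubWork:
    """Stand-in for torch.distributed.Work (illustration only)."""
    def wait(self):
        return True

def stub_all_gather(tensor, async_op=False):
    """Mimics the documented return convention of all_gather with a
    single stand-in rank and no real communication."""
    buffers = [list(tensor)]
    if async_op:
        return buffers, _StubWork()
    return buffers

buffers, work = stub_all_gather([1.0, 2.0], async_op=True)
# ... overlap independent computation here ...
work.wait()              # buffers are only valid after wait()
result = buffers[0]
```

The same (buffer_list, work_handle) convention applies to the other collectives below that accept async_op.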

all_gather_object(obj, group)

Gathers picklable objects from the whole process group to all ranks.

This acts as a wrapper around torch.distributed.all_gather_object that automatically initializes the output buffer list on all ranks.

Parameters:

- obj (T, required): The local object to send. Must be picklable.
- group (ProcessGroup, required): The process group to work on.

Returns:

- list[T]: A list of the objects gathered from all ranks.

all_gather_variadic_shape(tensor, group, async_op=False)

Gathers tensors of different shapes from the whole process group to all ranks.

Unlike standard all_gather, this function first communicates the shape of the tensor on every rank, allowing receive buffers to be sized dynamically.
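The two-phase protocol can be modeled in a single process, with list lengths standing in for tensor shapes (a conceptual sketch, not the actual implementation):

```python
def all_gather_variadic_sim(local_per_rank):
    """Single-process model of the two-phase variadic gather."""
    # Phase 1: exchange shapes (lengths, here) so every rank can
    # allocate correctly sized receive buffers.
    shapes = [len(t) for t in local_per_rank]
    # Phase 2: allocate buffers of the advertised sizes and gather.
    buffers = [[None] * n for n in shapes]
    for src, data in enumerate(local_per_rank):
        buffers[src][:] = data
    return buffers

# Rank 0 holds shape [1], rank 1 holds [2], rank 2 holds [3]
out = all_gather_variadic_sim([[1], [2, 2], [3, 3, 3]])
```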

Parameters:

- tensor (Tensor, required): The local tensor to send.
- group (ProcessGroup, required): The process group to work on.
- async_op (bool, default False): Whether the final data gathering should be asynchronous. Note that the shape exchange is always synchronous.

Returns:

- list[Tensor] | tuple[list[Tensor], Work]: If async_op is False, a list of gathered tensors of varying shapes. If async_op is True, a tuple containing (buffer_list, work_handle).

gather(tensor, group, group_dst, async_op=False)

Gathers tensors from the process group to a specific destination rank.

This function assumes that tensors on all ranks have the same shape and dtype as the tensor on the current rank. It automatically allocates the output buffer list on the destination.

Parameters:

- tensor (Tensor, required): The local tensor to send.
- group (ProcessGroup, required): The process group to work on.
- group_dst (int, required): The rank within the group that will receive the tensors.
- async_op (bool, default False): Whether the operation should be asynchronous.

Returns:

- list[Tensor] | tuple[list[Tensor] | None, Work] | None: If async_op is False, a list of tensors on the destination rank and None elsewhere. If async_op is True, a tuple containing (buffer_list, work_handle).

gather_object(obj, group, group_dst)

Gathers picklable objects from the whole process group to a specific destination rank.

This acts as a wrapper around torch.distributed.gather_object that automatically initializes the output buffer list on the destination rank.

Parameters:

- obj (T, required): The local object to send. Must be picklable.
- group (ProcessGroup, required): The process group to work on.
- group_dst (int, required): The rank within the group that will receive the objects.

Returns:

- list[T] | None: A list of objects from all ranks on the destination rank; None on other ranks.

gather_variadic_shape(tensor, group, group_dst)

Gathers tensors of different shapes from the process group to a specific rank.

This function coordinates shape exchange and uses point-to-point communication (isend/irecv) to gather tensors that may differ in shape across ranks.

This function does not currently support async_op.
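The point-to-point flow can be sketched in a single process, with completed isend/irecv pairs modeled as direct copies (conceptual only, not the actual implementation):

```python
def gather_variadic_sim(local_per_rank, group_dst):
    """Single-process model: shape exchange, then point-to-point
    transfers into buffers held only by the destination rank."""
    # The destination learns each source's shape (length, here)...
    shapes = [len(t) for t in local_per_rank]
    # ...posts one receive buffer per source (the irecv side)...
    gathered = [[None] * n for n in shapes]
    # ...and each source sends its tensor (the isend side).
    for src, data in enumerate(local_per_rank):
        gathered[src][:] = data
    # Only the destination rank holds the result; others get None.
    return [gathered if r == group_dst else None
            for r in range(len(local_per_rank))]

per_rank = gather_variadic_sim([[7], [8, 8]], group_dst=1)
```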

Parameters:

- tensor (Tensor, required): The local tensor to send.
- group (ProcessGroup, required): The process group to work on.
- group_dst (int, required): The rank within the group that will receive the tensors.

Returns:

- list[Tensor] | None: A list of tensors of varying shapes on the destination rank; None on other ranks.