About

The d9d.core.sharding package provides utilities for splitting and reconstructing complex nested structures (PyTrees) of PyTorch Tensors and Python Lists.

Sharding Spec

A Sharding Spec is a PyTree that mirrors the structure of your data (e.g., a State Dict).

  • Structure: Mirrors the data hierarchy. The spec structure is used to traverse the data; sharding operations flatten the data tree up to the leaves defined in the spec.
  • Leaves:
    • d9d.core.sharding.SpecShard(dim, do_stack=False):
      • Tensors: The tensor is split along dimension dim.
      • Lists: The list is split into chunks. dim must be 0.
      • do_stack: If True, tensors are unbound when sharding and stacked when unsharding (the sharded dimension is removed and later restored). If False (default), tensors are split/concatenated (dimensionality is preserved).
    • d9d.core.sharding.SpecReplicate (or None): The item is replicated (kept as-is/not split) across all shards.

Helper functions like shard_spec_on_dim allow generating these specs automatically.
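
For example, a spec for a small state-dict-like tree might look like the following sketch (import path assumed from the d9d.core.sharding module documented below):

import torch
from d9d.core.sharding import SpecReplicate, SpecShard  # import path assumed

state = {
    "weight": torch.randn(8, 4),  # to be split along dim 0
    "bias": torch.randn(8),       # to be kept whole on every shard
    "steps": [1, 2, 3, 4],        # lists may only be sharded on dim 0
}

spec = {
    "weight": SpecShard(dim=0),
    "bias": SpecReplicate(),      # plain None would also mean "replicate"
    "steps": SpecShard(dim=0),
}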

d9d.core.sharding

SpecReplicate dataclass

Specifies that a leaf node should be replicated across all shards.

Source code in d9d/core/sharding/spec.py
@dataclasses.dataclass(slots=True, frozen=True)
class SpecReplicate:
    """
    Specifies that a leaf node should be replicated across all shards.
    """

SpecShard dataclass

Specifies that a leaf node should be split along a specific dimension.

Attributes:

  • dim (int): The dimension to split.
  • do_stack (bool): If True, sharding squeezes the sharded dimension (its size must be exactly num_shards).

Source code in d9d/core/sharding/spec.py
@dataclasses.dataclass(slots=True, frozen=True)
class SpecShard:
    """
    Specifies that a leaf node should be split along a specific dimension.

    Attributes:
        dim: The dimension to split.
        do_stack: If True, sharding will squeeze the sharded dimension (it should be exactly the num_shards length)
    """

    dim: int
    do_stack: bool = False
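
To make the do_stack distinction concrete, here is a hedged sketch that shards a (2, 4) tensor into two shards both ways; import paths and the expected shapes follow the descriptions above:

import torch
from d9d.core.sharding import SpecShard, shard_tree  # import path assumed

x = {"w": torch.arange(8).reshape(2, 4)}

# do_stack=False (default): split/concatenate, dimensionality is preserved.
split_shards = shard_tree(x, {"w": SpecShard(dim=0)}, num_shards=2, enforce_even_split=True)
# per the description above, each split_shards[i]["w"] keeps rank 2: shape (1, 4)

# do_stack=True: unbind/stack, the sharded dim (size 2 == num_shards) is squeezed.
stacked_shards = shard_tree(x, {"w": SpecShard(dim=0, do_stack=True)}, num_shards=2, enforce_even_split=True)
# per the description above, each stacked_shards[i]["w"] drops to rank 1: shape (4,)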

shard_spec_nothing(tree)

Creates a sharding specification where no sharding is performed.

This effectively clones the tree structure but replaces every leaf with SpecReplicate.

Parameters:

  • tree (PyTree[Any], required): The input PyTree structure.

Returns:

  • ShardingSpec: A new PyTree matching the input structure, containing strictly SpecReplicate for all leaves.

Source code in d9d/core/sharding/auto_spec.py
def shard_spec_nothing(tree: PyTree[Any]) -> ShardingSpec:
    """
    Creates a sharding specification where no sharding is performed.

    This effectively clones the tree structure but replaces every leaf with SpecReplicate.

    Args:
        tree: The input PyTree structure.

    Returns:
        A new PyTree matching the input structure, containing strictly SpecReplicate for all leaves.
    """

    return pytree.tree_map(lambda _: SpecReplicate(), tree, is_leaf=lambda x: isinstance(x, (torch.Tensor, list)))
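
A short usage sketch (import path assumed from the d9d.core.sharding module documented here):

import torch
from d9d.core.sharding import shard_spec_nothing  # import path assumed

tree = {"w": torch.randn(4, 2), "meta": {"ids": [0, 1, 2]}}
spec = shard_spec_nothing(tree)
# every leaf (tensors and lists alike) is mapped to SpecReplicate()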

shard_spec_on_dim(tree, dim)

Creates a sharding specification to split all tensors in the tree on a specific dimension.

Iterates over the input tree:

  • If a leaf is a Tensor with enough dimensions, it is mapped to a SpecShard(dim) object.
  • If a leaf is a list, it is mapped to a SpecShard(0) object (only dim=0 is allowed for lists).
  • Other types and 0-dim tensors are mapped to SpecReplicate.

Parameters:

  • tree (PyTree[Any], required): The input PyTree structure.
  • dim (int, required): The dimension index to shard eligible tensors on.

Returns:

  • ShardingSpec: A new PyTree matching the input structure, containing SpecShard or SpecReplicate objects.

Raises:

  • ValueError: If a tensor exists in the tree with rank less than or equal to dim.

Source code in d9d/core/sharding/auto_spec.py
def shard_spec_on_dim(tree: PyTree[Any], dim: int) -> ShardingSpec:
    """
    Creates a sharding specification to split all tensors in the tree on a specific dimension.

    Iterates over the input tree:
    *   If a leaf is a Tensor with enough dimensions, it is mapped to a SpecShard(dim) object.
    *   If a leaf is a list, it is mapped to a SpecShard(0) object (only dim=0 is allowed for lists).
    *   Other types and 0-dim tensors are mapped to SpecReplicate.

    Args:
        tree: The input PyTree structure.
        dim: The dimension index to shard eligible tensors on.

    Returns:
        A new PyTree matching the input structure, containing SpecShard or SpecReplicate objects.

    Raises:
        ValueError: If a tensor exists in the tree with rank less than or equal to 'dim'.
    """

    return pytree.tree_map(
        lambda x: _tree_item_to_shard(x, dim),
        tree,
        is_leaf=lambda x: isinstance(x, (torch.Tensor, list))
    )
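
A short usage sketch (import path assumed from the d9d.core.sharding module documented here):

import torch
from d9d.core.sharding import shard_spec_on_dim  # import path assumed

tree = {
    "weight": torch.randn(8, 4),  # rank-2 tensor -> SpecShard(dim=0)
    "history": [10, 20, 30],      # list -> SpecShard(dim=0)
}
spec = shard_spec_on_dim(tree, dim=0)
# per the mapping rules above: {"weight": SpecShard(dim=0), "history": SpecShard(dim=0)}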

shard_tree(tree, sharding_spec, num_shards, enforce_even_split)

Shards a PyTree into a tuple of PyTrees, one for each shard rank.

This function takes a single global data structure and splits it into num_shards structures.

  • If a spec leaf is a SpecShard(dim), the tensor or list is split along that dimension, and the i-th slice goes to the i-th output tree.
  • If a spec leaf is SpecReplicate, the item is replicated (reference copy) to all output trees.

Parameters:

  • tree (TSameTree, required): The structure containing tensors to be sharded.
  • sharding_spec (ShardingSpec, required): A structure matching tree containing SpecShard or SpecReplicate objects.
  • num_shards (int, required): The total number of shards to split the tensors into.
  • enforce_even_split (bool, required): If True, raises a ValueError if a tensor's dimension size is not perfectly divisible by num_shards.

Returns:

  • tuple[TSameTree, ...]: A tuple of length num_shards. Each element is a PyTree matching the structure of the input tree, containing the local data for that specific rank.

Raises:

  • ValueError: If tree structures do not match, or valid sharding conditions are not met.

Source code in d9d/core/sharding/shard.py
def shard_tree(
        tree: TSameTree,
        sharding_spec: ShardingSpec,
        num_shards: int,
        enforce_even_split: bool
) -> tuple[TSameTree, ...]:
    """
    Shards a PyTree into a tuple of PyTrees, one for each shard rank.

    This function takes a single global data structure and splits it into `num_shards`
    structures.

    *   If a spec leaf is a ``SpecShard(dim)``, the tensor or list is split along that dimension,
        and the ``i``-th slice goes to the ``i``-th output tree.
    *   If a spec leaf is ``SpecReplicate``, the item is replicated (reference copy) to all
        output trees.

    Args:
        tree: The structure containing tensors to be sharded.
        sharding_spec: A structure matching 'tree' containing ``SpecShard`` or ``SpecReplicate`` objects.
        num_shards: The total number of shards to split the tensors into.
        enforce_even_split: If True, raises a ValueError if a tensor's dimension
            size is not perfectly divisible by ``num_shards``.

    Returns:
        A tuple of length ``num_shards``. Each element is a PyTree matching
        the structure of the input ``tree``, containing the local data for
        that specific rank.

    Raises:
        ValueError: If tree structures do not match, or valid sharding conditions
            are not met.
    """
    flat_spec, spec_struct = pytree.tree_flatten(sharding_spec)

    try:
        flat_tree = spec_struct.flatten_up_to(tree)
    except (ValueError, TypeError) as e:
        raise ValueError("Tree structure does not match sharding spec") from e

    sharded_leaves_per_node = [
        _shard_leaf_to_list(item, spec, num_shards, enforce_even_split)
        for item, spec in zip(flat_tree, flat_spec, strict=True)
    ]

    rank_leaves = list(zip(*sharded_leaves_per_node, strict=True))

    return tuple(spec_struct.unflatten(leaves) for leaves in rank_leaves)
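
A short usage sketch (import path assumed from the d9d.core.sharding module documented here):

import torch
from d9d.core.sharding import SpecReplicate, SpecShard, shard_tree  # import path assumed

tree = {"w": torch.arange(12).reshape(4, 3), "lr": torch.tensor(0.1)}
spec = {"w": SpecShard(dim=0), "lr": SpecReplicate()}

shards = shard_tree(tree, spec, num_shards=2, enforce_even_split=True)
# shards is a 2-tuple of trees; each shards[i]["w"] has shape (2, 3),
# while "lr" is replicated (the same tensor object) on both ranks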

unshard_tree(sharded_trees, sharding_spec)

Combines a sequence of PyTrees (one per rank) into a single global PyTree.

This is the inverse of shard_tree. It iterates over the provided trees, gathering corresponding leaves from each rank.

  • If the spec for a leaf is SpecShard(dim), the tensors from all ranks are concatenated (or stacked if do_stack=True) along that dimension.
  • If the spec is SpecReplicate, it assumes the data is replicated and takes the item from the first rank.

Parameters:

  • sharded_trees (Sequence[TSameTree], required): A sequence (list or tuple) of PyTrees. There must be one tree for each shard rank, and they must all share the same structure as sharding_spec.
  • sharding_spec (ShardingSpec, required): A structure matching the input trees containing SpecShard or SpecReplicate objects.

Returns:

  • TSameTree: A single PyTree where distinct shards have been merged into full tensors.

Raises:

  • ValueError: If sharded_trees is empty, or if the individual tree structures do not match the spec.

Source code in d9d/core/sharding/unshard.py
def unshard_tree(
        sharded_trees: Sequence[TSameTree],
        sharding_spec: ShardingSpec
) -> TSameTree:
    """
    Combines a sequence of PyTrees (one per rank) into a single global PyTree.

    This is the inverse of ``shard_tree``. It iterates over the provided trees,
    gathering corresponding leaves from each rank.

    *   If the spec for a leaf is ``SpecShard(dim)``, the tensors from all ranks are
        concatenated (or stacked if ``do_stack=True``) along that dimension.
    *   If the spec is ``SpecReplicate``, it assumes the data is replicated
        and takes the item from the first rank.

    Args:
        sharded_trees: A sequence (list or tuple) of PyTrees. There must be
            one tree for each shard rank, and they must all share the same
            structure as ``sharding_spec``.
        sharding_spec: A structure matching the input trees containing
            ``SpecShard`` or ``SpecReplicate`` objects.

    Returns:
        A single PyTree where distinct shards have been merged into full tensors.

    Raises:
        ValueError: If ``sharded_trees`` is empty, or if unit structures do
            not match the spec.
    """
    if not sharded_trees:
        raise ValueError("sharded_trees sequence cannot be empty")

    flat_spec, spec_struct = pytree.tree_flatten(sharding_spec)

    flat_shards_per_rank = []
    for i, tree in enumerate(sharded_trees):
        try:
            leaves = spec_struct.flatten_up_to(tree)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Structure mismatch at shard {i}: tree does not match sharding spec structure"
            ) from e

        flat_shards_per_rank.append(leaves)

    grouped_leaves = list(zip(*flat_shards_per_rank, strict=True))

    reconstructed_leaves = [
        _unshard_leaf_from_group(group, spec)
        for group, spec in zip(grouped_leaves, flat_spec, strict=True)
    ]

    return spec_struct.unflatten(reconstructed_leaves)
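
A round-trip sketch showing that unshard_tree inverts shard_tree (import path assumed from the d9d.core.sharding module documented here):

import torch
from d9d.core.sharding import shard_spec_on_dim, shard_tree, unshard_tree  # import path assumed

tree = {"w": torch.arange(8).reshape(4, 2)}
spec = shard_spec_on_dim(tree, dim=0)

shards = shard_tree(tree, spec, num_shards=2, enforce_even_split=True)
restored = unshard_tree(shards, spec)

assert torch.equal(restored["w"], tree["w"])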