About
The d9d.core.sharding package provides utilities for splitting and reconstructing complex nested structures (PyTrees) of PyTorch Tensors and Python Lists.
Sharding Spec
A Sharding Spec is a PyTree that mirrors the structure of your data (e.g., a State Dict).
- Structure: Mirrors the data hierarchy. The spec structure is used to traverse the data; sharding operations flatten the data tree up to the leaves defined in the spec.
- Leaves:
    - `d9d.core.sharding.SpecShard(dim, do_stack=False)`:
        - Tensors: The tensor is split along dimension `dim`.
        - Lists: The list is split into chunks. `dim` must be `0`.
        - `do_stack`: If `True`, tensors are unbound/stacked (reducing/increasing dimensionality). If `False` (default), tensors are split/concatenated (preserving dimensionality).
    - `d9d.core.sharding.SpecReplicate` (or `None`): The item is replicated (kept as-is/not split) across all shards.
Helper functions like shard_spec_on_dim allow generating these specs automatically.
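To make the spec/data mirroring concrete, here is a minimal sketch of a hand-written spec for a small state-dict-like tree (the keys and tensors are illustrative; it assumes `SpecShard` and `SpecReplicate` can be imported directly from `d9d.core.sharding`, as suggested by the qualified names above):

```python
import torch
from d9d.core.sharding import SpecReplicate, SpecShard  # assumed import path

# Illustrative data tree (a tiny state dict).
state_dict = {
    "embed.weight": torch.randn(8, 4),  # large tensor: shard along rows
    "norm.weight": torch.randn(4),      # small tensor: keep on every shard
    "step": 100,                        # plain Python value: replicated
}

# The spec mirrors the data structure; each leaf says how to handle that entry.
spec = {
    "embed.weight": SpecShard(dim=0),
    "norm.weight": SpecReplicate(),
    "step": SpecReplicate(),  # None would also mean "replicate"
}
```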
d9d.core.sharding
SpecReplicate
dataclass
Specifies that a leaf node should be replicated across all shards.
Source code in d9d/core/sharding/spec.py
SpecShard
dataclass
Specifies that a leaf node should be split along a specific dimension.
Attributes:

| Name | Type | Description |
|---|---|---|
| `dim` | `int` | The dimension to split. |
| `do_stack` | `bool` | If `True`, sharding squeezes the sharded dimension (its size must be exactly `num_shards`). |
Source code in d9d/core/sharding/spec.py
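The `do_stack` behaviour can be pictured with plain PyTorch operations; this sketch is only an analogy for the two splitting modes described above, not a call into the library:

```python
import torch

x = torch.arange(6).reshape(3, 2)

# do_stack=False (default): the shard dimension is preserved on each piece,
# and unsharding concatenates the pieces back together.
chunks = torch.chunk(x, 3, dim=0)   # three tensors of shape (1, 2)

# do_stack=True: the shard dimension is removed (its size must equal the
# number of shards), and unsharding stacks the pieces to restore it.
slices = torch.unbind(x, dim=0)     # three tensors of shape (2,)
```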
shard_spec_nothing(tree)
Creates a sharding specification where no sharding is performed.
This effectively clones the tree structure but replaces every leaf with SpecReplicate.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree[Any]` | The input PyTree structure. | required |

Returns:

| Type | Description |
|---|---|
| `ShardingSpec` | A new PyTree matching the input structure, containing only `SpecReplicate` for all leaves. |
Source code in d9d/core/sharding/auto_spec.py
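A short usage sketch (it assumes `shard_spec_nothing` is re-exported from `d9d.core.sharding`; the implementation lives in `d9d/core/sharding/auto_spec.py`):

```python
import torch
from d9d.core.sharding import shard_spec_nothing  # assumed import path

tree = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Same structure as `tree`, but every leaf is replaced by SpecReplicate.
spec = shard_spec_nothing(tree)
```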
shard_spec_on_dim(tree, dim)
Creates a sharding specification to split all tensors in the tree on a specific dimension.
Iterates over the input tree:

- If a leaf is a Tensor with enough dimensions, it is mapped to a `SpecShard(dim)` object.
- If a leaf is a list, it is mapped to a `SpecShard(0)` object (only `dim=0` is allowed for lists).
- Other types and 0-dim tensors are mapped to `SpecReplicate`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree[Any]` | The input PyTree structure. | required |
| `dim` | `int` | The dimension index to shard eligible tensors on. | required |

Returns:

| Type | Description |
|---|---|
| `ShardingSpec` | A new PyTree matching the input structure, containing `SpecShard` or `SpecReplicate` objects. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If a tensor exists in the tree with rank less than or equal to `dim`. |
Source code in d9d/core/sharding/auto_spec.py
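A short usage sketch (again assuming the function is re-exported from `d9d.core.sharding`):

```python
import torch
from d9d.core.sharding import shard_spec_on_dim  # assumed import path

tree = {
    "weight": torch.randn(8, 4),  # rank-2 tensor -> eligible, becomes SpecShard(dim=0)
    "scale": torch.tensor(1.0),   # 0-dim tensor  -> becomes SpecReplicate
}

spec = shard_spec_on_dim(tree, dim=0)
```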
shard_tree(tree, sharding_spec, num_shards, enforce_even_split)
Shards a PyTree into a tuple of PyTrees, one for each shard rank.
This function takes a single global data structure and splits it into `num_shards` structures.

- If a spec leaf is a `SpecShard(dim)`, the tensor or list is split along that dimension, and the `i`-th slice goes to the `i`-th output tree.
- If a spec leaf is `SpecReplicate`, the item is replicated (reference copy) to all output trees.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `TSameTree` | The structure containing tensors to be sharded. | required |
| `sharding_spec` | `ShardingSpec` | A structure matching `tree` containing `SpecShard`/`SpecReplicate` directives. | required |
| `num_shards` | `int` | The total number of shards to split the tensors into. | required |
| `enforce_even_split` | `bool` | If True, raises a ValueError if a tensor's dimension size is not perfectly divisible by `num_shards`. | required |
Returns:

| Type | Description |
|---|---|
| `tuple[TSameTree, ...]` | A tuple of length `num_shards`, where each element matches the structure of the input `tree` and holds the data for that specific rank. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If tree structures do not match, or valid sharding conditions are not met. |
Source code in d9d/core/sharding/shard.py
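A short usage sketch (assuming `shard_tree` and `shard_spec_on_dim` are re-exported from `d9d.core.sharding`):

```python
import torch
from d9d.core.sharding import shard_spec_on_dim, shard_tree  # assumed import path

tree = {"weight": torch.randn(8, 4)}
spec = shard_spec_on_dim(tree, dim=0)

# Four per-rank trees; each "weight" slice has shape (2, 4).
shards = shard_tree(tree, spec, num_shards=4, enforce_even_split=True)
```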
unshard_tree(sharded_trees, sharding_spec)
Combines a sequence of PyTrees (one per rank) into a single global PyTree.
This is the inverse of shard_tree. It iterates over the provided trees,
gathering corresponding leaves from each rank.
- If the spec for a leaf is `SpecShard(dim)`, the tensors from all ranks are concatenated (or stacked if `do_stack=True`) along that dimension.
- If the spec is `SpecReplicate`, it assumes the data is replicated and takes the item from the first rank.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sharded_trees` | `Sequence[TSameTree]` | A sequence (list or tuple) of PyTrees. There must be one tree for each shard rank, and they must all share the same structure as `sharding_spec`. | required |
| `sharding_spec` | `ShardingSpec` | A structure matching the input trees containing `SpecShard`/`SpecReplicate` directives. | required |
Returns:

| Type | Description |
|---|---|
| `TSameTree` | A single PyTree where distinct shards have been merged into full tensors. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in d9d/core/sharding/unshard.py
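A round-trip sketch combining `shard_tree` and `unshard_tree` (same import-path assumption as above):

```python
import torch
from d9d.core.sharding import shard_spec_on_dim, shard_tree, unshard_tree  # assumed

tree = {"weight": torch.randn(8, 4)}
spec = shard_spec_on_dim(tree, dim=0)

shards = shard_tree(tree, spec, num_shards=4, enforce_even_split=True)
restored = unshard_tree(shards, spec)

# Splitting and recombining should reproduce the original tensor.
assert torch.equal(restored["weight"], tree["weight"])
```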