# PyTree Sharding Utilities

## About

The `d9d.core.sharding` package provides utilities for splitting and reconstructing complex nested structures (PyTrees) of PyTorch Tensors and Python lists.
## Sharding Spec

A Sharding Spec is a PyTree that mirrors the structure of your data (e.g., a State Dict).

- **Structure**: Mirrors the data hierarchy. The spec structure is used to traverse the data; sharding operations flatten the data tree up to the leaves defined in the spec.
- **Leaves**:
    - `d9d.core.sharding.SpecShard(dim, do_stack=False)`:
        - Tensors: The tensor is split along dimension `dim`.
        - Lists: The list is split into chunks. `dim` must be `0`.
        - `do_stack`: If `True`, tensors are unbound/stacked (reducing/increasing dimensionality). If `False` (default), tensors are split/concatenated (preserving dimensionality).
    - `d9d.core.sharding.SpecReplicate` (or `None`): The item is replicated (kept as-is, not split) across all shards.

Helper functions like `shard_spec_on_dim` allow generating these specs automatically.
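The two `SpecShard` modes can be illustrated without the library itself. The sketch below uses NumPy arrays as a light-weight stand-in for `torch.Tensor` (the split/concatenate and unbind/stack semantics are analogous):

```python
import numpy as np

# Illustration of the two SpecShard modes, with NumPy standing in for torch.
x = np.arange(12).reshape(4, 3)  # shape (4, 3)

# do_stack=False (default): split along dim 0 -> each shard keeps rank 2.
split_shards = np.split(x, 2, axis=0)  # two arrays of shape (2, 3)
# Reconstruction concatenates, preserving dimensionality.
assert np.concatenate(split_shards, axis=0).shape == (4, 3)

# do_stack=True: unbind along dim 0 -> each shard loses a dimension.
unbound = [x[i] for i in range(x.shape[0])]  # four arrays of shape (3,)
# Reconstruction stacks, re-adding the dimension.
assert np.stack(unbound, axis=0).shape == (4, 3)
```

Note the dimensionality difference: splitting preserves rank per shard, unbinding reduces it by one.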
## d9d.core.sharding

### SpecReplicate

`dataclass`

Specifies that a leaf node should be replicated across all shards.
### SpecShard

`dataclass`

Specifies that a leaf node should be split across shards along dimension `dim` (unbound/stacked instead when `do_stack=True`).
### shard_spec_nothing(tree)

Creates a sharding specification where no sharding is performed.

This effectively clones the tree structure but replaces every leaf with `SpecReplicate`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree[Any]` | The input PyTree structure. | *required* |

Returns:

| Type | Description |
|---|---|
| `ShardingSpec` | A new PyTree matching the input structure, containing strictly `SpecReplicate` for all leaves. |
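A minimal sketch of the behavior, for dict-nested trees; `replicate_all` and the string marker are hypothetical stand-ins for the real API:

```python
REPLICATE = "SpecReplicate"  # stand-in marker for the real dataclass

def replicate_all(tree):
    """Hypothetical sketch of shard_spec_nothing: clone the dict structure,
    mapping every leaf (including lists, which are shardable leaves in this
    API) to a SpecReplicate marker."""
    if isinstance(tree, dict):
        return {k: replicate_all(v) for k, v in tree.items()}
    return REPLICATE  # tensors, lists, scalars: all replicated

spec = replicate_all({"w": [1, 2], "b": 3, "nested": {"x": 4}})
assert spec == {"w": REPLICATE, "b": REPLICATE, "nested": {"x": REPLICATE}}
```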
### shard_spec_on_dim(tree, dim)

Creates a sharding specification that splits all tensors in the tree on a specific dimension.

Iterates over the input tree:

- If a leaf is a Tensor with enough dimensions, it is mapped to a `SpecShard(dim)` object.
- If a leaf is a list, it is mapped to a `SpecShard(0)` object (only `dim=0` is allowed for lists).
- Other types and 0-dim tensors are mapped to `SpecReplicate`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree[Any]` | The input PyTree structure. | *required* |
| `dim` | `int` | The dimension index to shard eligible tensors on. | *required* |

Returns:

| Type | Description |
|---|---|
| `ShardingSpec` | A new PyTree matching the input structure, containing `SpecShard` or `SpecReplicate` objects. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If a tensor exists in the tree with rank less than or equal to `dim`. |
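The mapping rules can be sketched as follows. `spec_on_dim`, `spec_shard`, and the markers are hypothetical stand-ins, and NumPy arrays stand in for `torch.Tensor`:

```python
import numpy as np

SPEC_REPLICATE = "SpecReplicate"                 # stand-in markers for the
def spec_shard(dim): return ("SpecShard", dim)   # real dataclasses

def spec_on_dim(tree, dim):
    """Hypothetical sketch of shard_spec_on_dim's mapping rules."""
    if isinstance(tree, dict):
        return {k: spec_on_dim(v, dim) for k, v in tree.items()}
    if isinstance(tree, np.ndarray):      # stands in for torch.Tensor
        if tree.ndim == 0:
            return SPEC_REPLICATE         # 0-dim tensors are replicated
        if tree.ndim <= dim:
            raise ValueError("tensor rank must exceed dim")
        return spec_shard(dim)
    if isinstance(tree, list):
        return spec_shard(0)              # lists only shard on dim 0
    return SPEC_REPLICATE                 # everything else is replicated

tree = {"w": np.zeros((4, 3)), "ids": [0, 1, 2, 3], "step": 7}
spec = spec_on_dim(tree, dim=0)
# {"w": ("SpecShard", 0), "ids": ("SpecShard", 0), "step": "SpecReplicate"}
```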
### shard_tree(tree, sharding_spec, num_shards, enforce_even_split)

Shards a PyTree into a tuple of PyTrees, one for each shard rank.

This function takes a single global data structure and splits it into `num_shards` structures.

- If a spec leaf is a `SpecShard(dim)`, the tensor or list is split along that dimension, and the `i`-th slice goes to the `i`-th output tree.
- If a spec leaf is `SpecReplicate`, the item is replicated (reference copy) to all output trees.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `TSameTree` | The structure containing tensors to be sharded. | *required* |
| `sharding_spec` | `ShardingSpec` | A structure matching `tree` containing `SpecShard`/`SpecReplicate` leaves. | *required* |
| `num_shards` | `int` | The total number of shards to split the tensors into. | *required* |
| `enforce_even_split` | `bool` | If `True`, raises a `ValueError` if a tensor's dimension size is not perfectly divisible by `num_shards`. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[TSameTree, ...]` | A tuple of length `num_shards`, where each element matches the structure of the input `tree` and contains the data for that specific rank. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If tree structures do not match, or valid sharding conditions are not met. |
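A hypothetical emulation of these scatter semantics for flat dicts, with NumPy arrays standing in for tensors (`shard_tree_sketch` is not part of the library):

```python
import numpy as np

def shard_tree_sketch(tree, num_shards, dim=0, enforce_even=True):
    """Hypothetical emulation of shard_tree for flat dicts of arrays/lists:
    arrays are split along `dim`, lists along dim 0, and every other leaf
    is replicated (same reference) into each output tree."""
    def split_leaf(leaf):
        if isinstance(leaf, np.ndarray):
            if enforce_even and leaf.shape[dim] % num_shards != 0:
                raise ValueError("uneven split")
            return np.array_split(leaf, num_shards, axis=dim)
        if isinstance(leaf, list):
            if enforce_even and len(leaf) % num_shards != 0:
                raise ValueError("uneven split")
            k = -(-len(leaf) // num_shards)  # ceil division: chunk size
            return [leaf[i * k:(i + 1) * k] for i in range(num_shards)]
        return [leaf] * num_shards  # replicate to every rank

    split = {k: split_leaf(v) for k, v in tree.items()}
    return tuple({k: v[i] for k, v in split.items()}
                 for i in range(num_shards))

shards = shard_tree_sketch({"w": np.arange(8).reshape(4, 2), "lr": 0.1}, 2)
# shards[0]["w"] has shape (2, 2); "lr" appears unchanged on both ranks
```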
### unshard_tree(sharded_trees, sharding_spec)

Combines a sequence of PyTrees (one per rank) into a single global PyTree.

This is the inverse of `shard_tree`. It iterates over the provided trees, gathering corresponding leaves from each rank.

- If the spec for a leaf is `SpecShard(dim)`, the tensors from all ranks are concatenated (or stacked if `do_stack=True`) along that dimension.
- If the spec is `SpecReplicate`, it assumes the data is replicated and takes the item from the first rank.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sharded_trees` | `Sequence[TSameTree]` | A sequence (list or tuple) of PyTrees. There must be one tree for each shard rank, and they must all share the same structure. | *required* |
| `sharding_spec` | `ShardingSpec` | A structure matching the input trees containing `SpecShard`/`SpecReplicate` leaves. | *required* |

Returns:

| Type | Description |
|---|---|
| `TSameTree` | A single PyTree where distinct shards have been merged into full tensors. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If … |
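The gather semantics can be emulated as below. `unshard_tree_sketch` is a hypothetical stand-in for illustration only, again using NumPy arrays in place of tensors:

```python
import numpy as np

def unshard_tree_sketch(sharded_trees, dim=0):
    """Hypothetical emulation of unshard_tree for flat dicts: concatenate
    array leaves across ranks along `dim`, join list leaves, and take
    replicated leaves from the first rank."""
    out = {}
    for key in sharded_trees[0]:
        leaves = [t[key] for t in sharded_trees]
        if isinstance(leaves[0], np.ndarray):
            out[key] = np.concatenate(leaves, axis=dim)
        elif isinstance(leaves[0], list):
            out[key] = [x for chunk in leaves for x in chunk]
        else:
            out[key] = leaves[0]  # replicated: first rank's copy
    return out

# Merging per-rank trees recovers the global structure.
ranks = ({"w": np.zeros((2, 3)), "tag": "a"},
         {"w": np.ones((2, 3)), "tag": "a"})
merged = unshard_tree_sketch(ranks)
# merged["w"] has shape (4, 3); merged["tag"] is taken from rank 0
```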