PyTree Sharding Utilities

About

The d9d.core.sharding package provides utilities for splitting and reconstructing complex nested structures (PyTrees) of PyTorch Tensors and Python Lists.

Sharding Spec

A Sharding Spec is a PyTree that mirrors the structure of your data (e.g., a State Dict).

  • Structure: Mirrors the data hierarchy. The spec structure is used to traverse the data; sharding operations flatten the data tree up to the leaves defined in the spec.
  • Leaves:
    • d9d.core.sharding.SpecShard(dim, do_stack=False):
      • Tensors: The tensor is split along dimension dim.
      • Lists: The list is split into chunks. dim must be 0.
      • do_stack: If True, tensors are unbound/stacked (reducing/increasing dimensionality). If False (default), tensors are split/concatenated (preserving dimensionality).
    • d9d.core.sharding.SpecReplicate (or None): The item is replicated (kept as-is/not split) across all shards.

Helper functions like shard_spec_on_dim allow generating these specs automatically.
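As a concrete illustration of the leaf types above, the following sketch builds a spec for a toy state dict. The SpecShard/SpecReplicate classes here are local stand-ins that mirror only the fields documented on this page, so the snippet runs without the library installed; in real code you would import them from d9d.core.sharding.

```python
from dataclasses import dataclass

# Local stand-ins mirroring the documented leaf types, so this sketch
# runs without d9d installed. In real code, import these from
# d9d.core.sharding instead.
@dataclass
class SpecShard:
    dim: int
    do_stack: bool = False

@dataclass
class SpecReplicate:
    pass

# A spec mirroring a toy state dict: shard both parameters row-wise
# (dim 0), replicate the scalar step counter on every shard.
spec = {
    "linear.weight": SpecShard(dim=0),
    "linear.bias": SpecShard(dim=0),
    "step": SpecReplicate(),  # the docs also allow plain None here
}
```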

d9d.core.sharding

SpecReplicate dataclass

Specifies that a leaf node should be replicated across all shards.

SpecShard dataclass

Specifies that a leaf node should be split along a specific dimension.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| dim | int | The dimension to split along. |
| do_stack | bool | If True, sharding squeezes the sharded dimension (its size must equal num_shards) and unsharding stacks it back; if False (default), the dimension is preserved. |
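The split-versus-stack distinction can be demonstrated with plain torch ops; this mimics the documented behavior rather than calling the library itself:

```python
import torch

x = torch.arange(6).reshape(3, 2)  # global tensor, sharded on dim 0 into 3 shards

# do_stack=False (default): split along the dim; each shard keeps the dim.
split_shards = torch.chunk(x, 3, dim=0)
assert split_shards[0].shape == (1, 2)   # dimensionality preserved

# do_stack=True: the sharded dim is squeezed away (its size must equal
# num_shards); unsharding later stacks it back.
stacked_shards = torch.unbind(x, dim=0)
assert stacked_shards[0].shape == (2,)   # one dimension fewer
```

Either way the shards round-trip: concatenating the split shards, or stacking the unbound shards, reproduces the original tensor.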

shard_spec_nothing(tree)

Creates a sharding specification where no sharding is performed.

This effectively clones the tree structure but replaces every leaf with SpecReplicate.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree[Any] | The input PyTree structure. | required |

Returns:

| Type | Description |
|------|-------------|
| ShardingSpec | A new PyTree matching the input structure, containing strictly SpecReplicate for all leaves. |
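The effect can be sketched as a small recursive walk. This is a hypothetical re-implementation for illustration only; exactly which container types count as PyTree nodes is an assumption (lists are treated as leaves here, since this library shards them as units):

```python
def spec_nothing(tree):
    # Hypothetical sketch of shard_spec_nothing. Dicts and tuples are
    # treated as containers; everything else (tensors, lists, scalars)
    # is a leaf replaced by None, the documented shorthand for
    # SpecReplicate.
    if isinstance(tree, dict):
        return {k: spec_nothing(v) for k, v in tree.items()}
    if isinstance(tree, tuple):
        return tuple(spec_nothing(v) for v in tree)
    return None

spec = spec_nothing({"w": [1, 2, 3], "meta": ("tag", 0)})
# spec == {"w": None, "meta": (None, None)}
```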

shard_spec_on_dim(tree, dim)

Creates a sharding specification to split all tensors in the tree on a specific dimension.

Iterates over the input tree:

  • If a leaf is a Tensor with enough dimensions, it is mapped to a SpecShard(dim) object.
  • If a leaf is a list, it is mapped to a SpecShard(0) object (only dim=0 is allowed for lists).
  • Other types and 0-dim tensors are mapped to SpecReplicate.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree[Any] | The input PyTree structure. | required |
| dim | int | The dimension index to shard eligible tensors on. | required |

Returns:

| Type | Description |
|------|-------------|
| ShardingSpec | A new PyTree matching the input structure, containing SpecShard or SpecReplicate objects. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If a tensor exists in the tree with rank less than or equal to dim. |
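The per-leaf mapping rules can be sketched as follows. leaf_rule is a hypothetical helper returning plain tuples instead of the real spec classes, and the order in which the 0-dim check and the rank check apply is an assumption:

```python
import torch

def leaf_rule(leaf, dim):
    # Hypothetical sketch of the documented per-leaf mapping.
    if isinstance(leaf, torch.Tensor):
        if leaf.ndim == 0:
            return ("replicate", None)   # 0-dim tensors are replicated
        if leaf.ndim <= dim:
            raise ValueError(f"tensor rank {leaf.ndim} must exceed dim {dim}")
        return ("shard", dim)
    if isinstance(leaf, list):
        return ("shard", 0)              # lists only ever shard on dim 0
    return ("replicate", None)           # other types are replicated

assert leaf_rule(torch.zeros(4, 2), 1) == ("shard", 1)
assert leaf_rule([1, 2, 3], 1) == ("shard", 0)
assert leaf_rule(torch.tensor(3.0), 0) == ("replicate", None)
```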

shard_tree(tree, sharding_spec, num_shards, enforce_even_split)

Shards a PyTree into a tuple of PyTrees, one for each shard rank.

This function takes a single global data structure and splits it into num_shards structures.

  • If a spec leaf is a SpecShard(dim), the tensor or list is split along that dimension, and the i-th slice goes to the i-th output tree.
  • If a spec leaf is SpecReplicate, the item is replicated (reference copy) to all output trees.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | TSameTree | The structure containing tensors to be sharded. | required |
| sharding_spec | ShardingSpec | A structure matching tree containing SpecShard or SpecReplicate objects. | required |
| num_shards | int | The total number of shards to split the tensors into. | required |
| enforce_even_split | bool | If True, raises a ValueError if a tensor's dimension size is not perfectly divisible by num_shards. | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[TSameTree, ...] | A tuple of length num_shards. Each element is a PyTree matching the structure of the input tree, containing the local data for that specific rank. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the tree and spec structures do not match, or if sharding constraints are not met (e.g. an uneven split when enforce_even_split is True). |
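For a flat dict of tensors, the splitting behavior can be mimicked with plain torch ops. mini_shard_tree is a hypothetical sketch, not the library function, and the spec is simplified to an int dim (shard) or None (replicate):

```python
import torch

def mini_shard_tree(tree, spec, num_shards):
    # Hypothetical sketch of shard_tree semantics for a flat dict.
    # spec[key] is a dim index to split on, or None to replicate.
    shards = [{} for _ in range(num_shards)]
    for key, value in tree.items():
        dim = spec[key]
        if dim is None:
            pieces = [value] * num_shards          # replicate: same reference on every rank
        else:
            pieces = torch.chunk(value, num_shards, dim=dim)
        for rank_tree, piece in zip(shards, pieces):
            rank_tree[key] = piece
    return tuple(shards)

tree = {"w": torch.arange(8).reshape(4, 2), "step": torch.tensor(7)}
locals_per_rank = mini_shard_tree(tree, {"w": 0, "step": None}, num_shards=2)
assert locals_per_rank[0]["w"].shape == (2, 2)     # each rank gets half the rows
assert locals_per_rank[1]["step"] is tree["step"]  # replicated by reference
```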

unshard_tree(sharded_trees, sharding_spec)

Combines a sequence of PyTrees (one per rank) into a single global PyTree.

This is the inverse of shard_tree. It iterates over the provided trees, gathering corresponding leaves from each rank.

  • If the spec for a leaf is SpecShard(dim), the tensors from all ranks are concatenated (or stacked if do_stack=True) along that dimension.
  • If the spec is SpecReplicate, it assumes the data is replicated and takes the item from the first rank.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sharded_trees | Sequence[TSameTree] | A sequence (list or tuple) of PyTrees. There must be one tree for each shard rank, and they must all share the same structure as sharding_spec. | required |
| sharding_spec | ShardingSpec | A structure matching the input trees containing SpecShard or SpecReplicate objects. | required |

Returns:

| Type | Description |
|------|-------------|
| TSameTree | A single PyTree where distinct shards have been merged into full tensors. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If sharded_trees is empty, or if the tree structures do not match the spec. |
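The gathering behavior can likewise be mimicked for a flat dict of tensors. mini_unshard_tree is a hypothetical sketch, not the library function, with the spec simplified to an int dim (concatenate) or None (replicated: take the first rank's copy):

```python
import torch

def mini_unshard_tree(sharded_trees, spec):
    # Hypothetical sketch of unshard_tree semantics for a flat dict.
    if not sharded_trees:
        raise ValueError("sharded_trees must be non-empty")
    merged = {}
    for key, dim in spec.items():
        if dim is None:
            merged[key] = sharded_trees[0][key]    # replicated: first rank's copy
        else:
            merged[key] = torch.cat([t[key] for t in sharded_trees], dim=dim)
    return merged

ranks = (
    {"w": torch.zeros(2, 2), "step": torch.tensor(7)},
    {"w": torch.ones(2, 2), "step": torch.tensor(7)},
)
full = mini_unshard_tree(ranks, {"w": 0, "step": None})
assert full["w"].shape == (4, 2)   # halves concatenated back into the global tensor
```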