Typing Extensions

About

The d9d.core.types package gathers common Type Aliases used throughout the framework.

The d9d.core.protocol package defines interfaces (Protocols) for the standard PyTorch components used in the distributed training loop.

d9d.core.types

Common type definitions used throughout the framework.

CollateFn = Callable[[Sequence[TDataTree]], TDataTree] (module attribute)

Type alias for a function that collates a sequence of samples into a batch.

The function receives a sequence of individual data point structures (PyTrees) and is responsible for stacking or merging them into a single batched structure.
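As an illustration, the following is a minimal sketch of a function that fits the CollateFn shape. This page does not define TDataTree, so the sketch approximates it with Any, and `collate_tree` is a hypothetical helper rather than part of d9d. Leaves at the same position in each sample are gathered into a plain list, standing in for the tensor stacking a real collate function would do.

```python
from collections.abc import Callable, Sequence
from typing import Any

# Local stand-in for d9d's CollateFn; TDataTree is approximated by Any.
CollateFn = Callable[[Sequence[Any]], Any]


def collate_tree(samples: Sequence[Any]) -> Any:
    """Merge structurally identical samples into one batched tree.

    Containers are traversed recursively; leaves found at the same
    position in every sample are gathered into a list.
    """
    first = samples[0]
    if isinstance(first, dict):
        return {key: collate_tree([s[key] for s in samples]) for key in first}
    if isinstance(first, (list, tuple)):
        # zip(*samples) groups the i-th element of every sample together
        merged = [collate_tree(list(group)) for group in zip(*samples)]
        return type(first)(merged)
    return list(samples)  # leaf position: collect one value per sample


collate: CollateFn = collate_tree

batch = collate([
    {"x": 1.0, "meta": (0, "a")},
    {"x": 2.0, "meta": (1, "b")},
])
# batch == {"x": [1.0, 2.0], "meta": ([0, 1], ["a", "b"])}
```

The sketch assumes every sample has the same structure as the first one; mismatched keys or lengths would raise or be silently truncated by zip.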

PyTree = TLeaf | list['PyTree[TLeaf]'] | dict[str, 'PyTree[TLeaf]'] | tuple['PyTree[TLeaf]', ...] (module attribute)

A recursive type definition representing a tree of data.

This type alias covers standard Python containers (dictionaries, lists, tuples) nested arbitrarily deep, terminating in a leaf node of type TLeaf.

This is commonly used for handling nested state dictionaries or arguments passed to functions that support recursive traversal (similar to torch.utils._pytree).
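Recursive traversal over this structure can be sketched in a few lines. The `tree_map` and `tree_leaves` functions below are hypothetical, hand-written helpers that only recognize the three container types named in the PyTree alias; they are not d9d's implementation, and they are not the (private) torch.utils._pytree API.

```python
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf, rebuilding the same container structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, value) for value in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, value) for value in tree)
    return fn(tree)  # anything else is treated as a leaf


def tree_leaves(tree: Any) -> list[Any]:
    """Flatten a tree into its leaves, in depth-first order."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in tree_leaves(value)]
    if isinstance(tree, (list, tuple)):
        return [leaf for value in tree for leaf in tree_leaves(value)]
    return [tree]


tree = {"a": [1, (2, 3)], "b": 4}
doubled = tree_map(lambda x: x * 2, tree)  # {"a": [2, (4, 6)], "b": 8}
leaves = tree_leaves(tree)                 # [1, 2, 3, 4]
```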

ScalarTree = PyTree[str | float | int | bool] (module attribute)

A recursive tree structure where the leaf nodes are Python scalars (str, float, int, bool).
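To make the shape concrete, here is a minimal validator sketch that checks whether a value is a ScalarTree. The `is_scalar_tree` name and the metrics-style sample data are hypothetical; the key check mirrors the `dict[str, ...]` branch of the PyTree alias.

```python
from typing import Any

# bool is listed explicitly to match the alias, even though it subclasses int
SCALAR_TYPES = (str, float, int, bool)


def is_scalar_tree(tree: Any) -> bool:
    """Return True if every leaf of a dict/list/tuple tree is a Python scalar."""
    if isinstance(tree, dict):
        # PyTree dictionaries are keyed by str
        return all(isinstance(k, str) and is_scalar_tree(v) for k, v in tree.items())
    if isinstance(tree, (list, tuple)):
        return all(is_scalar_tree(v) for v in tree)
    return isinstance(tree, SCALAR_TYPES)


sample = {"loss": 0.42, "step": 100, "tags": ["warmup", True]}
assert is_scalar_tree(sample)
assert not is_scalar_tree({"loss": None})  # None is not a scalar leaf
```

Note that empty containers pass the check, since they contain no offending leaves.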

TensorTree = PyTree[torch.Tensor] (module attribute)

A recursive tree structure where the leaf nodes are PyTorch Tensors.

d9d.core.protocol

Package providing protocol definitions for standard PyTorch objects.

LRSchedulerProtocol

Bases: Protocol

Protocol defining an interface for a learning rate scheduler.

Any object satisfying this protocol must support stepping and state checkpointing via the Stateful interface (state_dict() and load_state_dict()).

load_state_dict(state_dict)

Restore the object's state from the provided state_dict.

Parameters:

state_dict (dict[str, Any], required): The state dict to restore from.

state_dict()

Return the object's state as a dictionary. The output of this method will be checkpointed and later passed back to load_state_dict() to restore the object.

Returns:

dict[str, Any]: The object's state dict.
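The contract above implies that feeding the output of state_dict() back into load_state_dict() restores the object. The sketch below shows that round trip. `StatefulLike`, `checkpoint_roundtrip`, and `Counter` are hypothetical names, and JSON serialization is only an assumption standing in for whatever checkpoint writer the framework actually uses.

```python
import json
from typing import Any, Protocol


class StatefulLike(Protocol):
    """Hypothetical local stand-in for the state_dict/load_state_dict pair."""

    def state_dict(self) -> dict[str, Any]: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


def checkpoint_roundtrip(src: StatefulLike, dst: StatefulLike) -> None:
    """Serialize src's state, then restore it into dst.

    JSON stands in for a real checkpoint writer, so this sketch only
    handles state made of JSON-compatible scalars and containers.
    """
    blob = json.dumps(src.state_dict())
    dst.load_state_dict(json.loads(blob))


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.count = state_dict["count"]


a, b = Counter(), Counter()
a.count = 7
checkpoint_roundtrip(a, b)
assert b.count == 7
```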

step()

Performs a single learning rate scheduling step.
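Because LRSchedulerProtocol is a Protocol, a scheduler conforms by having the right methods rather than by inheriting from it. The sketch below illustrates that with a toy step-decay scheduler. It defines its own runtime-checkable mirror, `SchedulerLike`, instead of importing d9d's class, because this page does not say whether LRSchedulerProtocol is decorated with @runtime_checkable; `SchedulerLike` and `StepDecayScheduler` are hypothetical.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SchedulerLike(Protocol):
    """Hypothetical local mirror of the LRSchedulerProtocol methods."""

    def step(self) -> None: ...
    def state_dict(self) -> dict[str, Any]: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class StepDecayScheduler:
    """Toy scheduler: multiplies the LR by gamma every step_size steps.

    It never inherits from SchedulerLike; it conforms structurally.
    """

    def __init__(self, base_lr: float, step_size: int, gamma: float) -> None:
        self.base_lr = base_lr
        self.step_size = step_size
        self.gamma = gamma
        self.num_steps = 0

    @property
    def lr(self) -> float:
        return self.base_lr * self.gamma ** (self.num_steps // self.step_size)

    def step(self) -> None:
        self.num_steps += 1

    def state_dict(self) -> dict[str, Any]:
        # Only the progress counter is checkpointed; hyperparameters are
        # assumed to be supplied again when the scheduler is rebuilt.
        return {"num_steps": self.num_steps}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.num_steps = state_dict["num_steps"]


sched = StepDecayScheduler(base_lr=0.1, step_size=2, gamma=0.5)
assert isinstance(sched, SchedulerLike)  # structural check, no inheritance
for _ in range(4):
    sched.step()
print(sched.lr)  # 0.025
```

A runtime-checkable isinstance test only verifies that the methods exist, not their signatures, so static type checking remains the stronger guarantee.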

OptimizerProtocol

Bases: Protocol

Protocol defining an interface for a standard PyTorch Optimizer object.

Any object satisfying this protocol must support the standard optimizer API (step() and zero_grad()) and state checkpointing via the Stateful interface.

load_state_dict(state_dict)

Restore the object's state from the provided state_dict.

Parameters:

state_dict (dict[str, Any], required): The state dict to restore from.

state_dict()

Return the object's state as a dictionary. The output of this method will be checkpointed and later passed back to load_state_dict() to restore the object.

Returns:

dict[str, Any]: The object's state dict.

step()

Performs a single optimization step.

zero_grad()

Sets the gradients of all optimized tensors to zero.
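The following sketch shows how code typed against this protocol might drive one training step: clear gradients, accumulate new ones, then step. To stay self-contained it avoids torch entirely. `OptimizerLike` is a hypothetical runtime-checkable mirror of the protocol, `ToySGD` and `Param` are hypothetical plain-Python stand-ins for an optimizer and its parameters, and the gradient is computed by hand in place of a backward pass.

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class OptimizerLike(Protocol):
    """Hypothetical local mirror of the OptimizerProtocol methods."""

    def step(self) -> None: ...
    def zero_grad(self) -> None: ...
    def state_dict(self) -> dict[str, Any]: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


@dataclass
class Param:
    value: float
    grad: float = 0.0


class ToySGD:
    """Plain-Python SGD over scalar parameters (no torch involved)."""

    def __init__(self, params: list[Param], lr: float) -> None:
        self.params = params
        self.lr = lr

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def state_dict(self) -> dict[str, Any]:
        return {"lr": self.lr}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.lr = state_dict["lr"]


def train_step(opt: OptimizerLike, params: list[Param], target: float) -> None:
    """One step on the loss sum((p - target) ** 2), with a hand-written gradient."""
    opt.zero_grad()  # clear gradients left over from the previous step
    for p in params:
        p.grad += 2.0 * (p.value - target)  # derivative of (p - target) ** 2
    opt.step()


params = [Param(1.0), Param(-1.0)]
opt = ToySGD(params, lr=0.25)
assert isinstance(opt, OptimizerLike)
train_step(opt, params, target=0.0)
print([p.value for p in params])  # [0.5, -0.5]
```

Calling zero_grad() before accumulating keeps gradients from one step leaking into the next, which is why it comes first in train_step.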