Typing Extensions
About
The d9d.core.types package gathers common type aliases used throughout the framework.
The d9d.core.protocol package defines standard interfaces (Protocols) for the PyTorch components used in the distributed training loop.
d9d.core.types
Common type definitions used throughout the framework.
CollateFn = Callable[[Sequence[TDataTree]], TDataTree]
module-attribute
Type alias for a function that collates a sequence of samples into a batch.
The function receives a sequence of individual data point structures (PyTrees) and is responsible for stacking or merging them into a single batched structure.
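A minimal sketch of a function matching this shape. To keep the example dependency-free, the leaves here are Python floats and "batching" collects them into lists; a real collate function would typically stack tensor leaves with torch.stack instead. The name collate_dicts is hypothetical.

```python
from collections.abc import Sequence

# Hypothetical collate function with the CollateFn shape: it receives a
# sequence of per-sample dicts and merges them into one batched dict.
def collate_dicts(samples: Sequence[dict[str, float]]) -> dict[str, list[float]]:
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}

batch = collate_dicts([{"x": 1.0, "y": 0.0}, {"x": 2.0, "y": 1.0}])
# batch == {"x": [1.0, 2.0], "y": [0.0, 1.0]}
```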
PyTree = TLeaf | list['PyTree[TLeaf]'] | dict[str, 'PyTree[TLeaf]'] | tuple['PyTree[TLeaf]', ...]
module-attribute
A recursive type definition representing a tree of data.
This type alias covers standard Python containers (dictionaries, lists, tuples)
nested arbitrarily deep, terminating in a leaf node of type TLeaf.
This is commonly used for handling nested state dictionaries or arguments
passed to functions that support recursive traversal (similar to torch.utils._pytree).
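The recursive traversal mentioned above can be sketched with a small tree_map helper (a hypothetical stand-in for utilities like torch.utils._pytree.tree_map) that recurses through dicts, lists, and tuples and applies a function at each leaf:

```python
def tree_map(fn, tree):
    # Recursively apply fn to every leaf of a PyTree-shaped value,
    # preserving the container structure (dict / list / tuple).
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, value) for value in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, value) for value in tree)
    return fn(tree)  # leaf node (TLeaf)

doubled = tree_map(lambda leaf: leaf * 2, {"a": [1, 2], "b": (3, {"c": 4})})
# doubled == {"a": [2, 4], "b": (6, {"c": 8})}
```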
ScalarTree = PyTree[str | float | int | bool]
module-attribute
A recursive tree structure where the leaf nodes are Python scalars (str, float, int, or bool).
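One common use of a scalar-leaved tree is holding nested metrics. As an illustrative sketch (the helper flatten_scalars is hypothetical and handles dict containers only), such a tree can be flattened into dotted keys for logging:

```python
def flatten_scalars(tree, prefix=""):
    # Flatten a dict-based ScalarTree into {"dotted.path": scalar} pairs.
    if isinstance(tree, dict):
        flat = {}
        for key, value in tree.items():
            flat.update(flatten_scalars(value, f"{prefix}{key}."))
        return flat
    return {prefix.rstrip("."): tree}  # scalar leaf

flat = flatten_scalars({"train": {"loss": 0.5, "acc": 0.9}, "step": 10})
# flat == {"train.loss": 0.5, "train.acc": 0.9, "step": 10}
```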
TensorTree = PyTree[torch.Tensor]
module-attribute
A recursive tree structure where the leaf nodes are PyTorch Tensors.
d9d.core.protocol
Package providing protocol definitions for standard PyTorch objects.
LRSchedulerProtocol
Bases: Protocol
Protocol defining an interface for a Learning Rate Scheduler.
This protocol ensures that the wrapped scheduler supports stepping and state checkpointing via the Stateful interface.
load_state_dict(state_dict)
Loads the scheduler state from the given state dictionary.
state_dict()
Returns the current scheduler state as a dictionary suitable for checkpointing.
step()
Performs a single learning rate scheduling step.
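Because Protocols are structural, any object exposing these three methods satisfies the interface without inheriting from a PyTorch base class. A minimal sketch (the ConstantScheduler class is hypothetical, not part of d9d):

```python
class ConstantScheduler:
    # Hypothetical scheduler that structurally satisfies
    # LRSchedulerProtocol: it implements step(), state_dict(),
    # and load_state_dict() with no torch dependency.
    def __init__(self, lr: float) -> None:
        self.lr = lr
        self.last_step = 0

    def step(self) -> None:
        # A real scheduler would update the optimizer's learning rate here.
        self.last_step += 1

    def state_dict(self) -> dict:
        return {"lr": self.lr, "last_step": self.last_step}

    def load_state_dict(self, state_dict: dict) -> None:
        self.lr = state_dict["lr"]
        self.last_step = state_dict["last_step"]
```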
OptimizerProtocol
Bases: Protocol
Protocol defining an interface for a standard PyTorch Optimizer object.
This protocol ensures that the wrapped optimizer supports the standard API and state checkpointing via the Stateful interface.
load_state_dict(state_dict)
Loads the optimizer state from the given state dictionary.
state_dict()
Returns the current optimizer state as a dictionary suitable for checkpointing.
step()
Performs a single optimization step.
zero_grad()
Sets the gradients of all optimized tensors to zero.