About
The d9d.model_state.mapper package solves the complexity of working with model checkpoints by providing
a declarative, graph-based framework for transforming model states.
Core Concept
Loading large-scale models is rarely a simple 1-to-1 key matching operation. You often face challenges such as:
- Naming Mismatches: HuggingFace uses model.layers.0, your custom model uses transformer.h.0.
- Shape Mismatches: The checkpoint stores Q, K, and V separately, but your model implementation expects a stacked QKV tensor.
- Scale: The checkpoint is 500GB. You cannot load the whole dictionary on every GPU to process it.
Instead of writing a manual loop that loads tensors and blindly modifies them, this framework treats state transformation as a Directed Acyclic Graph (DAG).
This declarative approach lets d9d perform complex transform-save and transform-load operations efficiently in a streamed manner, without loading the whole checkpoint into memory.
Usage Examples
Pass-through Mapping for PyTorch Module
Use this when you simply want to load a checkpoint whose keys match the model definition (standard load_state_dict behavior) but still want to utilize d9d's streaming/sharding capabilities.
import torch.nn as nn
from d9d.model_state.mapper.adapters import identity_mapper_from_module
# Define your PyTorch model
model = nn.Sequential(
nn.Linear(10, 10),
nn.ReLU(),
nn.Linear(10, 5)
)
# Automatically generate a mapper based on the model's actual parameter names
# This creates Identity mappers for "0.weight", "0.bias", "2.weight", "2.bias"
mapper = identity_mapper_from_module(model)
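To see what the generated mapper covers, you can inspect its dependency groups (the printed output below is illustrative):
# Each parameter gets its own identity group, so it can be streamed/sharded independently
for group in mapper.state_dependency_groups():
    print(sorted(group.inputs), "->", sorted(group.outputs))
# e.g. ['0.bias'] -> ['0.bias'], ['0.weight'] -> ['0.weight'], ...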
Using Leaf Mappers
This example demonstrates using leaf mappers to handle a common mismatch scenario: merging separate Query/Key/Value tensors into a single stacked tensor.
import torch
from d9d.model_state.mapper.leaf import (
ModelStateMapperRename,
ModelStateMapperStackTensors
)
# Stacking Tensors
# Scenario: the checkpoint has separate Q, K, V linear layers, but we need one QKV tensor
stack_mapper = ModelStateMapperStackTensors(
source_names=["attn.q.weight", "attn.k.weight", "attn.v.weight"],
target_name="attn.qkv.weight",
stack_dim=0
)
# To show what this mapper needs:
print(stack_mapper.state_dependency_groups())
# Output: {StateGroup(inputs={'attn.q.weight', ...}, outputs={'attn.qkv.weight'})}
# To actually execute:
dummy_data = {
"attn.q.weight": torch.randn(64, 64),
"attn.k.weight": torch.randn(64, 64),
"attn.v.weight": torch.randn(64, 64),
}
result = stack_mapper.apply(dummy_data)
print(result["attn.qkv.weight"].shape)
# Output: torch.Size([3, 64, 64])
Composing Complex Pipelines
Converting an entire model state requires processing multiple keys in parallel, and potentially chaining operations (e.g., Rename then Stack).
from d9d.model_state.mapper.compose import ModelStateMapperSequential, ModelStateMapperParallel
from d9d.model_state.mapper.leaf import ModelStateMapperRename, ModelStateMapperStackTensors
# Define a transformation pipeline
mapper = ModelStateMapperSequential([
# Step 1: Rename keys to standard format
ModelStateMapperParallel([
ModelStateMapperRename("bert.encoder.layer.0.attention.self.query.weight", "layer.0.q"),
ModelStateMapperRename("bert.encoder.layer.0.attention.self.key.weight", "layer.0.k"),
ModelStateMapperRename("bert.encoder.layer.0.attention.self.value.weight", "layer.0.v"),
]),
# Step 2: Stack them into a specialized attention tensor
ModelStateMapperStackTensors(
source_names=["layer.0.q", "layer.0.k", "layer.0.v"],
target_name="layer.0.qkv",
stack_dim=0
)
])
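As a minimal sketch (tensor shapes are illustrative), the composed mapper can then be applied to a dictionary containing its full input group, just like the leaf mapper above:
import torch
checkpoint_slice = {
    "bert.encoder.layer.0.attention.self.query.weight": torch.randn(64, 64),
    "bert.encoder.layer.0.attention.self.key.weight": torch.randn(64, 64),
    "bert.encoder.layer.0.attention.self.value.weight": torch.randn(64, 64),
}
result = mapper.apply(checkpoint_slice)
print(result["layer.0.qkv"].shape)
# torch.Size([3, 64, 64])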
d9d.model_state.mapper
This package provides core components of the state mapping system.
ModelStateMapper
Bases: ABC
The abstract base class for all model state transformation operations.
This class serves as the interface between the definition of a transformation topology and the actual execution of tensor operations.
It enforces a Declarative vs. Imperative separation of concerns:
- Declarative (Topology): Through state_dependency_groups(), the mapper announces what it intends to do without handling any data. This allows the system to build execution graphs, validate chains, detect collisions, and shard tasks before allocating memory.
- Imperative (Execution): Through apply(), the mapper performs the actual logic (PyTorch operations) on model states.
A minimal implementation sketch is shown below.
Source code in d9d/model_state/mapper/abc.py
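For illustration, here is a minimal sketch of a custom mapper that transposes a single tensor in place; the exact import path for ModelStateMapper and StateGroup is assumed here:
import torch
from d9d.model_state.mapper import ModelStateMapper, StateGroup  # import path assumed

class TransposeMapper(ModelStateMapper):
    def __init__(self, name: str):
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        # Declarative: announce one disjoint group with the same key as input and output
        return frozenset({
            StateGroup(inputs=frozenset({self._name}), outputs=frozenset({self._name}))
        })

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Imperative: perform the actual tensor operation
        return {self._name: group[self._name].T}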
apply(group)
abstractmethod
Executes the transformation logic on a specific dictionary of tensors.
The orchestration system guarantees that the group dictionary passed here contains
all keys listed in the inputs of the active StateGroup.
Implementations of this method must guarantee that the result contains all keys listed in the outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| group | dict[str, Tensor] | A dictionary containing the source data. Keys match the inputs of the active StateGroup. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | A dictionary containing the transformed data. Keys must strictly match the outputs of the active StateGroup. |
Source code in d9d/model_state/mapper/abc.py
state_dependency_groups()
abstractmethod
Calculates and returns the set of independent dependency groups this mapper handles.
Returns:
| Type | Description |
|---|---|
| frozenset[StateGroup] | A frozenset of StateGroup objects, where each group represents a disjoint operation. For example, a mapper that renames ten independent tensors would return ten distinct groups, allowing them to be sharded or processed individually. |
Source code in d9d/model_state/mapper/abc.py
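For example, a simple rename mapper declares a single one-to-one group (printed form is illustrative):
from d9d.model_state.mapper.leaf import ModelStateMapperRename

rename = ModelStateMapperRename("old.weight", "new.weight")
print(rename.state_dependency_groups())
# frozenset({StateGroup(inputs=frozenset({'old.weight'}), outputs=frozenset({'new.weight'}))})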
StateGroup
dataclass
Represents an atomic unit of dependency in the model state transformation graph.
A StateGroup defines a strict contract between a set of input keys (source)
and a set of output keys (destination).
Attributes:
| Name | Type | Description |
|---|---|---|
| inputs | frozenset[str] | The complete set of keys required from the source state dictionary to satisfy this dependency. |
| outputs | frozenset[str] | The complete set of keys that will be produced as a result of this transformation. |
Source code in d9d/model_state/mapper/abc.py
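As a sketch (the import path shown is assumed), constructing a group that consumes three keys and produces one looks like this:
from d9d.model_state.mapper import StateGroup  # import path assumed

group = StateGroup(
    inputs=frozenset({"attn.q.weight", "attn.k.weight", "attn.v.weight"}),
    outputs=frozenset({"attn.qkv.weight"}),
)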
d9d.model_state.mapper.adapters
This package provides utility functions that create simple ModelStateMapper instances from objects such as PyTorch modules or other mappers.
identity_mapper_from_mapper_outputs(mapper)
Creates an identity mapper covering all outputs produced by the provided mapper.
This function inspects the state_dependency_groups() of the input mapper,
extracts every key listed in the outputs set of each group, and creates a
corresponding ModelStateMapperIdentity for it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mapper | ModelStateMapper | The mapper whose output signature will be inspected to generate the new identity mapper. | required |
Returns:
| Type | Description |
|---|---|
| ModelStateMapper | A composite mapper that acts as a pass-through for every key produced by the source mapper. |
Source code in d9d/model_state/mapper/adapters/mapper.py
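For instance, you can build a pass-through mapper for everything a stack mapper produces:
from d9d.model_state.mapper.adapters import identity_mapper_from_mapper_outputs
from d9d.model_state.mapper.leaf import ModelStateMapperStackTensors

stack = ModelStateMapperStackTensors(
    source_names=["attn.q.weight", "attn.k.weight", "attn.v.weight"],
    target_name="attn.qkv.weight",
    stack_dim=0
)
# The result passes "attn.qkv.weight" through unchanged
passthrough = identity_mapper_from_mapper_outputs(stack)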
identity_mapper_from_module(module)
Creates an identity mapper for every parameter in a single PyTorch module.
It is useful when you want to define a "pass-through" pipeline where the
source checkpoint keys are expected to exactly match the model's current
parameter names (standard load_state_dict behavior).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The instantiated PyTorch model to inspect. | required |
Source code in d9d/model_state/mapper/adapters/module.py
d9d.model_state.mapper.compose
Complex state mappers are built using composition. This package provides ModelStateMapper implementations that are composed of other mappers.
ModelStateMapperParallel
Bases: ModelStateMapper
Executes a list of state mappers independently alongside each other.
This class aggregates multiple mappers into a single logical unit. It enforces strict isolation between the mappers: no two mappers can consume the same input key (input collision) or produce the same output key (output collision).
During execution (apply), it routes the specific subset of the input dictionary
to the sub-mapper responsible for those keys.
Source code in d9d/model_state/mapper/compose/parallel.py
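A minimal sketch: two independent renames bundled into a single logical unit.
from d9d.model_state.mapper.compose import ModelStateMapperParallel
from d9d.model_state.mapper.leaf import ModelStateMapperRename

parallel = ModelStateMapperParallel([
    ModelStateMapperRename("enc.w", "encoder.weight"),
    ModelStateMapperRename("enc.b", "encoder.bias"),
])
# Each rename keeps its own dependency group, so the two keys can still be
# sharded and processed independently.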
ModelStateMapperSequential
Bases: ModelStateMapper
Executes a list of mappers in a specific sequence (pipeline).
This class manages the data flow from one mapper to the next. It abstracts away intermediate states, exposing only the inputs required by the first relevant stage and the outputs produced by the final relevant stage.
Key Features:
- Gap Filling: Automatically injects Identity mappers if a tensor needs to pass through a stage without modification to reach a later stage or the final output.
- Group Merging: Computes the net dependency graph. If Stage A requires 'x' and produces 'y', and Stage B requires 'y' and produces 'z', the Sequential mapper reports a single group {x} -> {z}.
Source code in d9d/model_state/mapper/compose/sequential.py
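A short sketch of group merging (the reported group is shown as a comment):
from d9d.model_state.mapper.compose import ModelStateMapperSequential
from d9d.model_state.mapper.leaf import ModelStateMapperRename

seq = ModelStateMapperSequential([
    ModelStateMapperRename("x", "y"),  # Stage A: requires 'x', produces 'y'
    ModelStateMapperRename("y", "z"),  # Stage B: requires 'y', produces 'z'
])
for group in seq.state_dependency_groups():
    print(set(group.inputs), "->", set(group.outputs))
# {'x'} -> {'z'}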
ModelStateMapperShard
Bases: ModelStateMapper
Wraps another state mapper and restricts its execution to a specific subset (shard) of dependency groups.
This is primarily used for parallelizing model loading across multiple processes
or nodes. By assigning a different current_shard index to each process,
the total set of tensors required by the sub_mapper is split evenly,
preventing every process from loading the entire checkpoint.
Source code in d9d/model_state/mapper/compose/shard.py
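A hedged sketch of sharded loading; the constructor argument names used here (sub_mapper, current_shard, shards_total) are assumptions, not the confirmed signature:
from d9d.model_state.mapper.compose import ModelStateMapperShard
from d9d.model_state.mapper.leaf import ModelStateMapperRename

full_mapper = ModelStateMapperRename("old.weight", "new.weight")
# Hypothetical argument names below: each process handles only its own shard of groups
shard_for_rank_0 = ModelStateMapperShard(
    sub_mapper=full_mapper,
    current_shard=0,
    shards_total=4,
)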
filter_empty_mappers(mappers)
Filters out mappers that have no effect (no inputs and no outputs).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mappers | Sequence[ModelStateMapper] | The list of mappers to filter. | required |
Returns:
| Type | Description |
|---|---|
| list[ModelStateMapper] | A new list containing only active mappers. |
Source code in d9d/model_state/mapper/compose/helper.py
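A small sketch (assuming an empty Parallel mapper declares no inputs or outputs):
from d9d.model_state.mapper.compose import ModelStateMapperParallel, filter_empty_mappers
from d9d.model_state.mapper.leaf import ModelStateMapperRename

active = filter_empty_mappers([
    ModelStateMapperRename("a", "b"),
    ModelStateMapperParallel([]),  # no inputs and no outputs
])
# 'active' keeps only the rename mapper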
d9d.model_state.mapper.leaf
This package provides leaf mapper implementations.
ModelStateMapperDistribute
Bases: ModelStateMapper
Converts a single local Tensor object into a DTensor object with specified
device_mesh and placements.
Source code in d9d/model_state/mapper/leaf/dtensor.py
ModelStateMapperGatherFullTensor
Bases: ModelStateMapper
Gathers a single DTensor object into a full Tensor object.
Source code in d9d/model_state/mapper/leaf/dtensor.py
ModelStateMapperIdentity
Bases: ModelStateMapper
Passes a single state tensor through unchanged.
Source code in d9d/model_state/mapper/leaf/identity.py
ModelStateMapperRename
Bases: ModelStateMapper
Renames a single state tensor from name_from to name_to.
Source code in d9d/model_state/mapper/leaf/rename.py
ModelStateMapperSelectChildModules
Bases: ModelStateMapper
Selects a set of keys belonging to a specific parent module (prefix) and renames them by removing that prefix.
This is effectively a batch rename operation that "hoists" parameters from a submodule scope to the current scope.
Source code in d9d/model_state/mapper/leaf/select_child.py
ModelStateMapperStackTensors
Bases: ModelStateMapper
Stacks multiple input tensors named source_names into a single output tensor named target_name, producing a new stack_dim dimension.
Source code in d9d/model_state/mapper/leaf/stack.py