Model State Mapper

About

The d9d.model_state.mapper package solves the complexity of working with model checkpoints by providing a declarative, graph-based framework for transforming model states.

Core Concept

Loading large-scale models is rarely a simple 1-to-1 key matching operation. You often face challenges such as:

  • Naming Mismatches: HuggingFace uses model.layers.0, your custom model uses transformer.h.0.
  • Shape Mismatches: The checkpoint stores Q, K, and V separately, but your model implementation expects a stacked QKV tensor.
  • Scale: The checkpoint is 500GB. You cannot load the whole dictionary on every GPU to process it.

Instead of writing a manual loop that loads tensors and blindly modifies them, this framework treats state transformation as a Directed Acyclic Graph (DAG).

This declarative approach lets d9d perform complex transform-save and transform-load operations efficiently in a streamed manner, without loading the whole checkpoint into memory.
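As a minimal illustration of the streaming idea (the helper names below are hypothetical, not the d9d API): because each dependency group declares exactly which keys it needs and produces, a loader can fetch, transform, and release one group at a time.

```python
# Illustrative sketch only -- `stream_transform`, `load_tensor` and
# `save_tensor` are hypothetical names, not part of the d9d API.
def stream_transform(groups, load_tensor, save_tensor):
    """Apply each (inputs, outputs, fn) group without holding the full state."""
    for inputs, outputs, fn in groups:
        # Load only this group's inputs (e.g. from a sharded checkpoint).
        local = {name: load_tensor(name) for name in inputs}
        result = fn(local)
        assert set(result) == set(outputs)  # contract: exact output keys
        for name, tensor in result.items():
            save_tensor(name, tensor)
        # `local` and `result` go out of scope here and can be freed.

# Tiny demo with plain floats standing in for tensors:
store = {"a": 1.0, "b": 2.0}
out = {}
groups = [
    (["a", "b"], ["sum"], lambda g: {"sum": g["a"] + g["b"]}),
]
stream_transform(groups, store.__getitem__, out.__setitem__)
print(out)  # {'sum': 3.0}
```

Peak memory is bounded by the largest single group rather than by the checkpoint size.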

Usage Examples

Pass-through Mapping for PyTorch Module

Use this when you want to load a checkpoint whose keys exactly match the model definition (standard load_state_dict behavior) but still want to use d9d's streaming/sharding capabilities.

import torch.nn as nn
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# Define your PyTorch model
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 5)
)

# Automatically generate a mapper based on the model's actual parameter names
# This creates Identity mappers for "0.weight", "0.bias", "2.weight", "2.bias"
mapper = identity_mapper_from_module(model)

Using Leaf Mappers

This example demonstrates using leaf mappers to handle a common mismatch scenario: merging separate Query/Key/Value tensors into a single stacked tensor.

import torch
from d9d.model_state.mapper.leaf import ModelStateMapperStackTensors

# Stacking Tensors
# Scenario: Checkpoint has separate Q, K, V linear layers, we need one QKV tensor
stack_mapper = ModelStateMapperStackTensors(
    source_names=["attn.q.weight", "attn.k.weight", "attn.v.weight"],
    target_name="attn.qkv.weight",
    stack_dim=0
)

# To show what this mapper needs:
print(stack_mapper.state_dependency_groups())
# Output: {StateGroup(inputs={'attn.q.weight', ...}, outputs={'attn.qkv.weight'})}

# To actually execute:
dummy_data = {
    "attn.q.weight": torch.randn(64, 64),
    "attn.k.weight": torch.randn(64, 64),
    "attn.v.weight": torch.randn(64, 64),
}
result = stack_mapper.apply(dummy_data)
print(result["attn.qkv.weight"].shape) 
# Output: torch.Size([3, 64, 64])

Composing Complex Pipelines

Converting an entire model state requires processing multiple keys in parallel, and potentially chaining operations (e.g., Rename then Stack).

from d9d.model_state.mapper.compose import ModelStateMapperSequential, ModelStateMapperParallel
from d9d.model_state.mapper.leaf import ModelStateMapperRename, ModelStateMapperStackTensors

# Define a transformation pipeline
mapper = ModelStateMapperSequential([
    # Step 1: Rename keys to standard format
    ModelStateMapperParallel([
        ModelStateMapperRename("bert.encoder.layer.0.attention.self.query.weight", "layer.0.q"),
        ModelStateMapperRename("bert.encoder.layer.0.attention.self.key.weight", "layer.0.k"),
        ModelStateMapperRename("bert.encoder.layer.0.attention.self.value.weight", "layer.0.v"),
    ]),
    # Step 2: Stack them into a specialized attention tensor
    ModelStateMapperStackTensors(
        source_names=["layer.0.q", "layer.0.k", "layer.0.v"],
        target_name="layer.0.qkv",
        stack_dim=0
    )
])

d9d.model_state.mapper

This package provides core components of the state mapping system.

ModelStateMapper

Bases: ABC

The abstract base class for all model state transformation operations.

This class serves as the interface between the definition of a transformation topology and the actual execution of tensor operations.

It enforces a Declarative vs. Imperative separation of concerns:

  1. Declarative (Topology): Through state_dependency_groups(), the mapper announces what it intends to do without handling any data. This allows the system to build execution graphs, validate chains, detect collisions, and shard tasks before allocating memory.
  2. Imperative (Execution): Through apply(), the mapper performs the actual logic (PyTorch operations) on model states.
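The two-phase contract can be sketched with stand-in classes (these are illustrative only, not the actual d9d implementations): one method announces the topology with no data in hand, and the other performs the tensor work.

```python
# Hypothetical stand-ins mirroring the documented contract; not the d9d classes.
from dataclasses import dataclass


@dataclass(frozen=True)
class Group:
    inputs: frozenset
    outputs: frozenset


class RenameMapper:
    def __init__(self, name_from, name_to):
        self.name_from = name_from
        self.name_to = name_to

    # Declarative: announce what will happen without touching any data.
    def state_dependency_groups(self):
        return frozenset({Group(frozenset({self.name_from}),
                                frozenset({self.name_to}))})

    # Imperative: perform the actual transformation on supplied states.
    def apply(self, group):
        return {self.name_to: group[self.name_from]}


m = RenameMapper("old.w", "new.w")
print(m.state_dependency_groups())   # topology only, no tensors involved
print(m.apply({"old.w": 42}))        # {'new.w': 42}
```

The orchestrator can plan sharding and validate the whole graph from `state_dependency_groups()` alone, calling `apply()` only once tensors are actually needed.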

apply(group) abstractmethod

Executes the transformation logic on a specific dictionary of tensors.

The orchestration system guarantees that the group dictionary passed here contains all keys listed in the inputs of the active StateGroup.

Implementation of this method should guarantee that the result will contain all keys listed in the outputs.

Parameters:

  • group (dict[str, Tensor], required): A dictionary containing the source data. Keys match StateGroup.inputs.

Returns:

  • dict[str, Tensor]: A dictionary containing the transformed data. Keys must strictly match StateGroup.outputs.

state_dependency_groups() abstractmethod

Calculates and returns the set of independent dependency groups this mapper handles.

Returns:

  • frozenset[StateGroup]: A frozenset of StateGroup objects. Each group represents a disjoint operation. For example, a mapper that renames ten independent tensors would return ten distinct StateGroup objects, allowing them to be sharded or processed individually.

StateGroup dataclass

Represents an atomic unit of dependency in the model state transformation graph.

A StateGroup defines a strict contract between a set of input keys (source) and a set of output keys (destination).

Attributes:

  • inputs (frozenset[str]): The complete set of keys required from the source state dictionary to satisfy this dependency.
  • outputs (frozenset[str]): The complete set of keys that will be produced as a result of this transformation.
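Because both attributes are frozensets, a group is fully hashable and identical groups deduplicate naturally inside a frozenset. A sketch with a hypothetical stand-in dataclass (not the d9d class itself):

```python
from dataclasses import dataclass


# Hypothetical stand-in mirroring the documented attributes.
@dataclass(frozen=True)
class StateGroup:
    inputs: frozenset
    outputs: frozenset


g1 = StateGroup(frozenset({"attn.q", "attn.k"}), frozenset({"attn.qk"}))
g2 = StateGroup(frozenset({"attn.k", "attn.q"}), frozenset({"attn.qk"}))

# frozen=True + frozenset fields -> hashable, so groups can live in sets.
assert g1 == g2          # frozenset equality ignores element order
print(len({g1, g2}))     # 1 -- identical groups collapse to one entry
```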

d9d.model_state.mapper.adapters

This package provides utility functions for creating simple ModelStateMapper instances from objects such as PyTorch modules or other mappers.

identity_mapper_from_mapper_outputs(mapper)

Creates an identity mapper covering all outputs produced by the provided mapper.

This function inspects the state_dependency_groups() of the input mapper, extracts every key listed in the outputs set of each group, and creates a corresponding ModelStateMapperIdentity for it.

Parameters:

  • mapper (ModelStateMapper, required): The mapper whose output signature will be inspected to generate the new identity mapper.

Returns:

  • ModelStateMapper: A composite mapper that acts as a pass-through for every key produced by the source mapper.

identity_mapper_from_module(module)

Creates an identity mapper for every parameter in a single PyTorch module.

It is useful when you want to define a "pass-through" pipeline where the source checkpoint keys are expected to exactly match the model's current parameter names (standard load_state_dict behavior).

Parameters:

  • module (Module, required): The instantiated PyTorch model to inspect.

d9d.model_state.mapper.compose

Complex state mappers are built using composition. This package provides ModelStateMapper implementations that are composed of other mappers.

ModelStateMapperParallel

Bases: ModelStateMapper

Executes a list of state mappers independently of one another.

This class aggregates multiple mappers into a single logical unit. It enforces strict isolation between the mappers: no two mappers can consume the same input key (input collision) or produce the same output key (output collision).

During execution (apply), it routes the specific subset of the input dictionary to the sub-mapper responsible for those keys.
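The isolation rule can be illustrated with a small stand-alone check (a hypothetical helper, not the d9d implementation): walk every mapper's groups and reject any input or output key that has already been claimed.

```python
# Illustrative collision check; `check_disjoint` is a hypothetical helper.
def check_disjoint(groups_per_mapper):
    """Raise ValueError if two mappers share an input or an output key."""
    seen_in, seen_out = set(), set()
    for groups in groups_per_mapper:
        for inputs, outputs in groups:
            if seen_in & set(inputs):
                raise ValueError(f"input collision: {seen_in & set(inputs)}")
            if seen_out & set(outputs):
                raise ValueError(f"output collision: {seen_out & set(outputs)}")
            seen_in |= set(inputs)
            seen_out |= set(outputs)


check_disjoint([[({"a"}, {"x"})], [({"b"}, {"y"})]])   # fine: fully disjoint
try:
    check_disjoint([[({"a"}, {"x"})], [({"a"}, {"y"})]])
except ValueError as e:
    print(e)  # input collision: {'a'}
```

This validation happens at topology time, before any tensor is loaded.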

ModelStateMapperSequential

Bases: ModelStateMapper

Executes a list of mappers in a specific sequence (pipeline).

This class manages the data flow from one mapper to the next. It abstracts away intermediate states, exposing only the inputs required by the first relevant stage and the outputs produced by the final relevant stage.

Key Features:

  1. Gap Filling: Automatically injects Identity mappers if a tensor needs to pass through a stage without modification to reach a later stage or the final output.

  2. Group Merging: Computes the net dependency graph. If Stage A requires 'x' and produces 'y', and Stage B requires 'y' and produces 'z', the Sequential mapper reports a single group {x} -> {z}.
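The group-merging computation can be sketched with plain sets (a hypothetical helper, not the d9d implementation): intermediate keys produced by one stage and consumed by the next are cancelled out, leaving only the net boundary.

```python
# Illustrative net-dependency computation; `merge_chain` is hypothetical.
def merge_chain(stage_a, stage_b):
    """Given (inputs, outputs) of two chained stages, report the net group."""
    a_in, a_out = stage_a
    b_in, b_out = stage_b
    intermediate = a_out & b_in                  # produced by A, consumed by B
    net_inputs = a_in | (b_in - intermediate)    # external inputs only
    net_outputs = (a_out - intermediate) | b_out # final outputs only
    return net_inputs, net_outputs


# Stage A: {x} -> {y}, Stage B: {y} -> {z}  ==>  net group {x} -> {z}
print(merge_chain(({"x"}, {"y"}), ({"y"}, {"z"})))  # ({'x'}, {'z'})
```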

ModelStateMapperShard

Bases: ModelStateMapper

Wraps another state mapper and restricts its execution to a specific subset (shard) of dependency groups.

This is primarily used for parallelizing model loading across multiple processes or nodes. By assigning a different current_shard index to each process, the total set of tensors required by the sub_mapper is split evenly, preventing every process from loading the entire checkpoint.
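One simple way to picture the shard assignment (the actual d9d scheme may differ; this round-robin split is an assumption for illustration): sort the groups deterministically, then let each process keep every num_shards-th group starting from its own index.

```python
# Illustrative round-robin sharding; `shard_groups` is a hypothetical helper.
def shard_groups(groups, num_shards, current_shard):
    """Assign each process a disjoint, roughly even subset of groups."""
    ordered = sorted(groups)  # deterministic order across all processes
    return [g for i, g in enumerate(ordered) if i % num_shards == current_shard]


groups = ["g0", "g1", "g2", "g3", "g4"]
print(shard_groups(groups, 2, 0))  # ['g0', 'g2', 'g4']
print(shard_groups(groups, 2, 1))  # ['g1', 'g3']
```

Because every group lands on exactly one shard, no tensor is loaded by more than one process.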

filter_empty_mappers(mappers)

Filters out mappers that have no effect (no inputs and no outputs).

Parameters:

  • mappers (Sequence[ModelStateMapper], required): The list of mappers to filter.

Returns:

  • list[ModelStateMapper]: A new list containing only active mappers.

d9d.model_state.mapper.leaf

This package provides leaf mapper implementations.

ModelStateMapperDistribute

Bases: ModelStateMapper

Converts a single local Tensor object into a DTensor object with specified device_mesh and placements.

ModelStateMapperGatherFullTensor

Bases: ModelStateMapper

Gathers a single DTensor object into a full Tensor object.

ModelStateMapperIdentity

Bases: ModelStateMapper

Passes a single state tensor through unchanged.

ModelStateMapperRename

Bases: ModelStateMapper

Renames a single state tensor from name_from to name_to.

ModelStateMapperSelectChildModules

Bases: ModelStateMapper

Selects a set of keys belonging to a specific parent module (prefix) and renames them by removing that prefix.

This is effectively a batch rename operation that "hoists" parameters from a submodule scope to the current scope.
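The hoisting behaviour amounts to a prefix filter plus strip, which can be sketched in a few lines (an illustrative stand-in, not the d9d implementation):

```python
# Illustrative prefix hoisting; `select_child` is a hypothetical helper.
def select_child(state, prefix):
    """Keep keys under `prefix` and strip the prefix from them."""
    p = prefix + "."
    return {k[len(p):]: v for k, v in state.items() if k.startswith(p)}


state = {"encoder.layer.0.weight": 1, "decoder.layer.0.weight": 2}
print(select_child(state, "encoder"))  # {'layer.0.weight': 1}
```

Keys outside the selected parent module are dropped rather than passed through.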

ModelStateMapperStackTensors

Bases: ModelStateMapper

Stacks multiple input tensors (named by source_names) into a single output tensor (named target_name), creating a new dimension at stack_dim.