About

The d9d.model_state.mapper package tames the complexity of working with model checkpoints by providing a declarative, graph-based framework for transforming model states.

Core Concept

Loading large-scale models is rarely a simple 1-to-1 key matching operation. You often face challenges such as:

  • Naming Mismatches: HuggingFace uses model.layers.0, while your custom model uses transformer.h.0.
  • Shape Mismatches: The checkpoint stores Q, K, and V separately, but your model implementation expects a stacked QKV tensor.
  • Scale: The checkpoint is 500GB. You cannot load the whole dictionary on every GPU to process it.

Instead of writing a manual loop that loads tensors and blindly modifies them, this framework treats state transformation as a Directed Acyclic Graph (DAG).

This declarative approach lets d9d perform complex transform-save and transform-load operations efficiently in a streamed manner, without loading the whole checkpoint into memory.

Usage Examples

Pass-through Mapping for PyTorch Module

Use this when you simply want to load a checkpoint whose keys already match the model definition (standard load_state_dict behavior), but still want to utilize d9d's streaming/sharding capabilities.

import torch.nn as nn
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# Define your PyTorch model
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 5)
)

# Automatically generate a mapper based on the model's actual parameter names
# This creates Identity mappers for "0.weight", "0.bias", "2.weight", "2.bias"
mapper = identity_mapper_from_module(model)
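
To verify coverage, you can inspect the generated mapper's dependency groups; each parameter key simply maps to itself:

# Each StateGroup has identical inputs and outputs, e.g.
# StateGroup(inputs=frozenset({'0.weight'}), outputs=frozenset({'0.weight'}))
for group in mapper.state_dependency_groups():
    print(group)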

Using Leaf Mappers

This example demonstrates using leaf mappers to handle a common mismatch scenario: merging separate Query/Key/Value tensors into a single tensor.

import torch
from d9d.model_state.mapper.leaf import (
    ModelStateMapperRename,
    ModelStateMapperStackTensors
)

# Stacking Tensors
# Scenario: Checkpoint has separate Q, K, V linear layers, we need one QKV tensor
stack_mapper = ModelStateMapperStackTensors(
    source_names=["attn.q.weight", "attn.k.weight", "attn.v.weight"],
    target_name="attn.qkv.weight",
    stack_dim=0
)

# To show what this mapper needs:
print(stack_mapper.state_dependency_groups())
# Output: frozenset({StateGroup(inputs=frozenset({'attn.q.weight', ...}), outputs=frozenset({'attn.qkv.weight'}))})

# To actually execute:
dummy_data = {
    "attn.q.weight": torch.randn(64, 64),
    "attn.k.weight": torch.randn(64, 64),
    "attn.v.weight": torch.randn(64, 64),
}
result = stack_mapper.apply(dummy_data)
print(result["attn.qkv.weight"].shape) 
# Output: torch.Size([3, 64, 64])
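
The ModelStateMapperRename import above works the same way. A minimal sketch of renaming a single key (the key names here are purely illustrative):

# Renaming a Tensor
# Scenario: the checkpoint uses a HuggingFace-style key, the model expects a custom one
rename_mapper = ModelStateMapperRename(
    "model.layers.0.self_attn.q_proj.weight",
    "transformer.h.0.attn.q.weight"
)

result = rename_mapper.apply(
    {"model.layers.0.self_attn.q_proj.weight": torch.randn(64, 64)}
)
print(list(result.keys()))
# Output: ['transformer.h.0.attn.q.weight']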

Composing Complex Pipelines

Converting an entire model state requires processing multiple keys in parallel, and potentially chaining operations (e.g., Rename then Stack).

from d9d.model_state.mapper.compose import ModelStateMapperSequential, ModelStateMapperParallel
from d9d.model_state.mapper.leaf import ModelStateMapperRename, ModelStateMapperStackTensors

# Define a transformation pipeline
mapper = ModelStateMapperSequential([
    # Step 1: Rename keys to standard format
    ModelStateMapperParallel([
        ModelStateMapperRename("bert.encoder.layer.0.attention.self.query.weight", "layer.0.q"),
        ModelStateMapperRename("bert.encoder.layer.0.attention.self.key.weight", "layer.0.k"),
        ModelStateMapperRename("bert.encoder.layer.0.attention.self.value.weight", "layer.0.v"),
    ]),
    # Step 2: Stack them into a specialized attention tensor
    ModelStateMapperStackTensors(
        source_names=["layer.0.q", "layer.0.k", "layer.0.v"],
        target_name="layer.0.qkv",
        stack_dim=0
    )
])
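
Because ModelStateMapperSequential merges dependency groups (see the ModelStateMapperSequential reference below), the composed mapper exposes a single net group: the three original BERT keys map directly to the stacked output.

print(mapper.state_dependency_groups())
# A single merged group:
#   inputs  = {'bert.encoder.layer.0.attention.self.query.weight',
#              'bert.encoder.layer.0.attention.self.key.weight',
#              'bert.encoder.layer.0.attention.self.value.weight'}
#   outputs = {'layer.0.qkv'}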

d9d.model_state.mapper

This package provides the core components of the state mapping system.

ModelStateMapper

Bases: ABC

The abstract base class for all model state transformation operations.

This class serves as the interface between the definition of a transformation topology and the actual execution of tensor operations.

It enforces a Declarative vs. Imperative separation of concerns:

  1. Declarative (Topology): Through state_dependency_groups(), the mapper announces what it intends to do without handling any data. This allows the system to build execution graphs, validate chains, detect collisions, and shard tasks before allocating memory.
  2. Imperative (Execution): Through apply(), the mapper performs the actual logic (PyTorch operations) on model states.
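
For illustration, a minimal custom mapper that follows this contract might look as follows. This is a sketch; the import path of the base classes is assumed from the source locations shown below.

import torch
from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup

class TransposeMapper(ModelStateMapper):
    """Transposes a single 2-D tensor, keeping its original key."""

    def __init__(self, name: str):
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        # Declarative: announce the topology without touching any data.
        return frozenset([
            StateGroup(inputs=frozenset([self._name]), outputs=frozenset([self._name]))
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Imperative: perform the actual tensor operation.
        return {self._name: group[self._name].T}
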
Source code in d9d/model_state/mapper/abc.py
class ModelStateMapper(abc.ABC):
    """
   The abstract base class for all model state transformation operations.

   This class serves as the interface between the definition of a transformation
   topology and the actual execution of tensor operations.

   It enforces a Declarative vs. Imperative separation of concerns:

   1.  Declarative (Topology): Through `state_dependency_groups()`, the mapper
       announces *what* it intends to do without handling any data. This allows the system to build execution graphs,
       validate chains, detect collisions, and shard tasks *before* allocating memory.
   2.  Imperative (Execution): Through `apply()`, the mapper performs the
       actual logic (PyTorch operations) on model states.
   """

    @abc.abstractmethod
    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """
        Calculates and returns the set of independent dependency groups this mapper handles.

        Returns:
            A frozenset of `StateGroup` objects. Each group
            represents a disjoint operation. For example, a mapper that renames ten
            independent tensors would return ten distinct `StateGroup` objects,
            allowing them to be sharded or processed individually.
        """
        ...

    @abc.abstractmethod
    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Executes the transformation logic on a specific dictionary of tensors.

        The orchestration system guarantees that the `group` dictionary passed here contains
        all keys listed in the `inputs` of the active `StateGroup`.

        Implementation of this method should guarantee that the result will contain all keys listed in the `outputs`.

        Args:
           group: A dictionary containing the source data.
               Keys match `StateGroup.inputs`.

        Returns:
           A dictionary containing the transformed data. Keys must strictly match `StateGroup.outputs`.
        """
        ...

apply(group) abstractmethod

Executes the transformation logic on a specific dictionary of tensors.

The orchestration system guarantees that the group dictionary passed here contains all keys listed in the inputs of the active StateGroup.

Implementations of this method should guarantee that the result contains all keys listed in the outputs.

Parameters:

    group (dict[str, Tensor], required): A dictionary containing the source data. Keys match StateGroup.inputs.

Returns:

    dict[str, Tensor]: A dictionary containing the transformed data. Keys must strictly match StateGroup.outputs.

Source code in d9d/model_state/mapper/abc.py
@abc.abstractmethod
def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Executes the transformation logic on a specific dictionary of tensors.

    The orchestration system guarantees that the `group` dictionary passed here contains
    all keys listed in the `inputs` of the active `StateGroup`.

    Implementation of this method should guarantee that the result will contain all keys listed in the `outputs`.

    Args:
       group: A dictionary containing the source data.
           Keys match `StateGroup.inputs`.

    Returns:
       A dictionary containing the transformed data. Keys must strictly match `StateGroup.outputs`.
    """
    ...

state_dependency_groups() abstractmethod

Calculates and returns the set of independent dependency groups this mapper handles.

Returns:

    frozenset[StateGroup]: A frozenset of StateGroup objects. Each group represents a disjoint operation. For example, a mapper that renames ten independent tensors would return ten distinct StateGroup objects, allowing them to be sharded or processed individually.

Source code in d9d/model_state/mapper/abc.py
@abc.abstractmethod
def state_dependency_groups(self) -> frozenset[StateGroup]:
    """
    Calculates and returns the set of independent dependency groups this mapper handles.

    Returns:
        A frozenset of `StateGroup` objects. Each group
        represents a disjoint operation. For example, a mapper that renames ten
        independent tensors would return ten distinct `StateGroup` objects,
        allowing them to be sharded or processed individually.
    """
    ...

StateGroup dataclass

Represents an atomic unit of dependency in the model state transformation graph.

A StateGroup defines a strict contract between a set of input keys (source) and a set of output keys (destination).

Attributes:

    inputs (frozenset[str]): The complete set of keys required from the source state dictionary to satisfy this dependency.

    outputs (frozenset[str]): The complete set of keys that will be produced as a result of this transformation.
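
For example, the QKV merge from the usage examples above would be described by a single group with three inputs and one output. A minimal sketch, with the import path assumed from the source location below:

from d9d.model_state.mapper.abc import StateGroup

# Three source keys are consumed together to produce one stacked key.
group = StateGroup(
    inputs=frozenset({"attn.q.weight", "attn.k.weight", "attn.v.weight"}),
    outputs=frozenset({"attn.qkv.weight"}),
)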

Source code in d9d/model_state/mapper/abc.py
@dataclasses.dataclass(frozen=True)
class StateGroup:
    """
    Represents an atomic unit of dependency in the model state transformation graph.

    A `StateGroup` defines a strict contract between a set of input keys (source)
    and a set of output keys (destination).

    Attributes:
        inputs: The complete set of keys required from the source state dictionary to satisfy this dependency.
        outputs: The complete set of keys that will be produced as a result of this transformation.
    """

    inputs: frozenset[str]
    outputs: frozenset[str]

d9d.model_state.mapper.adapters

This package provides utility functions that create simple ModelStateMapper instances from other objects, such as PyTorch modules or existing mappers.

identity_mapper_from_mapper_outputs(mapper)

Creates an identity mapper covering all outputs produced by the provided mapper.

This function inspects the state_dependency_groups() of the input mapper, extracts every key listed in the outputs set of each group, and creates a corresponding ModelStateMapperIdentity for it.

Parameters:

    mapper (ModelStateMapper, required): The mapper whose output signature will be inspected to generate the new identity mapper.

Returns:

    ModelStateMapper: A composite mapper that acts as a pass-through for every key produced by the source mapper.
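
A brief sketch of how this might be used, for instance to mirror the outputs of a conversion mapper (the key names are illustrative):

from d9d.model_state.mapper.adapters import identity_mapper_from_mapper_outputs
from d9d.model_state.mapper.leaf import ModelStateMapperStackTensors

qkv_mapper = ModelStateMapperStackTensors(
    source_names=["q", "k", "v"], target_name="qkv", stack_dim=0
)

# The resulting mapper passes through exactly one key: "qkv".
passthrough = identity_mapper_from_mapper_outputs(qkv_mapper)
print(passthrough.state_dependency_groups())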

Source code in d9d/model_state/mapper/adapters/mapper.py
def identity_mapper_from_mapper_outputs(mapper: ModelStateMapper) -> ModelStateMapper:
    """
    Creates an identity mapper covering all outputs produced by the provided mapper.

    This function inspects the `state_dependency_groups()` of the input `mapper`,
    extracts every key listed in the `outputs` set of each group, and creates a
    corresponding `ModelStateMapperIdentity` for it.

    Args:
        mapper: The mapper whose output signature will be inspected to generate the new identity mapper.

    Returns:
        A composite mapper that acts as a pass-through for every key produced by the source `mapper`.
    """

    mappers: list[ModelStateMapper] = []

    for state_group in mapper.state_dependency_groups():
        for output_name in state_group.outputs:
            mappers.append(ModelStateMapperIdentity(output_name))

    return ModelStateMapperParallel(mappers)

identity_mapper_from_module(module)

Creates an identity mapper for every parameter in a single PyTorch module.

It is useful when you want to define a "pass-through" pipeline where the source checkpoint keys are expected to exactly match the model's current parameter names (standard load_state_dict behavior).

Parameters:

    module (Module, required): The instantiated PyTorch model to inspect.

Source code in d9d/model_state/mapper/adapters/module.py
def identity_mapper_from_module(module: nn.Module) -> ModelStateMapper:
    """
    Creates an identity mapper for every parameter in a single PyTorch module.

    It is useful when you want to define a "pass-through" pipeline where the
    source checkpoint keys are expected to exactly match the model's current
    parameter names (standard `load_state_dict` behavior).

    Args:
        module: The instantiated PyTorch model to inspect.
    """

    return ModelStateMapperParallel(
        [ModelStateMapperIdentity(key) for key in module.state_dict()]
    )

d9d.model_state.mapper.compose

Complex state mappers are built using composition. This package provides ModelStateMapper implementations that are composed of other mappers.

ModelStateMapperParallel

Bases: ModelStateMapper

Executes a list of state mappers independently alongside each other.

This class aggregates multiple mappers into a single logical unit. It enforces strict isolation between the mappers: no two mappers can consume the same input key (input collision) or produce the same output key (output collision).

During execution (apply), it routes the specific subset of the input dictionary to the sub-mapper responsible for those keys.
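
The isolation guarantee means collisions are detected at construction time. A minimal sketch:

from d9d.model_state.mapper.compose import ModelStateMapperParallel
from d9d.model_state.mapper.leaf import ModelStateMapperRename

try:
    ModelStateMapperParallel([
        ModelStateMapperRename("a", "shared"),
        ModelStateMapperRename("b", "shared"),  # collides with the output above
    ])
except ValueError as exc:
    print(exc)  # Found colliding output keys: frozenset({'shared'})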

Source code in d9d/model_state/mapper/compose/parallel.py
class ModelStateMapperParallel(ModelStateMapper):
    """
    Executes a list of states mappers independently alongside each other.

    This class aggregates multiple mappers into a single logical unit.
    It enforces strict isolation between the mappers: no two mappers can
    consume the same input key (input collision) or produce the same output
    key (output collision).

    During execution (`apply`), it routes the specific subset of the input dictionary
    to the sub-mapper responsible for those keys.
    """

    def __init__(self, mappers: Sequence[ModelStateMapper]):
        mappers_lst = filter_empty_mappers(mappers)

        all_groups = set()
        inputs_to_mapper = {}

        seen_inputs: set[str] = set()
        seen_outputs: set[str] = set()
        for mapper in mappers_lst:
            sub_groups = mapper.state_dependency_groups()

            for sub_group in sub_groups:
                if not seen_inputs.isdisjoint(sub_group.inputs):
                    raise ValueError(f"Found a colliding input group: {sub_group.inputs}")
                seen_inputs.update(sub_group.inputs)

                if not seen_outputs.isdisjoint(sub_group.outputs):
                    raise ValueError(f"Found colliding output keys: {sub_group.outputs}")
                seen_outputs.update(sub_group.outputs)

                all_groups.add(sub_group)
                inputs_to_mapper[sub_group.inputs] = mapper

        self._all_groups = frozenset(all_groups)
        self._inputs_to_mapper = inputs_to_mapper

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return self._all_groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        group_keys = frozenset(group.keys())

        if group_keys not in self._inputs_to_mapper:
            raise ValueError("Tried to run a parallel mapper with undefined group. Perhaps you sent groups that are "
                             "not isolated?")

        return self._inputs_to_mapper[group_keys].apply(group)

ModelStateMapperSequential

Bases: ModelStateMapper

Executes a list of mappers in a specific sequence (pipeline).

This class manages the data flow from one mapper to the next. It abstracts away intermediate states, exposing only the inputs required by the first relevant stage and the outputs produced by the final relevant stage.

Key Features:

  1. Gap Filling: Automatically injects Identity mappers if a tensor needs to pass through a stage without modification to reach a later stage or the final output.

  2. Group Merging: Computes the net dependency graph. If Stage A requires 'x' and produces 'y', and Stage B requires 'y' and produces 'z', the Sequential mapper reports a single group {x} -> {z}.
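
Both behaviours can be observed from the reported dependency groups. A minimal sketch:

from d9d.model_state.mapper.compose import ModelStateMapperSequential
from d9d.model_state.mapper.leaf import ModelStateMapperRename, ModelStateMapperStackTensors

pipeline = ModelStateMapperSequential([
    # Stage 1 only mentions "q_old"; identity mappers for "k" and "v" are
    # injected automatically so they can pass through to the next stage.
    ModelStateMapperRename("q_old", "q"),
    # Stage 2 consumes the renamed "q" together with the passed-through "k" and "v".
    ModelStateMapperStackTensors(source_names=["q", "k", "v"], target_name="qkv", stack_dim=0),
])

print(pipeline.state_dependency_groups())
# A single merged group: inputs {'q_old', 'k', 'v'} -> outputs {'qkv'}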

Source code in d9d/model_state/mapper/compose/sequential.py
class ModelStateMapperSequential(ModelStateMapper):
    """
    Executes a list of mappers in a specific sequence (pipeline).

    This class manages the data flow from one mapper to the next. It abstracts
    away intermediate states, exposing only the inputs required by the first
    relevant stage and the outputs produced by the final relevant stage.

    Key Features:

    1. **Gap Filling**: Automatically injects `Identity` mappers if a tensor needs
       to pass through a stage without modification to reach a later stage or
       the final output.

    2. **Group Merging**: Computes the net dependency graph. If Stage A requires 'x'
       and produces 'y', and Stage B requires 'y' and produces 'z', the
       Sequential mapper reports a single group `{x} -> {z}`.
    """

    def __init__(self, mappers: list[ModelStateMapper]):
        mappers = filter_empty_mappers(mappers)
        if not mappers:
            raise ValueError("Mappers list cannot be empty.")

        mappers = self._fill_gaps(mappers)

        self._groups = self._compute_pipeline_groups(mappers)
        self._mappers = mappers

    @staticmethod
    def _fill_gaps(mappers: list[ModelStateMapper]) -> list[ModelStateMapper]:
        mappers = mappers.copy()

        # propagate inputs from bottom to top
        for stage_i in range(1, len(mappers))[::-1]:
            groups_current = mappers[stage_i].state_dependency_groups()
            groups_prev = mappers[stage_i - 1].state_dependency_groups()
            current_stage_requires = frozenset.union(*(x.inputs for x in groups_current))
            prev_stage_produces = frozenset.union(*(x.outputs for x in groups_prev))

            needs_to_pass_through = current_stage_requires - prev_stage_produces

            mappers[stage_i - 1] = ModelStateMapperParallel(
                [mappers[stage_i - 1]] + [ModelStateMapperIdentity(x) for x in needs_to_pass_through]
            )

        # propagate outputs from top to bottom
        for stage_i in range(0, len(mappers) - 1):
            groups_current = mappers[stage_i].state_dependency_groups()
            groups_next = mappers[stage_i + 1].state_dependency_groups()
            current_stage_produces = frozenset.union(*(x.outputs for x in groups_current))
            next_stage_requires = frozenset.union(*(x.inputs for x in groups_next))

            needs_to_pass_through = current_stage_produces - next_stage_requires

            mappers[stage_i + 1] = ModelStateMapperParallel(
                [mappers[stage_i + 1]] + [ModelStateMapperIdentity(x) for x in needs_to_pass_through]
            )

        return mappers

    @staticmethod
    def _compute_pipeline_groups(mappers: list[ModelStateMapper]) -> frozenset[StateGroup]:
        outputs_depend_on_inputs = {}

        # given a fully connected graph, we can just go upwards
        for last_group_traced in mappers[-1].state_dependency_groups():
            required_inputs = last_group_traced.inputs

            for mapper_i in range(0, len(mappers) - 1)[::-1]:
                next_visit_groups = [x for x in mappers[mapper_i].state_dependency_groups()
                                     if not x.outputs.isdisjoint(required_inputs)]

                required_inputs = frozenset.union(*(x.inputs for x in next_visit_groups))

            outputs_depend_on_inputs[last_group_traced.outputs] = required_inputs

        return ModelStateMapperSequential._merge_groups(list(outputs_depend_on_inputs.items()))

    @staticmethod
    def _merge_groups(groups: Sequence[tuple[AbstractSet[str], AbstractSet[str]]]) -> frozenset[StateGroup]:
        saved_groups: list[tuple[set[str], set[str]]] = []

        saved_groups_modified = True
        while saved_groups_modified:
            saved_groups_modified = False
            for output_names, input_names in groups:
                was_new_group_created = False
                for group in saved_groups:
                    if group[0].intersection(input_names) or group[1].intersection(output_names):
                        group[0].update(input_names)
                        group[1].update(output_names)
                        was_new_group_created = True
                        saved_groups_modified = True

                if not was_new_group_created:
                    saved_groups.append((set(input_names), set(output_names)))

            groups = saved_groups
            saved_groups = []

        return frozenset(StateGroup(inputs=frozenset(x[0]), outputs=frozenset(x[1])) for x in groups)

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return self._groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        current_state = group
        next_state = {}
        for mapper in self._mappers:
            for deps in mapper.state_dependency_groups():
                if not deps.inputs <= current_state.keys():
                    continue

                next_state.update(mapper.apply({k: v for k, v in current_state.items() if k in deps.inputs}))

            current_state = next_state
            next_state = {}

        return current_state

ModelStateMapperShard

Bases: ModelStateMapper

Wraps another state mapper and restricts its execution to a specific subset (shard) of dependency groups.

This is primarily used for parallelizing model loading across multiple processes or nodes. By assigning a different current_shard index to each process, the total set of tensors required by the sub_mapper is split evenly, preventing every process from loading the entire checkpoint.
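
A minimal sketch of how two ranks might each take half of the work, assuming the compose package exports ModelStateMapperShard alongside the other composition mappers:

from d9d.model_state.mapper.compose import ModelStateMapperParallel, ModelStateMapperShard
from d9d.model_state.mapper.leaf import ModelStateMapperIdentity

full_mapper = ModelStateMapperParallel(
    [ModelStateMapperIdentity(name) for name in ("a", "b", "c", "d")]
)

# Each process wraps the same mapper with its own shard index.
shard_0 = ModelStateMapperShard(full_mapper, total_shards=2, current_shard=0)
shard_1 = ModelStateMapperShard(full_mapper, total_shards=2, current_shard=1)

# Together the two shards cover all four groups; each rank handles only its half.
print(len(shard_0.state_dependency_groups()), len(shard_1.state_dependency_groups()))
# Output: 2 2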

Source code in d9d/model_state/mapper/compose/shard.py
class ModelStateMapperShard(ModelStateMapper):
    """
    Wraps another state mapper and restricts its execution to a specific subset (shard)
    of dependency groups.

    This is primarily used for parallelizing model loading across multiple processes
    or nodes. By assigning a different `current_shard` index to each process,
    the total set of tensors required by the `sub_mapper` is split evenly,
    preventing every process from loading the entire checkpoint.
    """

    def __init__(self, sub_mapper: ModelStateMapper, total_shards: int, current_shard: int):
        self._groups = self._shard_groups(
            sub_mapper.state_dependency_groups(),
            n_shards=total_shards, shard=current_shard
        )
        self._sub_mapper = sub_mapper
        self._total_shards = total_shards
        self._current_shard = current_shard

    @staticmethod
    def _shard_groups(groups: frozenset[StateGroup], n_shards: int, shard: int) -> frozenset[StateGroup]:
        groups_sorted = sorted(groups, key=lambda x: sorted(x.inputs))
        groups_shard = [x for i, x in enumerate(groups_sorted) if i % n_shards == shard]
        return frozenset(groups_shard)

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return self._groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return self._sub_mapper.apply(group)

filter_empty_mappers(mappers)

Filters out mappers that have no effect (no inputs and no outputs).

Parameters:

    mappers (Sequence[ModelStateMapper], required): The list of mappers to filter.

Returns:

    list[ModelStateMapper]: A new list containing only active mappers.

Source code in d9d/model_state/mapper/compose/helper.py
def filter_empty_mappers(mappers: Sequence[ModelStateMapper]) -> list[ModelStateMapper]:
    """
    Filters out mappers that have no effect (no inputs and no outputs).

    Args:
        mappers: The list of mappers to filter.

    Returns:
        A new list containing only active mappers.
    """
    result = []
    for mapper in mappers:
        for group in mapper.state_dependency_groups():
            if len(group.inputs) > 0 or len(group.outputs) > 0:
                result.append(mapper)
                break
    return result

d9d.model_state.mapper.leaf

This package provides leaf mapper implementations.

ModelStateMapperDistribute

Bases: ModelStateMapper

Converts a single local Tensor object into a DTensor object with specified device_mesh and placements.
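
A hedged sketch, assuming an already initialized torch.distributed process group and a recent PyTorch that exposes DTensor placements under torch.distributed.tensor:

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

from d9d.model_state.mapper.leaf import ModelStateMapperDistribute

# Assumption: torch.distributed has been initialized on every rank, 8 GPUs available.
mesh = init_device_mesh("cuda", (8,))

mapper = ModelStateMapperDistribute(
    "layer.0.weight",
    device_mesh=mesh,
    placements=[Shard(0)],  # shard dimension 0 across the mesh
)
# The dependency group keeps the same key: "layer.0.weight" -> "layer.0.weight"
print(mapper.state_dependency_groups())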

Source code in d9d/model_state/mapper/leaf/dtensor.py
class ModelStateMapperDistribute(ModelStateMapper):
    """
    Converts a single local Tensor object into a DTensor object with specified
    `device_mesh` and `placements`.
    """

    def __init__(self, name: str, device_mesh: DeviceMesh | None, placements: Sequence[Placement] | None):
        self._name = name

        self._device_mesh = device_mesh
        self._placements = placements

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return frozenset([StateGroup(inputs=frozenset([self._name]), outputs=frozenset([self._name]))])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {
            self._name: distribute_tensor(
                group[self._name],
                device_mesh=self._device_mesh,
                placements=self._placements,
                src_data_rank=None  # do not communicate here
            )
        }

ModelStateMapperGatherFullTensor

Bases: ModelStateMapper

Gathers a single DTensor object into a full Tensor object.

Source code in d9d/model_state/mapper/leaf/dtensor.py
class ModelStateMapperGatherFullTensor(ModelStateMapper):
    """
    Gathers a single DTensor object into a full Tensor object.
    """

    def __init__(self, name: str):
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return frozenset([StateGroup(inputs=frozenset([self._name]), outputs=frozenset([self._name]))])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        tensor = group[self._name]

        if not isinstance(tensor, DTensor):
            raise ValueError("Cannot gather anything but DTensor")

        return {
            self._name: tensor.full_tensor()
        }

ModelStateMapperIdentity

Bases: ModelStateMapper

Passes a single state tensor through unchanged.

Source code in d9d/model_state/mapper/leaf/identity.py
class ModelStateMapperIdentity(ModelStateMapper):
    """
    Passes a single state tensor through unchanged.
    """

    def __init__(self, name: str):
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return frozenset([
            StateGroup(
                inputs=frozenset([self._name]),
                outputs=frozenset([self._name])
            )
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return group

ModelStateMapperRename

Bases: ModelStateMapper

Renames a single state tensor from name_from to name_to.

Source code in d9d/model_state/mapper/leaf/rename.py
class ModelStateMapperRename(ModelStateMapper):
    """
    Renames a single state tensor from `name_from` to `name_to`.
    """

    def __init__(self, name_from: str, name_to: str):
        self._name_from = name_from
        self._name_to = name_to

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return frozenset([
            StateGroup(
                inputs=frozenset([self._name_from]),
                outputs=frozenset([self._name_to])
            )
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {
            self._name_to: group[self._name_from]
        }

ModelStateMapperSelectChildModules

Bases: ModelStateMapper

Selects a set of keys belonging to a specific parent module (prefix) and renames them by removing that prefix.

This is effectively a batch rename operation that "hoists" parameters from a submodule scope to the current scope.
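
A minimal sketch hoisting two parameters out of a "decoder" submodule scope (the names are illustrative):

from d9d.model_state.mapper.leaf import ModelStateMapperSelectChildModules

hoist = ModelStateMapperSelectChildModules(
    base_names=["weight", "bias"],
    parent_name="decoder",
)

print(hoist.state_dependency_groups())
# Two independent groups:
#   {'decoder.weight'} -> {'weight'}
#   {'decoder.bias'}   -> {'bias'}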

Source code in d9d/model_state/mapper/leaf/select_child.py
class ModelStateMapperSelectChildModules(ModelStateMapper):
    """
    Selects a set of keys belonging to a specific parent module (prefix) and
    renames them by removing that prefix.

    This is effectively a batch rename operation that "hoists" parameters
    from a submodule scope to the current scope.
    """

    def __init__(self, base_names: list[str], parent_name: str):
        self._base_names = base_names
        self._parent_prefix = f"{parent_name}."

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return frozenset([
            StateGroup(
                inputs=frozenset([self._parent_prefix + name]),
                outputs=frozenset([name])
            )
            for name in self._base_names
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        name, value = next(iter(group.items()))
        if name.startswith(self._parent_prefix):
            return {
                name[len(self._parent_prefix):]: value
            }
        else:
            return {

            }

ModelStateMapperStackTensors

Bases: ModelStateMapper

Stacks multiple input tensors named source_names into a single output tensor named target_name, producing a new dimension at stack_dim.

Source code in d9d/model_state/mapper/leaf/stack.py
class ModelStateMapperStackTensors(ModelStateMapper):
    """
    Stacks multiple input tensors with names `source_names` into a single output tensor with name `target_name`
    producing new `stack_dim` dimension.
    """

    def __init__(self, source_names: list[str], target_name: str, stack_dim: int):
        self._source_names = source_names
        self._target_name = target_name
        self._stack_dim = stack_dim

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return frozenset([
            StateGroup(
                inputs=frozenset(self._source_names),
                outputs=frozenset([self._target_name])
            )
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        source_tensors = [group[name] for name in self._source_names]
        return {
            self._target_name: torch.stack(source_tensors, dim=self._stack_dim)
        }