About
The d9d.model_state.io package handles the reading and writing of model checkpoints.
We use a checkpoint format compatible with the HuggingFace format: parameter tensors are stored in sharded `model-00001-of-XXXXX.safetensors` files, accompanied by a `model.safetensors.index.json` file containing the metadata.
It is tightly integrated with the d9d.model_state.mapper framework to allow for Streamed Transformation - converting model architectures on-the-fly during IO without loading the entire model into memory.
Core Concepts
Why Support Transformations
In d9d, all model state input/output logic is natively integrated with model state mapping and transformation. This combined system acts as a powerful abstraction layer that decouples the checkpoint architecture (how weights are stored on disk) from the model architecture (how weights are used in PyTorch code).
This integration is critical for:
- Native HuggingFace Compatibility: You can use highly optimized, custom model implementations (e.g., using a single packed `qkv_proj` tensor) while reading directly from standard community checkpoints (which typically store `q_proj`, `k_proj`, and `v_proj` separately). The mapper handles the reshaping and stacking on-the-fly during the read stream, as illustrated by the sketch after this list. This eliminates the need to maintain separate "conversion scripts" or to store duplicate, converted copies of large models.
- Runtime Structure Adapting (e.g., LoRA): When injecting adapters like LoRA, the runtime model structure changes - often wrapping original layers. For example, a standard `some_linear.weight` on disk might need to be loaded into `some_linear.orig.weight` in memory. Instead of loading the full state dict and manually patching keys (which spikes memory), the mapper reroutes these keys without ever materializing the model weights in full.
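To make the QKV example concrete, here is a minimal sketch of the transformation the mapper performs conceptually. It uses plain PyTorch rather than the d9d mapper API, and the tensor names and shapes are purely illustrative.

```python
import torch

# Illustrative per-projection tensors as a HuggingFace-style checkpoint
# would typically store them (one tensor per projection).
hidden = 64
q_proj = torch.randn(hidden, hidden)
k_proj = torch.randn(hidden, hidden)
v_proj = torch.randn(hidden, hidden)

# The packed layout a custom implementation might expect: a single
# qkv_proj tensor obtained by stacking the three projections.
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)
assert qkv_proj.shape == (3 * hidden, hidden)

# A mapper group expresses this rule declaratively, so the stacking happens
# per-group during the read stream; q_proj/k_proj/v_proj can be evicted
# as soon as qkv_proj has been produced.
```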
How Loading Works
Standard model loading involves loading a huge dictionary into CPU RAM, filtering and processing it, and moving the results to GPU. This approach is inefficient: it requires many CPU-GPU transfers, consumes a large amount of memory, and duplicates work across pipeline parallel workers.
d9d proposes a different approach:
- Streaming & Eviction: Tensors are loaded in a streamed manner and kept in memory only while they are needed. Once a mapper group (e.g., "stack Q, K, V") is executed, the source tensors are immediately evicted from memory.
- Topology-Aware Loading: Instead of blindly loading all the files, the reader inspects the `ModelStateMapper` and calculates exactly which files contain the required inputs (see the sketch after this list).
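As a rough illustration of what topology-aware loading means (this is not the d9d reader's actual implementation), a HuggingFace-style model.safetensors.index.json maps every parameter name to the shard that contains it, so a reader only needs to open the shards referenced by the keys it actually wants:

```python
import json
from pathlib import Path

from safetensors import safe_open


def load_selected(src_dir: Path, wanted_keys: set[str], device: str = "cpu"):
    """Yield only the requested tensors, opening only the shards that hold them.

    Conceptual sketch of topology-aware loading, not the d9d reader itself.
    """
    index = json.loads((src_dir / "model.safetensors.index.json").read_text())
    weight_map: dict[str, str] = index["weight_map"]  # param name -> shard file

    # Group requested keys by shard so each file is opened exactly once.
    keys_by_file: dict[str, list[str]] = {}
    for key in wanted_keys:
        keys_by_file.setdefault(weight_map[key], []).append(key)

    for file_name, keys in keys_by_file.items():
        with safe_open(str(src_dir / file_name), framework="pt", device=device) as f:
            for key in keys:
                yield key, f.get_tensor(key)
```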
How Saving Works
Standard model saving often requires gathering all parameters to a single rank (causing OOM) or manual orchestration of file names and indices across hundreds of GPUs.
d9d's approach automates the checkpoint exporting lifecycle for large-scale distributed setups:
- Streaming & Eviction: Tensors are saved in a streamed manner and kept in memory only while they are needed. Once a mapper group (e.g., "stack Q, K, V") is executed, the source tensors are immediately evicted from memory. Target tensors are kept in memory only until they are flushed to their respective `.safetensors` files (see the sketch after this list).
- Distributed Awareness: In addition to local model exporting, we provide distribution-aware export functions. The writer natively understands distributed topologies (via `ProcessGroup` or `DeviceMesh`). In Pipeline Parallel scenarios, it identifies which rank holds the master copy of each stage, ensuring that parameters are written exactly once without race conditions or duplication.
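The streaming save can be pictured as a small buffer that is flushed to a new shard whenever its size exceeds the limit. The sketch below illustrates the idea with plain safetensors calls; it is not the d9d writer (file naming is simplified and the index file is omitted).

```python
from pathlib import Path
from typing import Iterable

import torch
from safetensors.torch import save_file


def streaming_save(
    states: Iterable[tuple[str, torch.Tensor]],
    dest_dir: Path,
    shard_size_gb: float = 1.0,
) -> None:
    """Buffer tensors and flush a new shard whenever the size limit is reached."""
    limit_bytes = int(shard_size_gb * 1024**3)
    buffer: dict[str, torch.Tensor] = {}
    buffered_bytes = 0
    shard_id = 0

    def flush() -> None:
        nonlocal buffer, buffered_bytes, shard_id
        if buffer:
            save_file(buffer, str(dest_dir / f"shard-{shard_id:05d}.safetensors"))
            buffer, buffered_bytes, shard_id = {}, 0, shard_id + 1

    for name, tensor in states:
        buffer[name] = tensor.contiguous()
        buffered_bytes += tensor.numel() * tensor.element_size()
        if buffered_bytes >= limit_bytes:
            flush()  # target tensors leave memory as soon as they hit disk
    flush()
```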
Usage Examples
These examples primarily show how to load and write model states in a pass-through way. If you want to see examples of complex model state mapping, please refer to the ModelStateMapper documentation.
Raw I/O - Streamed Loading
This example shows how to load a model without spiking memory usage.
from pathlib import Path
from d9d.model_state.io.reader import read_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module
# Setup the model (e.g., empty or on the meta device)
model = ...
# Define the mapper (Topology)
# In this example we will load all the parameters the model contains
mapper = identity_mapper_from_module(model)
# Start the stream
# 'src_dir' must contain safetensors files and model.safetensors.index.json
loader_stream = read_model_state(
src_dir=Path("./checkpoint"),
mapper=mapper,
device="cpu" # or "cuda:0"
)
# Iterate through the transformed results
state_dict = {}
for name, tensor in loader_stream:
print(f"Loaded and transformed: {name} -> {tensor.shape}")
state_dict[name] = tensor
Raw I/O - Streamed Saving
Saving a model locally, automatically splitting into 1 GB shards.
from pathlib import Path
from d9d.model_state.io.writer import write_model_state_local
from d9d.model_state.mapper.adapters import identity_mapper_from_module
model = ...  # the torch.nn.Module to save
# 1. Create a generator for your model states
state_generator = model.named_parameters()
# 2. Define a mapper (Identity if no transformation is needed during save)
mapper = identity_mapper_from_module(model)
# 3. Write
# This handles sharding and metadata file creation automatically
write_model_state_local(
dest_dir=Path("./output_checkpoint"),
mapper=mapper,
state_generator=state_generator,
shard_size_gb=1.0 # Split files if they exceed 1GB
)
Raw I/O - Distributed Load-Transform-Save
One of the most powerful features of d9d is the ability to perform Offline Checkpoint Conversion using a distributed cluster.
If you have a massive checkpoint in Format A (e.g., HuggingFace) and need to convert it to Format B (e.g., a custom training format with packed QKV), you don't need a single machine with 1 TB of RAM. Instead, you can spin up 8 GPUs, have each GPU process 1/8th of the keys in parallel, and write a new sharded checkpoint.
import torch.distributed as dist
from pathlib import Path
from d9d.model_state.io import read_model_state, write_model_state_distributed
from d9d.model_state.mapper.compose import ModelStateMapperShard
from d9d.model_state.mapper.adapters import identity_mapper_from_mapper_outputs
# 1. Initialize distributed environment
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
# 2. Define the global transformation logic
# This describes how the ENTIRE model should be converted.
# e.g., "Stack Q,K,V", "Rename MLP", "Load everything else as-is"
mapper = build_my_fancy_custom_mapper()
# 3. Shard the workload
# We wrap the mapper to restrict execution.
# Rank 0 will only process the first 1/N dependency groups, Rank 1 the next, etc.
# This ensures that no two ranks load/process/save the same tensors.
local_work_mapper = ModelStateMapperShard(
sub_mapper=mapper,
total_shards=world_size,
current_shard=rank
)
# 4. Define saving topology
# The 'read_model_state' generator below will yield tensors that have
# ALREADY been transformed to their target names/shapes.
# The writer just needs to accept these new keys and save them.
# We generate this identity mapper automatically from the output signature.
writer_mapper = identity_mapper_from_mapper_outputs(local_work_mapper)
# 5. Execute the pipeline
# - Reader: Loads specific source files, transforms, yields new tensors.
# - Writer: Receives new tensors, saves to '*.safetensors' files with temporary names
# - Finalizer: Rank 0 creates 'model.safetensors.index.json' covering all ranks and renames .safetensors files to their final names.
write_model_state_distributed(
dest_dir=Path("./converted_checkpoint"),
mapper=writer_mapper,
state_generator=read_model_state(
src_dir=Path("./original_checkpoint"),
mapper=local_work_mapper, # Defines what to load and how to transform it
device="cuda",
show_progress=False # Disable read bars to avoid stderr spam
),
process_group=dist.distributed_c10d._get_default_group(),
shard_size_gb=4.0,
show_progress=True # Master rank will show global save progress
)
PyTorch Module I/O - Streamed Loading
Loading a checkpoint where disk keys exactly match model keys. identity_mapper_from_module ensures only existing model parameters are loaded.
from pathlib import Path
from d9d.model_state.io import load_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module
# 1. Setup Model (e.g., empty or on meta device)
model = ...
# 2. Create Identity Topology
# This tells d9d: "Load every key that exists in 'model' as is."
mapper = identity_mapper_from_module(model)
# 3. Stream & Inject
load_model_state(
src_dir=Path("./checkpoints/v1"),
mapper=mapper,
device="cuda",
model=model
)
PyTorch Module I/O - Streamed Saving (DeviceMesh)
Saves a model in a complex ND Parallel environment using PyTorch DeviceMesh.
This example demonstrates:
- DTensor Gathering: Automatically gathers `DTensor` shards from the mesh into full tensors before writing.
- Concurrency Within a PP Rank: In a Data/Tensor/... parallel setup, multiple GPUs hold replicated or sharded copies of the same parameters. This function uses the `DeviceMesh` to ensure that only the "canonical" PP replica (DP Rank 0, TP Rank 0, ...) writes to disk, preventing write conflicts.
- Concurrency Across PP Ranks: Each PP rank writes its data into its own files. After all the PP ranks finish writing, PP Rank 0 merges the metadata from the different PP ranks into a single global checkpoint index file.
from pathlib import Path
from torch.distributed.device_mesh import init_device_mesh
from d9d.model_state.io import save_model_state_pipeline_parallel
from d9d.model_state.mapper.compose import ModelStateMapperParallel
from d9d.model_state.mapper.adapters import identity_mapper_from_module
# 1. Setup 3D Mesh
# pp=2 (Pipeline), dp=2 (Data), tp=2 (Tensor)
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))
# 2. Define Model Stages
# In this example, each PP rank manages two distinct parts of the model.
my_stages = [TransformerStage(...), TransformerStage(...)]
# 3. Create Topology
# Since this rank manages multiple modules, we create a Parallel mapper
# to combine the requirements of all stages.
mapper = ModelStateMapperParallel([
identity_mapper_from_module(stage) for stage in my_stages
])
# 4. Save
# The system inspects the mesh. It identifies if the current rank is
# the "Master" for the provided stages (i.e., dp_rank=0, tp_rank=0).
# If so, it gathers DTensors and writes. If not, it skips writing
# but participates in the collective gather.
save_model_state_pipeline_parallel(
dest_dir=Path("./checkpoint"),
mapper=mapper,
device_mesh=mesh,
pipeline_dim_name="pp",
models=my_stages,
shard_size_gb=4.0
)
d9d.model_state.io
load_model_state(src_dir, mapper, device, model, show_progress=True)
High-level utility to stream a checkpoint directly into a PyTorch module.
This function orchestrates the full loading lifecycle:
- Topology Mapping: Uses `mapper` to rename/stack/reshape on-disk states into model states.
- Automatic Distribution: If the `model` contains `DTensor`s, the loaded local tensors are automatically sharded/replicated to match the model's placement schema.
- Streaming Read & Inject: After a model state is loaded and transformed, it is injected into `model` using `load_state_dict(...)`.
NOTICE: Only states specified in mapper will be loaded! You can use
d9d.model_state.mapper.adapters.identity_mapper_from_module(module) to create a mapper that will load every
model state without changing it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `src_dir` | `Path` | Directory containing .safetensors and index files. | required |
| `mapper` | `ModelStateMapper` | The topology defining how disk keys map to model keys. | required |
| `device` | `str` | The device to load tensors onto (usually "cpu" or "cuda"). | required |
| `model` | `Module` | The model instance to load weights into. | required |
| `show_progress` | `bool` | Whether to display the loading progress bar. | `True` |
Source code in d9d/model_state/io/module_reader.py
read_model_state(src_dir, mapper, device, show_progress=True)
Reads a model checkpoint from disk, transforming it on-the-fly according to the state mapper.
This function uses a streaming approach. It analyzes the mapper to determine which files need to be loaded. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `src_dir` | `Path` | The directory containing .safetensors files and the model.safetensors.index.json index. | required |
| `mapper` | `ModelStateMapper` | The transformation graph defining how to map on-disk keys to output keys. | required |
| `device` | `str` | The device to load tensors onto (e.g., "cpu", "cuda:0"). | required |
| `show_progress` | `bool` | Whether to display a progress bar. | `True` |
Yields:
| Type | Description |
|---|---|
| `Iterable[tuple[str, Tensor]]` | A tuple containing the transformed parameter name and its tensor value. |
Source code in d9d/model_state/io/reader.py
save_model_state(dest_dir, mapper, model, shard_size_gb=4.0, show_progress=True)
High-level utility to save a PyTorch model to disk on a single process.
NOTICE: Only states specified in mapper will be saved! You can use
d9d.model_state.mapper.adapters.identity_mapper_from_module(module) to create a mapper that will save every
model state without changing it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | The directory to save .safetensors shards and index. | required |
| `mapper` | `ModelStateMapper` | Topology defining how model keys map to disk keys. | required |
| `model` | `Module` | The PyTorch module to save. | required |
| `shard_size_gb` | `float` | Max size per shard file in Gigabytes. | `4.0` |
| `show_progress` | `bool` | Whether to display a progress bar. | `True` |
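A minimal usage sketch for save_model_state, assuming the function is importable from d9d.model_state.io as listed in this reference; the model and destination path are placeholders.

```python
from pathlib import Path

from d9d.model_state.io import save_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

model = ...  # any torch.nn.Module held by this single process

# Save every model state as-is, splitting shards at roughly 2 GB.
save_model_state(
    dest_dir=Path("./single_process_checkpoint"),
    mapper=identity_mapper_from_module(model),
    model=model,
    shard_size_gb=2.0
)
```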
Source code in d9d/model_state/io/module_writer.py
save_model_state_pipeline_parallel(dest_dir, mapper, device_mesh, pipeline_dim_name, models, shard_size_gb=4.0, show_progress=True)
High-level utility to save a model in a Distributed Pipeline Parallel environment to disk.
Features:
- Auto-Gather: Converts `DTensor` parameters to full tensors before saving.
- Distribution Awareness: Uses the `device_mesh` to ensure that for a given pipeline stage, only the master rank writes the checkpoint, preventing Write-After-Write conflicts.
- Index Merging: Aggregates metadata from all independent pipeline stages into one global index file.
NOTICE: Only states specified in mapper will be saved! You can use
d9d.model_state.mapper.adapters.identity_mapper_from_module(module) to create a mapper that will save every
model state without changing it.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | The directory to save .safetensors shards and the index file. | required |
| `mapper` | `ModelStateMapper` | Topology defining how model keys map to disk keys. | required |
| `device_mesh` | `DeviceMesh` | The cluster topology mesh. | required |
| `pipeline_dim_name` | `str` | The specific dimension name in the mesh used for pipelining. | required |
| `models` | `list[Module]` | A list of modules (pipeline stages) processed by this PP rank. | required |
| `shard_size_gb` | `float` | Max size per shard file in Gigabytes. | `4.0` |
| `show_progress` | `bool` | Whether to display a progress bar. | `True` |
Source code in d9d/model_state/io/module_writer.py
write_model_state_distributed(dest_dir, mapper, state_generator, process_group, shard_size_gb=4.0, show_progress=True)
Saves model states in a distributed setup (multiple processes).
This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
Each rank writes its own shard. Rank 0 gathers indices and finalizes the checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | Destination directory. | required |
| `mapper` | `ModelStateMapper` | Mapping to apply to states before saving. | required |
| `state_generator` | `Iterable[tuple[str, Tensor]]` | Stream of (name, tensor) pairs from the model. | required |
| `process_group` | `ProcessGroup` | The distributed process group. | required |
| `shard_size_gb` | `float` | Maximum shard size in GB. | `4.0` |
| `show_progress` | `bool` | Whether to show the progress bar. | `True` |
Source code in d9d/model_state/io/writer.py
write_model_state_local(dest_dir, mapper, state_generator, shard_size_gb=4.0, show_progress=True)
Saves model states to disk in a single local process.
This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | Destination directory. | required |
| `mapper` | `ModelStateMapper` | Mapping to apply to states before saving. | required |
| `state_generator` | `Iterable[tuple[str, Tensor]]` | Stream of (name, tensor) pairs to save. | required |
| `shard_size_gb` | `float` | Maximum size of a single .safetensors file in GB. | `4.0` |
| `show_progress` | `bool` | Whether to show the progress bar. | `True` |
Source code in d9d/model_state/io/writer.py
write_model_state_pipeline_parallel(dest_dir, mapper, state_generator, device_mesh, pipeline_dim_name, shard_size_gb=4.0, show_progress=True)
Saves model states in a complex ND distributed training setting.
This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
This handles Pipeline Parallelism by ensuring that only one rank per pipeline stage actually writes data to disk to avoid duplication.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | Destination directory. | required |
| `mapper` | `ModelStateMapper` | Mapping to apply to states before saving. | required |
| `state_generator` | `Iterable[tuple[str, Tensor]]` | Stream of (name, tensor) pairs from the model. | required |
| `device_mesh` | `DeviceMesh` | The PyTorch DeviceMesh representing the cluster layout. | required |
| `pipeline_dim_name` | `str` | The name of the mesh dimension responsible for pipeline parallelism. | required |
| `shard_size_gb` | `float` | Maximum shard size in GB. | `4.0` |
| `show_progress` | `bool` | Whether to show the progress bar. | `True` |
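A hedged usage sketch for write_model_state_pipeline_parallel, following the signature above; the import path, mesh layout, and stage module are assumptions to adapt to your setup.

```python
from pathlib import Path

from torch.distributed.device_mesh import init_device_mesh

from d9d.model_state.io.writer import write_model_state_pipeline_parallel
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# Assumed 2x2 mesh: "pp" for pipeline stages, "dp" for data parallel replicas.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
stage = ...  # the pipeline stage module owned by this PP rank

# Stream this stage's parameters to disk; only one rank per stage writes,
# and the master rank merges the per-stage metadata into the global index.
write_model_state_pipeline_parallel(
    dest_dir=Path("./pp_checkpoint"),
    mapper=identity_mapper_from_module(stage),
    state_generator=stage.named_parameters(),
    device_mesh=mesh,
    pipeline_dim_name="pp",
    shard_size_gb=4.0
)
```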
Source code in d9d/model_state/io/writer.py