Model State IO
About
The d9d.model_state.io package handles the reading and writing of model checkpoints.
We use a checkpoint format compatible with the HuggingFace format. It is characterized by sharded `model-00001-of-XXXXX.safetensors` files storing parameter tensors, along with a `model.safetensors.index.json` file containing the metadata.
It is tightly integrated with the d9d.model_state.mapper framework to allow for Streamed Transformation - converting model architectures on-the-fly during IO without loading the entire model into memory.
Core Concepts
Why Support Transformations
In d9d all the model state input/output logic is natively integrated with mapping and transforming model states. Such a combined system acts as a powerful abstraction layer that decouples the checkpoint architecture (how weights are stored on disk) from the model architecture (how weights are used in PyTorch code).
This integration is critical for:
- Native HuggingFace Compatibility: You can use highly optimized, custom model implementations (e.g., using a single packed `qkv_proj` tensor) while reading directly from standard community checkpoints (which typically store `q_proj`, `k_proj`, and `v_proj` separately). The mapper handles the reshaping and stacking on-the-fly during the read stream. This eliminates the need to maintain separate "conversion scripts" or store duplicate, converted copies of large models.
- Runtime Structure Adapting (e.g., LoRA): When injecting adapters like LoRA, the runtime model structure changes, often wrapping the original layers. For example, a standard `some_linear.weight` on disk might need to be loaded into `some_linear.orig.weight` in memory. Instead of loading the full state dict and manually patching keys (which spikes memory), the mapper reroutes these keys without fully materializing the model weights.
How Loading Works
Standard model loading involves loading a huge dictionary into CPU RAM, filtering and processing it, and moving the results to GPU. This approach is inefficient: it requires many CPU-GPU transfers, consumes a large amount of memory, and duplicates work across pipeline parallel workers.
d9d proposes a different approach:
- Streaming & Eviction: Tensors are loaded in a streamed manner and therefore kept in memory only while needed. Once a mapper group (e.g., "stack Q, K, V") is executed, the source tensors are immediately evicted from memory.
- Topology-Aware Loading: Instead of blindly loading all the files, the reader inspects the `ModelStateMapper` and calculates exactly which files contain the required inputs.
How Saving Works
Standard model saving often requires gathering all parameters to a single rank (causing OOM) or manual orchestration of file names and indices across hundreds of GPUs.
d9d's approach automates the checkpoint exporting lifecycle for large-scale distributed setups:
- Streaming & Eviction: Tensors are saved in a streamed manner and therefore kept in memory only while needed. Once a mapper group (e.g., "stack Q, K, V") is executed, the source tensors are immediately evicted from memory. Target tensors are kept in memory only until they are flushed to their respective `.safetensors` files.
- Distributed Awareness: In addition to local model exporting, we provide distributed-aware export functions. The writer natively understands distributed topologies (via `ProcessGroup` or `DeviceMesh`). In Pipeline Parallel scenarios, it identifies which rank holds a given stage's master copy, ensuring that parameters are written exactly once without race conditions or duplication.
Usage Examples
These examples primarily show how to load and write model states in a pass-through way. For examples of complex model state mapping, please refer to the `ModelStateMapper` documentation.
Raw I/O - Streamed Loading
This example shows how to load a model without spiking memory usage.
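A minimal sketch of this pattern using `read_model_state` from this page. The `make_passthrough_mapper` helper is hypothetical (see the `ModelStateMapper` documentation for real mapper construction), and the checkpoint path is illustrative:

```python
from pathlib import Path

from d9d.model_state.io import read_model_state

# Hypothetical helper: builds a pass-through ModelStateMapper for the
# checkpoint's keys (see the d9d.model_state.mapper documentation).
mapper = make_passthrough_mapper()

# read_model_state streams tensors: each one is loaded from its shard,
# transformed by the mapper, yielded, and then evicted from memory, so
# peak memory stays far below the full checkpoint size.
for name, tensor in read_model_state(
    src_dir=Path("/checkpoints/my-model"),
    mapper=mapper,
    device="cpu",
):
    print(name, tuple(tensor.shape))
```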
Raw I/O - Streamed Saving
Saving a model locally, automatically splitting into 1 GB shards.
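A sketch using the documented `write_model_state_local` and `identity_mapper_from_module`; the `build_model` call is a placeholder for your own model construction:

```python
from pathlib import Path

from d9d.model_state.io import write_model_state_local
from d9d.model_state.mapper.adapters import identity_mapper_from_module

model = build_model()  # placeholder: any torch.nn.Module

# Identity mapper: every model state is saved under its original name.
mapper = identity_mapper_from_module(model)

# state_dict() holds references rather than copies, so this stays cheap;
# the writer flushes tensors into ~1 GB .safetensors shards as it goes.
write_model_state_local(
    dest_dir=Path("/checkpoints/export"),
    mapper=mapper,
    state_generator=model.state_dict().items(),
    shard_size_gb=1.0,
)
```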
Raw I/O - Distributed Load-Transform-Save
One of the most powerful features of d9d is the ability to perform Offline Checkpoint Conversion using a distributed cluster.
If you have a massive checkpoint in Format A (e.g., HuggingFace) and need to convert it to Format B (e.g., a custom training format with packed QKV), you don't need a single machine with 1 TB of RAM. Instead, you can spin up 8 GPUs, have each GPU process 1/8th of the keys in parallel, and write a new sharded checkpoint.
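A sketch of such a conversion job, run as one process per rank (e.g. under `torchrun`). The two mapper-building helpers are hypothetical placeholders; only the I/O calls come from this page:

```python
from pathlib import Path

import torch.distributed as dist

from d9d.model_state.io import read_model_state, write_model_state_distributed

dist.init_process_group("gloo")
rank, world = dist.get_rank(), dist.get_world_size()

# Hypothetical helpers: each rank reads a pass-through view of its
# 1/world-th of the keys, then repacks them (e.g. q/k/v -> packed qkv).
load_mapper = make_slice_mapper(rank, world)
save_mapper = make_packing_mapper()

stream = read_model_state(
    src_dir=Path("/checkpoints/format_a"),
    mapper=load_mapper,
    device="cpu",
    position=rank,  # one tqdm row per rank
)

# Each rank writes its own shards; rank 0 gathers and finalizes the index.
write_model_state_distributed(
    dest_dir=Path("/checkpoints/format_b"),
    mapper=save_mapper,
    state_generator=stream,
    process_group=dist.group.WORLD,
    position=rank,
)
```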
PyTorch Module I/O - Streamed Loading
Loading a checkpoint where disk keys exactly match model keys. `identity_mapper_from_module` ensures that only existing model parameters are loaded.
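A sketch using the documented `load_model_state`, with `build_model` as a placeholder for your own model construction:

```python
from pathlib import Path

from d9d.model_state.io import load_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

model = build_model()  # placeholder: any torch.nn.Module

# Disk keys match model keys one-to-one; only parameters that actually
# exist on the module are pulled from the checkpoint.
mapper = identity_mapper_from_module(model)

load_model_state(
    src_dir=Path("/checkpoints/export"),
    mapper=mapper,
    device="cpu",
    model=model,
)
```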
PyTorch Module I/O - Streamed Saving (DeviceMesh)
Saves a model in a complex ND Parallel environment using PyTorch DeviceMesh.
This features:
- DTensor Gathering: Automatically gathers `DTensor` shards from the mesh into full tensors before writing.
- Concurrency Within PP Rank: In a Data/Tensor/... parallel setup, multiple GPUs hold replicated or sharded copies of the same parameters. This function uses the `DeviceMesh` to ensure that only the "canonical" PP replica (DP Rank 0, TP Rank 0, ...) writes to disk, preventing write conflicts.
- Concurrency Across PP Ranks: Each PP rank writes its data into its own files. After all the PP ranks finish writing, PP Rank 0 merges the metadata from the different PP ranks into a single global checkpoint index file.
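A sketch using the documented `save_model_state_pipeline_parallel`. The mesh shape and the `build_local_stages` / `make_stage_mapper` helpers are illustrative assumptions:

```python
from pathlib import Path

from torch.distributed.device_mesh import init_device_mesh

from d9d.model_state.io import save_model_state_pipeline_parallel

# Illustrative 2x4 layout: 2 pipeline stages x 4 data-parallel replicas.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))

# Hypothetical helpers: the stage modules owned by this PP rank, and a
# ModelStateMapper covering all of their states.
stages = build_local_stages(mesh)
mapper = make_stage_mapper(stages)

# DTensor shards are gathered, only the canonical replica of each stage
# writes, and PP rank 0 merges the per-stage indices at the end.
save_model_state_pipeline_parallel(
    dest_dir=Path("/checkpoints/pp_export"),
    mapper=mapper,
    device_mesh=mesh,
    pipeline_dim_name="pp",
    models=stages,
)
```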
d9d.model_state.io
load_model_state(src_dir, mapper, device, model, show_progress=True, position=None)
High-level utility to stream a checkpoint directly into a PyTorch module.
This function orchestrates the full loading lifecycle:
- Topology Mapping: Uses `mapper` to rename/stack/reshape on-disk states to model states.
- Automatic Distribution: If the `model` contains `DTensor`s, the loaded local tensors are automatically sharded/replicated to match the model's placement schema.
- Streaming Read & Inject: After loading and transforming a model state, it is injected into `model` using `load_state_dict(...)`.
NOTICE: Only states specified in the mapper will be loaded! You can use `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that loads every model state without changing it.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `src_dir` | `Path` | Directory containing `.safetensors` and index files. | *required* |
| `mapper` | `ModelStateMapper` | The topology defining how mapping from disk keys to model keys works. | *required* |
| `device` | `str` | The device to load tensors onto (usually `"cpu"` or `"cuda"`). | *required* |
| `model` | `Module` | The model instance to load weights into. | *required* |
| `show_progress` | `bool` | Whether to display the loading progress bar. | `True` |
| `position` | `int \| None` | Row index for the tqdm bar. Pass the process local rank to stack one bar per rank without interleaving. | `None` |
read_model_state(src_dir, mapper, device, show_progress=True, position=None)
Reads a model checkpoint from disk, transforming it on-the-fly according to the state mapper.
This function uses a streaming approach. It analyzes the mapper to determine which files need to be loaded. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `src_dir` | `Path` | The directory containing `.safetensors` files and the index file. | *required* |
| `mapper` | `ModelStateMapper` | The transformation graph defining how to map on-disk keys to output keys. | *required* |
| `device` | `str` | The device to load tensors onto (e.g., `"cpu"`, `"cuda:0"`). | *required* |
| `show_progress` | `bool` | Whether to display a progress bar. | `True` |
| `position` | `int \| None` | Row index for the tqdm bar. Pass the process local rank to stack one bar per rank without interleaving. | `None` |
Yields:

| Type | Description |
|---|---|
| `Iterable[tuple[str, Tensor]]` | A tuple containing the transformed parameter name and its tensor value. |
save_model_state(dest_dir, mapper, model, shard_size_gb=4.0, show_progress=True)
High-level utility to save a PyTorch model to disk on a single process.
NOTICE: Only states specified in the mapper will be saved! You can use `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that saves every model state without changing it.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | The directory to save `.safetensors` shards and index. | *required* |
| `mapper` | `ModelStateMapper` | Topology defining how model keys map to disk keys. | *required* |
| `model` | `Module` | The PyTorch module to save. | *required* |
| `shard_size_gb` | `float` | Max size per shard file in gigabytes. | `4.0` |
| `show_progress` | `bool` | Whether to display a progress bar. | `True` |
save_model_state_pipeline_parallel(dest_dir, mapper, device_mesh, pipeline_dim_name, models, shard_size_gb=4.0, show_progress=True, position=None)
High-level utility to save a model in a Distributed Pipeline Parallel environment to disk.
Features:
- Auto-Gather: Converts `DTensor` parameters to full tensors before saving.
- Distribution Awareness: Uses the `device_mesh` to ensure that for a given pipeline stage, only the master rank writes the checkpoint, preventing Write-After-Write conflicts.
- Index Merging: Aggregates metadata from all independent pipeline stages into one global index file.
NOTICE: Only states specified in the mapper will be saved! You can use `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that saves every model state without changing it.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | The directory to save `.safetensors` shards and the index file. | *required* |
| `mapper` | `ModelStateMapper` | Topology defining how model keys map to disk keys. | *required* |
| `device_mesh` | `DeviceMesh` | The cluster topology mesh. | *required* |
| `pipeline_dim_name` | `str` | The specific dimension name in the mesh used for pipelining. | *required* |
| `models` | `list[Module]` | A list of modules (pipeline stages) processed by this PP rank. | *required* |
| `shard_size_gb` | `float` | Max size per shard file in gigabytes. | `4.0` |
| `show_progress` | `bool` | Whether to display a progress bar. | `True` |
| `position` | `int \| None` | Row index for the tqdm bar. Pass the process local rank to stack one bar per rank without interleaving. | `None` |
write_model_state_distributed(dest_dir, mapper, state_generator, process_group, shard_size_gb=4.0, show_progress=True, position=None)
Saves model states in a distributed setup (multiple processes).
This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
Each rank writes its own shard. Rank 0 gathers indices and finalizes the checkpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | Destination directory. | *required* |
| `mapper` | `ModelStateMapper` | Mapping to apply to states before saving. | *required* |
| `state_generator` | `Iterable[tuple[str, Tensor]]` | Stream of (name, tensor) pairs from the model. | *required* |
| `process_group` | `ProcessGroup` | The distributed process group. | *required* |
| `shard_size_gb` | `float` | Maximum shard size in GB. | `4.0` |
| `show_progress` | `bool` | Whether to show the progress bar. | `True` |
| `position` | `int \| None` | Row index for the tqdm bar. Pass the process local rank to stack one bar per rank without interleaving. | `None` |
write_model_state_local(dest_dir, mapper, state_generator, shard_size_gb=4.0, show_progress=True)
Saves model states to disk in a single local process.
This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | Destination directory. | *required* |
| `mapper` | `ModelStateMapper` | Mapping to apply to states before saving. | *required* |
| `state_generator` | `Iterable[tuple[str, Tensor]]` | Stream of (name, tensor) pairs to save. | *required* |
| `shard_size_gb` | `float` | Maximum size of a single `.safetensors` file in GB. | `4.0` |
| `show_progress` | `bool` | Whether to show the progress bar. | `True` |
write_model_state_pipeline_parallel(dest_dir, mapper, state_generator, device_mesh, pipeline_dim_name, shard_size_gb=4.0, show_progress=True, position=None)
Saves model states in a complex ND distributed training setting.
This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.
This handles Pipeline Parallelism by ensuring that only one rank per pipeline stage actually writes data to disk to avoid duplication.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dest_dir` | `Path` | Destination directory. | *required* |
| `mapper` | `ModelStateMapper` | Mapping to apply to states before saving. | *required* |
| `state_generator` | `Iterable[tuple[str, Tensor]]` | Stream of (name, tensor) pairs from the model. | *required* |
| `device_mesh` | `DeviceMesh` | The PyTorch DeviceMesh representing the cluster layout. | *required* |
| `pipeline_dim_name` | `str` | The name of the mesh dimension responsible for pipeline parallelism. | *required* |
| `shard_size_gb` | `float` | Maximum shard size in GB. | `4.0` |
| `show_progress` | `bool` | Whether to show the progress bar. | `True` |
| `position` | `int \| None` | Row index for the tqdm bar. Pass the process local rank to stack one bar per rank without interleaving. | `None` |