About

The d9d.model_state.io package handles the reading and writing of model checkpoints.

We use a checkpoint format that is compatible with the HuggingFace format. This format stores parameter tensors in sharded model-00001-of-XXXXX.safetensors files, along with a model.safetensors.index.json file containing the metadata.

It is tightly integrated with the d9d.model_state.mapper framework to allow for Streamed Transformation - converting model architectures on-the-fly during IO without loading the entire model into memory.

Core Concepts

Why Support Transformations

In d9d, all model state input/output logic is natively integrated with mapping and transforming model states. This combined system acts as a powerful abstraction layer that decouples the checkpoint architecture (how weights are stored on disk) from the model architecture (how weights are used in PyTorch code).

This integration is critical for:

  • Native HuggingFace Compatibility: You can use highly optimized, custom model implementations (e.g., using a single packed qkv_proj tensor) while reading directly from standard community checkpoints (which typically store q_proj, k_proj, and v_proj separately). The mapper handles the reshaping and stacking on-the-fly during the read stream, which eliminates the need to maintain separate "conversion scripts" or store duplicate, converted copies of large models.
  • Runtime Structure Adapting (e.g., LoRA): When injecting adapters like LoRA, the runtime model structure changes - often wrapping original layers. For example, a standard some_linear.weight on disk might need to be loaded into some_linear.orig.weight in memory. Instead of loading the full state dict and manually patching keys (which spikes memory), the mapper reroutes these keys without fully materializing the model weights; a minimal sketch of this rerouting idea follows below.
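
To illustrate the idea outside of d9d, here is a minimal, library-agnostic sketch of key rerouting applied to a tensor stream. reroute_lora_keys and wrapped_modules are hypothetical names, not part of the d9d API - the ModelStateMapper expresses the same thing declaratively:

from collections.abc import Iterable, Iterator

import torch


def reroute_lora_keys(
        stream: Iterable[tuple[str, torch.Tensor]],
        wrapped_modules: set[str],
) -> Iterator[tuple[str, torch.Tensor]]:
    """Rename 'some_linear.weight' -> 'some_linear.orig.weight' for wrapped modules."""
    for name, tensor in stream:
        prefix, _, leaf = name.rpartition(".")
        if prefix in wrapped_modules:
            # only one tensor is alive at a time - no full state dict is built
            yield f"{prefix}.orig.{leaf}", tensor
        else:
            yield name, tensor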

How Loading Works

Standard model loading involves loading a huge dictionary into CPU RAM, filtering and processing it, and moving the results to GPU. This approach is inefficient: it requires many CPU-GPU transfers, consumes a large amount of memory, and duplicates work across pipeline parallel workers.

d9d proposes a different approach:

  • Streaming & Eviction: Tensors are loaded in a streamed manner and are therefore kept in memory only while needed. Once a mapper group (e.g., "stack Q, K, V") is executed, the source tensors are immediately evicted from memory.
  • Topology-Aware Loading: Instead of blindly loading all the files, the reader inspects the ModelStateMapper and calculates exactly which files contain the required inputs (see the sketch below).
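
The file-selection step can be pictured using the HuggingFace index format itself. A simplified sketch, assuming only the standard model.safetensors.index.json layout; the required_keys set stands in for what the ModelStateMapper would report:

import json
from collections import defaultdict
from pathlib import Path


def plan_reads(src_dir: Path, required_keys: set[str]) -> dict[str, set[str]]:
    """Group required tensor names by the shard file that contains them."""
    index = json.loads((src_dir / "model.safetensors.index.json").read_text())
    weight_map: dict[str, str] = index["weight_map"]  # tensor name -> file name

    plan: dict[str, set[str]] = defaultdict(set)
    for key in required_keys:
        plan[weight_map[key]].add(key)
    return plan  # only these files get opened; others are never touched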

How Saving Works

Standard model saving often requires gathering all parameters to a single rank (causing OOM) or manually orchestrating file names and indices across hundreds of GPUs.

d9d's approach automates the checkpoint exporting lifecycle for large-scale distributed setups:

  • Streaming & Eviction: Tensors are saved in a streamed manner and are therefore kept in memory only while needed. Once a mapper group (e.g., "stack Q, K, V") is executed, the source tensors are immediately evicted from memory. Target tensors are kept in memory only until they are flushed to their respective .safetensors files (a shard-rolling sketch follows this list).
  • Distributed Awareness: In addition to local model exporting, we provide distributed-aware export functions. The writer natively understands distributed topologies (via ProcessGroup or DeviceMesh). In Pipeline Parallel scenarios, it identifies which rank holds the master copy of each stage, ensuring that parameters are written exactly once, without race conditions or duplication.
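
The shard-rolling behaviour can be pictured with plain safetensors calls. This is a simplified sketch, not d9d's actual writer (file naming and the index file are omitted; save_file is the real safetensors API):

from collections.abc import Iterable
from pathlib import Path

import torch
from safetensors.torch import save_file


def write_sharded(
        dest_dir: Path,
        stream: Iterable[tuple[str, torch.Tensor]],
        shard_size_gb: float = 4.0
):
    dest_dir.mkdir(parents=True, exist_ok=True)
    limit = int(shard_size_gb * 1024 ** 3)
    pending: dict[str, torch.Tensor] = {}
    pending_bytes = 0
    shard_no = 0

    def flush():
        nonlocal pending, pending_bytes, shard_no
        if pending:
            save_file(pending, str(dest_dir / f"shard-{shard_no:05d}.safetensors"))
            pending, pending_bytes = {}, 0
            shard_no += 1

    for name, tensor in stream:
        # safetensors requires contiguous tensors
        pending[name] = tensor.contiguous()
        pending_bytes += tensor.numel() * tensor.element_size()
        if pending_bytes >= limit:
            flush()  # flushing evicts the accumulated tensors from memory
    flush()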

Usage Examples

These examples primarily show how to load and write model states in a pass-through way. For examples of complex model state mapping, please refer to the ModelStateMapper documentation.

Raw I/O - Streamed Loading

This example shows how to load a model without spiking memory usage.

from pathlib import Path
from d9d.model_state.io.reader import read_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# An already-constructed nn.Module; its state names define what to load
model = ...

# Define the mapper (Topology)
# in this example we will load all the parameters the model contains
mapper = identity_mapper_from_module(model)

# Start the stream
# 'src_dir' must contain safetensors files and model.safetensors.index.json
loader_stream = read_model_state(
    src_dir=Path("./checkpoint"),
    mapper=mapper,
    device="cpu"  # or "cuda:0"
)

# Iterate through the transformed results
state_dict = {}
for name, tensor in loader_stream:
    print(f"Loaded and transformed: {name} -> {tensor.shape}")
    state_dict[name] = tensor
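
A note on the pattern above: accumulating every yielded tensor into state_dict rebuilds the full in-memory state dict, which defeats the streaming benefit; it is done here only to make the stream visible. In practice, consume each tensor as it is yielded, or use load_model_state (documented below), which injects tensors into a module one at a time.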

Raw I/O - Streamed Saving

Saving a model locally, automatically splitting into 1 GB shards.

from pathlib import Path
from d9d.model_state.io.writer import write_model_state_local
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# 1. Create a generator for your model states
# ('model' is an already-constructed nn.Module)
state_generator = model.named_parameters()

# 2. Define a mapper (Identity if no transformation is needed during save)
mapper = identity_mapper_from_module(model)

# 3. Write
# This handles sharding and metadata file creation automatically
write_model_state_local(
    dest_dir=Path("./output_checkpoint"),
    mapper=mapper,
    state_generator=state_generator,
    shard_size_gb=1.0  # Split files if they exceed 1GB
)
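
Note that model.named_parameters() yields parameters only. If the checkpoint should also include buffers (e.g., running statistics of normalization layers), the generator side can be extended with standard PyTorch calls; whether the mapper passes those extra keys through depends on how it was constructed:

from itertools import chain

# (name, tensor) pairs for parameters followed by buffers
state_generator = chain(model.named_parameters(), model.named_buffers())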

Raw I/O - Distributed Load-Transform-Save

One of the most powerful features of d9d is the ability to perform Offline Checkpoint Conversion using a distributed cluster.

If you have a massive checkpoint in Format A (e.g., HuggingFace) and need to convert it to Format B (e.g., a custom training format with packed QKV), you don't need a single machine with 1 TB of RAM. Instead, you can spin up 8 GPUs, have each GPU process 1/8th of the keys in parallel, and write a new sharded checkpoint.

import torch.distributed as dist
from pathlib import Path
from d9d.model_state.io import read_model_state, write_model_state_distributed
from d9d.model_state.mapper.compose import ModelStateMapperShard
from d9d.model_state.mapper.adapters import identity_mapper_from_mapper_outputs

# 1. Initialize distributed environment
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()

# 2. Define the global transformation logic
# This describes how the ENTIRE model should be converted.
# e.g., "Stack Q,K,V", "Rename MLP", "Load everything else as-is"
mapper = build_my_fancy_custom_mapper()

# 3. Shard the workload
# We wrap the mapper to restrict execution.
# Rank 0 will only process the first 1/N dependency groups, Rank 1 the next, etc.
# This ensures that no two ranks load/process/save the same tensors.
local_work_mapper = ModelStateMapperShard(
    sub_mapper=mapper,
    total_shards=world_size,
    current_shard=rank
)

# 4. Define saving topology
# The 'read_model_state' generator below will yield tensors that have 
# ALREADY been transformed to their target names/shapes.
# The writer just needs to accept these new keys and save them.
# We generate this identity mapper automatically from the output signature.
writer_mapper = identity_mapper_from_mapper_outputs(local_work_mapper)

# 5. Execute the pipeline
# - Reader: Loads specific source files, transforms, yields new tensors.
# - Writer: Receives new tensors, saves to '*.safetensors' files with temporary names
# - Finalizer: Rank 0 creates 'model.safetensors.index.json' covering all ranks and renames .safetensors files to their final names.
write_model_state_distributed(
    dest_dir=Path("./converted_checkpoint"),
    mapper=writer_mapper,
    state_generator=read_model_state(
        src_dir=Path("./original_checkpoint"),
        mapper=local_work_mapper, # Defines what to load and transformations
        device="cuda",
        show_progress=False # Disable read bars to avoid stderr spam
    ),
    process_group=dist.group.WORLD,  # the default (global) process group
    shard_size_gb=4.0,
    show_progress=True # Master rank will show global save progress
)
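
To run this across 8 GPUs, launch it as a standard distributed job, e.g. with torchrun (the script name is illustrative):

torchrun --nproc_per_node=8 convert_checkpoint.py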

PyTorch Module I/O - Streamed Loading

Loading a checkpoint where disk keys exactly match model keys. identity_mapper_from_module ensures only existing model parameters are loaded.

from pathlib import Path
from d9d.model_state.io import load_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# 1. Setup Model (e.g., empty or on meta device)
model = ...

# 2. Create Identity Topology
# This tells d9d: "Load every key that exists in 'model' as is."
mapper = identity_mapper_from_module(model)

# 3. Stream & Inject
load_model_state(
    src_dir=Path("./checkpoints/v1"),
    mapper=mapper,
    device="cuda",
    model=model
)

PyTorch Module I/O - Streamed Saving (DeviceMesh)

Saves a model in a complex ND Parallel environment using PyTorch DeviceMesh.

Features:

  • DTensor Gathering: Automatically gathers DTensor shards from the mesh into full tensors before writing.
  • Concurrency Within PP Rank: In a Data/Tensor/... parallel setup, multiple GPUs hold replicated or sharded copies of the same parameters. This function uses the DeviceMesh to ensure that only the "canonical" PP replica (DP Rank 0, TP Rank 0, ...) writes to disk, preventing write conflicts.
  • Concurrency Across PP Ranks: Each PP rank writes the data into its own files. After all the PP ranks finish writing, PP Rank 0 merges the metadata from different PP ranks into a single global checkpoint index file.

from pathlib import Path

from torch.distributed.device_mesh import init_device_mesh
from d9d.model_state.io import save_model_state_pipeline_parallel
from d9d.model_state.mapper.compose import ModelStateMapperParallel
from d9d.model_state.mapper.adapters import identity_mapper_from_module

# 1. Setup 3D Mesh
# pp=2 (Pipeline), dp=2 (Data), tp=2 (Tensor)
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

# 2. Define Model Stages
# In this example, each PP rank manages two distinct parts of the model.
my_stages = [TransformerStage(...), TransformerStage(...)]

# 3. Create Topology
# Since this rank manages multiple modules, we create a Parallel mapper
# to combine the requirements of all stages.
mapper = ModelStateMapperParallel([
    identity_mapper_from_module(stage) for stage in my_stages
])

# 4. Save
# The system inspects the mesh. It identifies if the current rank is 
# the "Master" for the provided stages (i.e., dp_rank=0, tp_rank=0).
# If so, it gathers DTensors and writes. If not, it skips writing 
# but participates in the collective gather.
save_model_state_pipeline_parallel(
    dest_dir=Path("./checkpoint"),
    mapper=mapper,
    device_mesh=mesh,
    pipeline_dim_name="pp",
    models=my_stages,
    shard_size_gb=4.0
)
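
Note that this example mesh spans 2 × 2 × 2 = 8 processes, so the job must be launched with exactly 8 ranks (one GPU each).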

d9d.model_state.io

load_model_state(src_dir, mapper, device, model, show_progress=True)

High-level utility to stream a checkpoint directly into a PyTorch module.

This function orchestrates the full loading lifecycle:

  1. Topology Mapping: Uses mapper to rename/stack/reshape on-disk states to model states.

  2. Automatic Distribution: If the model contains DTensors, the loaded local tensors are automatically sharded/replicated to match the model's placement schema.

  3. Streaming Read & Inject: After loading and transforming a model state, it will be injected into model using load_state_dict(...).

NOTICE: Only states specified in mapper will be loaded! You can use d9d.model_state.mapper.adapters.identity_mapper_from_module(module) to create a mapper that will load every model state without changing it.

Parameters:

  • src_dir (Path, required): Directory containing .safetensors and index files.
  • mapper (ModelStateMapper, required): The topology defining how mapping from disk keys to model keys works.
  • device (str, required): The device to load tensors onto (usually "cpu" or "cuda").
  • model (Module, required): The model instance to load weights into.
  • show_progress (bool, default True): Whether to display the loading progress bar.
Source code in d9d/model_state/io/module_reader.py
def load_model_state(
        src_dir: Path,
        mapper: ModelStateMapper,
        device: str,
        model: nn.Module,
        show_progress: bool = True,
):
    """
    High-level utility to stream a checkpoint directly into a PyTorch module.

    This function orchestrates the full loading lifecycle:

    1.  Topology Mapping: Uses `mapper` to rename/stack/reshape on-disk states to model states.

    2.  Automatic Distribution: If the `model` contains `DTensor`s, the loaded local tensors are automatically
        sharded/replicated to match the model's placement schema.

    3.  Streaming Read & Inject: After loading and transforming a model state, it will be injected into `model`
        using `load_state_dict(...)`.

    NOTICE: Only states specified in `mapper` will be loaded! You can use
    `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will load every
    model state without changing it.

    Args:
        src_dir: Directory containing .safetensors and index files.
        mapper: The topology defining how mapping from disk keys to model keys works.
        device: The device to load tensors onto (usually "cpu" or "cuda").
        model: The model instance to load weights into.
        show_progress: Whether to display the loading progress bar.
    """

    for state_name, state_value in read_model_state(
            src_dir=src_dir,
            mapper=_augment_mapper_for_injection(model, mapper),
            device=device,
            show_progress=show_progress
    ):
        model.load_state_dict({state_name: state_value}, strict=False)

read_model_state(src_dir, mapper, device, show_progress=True)

Reads a model checkpoint from disk, transforming it on-the-fly according to the state mapper.

This function uses a streaming approach. It analyzes the mapper to determine which files need to be loaded. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.

Parameters:

  • src_dir (Path, required): The directory containing .safetensors files and the model.safetensors.index.json file.
  • mapper (ModelStateMapper, required): The transformation graph defining how to map on-disk keys to output keys.
  • device (str, required): The device to load tensors onto (e.g., "cpu", "cuda:0").
  • show_progress (bool, default True): Whether to display a progress bar.

Yields:

  • Iterable[tuple[str, Tensor]]: A tuple containing the transformed parameter name and its tensor value.

Source code in d9d/model_state/io/reader.py
def read_model_state(
        src_dir: Path,
        mapper: ModelStateMapper,
        device: str,
        show_progress: bool = True
) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Reads a model checkpoint from disk, transforming it on-the-fly according to the state mapper.

    This function uses a streaming approach. It analyzes the mapper to determine which files
    need to be loaded. Tensors are loaded into memory only when needed and evicted immediately
    after the mapper processes them.

    Args:
        src_dir: The directory containing .safetensors files and `model.safetensors.index.json` file.
        mapper: The transformation graph defining how to map on-disk keys to output keys.
        device: The device to load tensors onto (e.g., "cpu", "cuda:0").
        show_progress: Whether to display a progress bar.

    Yields:
        A tuple containing the transformed parameter name and its tensor value.
    """

    yield from _StateLoadingFlow(
        src_dir=src_dir,
        device=device,
        mapper=mapper,
        show_progress=show_progress
    ).load()

save_model_state(dest_dir, mapper, model, shard_size_gb=4.0, show_progress=True)

High-level utility to save a PyTorch model to disk on a single process.

NOTICE: Only states specified in mapper will be saved! You can use d9d.model_state.mapper.adapters.identity_mapper_from_module(module) to create a mapper that will save every model state without changing it.

Parameters:

  • dest_dir (Path, required): The directory to save .safetensors shards and index.
  • mapper (ModelStateMapper, required): Topology defining how model keys map to disk keys.
  • model (Module, required): The PyTorch module to save.
  • shard_size_gb (float, default 4.0): Max size per shard file in gigabytes.
  • show_progress (bool, default True): Whether to display a progress bar.
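
A minimal usage sketch, assuming save_model_state is importable from d9d.model_state.io like the other helpers above ('model' is an already-constructed module; the path is illustrative):

from pathlib import Path

from d9d.model_state.io import save_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

save_model_state(
    dest_dir=Path("./exported_checkpoint"),
    mapper=identity_mapper_from_module(model),
    model=model,
    shard_size_gb=4.0
)
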
Source code in d9d/model_state/io/module_writer.py
def save_model_state(
        dest_dir: Path,
        mapper: ModelStateMapper,
        model: nn.Module,
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    High-level utility to save a PyTorch model to disk on a **single** process.

    NOTICE: Only states specified in `mapper` will be saved! You can use
    `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will save every
    model state without changing it.

    Args:
        dest_dir: The directory to save .safetensors shards and index.
        mapper: Topology defining how model keys map to disk keys.
        model: The PyTorch module to save.
        shard_size_gb: Max size per shard file in Gigabytes.
        show_progress: Whether to display a progress bar.
    """

    write_model_state_local(
        dest_dir=dest_dir,
        mapper=_augment_mapper_for_extraction([model], mapper),
        state_generator=_state_generator([model]),
        shard_size_gb=shard_size_gb,
        show_progress=show_progress
    )

save_model_state_pipeline_parallel(dest_dir, mapper, device_mesh, pipeline_dim_name, models, shard_size_gb=4.0, show_progress=True)

High-level utility to save a model in a Distributed Pipeline Parallel environment to disk.

Features:

  1. Auto-Gather: Converts DTensor parameters to full tensors before saving.

  2. Distribution Awareness: Uses the device_mesh to ensure that for a given pipeline stage, only the master rank writes the checkpoint, preventing Write-After-Write conflicts.

  3. Index Merging: Aggregates metadata from all independent pipeline stages into one global index file.

NOTICE: Only states specified in mapper will be saved! You can use d9d.model_state.mapper.adapters.identity_mapper_from_module(module) to create a mapper that will save every model state without changing it.

Parameters:

  • dest_dir (Path, required): Directory to save .safetensors shards and the index file.
  • mapper (ModelStateMapper, required): Topology defining how model keys map to disk keys.
  • device_mesh (DeviceMesh, required): The cluster topology mesh.
  • pipeline_dim_name (str, required): The specific dimension name in the mesh used for pipelining.
  • models (list[Module], required): A list of modules (pipeline stages) processed by this PP rank.
  • shard_size_gb (float, default 4.0): Max size per shard file in gigabytes.
  • show_progress (bool, default True): Whether to display a progress bar.
Source code in d9d/model_state/io/module_writer.py
def save_model_state_pipeline_parallel(
        dest_dir: Path,
        mapper: ModelStateMapper,
        device_mesh: DeviceMesh,
        pipeline_dim_name: str,
        models: list[nn.Module],
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    High-level utility to save a model in a Distributed Pipeline Parallel environment to disk.

    Features:

    1. **Auto-Gather**: Converts `DTensor` parameters to full tensors before saving.

    2. **Distribution Awareness**: Uses the `device_mesh` to ensure that for a given pipeline stage,
       only the master rank writes the checkpoint, preventing Write-After-Write conflicts.

    3. **Index Merging**: Aggregates metadata from all independent pipeline stages into one global index file.

    NOTICE: Only states specified in `mapper` will be saved! You can use
    `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will save every
    model state without changing it.

    Args:
        dest_dir: directory to save .safetensors shards and index file.
        mapper: Topology defining how model keys map to disk keys.
        device_mesh: The cluster topology mesh.
        pipeline_dim_name: The specific dimension name in the mesh used for pipelining.
        models: A list of modules (pipeline stages) processed by this PP rank.
        shard_size_gb: Max size per shard file in Gigabytes.
        show_progress: Whether to display a progress bar.
    """
    write_model_state_pipeline_parallel(
        dest_dir=dest_dir,
        mapper=_augment_mapper_for_extraction(models, mapper),
        state_generator=_state_generator(models),
        device_mesh=device_mesh,
        pipeline_dim_name=pipeline_dim_name,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress
    )

write_model_state_distributed(dest_dir, mapper, state_generator, process_group, shard_size_gb=4.0, show_progress=True)

Saves model states in a distributed setup (multiple processes).

This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.

Each rank writes its own shard. Rank 0 gathers indices and finalizes the checkpoint.

Parameters:

  • dest_dir (Path, required): Destination directory.
  • mapper (ModelStateMapper, required): Mapping to apply to states before saving.
  • state_generator (Iterable[tuple[str, Tensor]], required): Stream of (name, tensor) pairs from the model.
  • process_group (ProcessGroup, required): The distributed process group.
  • shard_size_gb (float, default 4.0): Maximum shard size in GB.
  • show_progress (bool, default True): Whether to show the progress bar.
Source code in d9d/model_state/io/writer.py
def write_model_state_distributed(
        dest_dir: Path,
        mapper: ModelStateMapper,
        state_generator: Iterable[tuple[str, torch.Tensor]],
        process_group: ProcessGroup,
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    Saves model states in a distributed setup (multiple processes).

    This function uses a streaming approach. It analyzes the mapper to determine which files
    need to be saved. Tensors are loaded into memory only when needed and evicted immediately
    after the mapper processes them.

    Each rank writes its own shard. Rank 0 gathers indices and finalizes the checkpoint.

    Args:
        dest_dir: Destination directory.
        mapper: Mapping to apply to states before saving.
        state_generator: Stream of (name, tensor) pairs from the model.
        process_group: The distributed process group.
        shard_size_gb: Maximum shard size in GB.
        show_progress: Whether to show the progress bar.
    """

    current_idx = _StateWritingFlowLocal(
        dest_dir=dest_dir,
        mapper=mapper,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress,
        sharding_rank=process_group.rank(),
        is_current_process_rank_master=True
    ).write(state_generator=state_generator)
    gather_idx = all_gather_object(current_idx, process_group)
    gather_idx_filter = [x for x in gather_idx if x is not None]
    if process_group.rank() == 0:
        _finalize_master(dest_dir, gather_idx_filter)

write_model_state_local(dest_dir, mapper, state_generator, shard_size_gb=4.0, show_progress=True)

Saves model states to disk in a single local process.

This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.

Parameters:

  • dest_dir (Path, required): Destination directory.
  • mapper (ModelStateMapper, required): Mapping to apply to states before saving.
  • state_generator (Iterable[tuple[str, Tensor]], required): Stream of (name, tensor) pairs to save.
  • shard_size_gb (float, default 4.0): Maximum size of a single .safetensors file in GB.
  • show_progress (bool, default True): Whether to show the progress bar.
Source code in d9d/model_state/io/writer.py
def write_model_state_local(
        dest_dir: Path,
        mapper: ModelStateMapper,
        state_generator: Iterable[tuple[str, torch.Tensor]],
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    Saves model states to disk in a single local process.

    This function uses a streaming approach. It analyzes the mapper to determine which files
    need to be saved. Tensors are loaded into memory only when needed and evicted immediately
    after the mapper processes them.

    Args:
        dest_dir: Destination directory.
        mapper: Mapping to apply to states before saving.
        state_generator: Stream of (name, tensor) pairs to save.
        shard_size_gb: Maximum size of a single .safetensors file in GB.
        show_progress: Whether to show the progress bar.
    """
    idx = _StateWritingFlowLocal(
        dest_dir=dest_dir,
        mapper=mapper,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress,
        sharding_rank=0,
        is_current_process_rank_master=True
    ).write(state_generator=state_generator)

    idx = cast(ModelStateIndex, idx)  # we are sure is_current_process_rank_master=True

    _finalize_master(dest_dir, [idx])

write_model_state_pipeline_parallel(dest_dir, mapper, state_generator, device_mesh, pipeline_dim_name, shard_size_gb=4.0, show_progress=True)

Saves model states in a complex ND distributed training setting.

This function uses a streaming approach. It analyzes the mapper to determine which files need to be saved. Tensors are loaded into memory only when needed and evicted immediately after the mapper processes them.

This handles Pipeline Parallelism by ensuring that only one rank per pipeline stage actually writes data to disk to avoid duplication.
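
Concretely, the implementation below designates a rank as the writer for its pipeline stage only when all of its non-pipeline mesh coordinates are zero (e.g., dp_rank=0 and tp_rank=0); all other ranks skip disk writes but still participate in the index gather.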

Parameters:

  • dest_dir (Path, required): Destination directory.
  • mapper (ModelStateMapper, required): Mapping to apply to states before saving.
  • state_generator (Iterable[tuple[str, Tensor]], required): Stream of (name, tensor) pairs from the model.
  • device_mesh (DeviceMesh, required): The PyTorch DeviceMesh representing the cluster layout.
  • pipeline_dim_name (str, required): The name of the mesh dimension responsible for pipeline parallelism.
  • shard_size_gb (float, default 4.0): Maximum shard size in GB.
  • show_progress (bool, default True): Whether to show the progress bar.
Source code in d9d/model_state/io/writer.py
def write_model_state_pipeline_parallel(
        dest_dir: Path,
        mapper: ModelStateMapper,
        state_generator: Iterable[tuple[str, torch.Tensor]],
        device_mesh: DeviceMesh,
        pipeline_dim_name: str,
        shard_size_gb: float = 4.0,
        show_progress: bool = True
):
    """
    Saves model states in a complex ND distributed training setting.

    This function uses a streaming approach. It analyzes the mapper to determine which files
    need to be saved. Tensors are loaded into memory only when needed and evicted immediately
    after the mapper processes them.

    This handles Pipeline Parallelism by ensuring that only one rank per pipeline stage
    actually writes data to disk to avoid duplication.

    Args:
        dest_dir: Destination directory.
        mapper: Mapping to apply to states before saving.
        state_generator: Stream of (name, tensor) pairs from the model.
        device_mesh: The PyTorch DeviceMesh representing the cluster layout.
        pipeline_dim_name: The name of the mesh dimension responsible for pipeline parallelism.
        shard_size_gb: Maximum shard size in GB.
        show_progress: Whether to show the progress bar.
    """

    pipeline_rank = device_mesh[pipeline_dim_name].get_rank()

    mesh_dim_names = device_mesh.mesh_dim_names
    coords = device_mesh.get_coordinate()
    if mesh_dim_names is None or coords is None:
        raise ValueError("Cannot save state using a DeviceMesh with no dim names or coords")

    non_pipeline_coord_sum = sum(
        coord
        for name, coord
        in zip(mesh_dim_names, coords, strict=True)
        if name != pipeline_dim_name
    )
    master_within_pipeline_rank = non_pipeline_coord_sum == 0

    current_idx = _StateWritingFlowLocal(
        dest_dir=dest_dir,
        mapper=mapper,
        shard_size_gb=shard_size_gb,
        show_progress=show_progress,
        sharding_rank=pipeline_rank,
        is_current_process_rank_master=master_within_pipeline_rank
    ).write(state_generator=state_generator)
    gather_idx = all_gather_object(current_idx, device_mesh.get_group(0))
    gather_idx_filter = [x for x in gather_idx if x is not None]
    if pipeline_rank == 0 and master_within_pipeline_rank:
        _finalize_master(dest_dir, gather_idx_filter)