Model Definition

ModelProvider

The ModelProvider controls the lifecycle of the nn.Module. In distributed training, models are rarely just "instantiated".

They must be initialized, parallelized, and mapped so their weights can be loaded from a checkpoint.

How to Write a ModelProvider

Choose a Model

Choose a model from d9d's catalogue or create your own.

Implement initialize_model_stage(...)

Implement the initialize_model_stage(...) method - it should prepare an nn.Module for the specified pipeline parallel stage, containing the model architecture in the target torch.dtype.

Note that models are initialized on the meta device, so you must not load model weights here.

Instead, this method should return a State Mapper that maps model weights on disk to model weights in memory.

You may also apply PEFT methods and other architectural patches here, but make sure the returned State Mapper reflects the changes they make.
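The meta-device constraint can be illustrated with plain PyTorch (a minimal sketch, independent of d9d's APIs):

```python
import torch
import torch.nn as nn

# Constructing a module under the meta device allocates no real storage;
# parameters carry only shape and dtype metadata.
with torch.device("meta"):
    model = nn.Linear(4096, 4096, bias=False).bfloat16()

print(all(p.is_meta for p in model.parameters()))  # True: no memory allocated yet

# Real storage is materialized later (e.g. before checkpoint weights are
# filled in via the State Mapper), not inside initialize_model_stage(...):
model = model.to_empty(device="cpu")
print(any(p.is_meta for p in model.parameters()))  # False
```

This is why loading weights inside initialize_model_stage(...) would fail: the parameters there have no backing storage to copy into.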

Implement parallelize_model_stage(...)

Implement the parallelize_model_stage(...) method - it should apply a Horizontal Parallelism strategy to the selected model in place.

If you use one of d9d's models, you can use their default strategies, such as parallelize_qwen3_moe_for_causal_lm (reference).

For a custom model, please see Horizontal Parallelism docs and reference implementations.

Implement prepare_export_model_stage(...)

Implement the prepare_export_model_stage(...) method - it should return a State Mapper that converts the in-memory model state into the form that will be saved to disk during final export.

Essentially, it should reverse the operations of the State Mapper produced in initialize_model_stage(...).
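The "reversal" relationship can be sketched with a pair of plain-Python key-renaming mappers (hypothetical helpers for illustration only; d9d's ModelStateMapper interface is richer than a dict rename):

```python
# Hypothetical illustration: a load-time mapper renames checkpoint keys to
# the in-memory names, and the export-time mapper applies the exact inverse
# so the saved checkpoint matches the on-disk schema again.

LOAD_RENAMES = {
    "transformer.wte.weight": "model.embed_tokens.weight",  # disk -> memory
}
EXPORT_RENAMES = {v: k for k, v in LOAD_RENAMES.items()}    # memory -> disk

def map_state(state: dict, renames: dict) -> dict:
    """Rename keys according to the mapping, leaving unknown keys untouched."""
    return {renames.get(k, k): v for k, v in state.items()}

on_disk = {"transformer.wte.weight": "W"}
in_memory = map_state(on_disk, LOAD_RENAMES)
round_trip = map_state(in_memory, EXPORT_RENAMES)

print(round_trip == on_disk)  # True: export undoes the load-time mapping
```

If initialize_model_stage(...) returned an identity mapper (as in the example below), prepare_export_model_stage(...) can simply return an identity mapper as well.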

Example Implementation

from pydantic import BaseModel
from d9d.loop.control.model_provider import *
from d9d.module.model.qwen3_moe import Qwen3MoEForCausalLM, Qwen3MoEForCausalLMParameters
from d9d.module.parallelism.model.qwen3_moe import parallelize_qwen3_moe_for_causal_lm
from d9d.module.block.hidden_states_aggregator import HiddenStatesAggregationMode
from d9d.model_state.mapper.adapters import identity_mapper_from_module


class ModelProviderConfig(BaseModel):
    model: Qwen3MoEForCausalLMParameters  # Hyperparameters for Qwen3 MoE
    checkpointing: bool  # Enable gradient checkpointing to save VRAM


class ProjectModelProvider(ModelProvider[Qwen3MoEForCausalLM]):
    def __init__(self, config: ModelProviderConfig):
        self._config = config

    def initialize_model_stage(self, context: InitializeModelStageContext) -> InitializeModelStageResult:
        # Initialize the raw model on Meta device in BF16 precision
        model = Qwen3MoEForCausalLM(
            params=self._config.model,
            stage=context.stage,
            hidden_states_snapshot_mode=HiddenStatesAggregationMode.no,
            enable_checkpointing=self._config.checkpointing
        ).bfloat16()

        return InitializeModelStageResult(
            model=model,
            state_mapper=identity_mapper_from_module(model)
        )

    def parallelize_model_stage(self, context: ParallelizeModelStageContext):
        # Applies specific distributed strategies
        # suited for Qwen3 MoE architecture.
        # You can apply your own horizontal parallelism strategy here.
        parallelize_qwen3_moe_for_causal_lm(
            dist_context=context.dist_context,
            stage=context.stage,
            model=context.model
        )

    def prepare_export_model_stage(self, context: PrepareExportModelStageContext) -> PrepareExportModelStageResult:
        # When exporting, save model weights as-is

        return PrepareExportModelStageResult(
            state_mapper=identity_mapper_from_module(context.model)
        )

    def dump_hparams(self) -> ScalarTree:
        return self._config.model_dump(mode="json")

d9d.loop.control.model_provider

InitializeModelStageContext dataclass

Context data required for initializing a specific model pipeline stage.

Attributes:

dist_context (DistributedContext): The distributed execution context.

stage (PipelineStageInfo): Metadata describing the current pipeline stage being initialized.

InitializeModelStageResult dataclass

Bases: Generic[TModel]

The result of initializing a model stage.

Attributes:

model (TModel): The PyTorch module.

state_mapper (ModelStateMapper): The mapper defining how to load weights into this module.

ModelProvider

Bases: ABC, Generic[TModel]

Abstract interface for defining the lifecycle of a distributed model.

This provider handles initialization, parallelization (sharding/replication/etc), and export preparation for models within the d9d framework.

dump_hparams()

Exports hyperparameters associated with this model for logging.

Returns:

ScalarTree: A dictionary of hyperparameter names and values.

initialize_model_stage(context) abstractmethod

Initializes the model architecture for a specific pipeline stage.

This method is responsible for constructing the nn.Module for the requested stage.

Construction occurs within a meta-device context; therefore, weights should not be loaded directly here. Instead, a ModelStateMapper must be returned to define how weights from a checkpoint map to the newly created module parameters.

This allows for architecture modifications, such as injecting LoRA adapters, provided that the returned mapper reflects the new structure.

Parameters:

context (InitializeModelStageContext, required): Context for this operation.

Returns:

InitializeModelStageResult[TModel]: Result of this operation.

parallelize_model_stage(context) abstractmethod

Converts the model parameters into distributed tensors (DTensors).

Implementations should modify the model in-place. This involves converting standard parameters into DTensors by replicating or sharding them according to the desired parallelism strategies.

Parameters:

context (ParallelizeModelStageContext[TModel], required): Context for this operation.

prepare_export_model_stage(context) abstractmethod

Prepares the state mapper required for saving the model to disk.

This method defines how the current in-memory model structure maps back to the serialized checkpoint format.

Parameters:

context (PrepareExportModelStageContext[TModel], required): Context for this operation.

Returns:

PrepareExportModelStageResult: Result of this operation.

register_events(context)

Register model-specific event subscriptions.

Parameters:

context (RegisterModelEventsContext, required): Context providing access to the distributed environment, the built model modules, and the event bus.
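The subscription pattern itself looks roughly like this (a toy sketch: the EventBus class and event name here are invented for illustration; consult d9d's EventBus reference for the real API):

```python
from collections import defaultdict
from typing import Any, Callable

class EventBus:
    """Minimal publish/subscribe bus standing in for d9d's event bus."""

    def __init__(self) -> None:
        self._subs: dict[str, list[Callable[[Any], None]]] = defaultdict(list)

    def subscribe(self, event: str, handler: Callable[[Any], None]) -> None:
        self._subs[event].append(handler)

    def publish(self, event: str, payload: Any) -> None:
        for handler in self._subs[event]:
            handler(payload)

# A register_events(...) implementation would attach model-specific hooks,
# e.g. housekeeping that must run after each optimizer step.
seen = []
bus = EventBus()
bus.subscribe("after_optimizer_step", seen.append)  # hypothetical event name
bus.publish("after_optimizer_step", {"step": 1})

print(seen)  # [{'step': 1}]
```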

ParallelizeModelStageContext dataclass

Bases: Generic[TModel]

Context data required for horizontally parallelizing a model stage.

Attributes:

dist_context (DistributedContext): The distributed execution context.

stage (PipelineStageInfo): Metadata describing the current pipeline stage.

model (TModel): The PyTorch module to be parallelized.

PrepareExportModelStageContext dataclass

Bases: Generic[TModel]

Context data required for preparing a model stage for export.

Attributes:

dist_context (DistributedContext): The distributed execution context.

model (TModel): The PyTorch module to be exported.

PrepareExportModelStageResult dataclass

The result of preparing a model stage for export.

Attributes:

state_mapper (ModelStateMapper): The mapper defining how model parameters map to disk storage.

RegisterModelEventsContext dataclass

Context for registering model-specific events.

Attributes:

dist_context (DistributedContext): The distributed execution context.

event_bus (EventBus): The event bus for subscribing to events.