# Model Definition

## ModelProvider

The ModelProvider controls the lifecycle of the nn.Module. In distributed training, models are rarely just "instantiated": they must be initialized, parallelized, and mapped for loading from a checkpoint.
## How to Write a ModelProvider

### Choose a Model

Choose a model from d9d's catalogue or create your own.
### Implement initialize_model_stage(...)

Implement the initialize_model_stage(...) method - it should construct an nn.Module for the specified pipeline-parallel stage, with the model architecture in the target torch.dtype.

Note that models are initialized on the meta device, so you must not load model weights here. Instead, this method should return a State Mapper that maps model weights on disk to model weights in memory. You may also apply PEFT methods and other architectural patches here, but make sure the returned State Mapper reflects the changes they make.
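For instance, here is a minimal sketch of the meta-device pattern in plain PyTorch. The state-mapper dict and the checkpoint key names are illustrative stand-ins, not d9d's actual ModelStateMapper:

```python
import torch
from torch import nn

# Build the architecture on the meta device: parameters get shapes and dtypes
# but no storage, so no checkpoint weights can (or should) be loaded here.
with torch.device("meta"):
    model = nn.Linear(16, 16, dtype=torch.bfloat16)

# Illustrative stand-in for a State Mapper: how on-disk checkpoint keys
# (hypothetical names) map to the in-memory parameter names of the new module.
state_mapper = {
    "encoder.proj.weight": "weight",
    "encoder.proj.bias": "bias",
}

assert all(p.is_meta for p in model.parameters())
```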
### Implement parallelize_model_stage(...)

Implement the parallelize_model_stage(...) method - it should apply a Horizontal Parallelism strategy to the selected model in-place.

If you use one of d9d's models, you may use their default strategies, such as parallelize_qwen3_moe_for_causal_lm (reference). For a custom model, see the Horizontal Parallelism docs and the reference implementations.
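To build intuition for what a strategy does, here is a conceptual single-process sketch of column-wise sharding using plain tensor chunks. A real implementation converts parameters to DTensors in-place via torch.distributed; the world size and layer below are illustrative:

```python
import torch
from torch import nn

world_size = 4
layer = nn.Linear(8, 8)

# Column-wise (output-dimension) sharding: under a real Horizontal Parallelism
# strategy, each rank would own one chunk of the weight as a DTensor shard.
# Here we only show the split itself.
shards = torch.chunk(layer.weight.detach(), world_size, dim=0)

assert len(shards) == world_size
assert all(s.shape == (8 // world_size, 8) for s in shards)
```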
### Implement prepare_export_model_stage(...)

Implement the prepare_export_model_stage(...) method - it should return a State Mapper that converts the in-memory model state into the form that will be saved to disk during the final export. Essentially, it should reverse the operations of the State Mapper produced by initialize_model_stage(...).
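Conceptually, if the initialization State Mapper is a disk-to-memory key mapping, the export mapper is its inverse. The dicts below are illustrative only; a real ModelStateMapper may also transform tensor values, e.g. merging PEFT deltas:

```python
# Hypothetical key mapping produced at initialization: disk -> memory.
init_mapping = {
    "transformer.h.0.attn.weight": "layers.0.attention.weight",
    "transformer.h.0.mlp.weight": "layers.0.mlp.weight",
}

# The export mapper reverses it (memory -> disk), so the final checkpoint
# lands back in the original on-disk format.
export_mapping = {memory: disk for disk, memory in init_mapping.items()}

assert export_mapping["layers.0.attention.weight"] == "transformer.h.0.attn.weight"
```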
### Example Implementation
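As a stand-in, the following is a hypothetical minimal provider for a toy MLP. The StateMapper class, the simplified method signatures, and all names are assumptions for illustration only - consult d9d's actual ModelProvider interface and ModelStateMapper for the real signatures:

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class StateMapper:
    """Illustrative stand-in for d9d's ModelStateMapper: a disk -> memory key map."""

    disk_to_memory: dict

    def inverted(self) -> "StateMapper":
        return StateMapper({v: k for k, v in self.disk_to_memory.items()})


class TinyModelProvider:
    """Sketch of the ModelProvider lifecycle for a toy two-layer MLP."""

    def __init__(self, hidden: int = 8):
        self.hidden = hidden

    def dump_hparams(self) -> dict:
        # Hyperparameters exported for logging.
        return {"hidden": self.hidden}

    def initialize_model_stage(self):
        # Construct the architecture on the meta device: no real memory is
        # allocated and no weights are loaded here.
        with torch.device("meta"):
            model = nn.Sequential(
                nn.Linear(self.hidden, self.hidden),
                nn.Linear(self.hidden, self.hidden),
            )
        # Identity mapping: checkpoint keys match module parameter names.
        mapper = StateMapper({name: name for name, _ in model.named_parameters()})
        return model, mapper

    def parallelize_model_stage(self, model: nn.Module) -> None:
        # A real provider would convert parameters to DTensors in-place here
        # (replicating / sharding per the chosen strategy). No-op in this toy.
        pass

    def prepare_export_model_stage(self, mapper: StateMapper) -> StateMapper:
        # Reverse the initialization mapping so the exported checkpoint
        # matches the on-disk format.
        return mapper.inverted()
```

A training loop would call initialize_model_stage first (on the meta device), then parallelize_model_stage, load weights through the returned mapper, and finally use prepare_export_model_stage when writing the last checkpoint.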
## d9d.loop.control.model_provider
### InitializeModelStageContext

dataclass

Context data required for initializing a specific model pipeline stage.

Attributes:

| Name | Type | Description |
|---|---|---|
| dist_context | DistributedContext | The distributed execution context. |
| stage | PipelineStageInfo | Metadata describing the current pipeline stage being initialized. |
### InitializeModelStageResult

dataclass

Bases: Generic[TModel]

The result of initializing a model stage.

Attributes:

| Name | Type | Description |
|---|---|---|
| model | TModel | The PyTorch module. |
| state_mapper | ModelStateMapper | The mapper defining how to load weights into this module. |
### ModelProvider

Abstract interface for defining the lifecycle of a distributed model.

This provider handles initialization, parallelization (sharding/replication/etc.), and export preparation for models within the d9d framework.

#### dump_hparams()

Exports hyperparameters associated with this model for logging.

Returns:

| Type | Description |
|---|---|
| ScalarTree | A dictionary of hyperparameter names and values. |
#### initialize_model_stage(context)

abstractmethod

Initializes the model architecture for a specific pipeline stage.

This method is responsible for constructing the nn.Module for the requested stage. Construction occurs within a meta-device context; therefore, weights should not be loaded directly here. Instead, a ModelStateMapper must be returned to define how weights from a checkpoint map to the newly created module parameters. This allows for architecture modifications, such as injecting LoRA adapters, provided that the returned mapper reflects the new structure.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | InitializeModelStageContext | Context for this operation. | required |

Returns:

| Type | Description |
|---|---|
| InitializeModelStageResult[TModel] | Result of this operation. |
#### parallelize_model_stage(context)

abstractmethod

Converts the model parameters into distributed tensors (DTensors).

Implementations should modify the model in-place. This involves converting standard parameters into DTensors by replicating or sharding them according to the desired parallelism strategies.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | ParallelizeModelStageContext[TModel] | Context for this operation. | required |
#### prepare_export_model_stage(context)

abstractmethod

Prepares the state mapper required for saving the model to disk.

This method defines how the current in-memory model structure maps back to the serialized checkpoint format.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | PrepareExportModelStageContext[TModel] | Context for this operation. | required |

Returns:

| Type | Description |
|---|---|
| PrepareExportModelStageResult | Result of this operation. |
#### register_events(context)

Register model-specific event subscriptions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | RegisterModelEventsContext | Context providing access to the distributed environment, the built model modules, and the event bus. | required |
### ParallelizeModelStageContext

dataclass

Bases: Generic[TModel]

Context data required for horizontally parallelizing a model stage.

Attributes:

| Name | Type | Description |
|---|---|---|
| dist_context | DistributedContext | The distributed execution context. |
| stage | PipelineStageInfo | Metadata describing the current pipeline stage. |
| model | TModel | The PyTorch module to be parallelized. |
### PrepareExportModelStageContext

dataclass

Bases: Generic[TModel]

Context data required for preparing a model stage for export.

Attributes:

| Name | Type | Description |
|---|---|---|
| dist_context | DistributedContext | The distributed execution context. |
| model | TModel | The PyTorch module to be exported. |
### PrepareExportModelStageResult

dataclass

The result of preparing a model stage for export.

Attributes:

| Name | Type | Description |
|---|---|---|
| state_mapper | ModelStateMapper | The mapper defining how model parameters map to disk storage. |
### RegisterModelEventsContext

dataclass

Context for registering model-specific events.

Attributes:

| Name | Type | Description |
|---|---|---|
| dist_context | DistributedContext | The distributed execution context. |
| event_bus | EventBus | The event bus for subscribing to events. |