Bring Your Own Model

d9d does not force you to use its model implementations. You are free to use your own custom implementations of any model you want, optionally built from d9d's high-performance building blocks.

Just make sure to follow the main design principles described below.

Main Principles

d9d opts for a "white-box" approach to modelling. We avoid heavy abstraction layers in favor of readable, standard PyTorch code.

No LayerSpecs

Some distributed frameworks force users to define models via metadata specification objects so that wrapping logic (such as FSDP or activation checkpointing) can be injected automatically. This indirection makes debugging difficult.

In d9d, you write standard nn.Module classes. Use nn.Linear, nn.RMSNorm, or d9d's optimized blocks directly. Distributed wrapping logic is handled transparently, preserving the standard PyTorch look and feel.

Explicit Composition

We avoid creating "Uber-Modules": single, massive classes (e.g., GenericTransformerBlock) that handle every possible architectural variation (MoE, Dense, Post-Norm, Pre-Norm, Parallel Dense-Attention) via dozens of flags and parameters.

Instead, d9d promotes explicit composition, as Hugging Face Transformers does. Composing each architecture explicitly keeps the call stack clear and makes the logic of a specific architecture easy to trace.
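A minimal sketch of the idea, using hypothetical class names (SelfAttention, DenseMLP, PreNormDenseLayer, PostNormDenseLayer) built only from standard PyTorch layers: each architectural variant is its own small class composed from shared building blocks, rather than one block driven by flags.

```python
import torch
from torch import nn


class SelfAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


class DenseMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(nn.functional.gelu(self.up_proj(x)))


class PreNormDenseLayer(nn.Module):
    """Pre-norm attention followed by a dense MLP; forward() reads top to bottom."""

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        super().__init__()
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.attn = SelfAttention(hidden_size, num_heads)
        self.mlp_norm = nn.LayerNorm(hidden_size)
        self.mlp = DenseMLP(hidden_size, intermediate_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        return x + self.mlp(self.mlp_norm(x))


class PostNormDenseLayer(nn.Module):
    """The post-norm variant is a separate class, not a `pre_norm=False` flag."""

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        super().__init__()
        self.attn = SelfAttention(hidden_size, num_heads)
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.mlp = DenseMLP(hidden_size, intermediate_size)
        self.mlp_norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn_norm(x + self.attn(x))
        return self.mlp_norm(x + self.mlp(x))


x = torch.randn(2, 10, 32)
pre = PreNormDenseLayer(hidden_size=32, num_heads=4, intermediate_size=64)
post = PostNormDenseLayer(hidden_size=32, num_heads=4, intermediate_size=64)
print(pre(x).shape, post(x).shape)
```

The small amount of duplication between the two layer classes is deliberate: reading either forward() tells you exactly what runs, with no branching on configuration.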

Pipelining-Aware Models

Please see Pipelining API.

Late Initialization

Constructing a large model on a single GPU (or even in CPU RAM) often leads to immediate out-of-memory (OOM) errors. d9d solves this via the ModuleLateInit protocol.

Modules implementing this protocol are safe to use with d9d's native Trainer framework.

The Trainer instantiates modules on the meta device (which allocates no memory for parameters), then lays out the distributed topology and sharding strategy.

Only then is reset_parameters() called to materialize the model weights, without allocating unnecessary memory.
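The sequence above can be sketched in plain PyTorch. This is an illustration only, not d9d's Trainer code: it assumes PyTorch 2.0 or newer (for using torch.device as a context manager), ScaledLinear is a hypothetical module, and the simple loop at the end stands in for whatever the Trainer does after sharding.

```python
import math

import torch
from torch import nn


class ScaledLinear(nn.Module):
    """Hypothetical module following the ModuleLateInit convention.

    __init__ only declares parameter shapes; all random initialization
    lives in reset_parameters(), so construction on the meta device is free.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.scale = nn.Parameter(torch.empty(out_features))

    def reset_parameters(self):
        # Initialize only this module's own parameters; submodules handle theirs.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.ones_(self.scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight) * self.scale


# 1. Build the model on the meta device: shapes and dtypes only, no storage.
with torch.device("meta"):
    model = nn.Sequential(ScaledLinear(16, 32), nn.ReLU(), ScaledLinear(32, 4))
assert all(p.is_meta for p in model.parameters())

# 2. This is where d9d's Trainer lays out the distributed topology and
#    sharding strategy; that step is omitted in this single-process sketch.

# 3. Allocate real but uninitialized storage, then initialize every
#    submodule that exposes reset_parameters().
model.to_empty(device="cpu")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

out = model(torch.randn(3, 16))
print(out.shape)  # torch.Size([3, 4])
```

Since the loop visits every submodule, each reset_parameters() should touch only the parameters owned directly by its module; otherwise nested parameters would be initialized more than once.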

Reference Implementations

For reference implementations, please see Qwen3-MoE.

d9d.module.base

Defines structural protocols and base classes for PyTorch modules used within the d9d framework.

ModuleLateInit

Bases: Protocol

Protocol for modules that support late parameter initialization.

Source code in d9d/module/base/late_init.py
import typing
from typing import Protocol


@typing.runtime_checkable
class ModuleLateInit(Protocol):
    """Protocol for modules that support late parameter initialization."""

    def reset_parameters(self):
        """Resets the module parameters (i.e. performs random initialization)."""
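Because the protocol is decorated with typing.runtime_checkable, conformance can be tested with isinstance(). The sketch below copies the protocol definition so it runs without d9d installed; NoInit is a hypothetical module used only for contrast.

```python
import typing
from typing import Protocol

from torch import nn


@typing.runtime_checkable
class ModuleLateInit(Protocol):
    """Copy of d9d's protocol so this snippet is self-contained."""

    def reset_parameters(self):
        """Resets the module parameters (i.e. performs random initialization)."""


class NoInit(nn.Module):
    """Hypothetical module that does not define reset_parameters()."""

    def forward(self, x):
        return x


# Protocol checks are structural: no subclassing or registration is needed,
# only the presence of a reset_parameters attribute.
print(isinstance(nn.Linear(4, 4), ModuleLateInit))  # True
print(isinstance(NoInit(), ModuleLateInit))  # False
```

Note that the check only confirms the method exists; it cannot verify that reset_parameters() actually initializes every parameter the module owns.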

reset_parameters()

Resets the module parameters (i.e. performs random initialization).
