Training Loop

Overview

The d9d.loop package provides the execution engine for distributed training.

The d9d Trainer separates the definition of the job (Models, Tasks, Data) from the execution of the job (Synchronization, Checkpointing, Profiling).

This allows the same code to run on a single GPU or a 1000-GPU Pipeline Parallel cluster without modifications.

Example

from d9d.loop.run import TrainingConfigurator

# Configure
trainer = TrainingConfigurator(
    mesh=mesh_params,                  # Physical cluster layout
    parameters=config,                 # Logic configuration (batch size, etc)

    # --- User Logic ---
    model_provider=...,                # How to build the model
    task_provider=...,                 # How to compute loss
    data_provider=...,                 # How to load data
    optimizer_provider=...,            # How to optimize
    lr_scheduler_provider=...          # LR scheduler
).configure()

# Execute
trainer.train()

Configuration & Construction

To ensure reproducibility, the Trainer is not instantiated directly with loose objects. It is built using the TrainingConfigurator and the dependency injection pattern.
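The dependency-injection idea can be illustrated with a minimal, self-contained sketch (the class and provider names below are hypothetical stand-ins, not the real d9d API): the configurator receives factories rather than pre-built objects, so it controls construction order, seeding, and placement.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical stand-in for TrainingConfigurator: it is handed
# factories ("providers"), not ready objects, so every component is
# built inside configure() in a controlled, reproducible order.
@dataclass
class Configurator:
    model_provider: Callable[[], object]
    optimizer_provider: Callable[[object], object]

    def configure(self) -> dict:
        # Construction order is owned by the configurator, not the caller.
        model = self.model_provider()
        optimizer = self.optimizer_provider(model)
        return {"model": model, "optimizer": optimizer}

state = Configurator(
    model_provider=lambda: "model",
    optimizer_provider=lambda m: f"optimizer({m})",
).configure()
```

Because only factories cross the boundary, the same user code can be re-run deterministically after the configurator has set up process groups and seeds.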

This class binds the device mesh definition, the trainer configuration, and the user-supplied providers into a Trainer object with a prepared TrainJobState.

d9d.loop.run.TrainingConfigurator

Orchestrates the assembly of the distributed training environment.

This class binds the infrastructure configuration (DeviceMesh), the training parameters (TrainerConfig), and the user-defined logic (Providers) to create a fully initialized state object capable of running the training loop.

__init__(mesh, parameters, task_provider, model_provider, data_provider, optimizer_provider, lr_scheduler_provider)

Constructs a configurator capable of building the full training state.

Parameters:

  • mesh (DeviceMeshParameters, required): Definition of the distributed device mesh topology.
  • parameters (TrainerConfig, required): The global configuration object for the trainer.
  • task_provider (TrainTaskProvider, required): Factory for creating the training task logic.
  • model_provider (ModelProvider, required): Factory for defining and creating model stages.
  • data_provider (DatasetProvider, required): Factory for providing training datasets.
  • optimizer_provider (OptimizerProvider, required): Factory for creating the optimizer.
  • lr_scheduler_provider (LRSchedulerProvider, required): Factory for creating the learning rate scheduler.

configure()

Instantiates all training components and returns a configured Trainer.

This method triggers the creation of the distributed context, sets seeds, builds the model, optimizer, data loaders, and attaches all auxiliary components (logging, profiling, checkpointing).

Returns:

  • Trainer (Trainer): A ready-to-use trainer instance encapsulating the job state.

The Configuration Lifecycle

The TrainingConfigurator.configure() method performs the following steps:

  1. Distributed Context Initialization:

    • Constructs the global DistributedContext, thereby initializing all the required NCCL process groups and DeviceMeshes.
  2. Seeding:

    • Sets distributed seeds using the configured base_seed. This ensures model initialization and other initial states are deterministic. More info.
  3. Event Bus Initialization:

    • Creates the global EventBus. Tasks and providers use this to register custom hooks.
    • Triggers EVENT_TRAIN_CONFIG_STARTED event.
  4. Task Instantiation:

    • Instantiates the TrainTask object using the specified TrainTaskProvider.
  5. Data Loader Construction:

    • Calls the DatasetProvider to get the dataset and wraps it into a DataLoader.
    • The DataLoader will move all the Tensor data to this worker's device automatically.
    • Triggers EVENT_TRAIN_DATA_LOADER_READY event.
  6. Model Materialization:

    • The ModelStageFactory runs. This is the heavy lifting of initialization:
      1. Meta Init: ModelProvider creates the model on the meta device (no memory usage).
      2. Parallelization: ModelProvider applies DTensor sharding/replication to parameters.
      3. Materialization: Empty tensors are allocated on the actual GPU.
      4. Wait: Hard barrier to ensure all ranks allocated memory successfully.
      5. Parameter Reset: model.reset_parameters() is called to generate random weights on GPU.
      6. Source Loading (Optional): If configured, a pretrained checkpoint (e.g., from HF) is streamed into the model using ModelStateMapper.
  7. Triggers EVENT_TRAIN_MODEL_STAGES_READY event.

  8. Optimizer and LR Scheduler Setup:

    • OptimizerFactory iterates over the model parameters.
    • Calls OptimizerProvider and LRSchedulerProvider.
    • Triggers both EVENT_TRAIN_OPTIMIZER_READY and EVENT_TRAIN_LR_SCHEDULER_READY events.
  9. State Assembly:

    • All components (including internal ones) are packed into the TrainJobState.
    • The Trainer is instantiated with this state and returned.
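The ordering of these stages, and the events fired along the way, can be sketched with a toy event bus (a simplified stand-in, not the real d9d EventBus; only the event names come from the lifecycle above):

```python
# Minimal sketch of the configure() ordering: each stage runs in a
# fixed sequence and fires its lifecycle event on a simple event bus.
class EventBus:
    def __init__(self):
        self.log = []
        self.handlers = {}

    def on(self, event, handler):
        # Providers and tasks would register custom hooks here.
        self.handlers.setdefault(event, []).append(handler)

    def trigger(self, event):
        self.log.append(event)
        for handler in self.handlers.get(event, []):
            handler()

def configure(bus: EventBus):
    bus.trigger("EVENT_TRAIN_CONFIG_STARTED")     # step 3: bus ready
    task = object()                               # step 4: task built
    bus.trigger("EVENT_TRAIN_DATA_LOADER_READY")  # step 5: data loader
    bus.trigger("EVENT_TRAIN_MODEL_STAGES_READY") # steps 6-7: model
    bus.trigger("EVENT_TRAIN_OPTIMIZER_READY")    # step 8: optimizer
    bus.trigger("EVENT_TRAIN_LR_SCHEDULER_READY") # step 8: scheduler
    return task                                   # step 9: assembly

bus = EventBus()
configure(bus)
```

A hook registered via bus.on("EVENT_TRAIN_MODEL_STAGES_READY", ...) before configure() would fire exactly once, after materialization but before the optimizer exists.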

Execution

To run a training job, call the .train() method on the Trainer object returned by the configuration process.

d9d.loop.run.Trainer

The main execution engine for running a distributed training job.

This class manages the training loop, lifecycle events, distributed synchronization, and periodic side-effects (logging, checkpointing).

__init__(state)

Constructs a Trainer from a pre-built job state.

Parameters:

  • state (TrainJobState, required): The encapsulated state object containing all initialized components (model, optimizer, dist_context, etc.).

export(export_to, load_checkpoint)

Exports the current model state to the specified directory.

This handles the distributed saving logic, allowing the model to be reconstituted later or used for inference.

Parameters:

  • export_to (Path, required): The directory path where the model artifacts will be saved.
  • load_checkpoint (bool, required): If True, attempts to load the latest checkpoint into the model before exporting.

train()

Executes the full training workflow.

The Training Lifecycle

The Trainer.train() method orchestrates the following lifecycle. It is critical to understand this flow when debugging distributed issues or checking for side effects.

1. Initialization & Recovery

Before the loop starts:

  1. Global Synchronization: The trainer waits for all ranks to come online (barrier).
  2. State Loading: The StateCheckpointer checks the filesystem.
    • If a checkpoint exists, its contents are loaded into all the Stateful objects inside the trainer's _state.
    • If no checkpoint exists, it starts from the first step.
  3. Context Entry: The trainer enters several context managers:
    • UI: Renders a progress bar.
    • Logging: Initiates a new run in selected experiment tracker and dumps run hyperparameters there. More info.
    • Garbage Collector: Disables automatic Python garbage collection.
    • Profiler: Starts torch.profiler hooks. More info.
    • Gradient Manager: Sets up backward hooks for synchronizing gradient states by all-reduce.
    • Gradient Clipper: Collects the model parameters whose gradients will be registered for clipping.
  4. Ready Hook Trigger: EVENT_TRAIN_READY is fired to mark the start of the primary train sequence.
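Entering several context managers in a fixed order, then unwinding them in reverse on exit, is the standard stack pattern; a minimal sketch with contextlib.ExitStack (subsystem names are illustrative labels, not d9d classes) looks like this:

```python
from contextlib import ExitStack, contextmanager

entered = []

@contextmanager
def subsystem(name):
    # Record enter/exit order; real subsystems would render a progress
    # bar, open a tracker run, disable GC, attach profiler hooks, etc.
    entered.append(f"enter:{name}")
    try:
        yield
    finally:
        entered.append(f"exit:{name}")

with ExitStack() as stack:
    # Contexts are entered in declaration order...
    for name in ["ui", "logging", "gc", "profiler",
                 "grad_manager", "grad_clipper"]:
        stack.enter_context(subsystem(name))
    # ...the loop body runs here; mark it with the ready event.
    entered.append("EVENT_TRAIN_READY")
# ...and exited in reverse (LIFO) order, even on exceptions.
```

ExitStack guarantees that, e.g., the profiler detaches before the logging run is closed, and that everything unwinds cleanly if training raises.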

2. The Step Loop

For every global step (step), the trainer performs the following actions in strict order:

  1. Triggers EVENT_TRAIN_STEP_PRE event.
  2. Microbatch Execution

    • Triggers EVENT_TRAIN_FORWARD_BACKWARD_PRE event.
    • The DataLoader yields a "Batch Group" containing \(N\) microbatches (calculated automatically based on BatchingConfig).
    • We delegate to the TrainTask for mapping data before feeding it into the model.
    • Gradients are accumulated locally, using either multiple regular forward-backward calls (if pipeline parallelism is disabled) or our internal pipelining API. We delegate to the TrainTask to compute loss values between the forward and backward passes.
    • The last gradient accumulation triggers all-reduce synchronization; communication may start overlapping with computation here.
    • We delegate to TrainTask to accumulate local metrics (e.g., token counts, accuracy) into the Metric state.
    • Triggers EVENT_TRAIN_FORWARD_BACKWARD_POST event.
  3. Metric Synchronization

    • Metric Sync Trigger: JobLogger triggers an async reduction of all metrics across the world. More info.
  4. Gradient Synchronization

    • Wait & Scale: The GradientManager waits for all backward hooks to finish. It synchronizes the total weighted loss across the world to determine the scaling factor, then divides all gradients by this factor (essential for correct averaging when batch sizes vary due to masking/packing). More info.
  5. Gradient Clipping

    • The GradientClipper calculates the global L2 norm of all parameter gradients.
    • If max_norm is set, gradients are modified in-place.
    • The total norm is logged.
    • More info.
  6. Optimization

    • Triggers EVENT_TRAIN_OPTIMIZER_STEP_PRE event.
    • Step: The Optimizer updates model parameters.
    • Schedule: The LRScheduler updates the learning rate for the next step.
    • Zero Grad: The GradientManager clears gradients for the next iteration.
    • Triggers EVENT_TRAIN_OPTIMIZER_STEP_POST event.
  7. Logging & Maintenance

    • Log: Metrics are finalized and written to the tracker.
    • GC: ManualGarbageCollector runs if the current step matches the GC period.
    • Event-Based Logic: Triggers EVENT_TRAIN_STEP_POST event.
    • Advance: The Stepper increments the step count.
  8. Checkpointing

    • If the current step matches checkpointing.period_steps, checkpointing is triggered. This acts as a global barrier.
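The gradient scaling in step 4 above (dividing by a globally synchronized total weight rather than a fixed microbatch count) can be sketched with plain numbers; the ranks, weights, and gradient values below are invented for illustration:

```python
# Each rank accumulates an unscaled gradient as a weighted sum over
# its microbatches (weight = e.g. unmasked token count). Dividing by
# the globally summed weight then yields the correct weighted average
# even when effective batch sizes vary due to masking/packing.

local_microbatches = {
    "rank0": [(4.0, 10), (2.0, 30)],  # (per-token mean grad, token weight)
    "rank1": [(1.0, 20)],
}

# Per-rank accumulation: weighted gradient sum plus local weight.
local_grads = {r: sum(g * w for g, w in mbs)
               for r, mbs in local_microbatches.items()}
local_weights = {r: sum(w for _, w in mbs)
                 for r, mbs in local_microbatches.items()}

# "All-reduce": sum both quantities across the world.
global_grad = sum(local_grads.values())      # 4*10 + 2*30 + 1*20 = 120
global_weight = sum(local_weights.values())  # 10 + 30 + 20 = 60

scaled_grad = global_grad / global_weight
```

Dividing instead by a fixed count (here, 3 microbatches) would weight the 20-token microbatch as heavily as the 30-token one, biasing the average.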

3. Finalization

  1. Event-specific: The system triggers EVENT_TRAIN_FINISHED event.
  2. Task-specific: We delegate to the TrainTask to do its specific finalization work.