# Training Loop

## Overview
The `d9d.loop` package provides the execution engine for distributed training.
The d9d `Trainer` separates the definition of the job (models, tasks, data) from the execution of the job (synchronization, checkpointing, profiling).
This allows the same code to run on a single GPU or on a 1000-GPU pipeline-parallel cluster without modification.
## Example
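A minimal end-to-end sketch of the intended workflow is shown below. Only `TrainingConfigurator`, its constructor parameters, `configure()`, and `train()` come from the API documented on this page; the provider and configuration objects (`my_mesh_parameters`, `my_task_provider`, etc.) are hypothetical placeholders for user-supplied implementations.

```python
# Schematic sketch only: the `my_*` objects below are hypothetical
# placeholders for user implementations of the d9d provider interfaces.
from d9d.loop.run import TrainingConfigurator

configurator = TrainingConfigurator(
    mesh=my_mesh_parameters,                        # DeviceMeshParameters
    parameters=my_trainer_config,                   # TrainerConfig
    task_provider=my_task_provider,                 # TrainTaskProvider
    model_provider=my_model_provider,               # ModelProvider
    data_provider=my_data_provider,                 # DatasetProvider
    optimizer_provider=my_optimizer_provider,       # OptimizerProvider
    lr_scheduler_provider=my_lr_scheduler_provider, # LRSchedulerProvider
)

# configure() builds the distributed context, model, optimizer, and data
# pipeline, returning a ready-to-use Trainer.
trainer = configurator.configure()

# train() runs the full training lifecycle described below.
trainer.train()
```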
## Configuration & Construction
To ensure reproducibility, the `Trainer` is not instantiated directly from loose objects. Instead, it is built with the `TrainingConfigurator`, following the dependency injection pattern.
This class binds:
- the infrastructure configuration,
- the job configuration,
- and the user logic (providers)

into a `Trainer` object with a prepared `TrainJobState`.
### d9d.loop.run.TrainingConfigurator
Orchestrates the assembly of the distributed training environment.
This class binds the infrastructure configuration (DeviceMesh), the training parameters (TrainerConfig), and the user-defined logic (Providers) to create a fully initialized state object capable of running the training loop.
#### __init__(mesh, parameters, task_provider, model_provider, data_provider, optimizer_provider, lr_scheduler_provider)
Constructs a configurator capable of building the full training state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mesh` | `DeviceMeshParameters` | Definition of the distributed device mesh topology. | required |
| `parameters` | `TrainerConfig` | The global configuration object for the trainer. | required |
| `task_provider` | `TrainTaskProvider` | Factory for creating the training task logic. | required |
| `model_provider` | `ModelProvider` | Factory for defining and creating model stages. | required |
| `data_provider` | `DatasetProvider` | Factory for providing training datasets. | required |
| `optimizer_provider` | `OptimizerProvider` | Factory for creating the optimizer. | required |
| `lr_scheduler_provider` | `LRSchedulerProvider` | Factory for creating the learning rate scheduler. | required |
#### configure()
Instantiates all training components and returns a configured Trainer.
This method triggers the creation of the distributed context, sets seeds, builds the model, optimizer, data loaders, and attaches all auxiliary components (logging, profiling, checkpointing).
Returns:
| Name | Type | Description |
|---|---|---|
| Trainer | `Trainer` | A ready-to-use trainer instance encapsulating the job state. |
### The Configuration Lifecycle

The `TrainingConfigurator.configure()` method performs the following steps:

1. **Distributed Context Initialization:** Constructs the global `DistributedContext`, thereby initializing all the required NCCL process groups and `DeviceMesh`es.
2. **Seeding:** Sets distributed seeds using the configured `base_seed`. This ensures model initialization and other initial states are deterministic. More info.
3. **Event Bus Initialization:** Creates the global `EventBus`, which tasks and providers use to register custom hooks. Triggers the `EVENT_TRAIN_CONFIG_STARTED` event.
4. **Task Instantiation:** Instantiates the `TrainTask` object using the specified `TrainTaskProvider`.
5. **Data Loader Construction:** Calls the `DatasetProvider` to get the dataset and wraps it in a `DataLoader`. The `DataLoader` automatically moves all tensor data to this worker's device. Triggers the `EVENT_TRAIN_DATA_LOADER_READY` event.
6. **Model Materialization:** The `ModelStageFactory` runs. This is the heavy lifting of initialization:
    - **Meta Init:** `ModelProvider` creates the model on the `meta` device (no memory usage).
    - **Parallelization:** `ModelProvider` applies `DTensor` sharding/replication to the parameters.
    - **Materialization:** Empty tensors are allocated on the actual GPU.
    - **Wait:** A hard barrier ensures that all ranks allocated memory successfully.
    - **Parameter Reset:** `model.reset_parameters()` is called to generate random weights on the GPU.
    - **Source Loading (Optional):** If configured, a pretrained checkpoint (e.g., from HF) is streamed into the model using `ModelStateMapper`.

    Triggers the `EVENT_TRAIN_MODEL_STAGES_READY` event.
7. **Optimizer and LR Scheduler Setup:** The `OptimizerFactory` iterates over the model parameters and calls the `OptimizerProvider` and `LRSchedulerProvider`. Triggers both the `EVENT_TRAIN_OPTIMIZER_READY` and `EVENT_TRAIN_LR_SCHEDULER_READY` events.
8. **State Assembly:** All components (including internal ones) are packed into the `TrainJobState`. The `Trainer` is instantiated with this state and returned.
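The seeding step above can be illustrated with a small self-contained sketch. The exact derivation d9d uses for per-rank seeds is not documented here; deriving each rank's stream from `base_seed` plus a rank offset, as below, is just one common scheme:

```python
import random

def make_rank_rng(base_seed: int, rank: int) -> random.Random:
    """Derive a deterministic, rank-specific RNG stream from one base seed.

    Hypothetical derivation for illustration only: the scheme d9d actually
    uses to distribute seeds across ranks may differ.
    """
    # Offsetting by a large prime keeps rank streams distinct while
    # remaining fully reproducible from (base_seed, rank).
    return random.Random(base_seed + 1_000_003 * rank)

# The same base_seed always reproduces the same stream (deterministic init)...
assert make_rank_rng(42, rank=0).random() == make_rank_rng(42, rank=0).random()

# ...while each rank draws from its own independent stream.
rng0 = make_rank_rng(42, rank=0)
rng1 = make_rank_rng(42, rank=1)
assert [rng0.random() for _ in range(5)] != [rng1.random() for _ in range(5)]
```

Deterministic seeding of this kind is what makes step 6's `reset_parameters()` produce identical weights across reruns of the same job.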
## Execution
To run a training job, call the `.train()` method on the `Trainer` object returned by the configuration process.
### d9d.loop.run.Trainer
The main execution engine for running a distributed training job.
This class manages the training loop, lifecycle events, distributed synchronization, and periodic side-effects (logging, checkpointing).
#### __init__(state)
Constructs a Trainer from a pre-built job state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `TrainJobState` | The encapsulated state object containing all initialized components (model, optimizer, dist_context, etc.). | required |
#### export(export_to, load_checkpoint)
Exports the current model state to the specified directory.
This handles the distributed saving logic, allowing the model to be reconstituted later or used for inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `export_to` | `Path` | The directory path where the model artifacts will be saved. | required |
| `load_checkpoint` | `bool` | If `True`, attempts to load the latest checkpoint into the model before exporting. | required |
#### train()
Executes the full training workflow.
### The Training Lifecycle
The Trainer.train() method orchestrates the following lifecycle. It is critical to understand this flow when debugging distributed issues or checking for side effects.
#### 1. Initialization & Recovery
Before the loop starts:
- **Global Synchronization:** The trainer waits for all ranks to come online (`barrier`).
- **State Loading:** The `StateCheckpointer` checks the filesystem.
    - If a checkpoint exists, it is loaded into all the `Stateful` objects inside its `_state`.
    - If no checkpoint exists, training starts from the first step.
- **Context Entry:** The trainer enters several context managers:
    - **UI:** Renders a progress bar.
    - **Logging:** Initiates a new run in the selected experiment tracker and records the run's hyperparameters there. More info.
    - **Garbage Collector:** Disables automatic Python garbage collection.
    - **Profiler:** Starts `torch.profiler` hooks. More info.
    - **Gradient Manager:** Sets up backward hooks for synchronizing gradient states via all-reduce.
    - **Gradient Clipper:** Collects the model parameters whose gradients will be registered for clipping.
- **Ready Hook Trigger:** `EVENT_TRAIN_READY` is fired to mark the start of the primary training sequence.
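The garbage-collector behavior described above (automatic collection disabled for the whole run, with explicit collections at a fixed step period, see step 7 of the loop below) can be sketched with a minimal stdlib stand-in. The `manual_gc` helper here is hypothetical; the real component's interface may differ:

```python
import gc
from contextlib import contextmanager

@contextmanager
def manual_gc(period_steps: int):
    """Disable automatic GC inside the loop; collect manually on a period.

    Minimal stand-in for the trainer's garbage-collector context manager
    (illustrative only, not d9d's actual implementation).
    """
    was_enabled = gc.isenabled()
    gc.disable()  # avoid unpredictable GC pauses in the middle of a step
    try:
        # The yielded callable collects only when the step hits the period.
        yield lambda step: gc.collect() if step % period_steps == 0 else 0
    finally:
        if was_enabled:
            gc.enable()  # restore the interpreter's default behavior

with manual_gc(period_steps=100) as maybe_collect:
    for step in range(1, 301):
        # ... run one training step ...
        maybe_collect(step)  # explicit collection at steps 100, 200, 300
```

Pinning collections to step boundaries keeps GC pauses out of the timed forward/backward region, which matters when ranks must stay in lockstep.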
#### 2. The Step Loop

For every global step (`step`), the trainer performs the following actions in strict order:

1. **Step Start:** Triggers the `EVENT_TRAIN_STEP_PRE` event.
2. **Microbatch Execution:**
    - Triggers the `EVENT_TRAIN_FORWARD_BACKWARD_PRE` event.
    - The `DataLoader` yields a "Batch Group" containing \(N\) microbatches (calculated automatically from the `BatchingConfig`).
    - The `TrainTask` maps the data before it is fed into the model.
    - Gradients are accumulated locally, either through multiple regular forward-backward calls when pipeline parallelism is disabled, or through the internal pipelining API. The `TrainTask` computes loss values between the forward and backward passes.
    - The last gradient accumulation triggers all-reduce synchronization. Communication may start overlapping with computation here.
    - The `TrainTask` accumulates local metrics (e.g., token counts, accuracy) into the `Metrics` state.
    - Triggers the `EVENT_TRAIN_FORWARD_BACKWARD_POST` event.
3. **Metric Synchronization:** The `JobLogger` triggers an async reduction of all metrics across the world. More info.
4. **Gradient Synchronization (Wait & Scale):** The `GradientManager` waits for all backward hooks to finish. It synchronizes the total weighted loss across the world to determine the scaling factor, then divides all gradients by this factor (essential for correct averaging when batch sizes vary due to masking/packing). More info.
5. **Gradient Clipping:**
    - The `GradientClipper` calculates the global L2 norm over all parameter gradients.
    - If `max_norm` is set, gradients are modified in place.
    - The total norm is logged. More info.
6. **Optimization:**
    - Triggers the `EVENT_TRAIN_OPTIMIZER_STEP_PRE` event.
    - **Step:** The `Optimizer` updates the model parameters.
    - **Schedule:** The `LRScheduler` updates the learning rate for the next step.
    - **Zero Grad:** The `GradientManager` clears gradients for the next iteration.
    - Triggers the `EVENT_TRAIN_OPTIMIZER_STEP_POST` event.
7. **Logging & Maintenance:**
    - **Log:** Metrics are finalized and written to the tracker.
    - **GC:** The `ManualGarbageCollector` runs if the current step matches the GC period.
    - **Event-Based Logic:** Triggers the `EVENT_TRAIN_STEP_POST` event.
    - **Advance:** The `Stepper` increments the step count.
8. **Checkpointing:** If the current step matches `checkpointing.period_steps`, checkpointing is triggered. This acts as a global barrier.
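The scaling logic in step 4 can be checked numerically with a minimal sketch (pure Python, not d9d's implementation). Two simulated ranks hold different numbers of valid tokens; summing per-token gradients locally and dividing once by the globally reduced token count reproduces the true per-token average, whereas naively averaging the per-rank means does not:

```python
# Minimal numerical sketch of step 4's "wait & scale" logic: each "rank"
# sums its per-token gradients, then every gradient is divided by the
# globally synchronized weight (here: the total valid-token count).
rank_grads = [[1.0, 2.0, 3.0], [10.0]]       # per-token grads on 2 ranks
local_sums = [sum(g) for g in rank_grads]    # what backward accumulates
local_counts = [len(g) for g in rank_grads]  # per-rank valid-token counts

global_weight = sum(local_counts)            # synchronized via all-reduce
global_grad = sum(local_sums) / global_weight  # scale once, after the reduce

# Reference: the exact mean over every token in the global batch.
all_tokens = [g for rank in rank_grads for g in rank]
true_mean = sum(all_tokens) / len(all_tokens)
assert abs(global_grad - true_mean) < 1e-12  # correct: 16 / 4 = 4.0

# Naively averaging per-rank means is biased when counts differ:
naive = sum(s / c for s, c in zip(local_sums, local_counts)) / len(rank_grads)
assert abs(naive - true_mean) > 1e-3         # biased: (2 + 10) / 2 = 6.0
```

This is why the scaling factor must come from a world-wide reduction rather than from each rank's local batch size when masking or packing makes per-rank token counts uneven.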
#### 3. Finalization

- **Event-specific:** The system triggers the `EVENT_TRAIN_FINISHED` event.
- **Task-specific:** The `TrainTask` performs its own task-specific finalization work.