# Inference Loop

## Overview
The d9d.loop package provides the execution engine not only for training but also for high-scale distributed inference.
The inference engine shares the same core philosophy as the Trainer: separating the definition of the job from the execution.
## Example
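The configure-then-execute pattern looks roughly like the following self-contained sketch. The stand-in classes below only mirror the pattern described on this page; they are not the real d9d API, and the toy model, dataset, and task are placeholders:

```python
# Self-contained sketch of the configure/execute separation.
# These stand-in classes mirror the pattern only; they are NOT the d9d API.

class InferenceJobState:
    """Bag of fully initialized components (model, loader, task, ...)."""
    def __init__(self, model, batches, task):
        self.model = model
        self.batches = batches
        self.task = task

class Inference:
    """Execution engine: iterates batches and feeds them through the model."""
    def __init__(self, state):
        self.state = state

    def infer(self):
        outputs = []
        for batch in self.state.batches:
            outputs.append(self.state.model(batch))  # forward pass only
        self.state.task(outputs)                     # hand results to the task
        return outputs

class InferenceConfigurator:
    """Binds providers, then builds a ready-to-run Inference engine."""
    def __init__(self, model_provider, data_provider, task_provider):
        self.model_provider = model_provider
        self.data_provider = data_provider
        self.task_provider = task_provider

    def configure(self):
        state = InferenceJobState(
            model=self.model_provider(),
            batches=self.data_provider(),
            task=self.task_provider(),
        )
        return Inference(state)

# Usage: the job definition (providers) is separate from execution (infer()).
configurator = InferenceConfigurator(
    model_provider=lambda: (lambda x: x * 2),    # toy "model"
    data_provider=lambda: [1, 2, 3],             # toy "dataset"
    task_provider=lambda: (lambda outs: None),   # toy "task sink"
)
engine = configurator.configure()
print(engine.infer())  # [2, 4, 6]
```

The point of the split is that the configurator owns all construction-order concerns, while the engine only runs the loop.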
## Configuration & Construction

The inference environment is assembled using the `InferenceConfigurator`. This class binds the infrastructure and user logic into a ready-to-execute `Inference` object.
### d9d.loop.run.InferenceConfigurator

Orchestrates the assembly of the distributed inference environment.

This class binds the infrastructure configuration (`DeviceMesh`), the inference parameters, and the user-defined logic (providers) to create a fully initialized state object capable of running the inference loop.
#### __init__(mesh, parameters, task_provider, model_provider, data_provider)

Constructs a configurator capable of building the full inference state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mesh` | `DeviceMeshParameters` | Definition of the distributed device mesh topology. | *required* |
| `parameters` | `InferenceConfig` | The global configuration object for inference. | *required* |
| `task_provider` | `InferenceTaskProvider` | Factory for creating the inference task logic. | *required* |
| `model_provider` | `ModelProvider` | Factory for defining and creating model stages. | *required* |
| `data_provider` | `DatasetProvider` | Factory for providing inference datasets. | *required* |
#### configure()

Instantiates all inference components and returns a configured `Inference` engine.

This method triggers the creation of the distributed context, sets seeds, builds the model and data loaders, and attaches all auxiliary components.

Returns:

| Name | Type | Description |
|---|---|---|
| `Inference` | `Inference` | A ready-to-use inference engine instance encapsulating the job state. |
### The Configuration Lifecycle

The `InferenceConfigurator.configure()` method performs a setup sequence similar to training, but optimized for forward-only execution:
- **Distributed Context Initialization:**
  - Constructs the global `DistributedContext`.
- **Seeding:**
  - Sets distributed seeds. Determinism is crucial in inference for reproducible sampling or validation splits.
- **Event Bus Initialization:**
  - Creates the global `EventBus` for lifecycle extensions. Tasks and providers can register custom hooks.
  - Triggers the `EVENT_INFERENCE_CONFIG_STARTED` event.
- **Task Instantiation:**
  - Instantiates the `InferenceTask`. This defines how inputs are processed and what to do with the outputs (e.g., writing to a JSONL file).
- **Data Loader Construction:**
  - Creates a distributed `DataLoader` that handles sharding the inference dataset across ranks.
- **Model Materialization:**
  - The `ModelStageFactory` runs to build the model. Note: this reuses the exact same `ModelProvider` as training.
  - Triggers the `EVENT_INFERENCE_MODEL_STAGES_READY` event.
- **State Assembly:**
  - Components are packed into `InferenceJobState`.
  - The `Inference` engine is instantiated.
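The lifecycle events fired during this sequence follow an ordinary publish/subscribe pattern. A minimal sketch of that pattern (the `EventBus` class and the event string values here are illustrative, not the d9d implementation; only the event constant names come from this page):

```python
from collections import defaultdict

# Illustrative publish/subscribe bus for lifecycle hooks. The event constant
# names match this page; their string values are assumptions.
EVENT_INFERENCE_CONFIG_STARTED = "inference.config_started"
EVENT_INFERENCE_MODEL_STAGES_READY = "inference.model_stages_ready"

class EventBus:
    def __init__(self):
        self._handlers = defaultdict(list)

    def register(self, event, handler):
        """Tasks and providers call this to attach custom hooks."""
        self._handlers[event].append(handler)

    def trigger(self, event, **payload):
        """The configurator fires events at the matching lifecycle points."""
        for handler in self._handlers[event]:
            handler(**payload)

bus = EventBus()
seen = []
bus.register(EVENT_INFERENCE_CONFIG_STARTED, lambda **kw: seen.append("config"))
bus.register(EVENT_INFERENCE_MODEL_STAGES_READY, lambda **kw: seen.append("model"))

bus.trigger(EVENT_INFERENCE_CONFIG_STARTED)
bus.trigger(EVENT_INFERENCE_MODEL_STAGES_READY)
print(seen)  # ['config', 'model']
```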
## Execution

To run the job, call the `.infer()` method on the configured object.
### d9d.loop.run.Inference

The main execution engine for running a distributed inference job.

This class manages the inference loop, lifecycle events, distributed synchronization, and periodic side effects (profiling, checkpointing). It ensures the model is in evaluation mode and runs within a `torch.inference_mode` context.
#### __init__(state)

Constructs an `Inference` engine from a pre-built job state.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `InferenceJobState` | The encapsulated state object containing all initialized components. | *required* |
#### infer()

Executes the full inference workflow.

This method:
- Waits for world synchronization.
- Loads the latest checkpoint if available.
- Iterates through the data loader.
- Executes the pipeline forward pass for every batch.
- Handles periodic garbage collection and profiling.
- Finalizes the task upon completion.
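The bullet points above can be sketched as a loop skeleton. Everything here is a stub standing in for the real components (synchronization, checkpointing, data loader, forward pass); it shows the control flow only, not the actual implementation:

```python
# Skeleton of the infer() workflow described above, with stub components.
def run_inference(sync, load_checkpoint, loader, forward, gc_every, collect, finalize):
    sync()                    # wait for world synchronization
    step = load_checkpoint()  # resume from a saved step if available, else 0
    results = []
    for step, batch in enumerate(loader, start=step):
        results.append(forward(batch))   # pipeline forward pass, no backward
        if (step + 1) % gc_every == 0:
            collect()                    # periodic garbage collection
    finalize(results)                    # task finalization (flush outputs)
    return results

# Toy stubs make the skeleton runnable end to end.
collected = []
out = run_inference(
    sync=lambda: None,
    load_checkpoint=lambda: 0,
    loader=[1, 2, 3, 4],
    forward=lambda b: b + 10,
    gc_every=2,
    collect=lambda: collected.append("gc"),
    finalize=lambda results: None,
)
print(out, collected)  # [11, 12, 13, 14] ['gc', 'gc']
```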
### The Inference Lifecycle

The `Inference.infer()` method orchestrates the execution flow. It is designed to be lean and memory-efficient.
#### 1. Initialization & Recovery

Before the loop starts:

- **Mode Switching:**
  - Enables `torch.inference_mode()`. This disables gradient calculation globally, saving significant memory.
  - Sets all model modules to `.eval()` mode (affecting Dropout, BatchNorm, etc.).
- **State Loading:**
  - The `StateCheckpointer` loads the model weights from the specified checkpoint.
  - If the job was interrupted previously, it also restores the `Stepper` and `DataLoader` state to resume exactly where it left off.
- **Context Entry:**
  - Enters UI, Garbage Collector, and Profiler contexts.
- **Ready Hook Trigger:**
  - The `EVENT_INFERENCE_READY` event is fired to mark initialization completion.
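Entering several long-lived contexts at once, as the context-entry step does, is commonly handled with `contextlib.ExitStack`, which guarantees they unwind in reverse order even on errors. A sketch of that pattern (the subsystem names are stand-ins, not the d9d internals):

```python
from contextlib import ExitStack, contextmanager

log = []

@contextmanager
def subsystem(name):
    # Stand-in for a UI, garbage-collector, or profiler context.
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")

# Enter all long-lived contexts for the whole run; ExitStack closes them
# in reverse order when the block exits, even if the loop raises.
with ExitStack() as stack:
    for name in ("ui", "gc", "profiler"):
        stack.enter_context(subsystem(name))
    log.append("loop runs here")

print(log)
# ['enter ui', 'enter gc', 'enter profiler', 'loop runs here',
#  'exit profiler', 'exit gc', 'exit ui']
```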
#### 2. The Step Loop

For every step:

- Triggers the `EVENT_INFERENCE_STEP_PRE` event.
- **Microbatch Execution:**
  - Triggers the `EVENT_INFERENCE_FORWARD_PRE` event.
  - The `DataLoader` yields a batch group.
  - The `InferenceTaskOperator` manages the execution.
  - Data is fed through the model. Unlike training, no backward pass is performed.
  - Triggers the `EVENT_INFERENCE_FORWARD_POST` event.
- **Maintenance:**
  - GC: the `ManualGarbageCollector` runs periodically to keep peak memory usage under control.
  - Event-based logic: triggers the `EVENT_INFERENCE_STEP_POST` event.
  - Advance: the `Stepper` increments.
- **Checkpointing:**
  - If configured, the system saves the progress of the inference job. This allows restarting a long-running generation job on a massive dataset without re-processing the first half.
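The resumability described above boils down to periodically persisting the step counter (and loader position) so a restart can skip already-processed data. A minimal sketch using a JSON file (the file format and function are illustrative, not the `StateCheckpointer` format):

```python
import json
import os
import tempfile

def process(dataset, ckpt_path, save_every=2):
    # Resume: read the next step to run if a progress checkpoint exists.
    start = 0
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            start = json.load(f)["next_step"]
    done = []
    for step in range(start, len(dataset)):
        done.append(dataset[step])           # stand-in for the forward pass
        if (step + 1) % save_every == 0:     # periodic progress checkpoint
            with open(ckpt_path, "w") as f:
                json.dump({"next_step": step + 1}, f)
    return done

ckpt = os.path.join(tempfile.mkdtemp(), "progress.json")
data = ["a", "b", "c", "d", "e"]
first = process(data[:3], ckpt)  # simulate an interruption after 3 items
resumed = process(data, ckpt)    # restart resumes from the last saved step
print(first, resumed)  # ['a', 'b', 'c'] ['c', 'd', 'e']
```

Note that `'c'` is processed twice: checkpoints are only written every `save_every` steps, so a restart may redo the few items since the last save.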
#### 3. Finalization

- **Event-specific:** The system triggers the `EVENT_INFERENCE_FINISHED` event.
- **Task-specific:** Calls `InferenceTask.finalize()`. This is typically used to close file handles (e.g., flushing the final lines of a generated dataset to disk).
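A typical `finalize()` that flushes generated records to a JSONL file might look like this sketch (the `JsonlWriterTask` class and its method names are illustrative, not part of the d9d API):

```python
import json
import os
import tempfile

class JsonlWriterTask:
    """Illustrative inference task: streams outputs to JSONL, closes on finalize."""

    def __init__(self, path):
        self._file = open(path, "w")

    def handle_output(self, record):
        # Called once per produced record during the step loop.
        self._file.write(json.dumps(record) + "\n")

    def finalize(self):
        # Flush any buffered lines to disk and release the file handle.
        self._file.flush()
        self._file.close()

path = os.path.join(tempfile.mkdtemp(), "generations.jsonl")
task = JsonlWriterTask(path)
for i in range(3):
    task.handle_output({"id": i, "text": f"sample-{i}"})
task.finalize()

with open(path) as f:
    print(len(f.readlines()))  # 3
```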