Inference Loop

Overview

The d9d.loop package provides the execution engine not only for training but also for high-scale distributed inference.

The inference engine shares the same core philosophy as the Trainer: separating the definition of the job from the execution.

Example

from d9d.loop.run import InferenceConfigurator

# Configure
inference = InferenceConfigurator(
    mesh=mesh_params,                  # Physical cluster layout
    parameters=config,                 # Logic configuration (batch size, etc.)

    model_provider=...,                # Same provider used in training
    task_provider=...,                 # Inference-specific logic (e.g., generation)
    data_provider=...,                 # Validation/Test dataset
).configure()

# Execute
inference.infer()

Configuration & Construction

The inference environment is assembled using the InferenceConfigurator.

This class binds the infrastructure and user logic into a ready-to-execute Inference object.

d9d.loop.run.InferenceConfigurator

Orchestrates the assembly of the distributed inference environment.

This class binds the infrastructure configuration (DeviceMesh), the inference parameters, and the user-defined logic (Providers) to create a fully initialized state object capable of running the inference loop.

__init__(mesh, parameters, task_provider, model_provider, data_provider)

Constructs a configurator capable of building the full inference state.

Parameters:

  • mesh (DeviceMeshParameters, required): Definition of the distributed device mesh topology.
  • parameters (InferenceConfig, required): The global configuration object for inference.
  • task_provider (InferenceTaskProvider, required): Factory for creating the inference task logic.
  • model_provider (ModelProvider, required): Factory for defining and creating model stages.
  • data_provider (DatasetProvider, required): Factory for providing inference datasets.

configure()

Instantiates all inference components and returns a configured Inference engine.

This method triggers the creation of the distributed context, sets seeds, builds the model, data loaders, and attaches all auxiliary components.

Returns:

  • Inference: A ready-to-use inference engine instance encapsulating the job state.

The Configuration Lifecycle

The InferenceConfigurator.configure() method performs a setup sequence similar to training, but optimized for forward-only execution:

  1. Distributed Context Initialization:

    • Initializes the distributed process group and builds the DeviceMesh topology from the supplied DeviceMeshParameters.
  2. Seeding:

    • Sets distributed seeds. Determinism is crucial in inference for reproducible sampling or validation splits.
  3. Event Bus Initialization:

    • Creates the global EventBus for lifecycle extensions. Tasks and Providers can register custom hooks.
    • Triggers EVENT_INFERENCE_CONFIG_STARTED event.
  4. Task Instantiation:

    • Instantiates the InferenceTask. This defines how inputs are processed and what to do with the outputs (e.g., writing to a JSONL file).
  5. Data Loader Construction:

    • Creates a distributed DataLoader that handles sharding the inference dataset across ranks.
  6. Model Materialization:

    • The ModelStageFactory runs to build the model.
    • Note: This reuses the exact same ModelProvider as training.
    • Triggers EVENT_INFERENCE_MODEL_STAGES_READY event.
  7. State Assembly:

    • Components are packed into InferenceJobState.
    • The Inference engine is instantiated.
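The seven steps above can be sketched in miniature. Everything in this sketch (the EventBus internals, the hook signature, and the trace strings) is illustrative only, not the actual d9d implementation:

```python
import random

# Minimal stand-in for the global EventBus described above (assumed API).
class EventBus:
    def __init__(self):
        self._hooks = {}

    def on(self, event, hook):
        self._hooks.setdefault(event, []).append(hook)

    def emit(self, event, **payload):
        for hook in self._hooks.get(event, []):
            hook(**payload)

def configure(seed=0):
    """Illustrates the ordering of the configuration lifecycle."""
    trace = []

    trace.append("distributed_context")   # 1. process group / device mesh
    random.seed(seed)                     # 2. deterministic seeding
    trace.append("seeding")

    bus = EventBus()                      # 3. event bus for lifecycle extensions
    bus.on("EVENT_INFERENCE_CONFIG_STARTED",
           lambda: trace.append("config_started_hook"))
    bus.emit("EVENT_INFERENCE_CONFIG_STARTED")

    trace.append("task")                  # 4. InferenceTask instantiation
    trace.append("data_loader")           # 5. sharded DataLoader construction
    trace.append("model")                 # 6. model materialization
    bus.emit("EVENT_INFERENCE_MODEL_STAGES_READY")
    trace.append("state_assembly")        # 7. pack into InferenceJobState
    return trace
```

The point of the sketch is the ordering guarantee: hooks registered before configuration begins observe EVENT_INFERENCE_CONFIG_STARTED before any task, loader, or model exists.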

Execution

To run the job, call the .infer() method on the configured object.

d9d.loop.run.Inference

The main execution engine for running a distributed inference job.

This class manages the inference loop, lifecycle events, distributed synchronization, and periodic side-effects (profiling, checkpointing). It ensures the model is in evaluation mode and runs within a torch.inference_mode context.

__init__(state)

Constructs an Inference engine from a pre-built job state.

Parameters:

Name Type Description Default
state InferenceJobState

The encapsulated state object containing all initialized components.

required

infer()

Executes the full inference workflow.

This method:

  1. Waits for world synchronization.
  2. Loads the latest checkpoint if available.
  3. Iterates through the data loader.
  4. Executes the pipeline forward pass for every batch.
  5. Handles periodic garbage collection and profiling.
  6. Finalizes the task upon completion.
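The workflow above can be condensed into a small control-flow skeleton. The function below is an illustration only; the callable names and keyword arguments are assumptions, not the d9d API:

```python
def infer(batches, forward, *, start_step=0, gc_every=10):
    """Illustrative skeleton of the infer() workflow.

    batches:    an iterable standing in for the distributed DataLoader
    forward:    a callable standing in for the pipeline forward pass
    start_step: step restored from the latest checkpoint, if any
    """
    # Steps 1-2 (world sync, checkpoint restore) happen before the loop.
    outputs = []
    step = start_step
    for batch in batches:                 # 3. iterate the data loader
        outputs.append(forward(batch))    # 4. forward pass, no backward
        step += 1
        if step % gc_every == 0:          # 5. periodic maintenance slot
            pass                          #    (GC / profiler flush would run here)
    # 6. task finalization would run here.
    return outputs
```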

The Inference Lifecycle

The Inference.infer() method orchestrates the execution flow. It is designed to be lean and memory-efficient.

1. Initialization & Recovery

Before the loop starts:

  1. Mode Switching:
    • Enables torch.inference_mode(). This disables gradient calculation globally, saving significant memory.
    • Sets all model modules to .eval() mode (affecting Dropout, BatchNorm, etc.).
  2. State Loading:
    • The StateCheckpointer loads the model weights from the specified checkpoint.
    • If the job was interrupted previously, it also restores the Stepper and DataLoader state to resume exactly where it left off.
  3. Context Entry:
    • Enters UI, Garbage Collector, and Profiler contexts.
  4. Ready Hook Trigger: EVENT_INFERENCE_READY is fired to mark initialization completion.
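The resumption behavior in step 2 can be sketched as follows, assuming a checkpoint is simply a record of the last completed step. The real StateCheckpointer also restores model weights and richer DataLoader state; this only illustrates the fast-forward idea:

```python
def resume(dataset, checkpoint=None):
    """Sketch of checkpoint-based resumption (assumed checkpoint layout:
    a dict holding the last completed step)."""
    start_step = checkpoint["step"] if checkpoint else 0
    # The restored loader skips samples that were already processed,
    # so the job continues exactly where it left off.
    remaining = dataset[start_step:]
    return start_step, remaining
```

For example, a job interrupted after step 4 of a 10-sample dataset resumes with the remaining 6 samples rather than reprocessing from the start.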

2. The Step Loop

For every step:

  1. Triggers EVENT_INFERENCE_STEP_PRE event.
  2. Microbatch Execution:

    • Triggers EVENT_INFERENCE_FORWARD_PRE event.
    • The DataLoader yields a batch group.
    • The InferenceTaskOperator manages the execution.
    • Data is fed through the model.
    • Unlike training, no backward pass is performed.
    • Triggers EVENT_INFERENCE_FORWARD_POST event.
  3. Maintenance:

    • GC: ManualGarbageCollector runs periodically to ensure peak memory usage is controlled.
    • Event-Based Logic: Triggers EVENT_INFERENCE_STEP_POST event.
    • Advance: The Stepper increments.
  4. Checkpointing:

    • If configured, the system saves the progress of the inference job. This allows restarting a long-running generation job on a massive dataset without re-processing the first half.
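A compressed sketch of what a single step emits, under assumed cadence parameters (gc_every and checkpoint_every are illustrative names; the actual intervals come from the inference configuration):

```python
def step_events(step, *, gc_every=50, checkpoint_every=200):
    """Events and maintenance fired for one step of the loop above
    (cadence values are illustrative)."""
    events = [
        "EVENT_INFERENCE_STEP_PRE",
        "EVENT_INFERENCE_FORWARD_PRE",
        "forward",                        # model forward pass, no backward
        "EVENT_INFERENCE_FORWARD_POST",
    ]
    if step % gc_every == 0:
        events.append("gc")               # ManualGarbageCollector run
    events.append("EVENT_INFERENCE_STEP_POST")
    events.append("stepper_advance")
    if step % checkpoint_every == 0:
        events.append("checkpoint")       # save inference progress
    return events
```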

3. Finalization

  1. Event-specific: The system triggers EVENT_INFERENCE_FINISHED event.
  2. Task-specific:
    • Calls InferenceTask.finalize().
    • This is typically used to close file handles (e.g., flushing the final lines of a generated dataset to disk).
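As a rough illustration of the JSONL pattern mentioned above, here is a hypothetical task-like sink. The class and method names are invented for this sketch and are not part of d9d:

```python
import io
import json

class JsonlWriterTask:
    """Hypothetical InferenceTask-style sink: buffers model outputs and
    flushes them as JSON Lines when finalize() is called."""

    def __init__(self, stream):
        self.stream = stream
        self._buffer = []

    def process_output(self, output):
        # Called per batch during the step loop.
        self._buffer.append(output)

    def finalize(self):
        # Flush any remaining rows and leave the stream in a clean state.
        for row in self._buffer:
            self.stream.write(json.dumps(row) + "\n")
        self._buffer.clear()
        self.stream.flush()
```

Usage: construct the task over an open file (or an in-memory stream), feed it outputs during the loop, and rely on finalize() to guarantee the last buffered lines actually reach disk.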