
User Tasks

A Task defines custom logic for a single training or inference step.

Each Task may implement the Stateful protocol, so you can store mutable state in it.

TrainTask

It is responsible for logging metrics, mapping batch inputs before they are fed into the model, and computing the task loss.

Init: create_metrics(...), dump_hparams(...).

Lifecycle:

  1. build_forward_inputs(...) (called once) ->
  2. compute_loss(...) (called once per pipeline microbatch when pipelining is enabled) ->
  3. update_metrics(...) (called once).

Exit: finalize(...).

State Management: state_dict(...), load_state_dict(...).

Events Registration: register_events(...) allows you to link specific custom methods to framework-wide Event Hooks.
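The lifecycle ordering above can be sketched with a toy scheduler. This is plain Python standing in for the framework's own loop, and the microbatch count is an assumption for illustration:

```python
# Toy walk-through of the TrainTask call ordering; not d9d code.
calls = []

def build_forward_inputs():
    calls.append("build_forward_inputs")

def compute_loss(microbatch):
    calls.append(f"compute_loss[{microbatch}]")

def update_metrics():
    calls.append("update_metrics")

n_microbatches = 3  # hypothetical pipeline schedule

build_forward_inputs()            # once per step
for mb in range(n_microbatches):
    compute_loss(mb)              # once per microbatch
update_metrics()                  # once per step
```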

InferenceTask

The InferenceTask defines the logic for a single inference step.

It handles the forward-only flow, processing the raw tensors produced by the model (e.g., logits, hidden states).

Lifecycle:

  1. build_forward_inputs(...) (called once) ->
  2. process_outputs(...) (called once per pipeline microbatch).

Exit: finalize(...).

State Management: state_dict(...), load_state_dict(...).

Events Registration: register_events(...) allows hooking into the Event Bus alongside regular execution.
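As a rough illustration of the subscribe-then-publish pattern that register_events(...) enables, here is a toy event bus. The class and its method names (subscribe, publish) are invented for this sketch and are not the d9d EventBus API:

```python
# Toy event bus for illustration only; the real EventBus API may differ.
class ToyEventBus:
    def __init__(self):
        self._handlers = {}

    def subscribe(self, event_name, handler):
        # Remember the handler under the event name.
        self._handlers.setdefault(event_name, []).append(handler)

    def publish(self, event_name, payload):
        # Invoke every handler registered for this event.
        for handler in self._handlers.get(event_name, []):
            handler(payload)

seen = []
bus = ToyEventBus()
bus.subscribe("step_end", seen.append)  # what a task might wire up
bus.publish("step_end", {"step": 1})    # fired by the surrounding loop
```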

Pipeline State

You may notice that the batch is only accessible in the build_forward_inputs(...) method, not in the others.

To transfer state between the Task lifecycle stages, use the PipelineState object.

ctx.state["target"] = torch.tensor([1, 0, 1, 0], device="cuda")

# ...

metrics["accuracy"].update(ctx.state["target"])

The pipeline state will automatically shard and unshard data if needed.

Additional documentation describes its internal behaviour.
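To make "shard and unshard" concrete, here is a toy sketch using plain lists. This is an assumption about the behaviour for illustration only; the real PipelineState operates on tensors and handles remainders, devices, and so on:

```python
# Toy illustration: split a value along dimension 0 into microbatch
# chunks, then reassemble it. Not the framework's implementation.
def shard(values, n_microbatches):
    size = len(values) // n_microbatches
    return [values[i * size:(i + 1) * size] for i in range(n_microbatches)]

def unshard(chunks):
    return [v for chunk in chunks for v in chunk]

targets = [1, 0, 1, 0]
chunks = shard(targets, 2)    # two microbatch chunks
restored = unshard(chunks)    # original order is recovered
```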

Example Implementation

import torch

from d9d.core.dist_context import DistributedContext
from d9d.core.types import ScalarTree
from d9d.module.block.head import LM_IGNORE_INDEX
from d9d.loop.control import *

class SFTTask(TrainTask[dict[str, torch.Tensor]]):
    def __init__(self, dist_ctx: DistributedContext):
        self._dist_ctx = dist_ctx

    def build_forward_inputs(self, ctx: BuildForwardInputsContext) -> BuildForwardInputsResult:
        # ctx.batch contains the output of the Collator.

        # Save labels in state for access during loss computation later
        ctx.state["labels"] = ctx.batch["labels"]

        # Return inputs for model.forward()
        # inputs are only for the first pipeline stage
        # kwargs are the same for all the pipeline stages
        return BuildForwardInputsResult(
            inputs={
                "input_ids": ctx.batch["input_ids"]
            },
            kwargs={
                "labels": ctx.batch["labels"],
                "position_ids": ctx.batch["position_ids"]
            }
        )

    def dump_hparams(self) -> ScalarTree:
        # This task has no extra hyperparameters to log; defer to the base implementation
        return super().dump_hparams()

    def compute_loss(self, ctx: ComputeLossContext) -> ComputeLossResult:
        # Retrieve log_probs calculated by the model pipeline
        logps = ctx.pipeline_results["logps"]

        # Calculate number of valid tokens (ignoring the -100 padding)
        # This is crucial for variable length batches.
        num_loss_tokens = (ctx.state["labels"] != LM_IGNORE_INDEX).sum()

        # Calculate average loss per valid token
        total_loss = logps.sum() / num_loss_tokens

        return ComputeLossResult(
            loss=total_loss,
            # loss_weight is used for gradient accumulation across the distributed world.
            # If batches have different token counts, we weight the gradient
            # by token count to get a mathematically correct average over the accumulation steps.
            # (The constant divisor only rescales all weights equally and cancels in the weighted mean.)
            loss_weight=num_loss_tokens / 1000
        )
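The loss_weight mechanics can be checked with toy numbers (pure arithmetic, independent of the framework): weighting each step's mean loss by its valid-token count recovers the true per-token average, which a naive mean of per-step means does not.

```python
# Two hypothetical accumulation steps with different valid-token counts.
losses = [2.0, 4.0]   # mean loss per token in each step
tokens = [10, 30]     # valid (non-ignored) tokens in each step

# Naive mean of per-step means ignores the imbalance:
naive = sum(losses) / len(losses)

# Weighted mean, as loss_weight enables:
weighted = sum(l * t for l, t in zip(losses, tokens)) / sum(tokens)

# Ground truth: total loss over all tokens divided by total tokens.
true_avg = (2.0 * 10 + 4.0 * 30) / (10 + 30)
```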

d9d.loop.control.task

BaseTask

Bases: ABC, Stateful, Generic[TBatch]

Abstract base class representing a unit of work (Task) in the training/inference loop.

build_forward_inputs(ctx) abstractmethod

Transforms raw data loaded from the DataLoader into arguments for the model.

Parameters:

  * ctx (BuildForwardInputsContext[TBatch], required) - Context object.

Returns:

  * BuildForwardInputsResult - Result object.

finalize(ctx)

Performs cleanup or final actions when the task execution finishes.

Parameters:

  * ctx (FinalizeContext, required) - Context object.

load_state_dict(state_dict)

Restores the task's state from the provided dictionary.

Parameters:

  * state_dict (dict[str, Any], required) - The state dictionary to load.

register_events(context)

Register task-specific event subscriptions.

Parameters:

  * context (RegisterTaskEventsContext, required) - Context providing access to the distributed environment and the event bus.

state_dict()

Returns the state dictionary for checkpointing this task.

Returns:

  * dict[str, Any] - A dictionary containing the task's state.

BuildForwardInputsContext dataclass

Bases: Generic[TBatch]

Context data to prepare inputs for the model forward pass.

Attributes:

  * batch (TBatch) - The raw batch data loaded from the DataLoader object.
  * state (PipelineState) - The current state of the pipeline. You can assign any data to this state object, and it will be accessible during this pipeline step (e.g. when computing loss).

BuildForwardInputsResult dataclass

The result of processing the raw batch into model inputs.

Attributes:

  * inputs (dict[str, Tensor]) - A dictionary of inputs passed to the model pipeline as input data (first stage only if using pipeline parallelism).
  * kwargs (dict[str, Any]) - A dictionary of keyword arguments passed to each pipeline stage.
  * pipeline_sharding_spec (PipelineShardingSpec | None) - A specification defining how inputs and kwargs should be split into micro-batches for pipeline parallelism. If None, the framework assumes the standard behaviour where all non-scalar Tensors and lists are split along dimension 0.

ComputeLossContext dataclass

Context data provided to calculate the loss during training.

Attributes:

  * pipeline_results (Mapping[str, Tensor]) - The outputs returned by the model's forward pass.
  * state (PipelineState) - The current state of the pipeline. You can assign any data to this state object, and it will be accessible during this pipeline step (e.g. when calculating metrics).
  * stepper (Stepper) - Component tracking the current step.

ComputeLossResult dataclass

The result of the loss computation.

Attributes:

  * loss (Tensor) - The scalar tensor representing the loss to be backpropagated.
  * loss_weight (Tensor | None) - The weight to apply to the loss (for synchronizing gradients using a weighted mean). None means 1.0.

CreateMetricsContext dataclass

Context data provided to initialize metrics.

CreateMetricsResult dataclass

Result of metric initialization.

Attributes:

  * metrics (dict[str, Metric]) - A dictionary mapping metric names to Metric instances.

FinalizeContext dataclass

Context data provided when the task is being finalized.

InferenceTask

Bases: BaseTask, ABC, Generic[TBatch]

Abstract base class for defining inference-specific logic.

build_forward_inputs(ctx) abstractmethod

Transforms raw data loaded from the DataLoader into arguments for the model.

Parameters:

  * ctx (BuildForwardInputsContext[TBatch], required) - Context object.

Returns:

  * BuildForwardInputsResult - Result object.

finalize(ctx)

Performs cleanup or final actions when the task execution finishes.

Parameters:

  * ctx (FinalizeContext, required) - Context object.

load_state_dict(state_dict)

Restores the task's state from the provided dictionary.

Parameters:

  * state_dict (dict[str, Any], required) - The state dictionary to load.

process_outputs(ctx) abstractmethod

Processes the model outputs (e.g. saving to disk, decoding tokens).

Parameters:

  * ctx (ProcessOutputsContext, required) - Context containing the model outputs and pipeline state.

register_events(context)

Register task-specific event subscriptions.

Parameters:

  * context (RegisterTaskEventsContext, required) - Context providing access to the distributed environment and the event bus.

state_dict()

Returns the state dictionary for checkpointing this task.

Returns:

  * dict[str, Any] - A dictionary containing the task's state.

InferenceTaskProvider

Bases: Protocol

Protocol for a callable that creates an InferenceTask instance.

__call__(ctx)

Creates and returns a new InferenceTask.

Parameters:

  * ctx (InferenceTaskProviderContext, required) - Context providing distributed environment information.

Returns:

  * InferenceTask - An instantiated InferenceTask.

InferenceTaskProviderContext dataclass

Context data provided to the factory creating an InferenceTask.

Attributes:

  * dist_context (DistributedContext) - Information about the distributed environment.

ProcessOutputsContext dataclass

Context data provided to process outputs during inference.

Attributes:

  * pipeline_results (dict[str, Tensor]) - The outputs returned by the model's forward pass.
  * state (PipelineState) - The current state of the pipeline.

RegisterTaskEventsContext dataclass

Context for registering task-specific events.

Attributes:

  * dist_context (DistributedContext) - The distributed execution context.
  * event_bus (EventBus) - The event bus for subscribing to events.

TrainTask

Bases: BaseTask, ABC, Generic[TBatch]

Abstract base class for defining training-specific logic.

build_forward_inputs(ctx) abstractmethod

Transforms raw data loaded from the DataLoader into arguments for the model.

Parameters:

  * ctx (BuildForwardInputsContext[TBatch], required) - Context object.

Returns:

  * BuildForwardInputsResult - Result object.

compute_loss(ctx) abstractmethod

Calculates the loss based on model outputs.

Parameters:

  * ctx (ComputeLossContext, required) - Context object.

Returns:

  * ComputeLossResult - Result object.

create_metrics(ctx)

Initializes metrics to be tracked during training.

Parameters:

  * ctx (CreateMetricsContext, required) - Context object.

Returns:

  * CreateMetricsResult - Result object.

dump_hparams()

Exports hyperparameters associated with this task for logging.

Returns:

  * ScalarTree - A dictionary of hyperparameter names and values.

finalize(ctx)

Performs cleanup or final actions when the task execution finishes.

Parameters:

  * ctx (FinalizeContext, required) - Context object.

load_state_dict(state_dict)

Restores the task's state from the provided dictionary.

Parameters:

  * state_dict (dict[str, Any], required) - The state dictionary to load.

register_events(context)

Register task-specific event subscriptions.

Parameters:

  * context (RegisterTaskEventsContext, required) - Context providing access to the distributed environment and the event bus.

state_dict()

Returns the state dictionary for checkpointing this task.

Returns:

  * dict[str, Any] - A dictionary containing the task's state.

update_metrics(ctx)

Updates the state of the metrics at the end of a training step.

Parameters:

  * ctx (UpdateMetricsContext, required) - Context object.

TrainTaskProvider

Bases: Protocol

Protocol for a callable that creates a TrainTask instance.

__call__(ctx)

Creates and returns a new TrainTask.

Parameters:

  * ctx (TrainTaskProviderContext, required) - Context object.

Returns:

  * TrainTask - An instantiated TrainTask.

TrainTaskProviderContext dataclass

Context data provided to the factory creating a TrainTask.

Attributes:

  * dist_context (DistributedContext) - Information about the distributed environment.

UpdateMetricsContext dataclass

Context data provided to update metrics after a step.

Attributes:

  * state (PipelineState) - The current state of the pipeline.
  * metrics (Mapping[str, Metric]) - The dictionary of metrics to be updated.