User Tasks
A Task defines the custom logic for a single training or inference step.
Each Task may implement the Stateful protocol, so it can keep mutable state.
TrainTask
It is responsible for logging metrics, mapping batch inputs before they are fed into the model, and computing the value of the task's loss function.
Init: create_metrics(...), dump_hparams(...).
Lifecycle:
build_forward_inputs(...) (called once) -> compute_loss(...) (called multiple times if pipelining is enabled, once per pipeline microbatch) -> update_metrics(...) (called once).
Exit: finalize(...).
State Management: state_dict(...), load_state_dict(...).
Events Registration: register_events(...) allows you to link specific custom methods to framework-wide Event Hooks.
InferenceTask
The InferenceTask defines the logic for a single inference step.
It is designed to handle the forward-only flow, processing the raw tensors produced by the model (e.g. logits, hidden states).
Lifecycle:
build_forward_inputs(...) (called once) -> process_outputs(...) (called once per pipeline microbatch).
Exit: finalize(...).
State Management: state_dict(...), load_state_dict(...).
Events Registration: register_events(...) allows hooking into the Event Bus alongside regular execution.
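To make the forward-only flow concrete, here is a minimal, self-contained sketch. It uses toy stand-in classes rather than the real InferenceTask base class and context objects (documented below), and the tiny vocabulary is purely illustrative:

```python
from typing import Any

# Hypothetical vocabulary, used only for this sketch.
VOCAB = {0: "<pad>", 1: "hello", 2: "world"}


class ToyInferenceTask:
    """Stand-in mirroring the documented InferenceTask flow."""

    def __init__(self) -> None:
        self.decoded: list[str] = []

    def build_forward_inputs(self, batch: dict[str, Any]) -> dict[str, Any]:
        # Map the raw batch into model inputs (called once per step).
        return {"input_ids": batch["input_ids"]}

    def process_outputs(self, pipeline_results: dict[str, Any]) -> None:
        # Forward-only: consume raw model outputs (here, fake token ids),
        # e.g. by decoding them. No loss, no backward pass.
        for ids in pipeline_results["token_ids"]:
            self.decoded.append(" ".join(VOCAB[i] for i in ids))


task = ToyInferenceTask()
task.build_forward_inputs({"input_ids": [[1, 2]]})
task.process_outputs({"token_ids": [[1, 2], [2, 1]]})
```

The same `process_outputs` body could just as well stream results to disk instead of accumulating them in memory.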
Pipeline State
You may notice that the batch is only accessible in the build_forward_inputs(...) method, not in the others. Don't worry!
There is a dedicated object for transferring arbitrary state between the Task lifecycle stages: it is called PipelineState.
The pipeline state will automatically shard and unshard its data when needed.
See the dedicated PipelineState documentation for details on its internal behaviour.
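A minimal, self-contained sketch of the idea, using simplified stand-ins for the framework types (the real PipelineState also handles sharding, and the attribute-style assignment shown here is an assumption):

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any

# Stand-in for the framework's PipelineState: a namespace that carries
# arbitrary data across the stages of one pipeline step.
PipelineState = SimpleNamespace


@dataclass
class BuildForwardInputsContext:
    batch: dict[str, Any]  # raw batch from the DataLoader
    state: PipelineState


@dataclass
class ComputeLossContext:
    pipeline_results: dict[str, Any]  # model forward outputs
    state: PipelineState


def build_forward_inputs(ctx: BuildForwardInputsContext) -> dict[str, Any]:
    # The batch is only visible here, so stash what later stages need:
    ctx.state.labels = ctx.batch["labels"]
    return {"input_ids": ctx.batch["input_ids"]}


def compute_loss(ctx: ComputeLossContext) -> float:
    # No batch here -- but the state written earlier is still available.
    labels = ctx.state.labels
    predictions = ctx.pipeline_results["predictions"]
    return sum(abs(p - y) for p, y in zip(predictions, labels)) / len(labels)


state = PipelineState()
inputs = build_forward_inputs(
    BuildForwardInputsContext(
        batch={"input_ids": [1, 2], "labels": [0.0, 1.0]}, state=state
    )
)
loss = compute_loss(
    ComputeLossContext(pipeline_results={"predictions": [0.0, 0.0]}, state=state)
)
```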
Example Implementation
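The real base classes live in d9d.loop.control.task and are documented below. The self-contained sketch here only illustrates the documented call order with stand-in types (no model, no framework imports): one build_forward_inputs, then one compute_loss per pipeline microbatch, then one update_metrics:

```python
class ToyTrainTask:
    """Stand-in mirroring the documented TrainTask lifecycle."""

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.losses: list[float] = []

    def build_forward_inputs(self, batch: list[float]) -> list[list[float]]:
        self.calls.append("build_forward_inputs")
        # Split the batch into two pipeline microbatches along dim 0
        # (mimicking the framework's default sharding behaviour).
        mid = len(batch) // 2
        return [batch[:mid], batch[mid:]]

    def compute_loss(self, microbatch: list[float]) -> float:
        self.calls.append("compute_loss")
        loss = sum(microbatch) / len(microbatch)  # toy "loss"
        self.losses.append(loss)
        return loss

    def update_metrics(self) -> float:
        self.calls.append("update_metrics")
        return sum(self.losses) / len(self.losses)


task = ToyTrainTask()
microbatches = task.build_forward_inputs([1.0, 3.0, 5.0, 7.0])  # called once
for mb in microbatches:
    task.compute_loss(mb)  # called once per microbatch
mean_loss = task.update_metrics()  # called once
```

In the real framework this driver loop belongs to the training loop itself; a user only implements the three methods.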
d9d.loop.control.task
BaseTask
Bases: ABC, Stateful, Generic[TBatch]
Abstract base class representing a unit of work (Task) in the training/inference loop.
build_forward_inputs(ctx)
abstractmethod
Transforms raw data loaded from the DataLoader into arguments for the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `BuildForwardInputsContext[TBatch]` | Context object. | required |

Returns:

| Type | Description |
|---|---|
| `BuildForwardInputsResult` | Result object. |
finalize(ctx)
Performs cleanup or final actions when the task execution finishes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `FinalizeContext` | Context object. | required |
load_state_dict(state_dict)
register_events(context)
Register task-specific event subscriptions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `context` | `RegisterTaskEventsContext` | Context providing access to the distributed environment and the event bus. | required |
BuildForwardInputsContext
dataclass
Bases: Generic[TBatch]
Context data to prepare inputs for the model forward pass.
Attributes:

| Name | Type | Description |
|---|---|---|
| `batch` | `TBatch` | The raw batch data loaded from the DataLoader. |
| `state` | `PipelineState` | The current state of the pipeline. You can assign any data to this state object, and it will be accessible during this pipeline step (e.g. when computing the loss). |
BuildForwardInputsResult
dataclass
The result of processing the raw batch into model inputs.
Attributes:

| Name | Type | Description |
|---|---|---|
| `inputs` | `dict[str, Tensor]` | A dictionary of inputs passed to the model pipeline as input data (first stage only if using pipeline parallelism). |
| `kwargs` | `dict[str, Any]` | A dictionary of keyword arguments passed to each pipeline stage. |
| `pipeline_sharding_spec` | `PipelineShardingSpec \| None` | A specification defining how inputs and kwargs should be split into micro-batches for pipeline parallelism. If None, the framework assumes the default behavior: all non-scalar Tensors and lists are split along dimension 0. |
ComputeLossContext
dataclass
Context data provided to calculate the loss during training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `pipeline_results` | `Mapping[str, Tensor]` | The outputs returned by the model's forward pass. |
| `state` | `PipelineState` | The current state of the pipeline. You can assign any data to this state object, and it will be accessible during this pipeline step (e.g. when calculating metrics). |
| `stepper` | `Stepper` | Component tracking the current step. |
ComputeLossResult
dataclass
CreateMetricsContext
dataclass
Context data provided to initialize metrics.
CreateMetricsResult
dataclass
FinalizeContext
dataclass
Context data provided when the task is being finalized.
InferenceTask
Bases: BaseTask, ABC, Generic[TBatch]
Abstract base class for defining inference-specific logic.
build_forward_inputs(ctx)
abstractmethod
Transforms raw data loaded from the DataLoader into arguments for the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `BuildForwardInputsContext[TBatch]` | Context object. | required |

Returns:

| Type | Description |
|---|---|
| `BuildForwardInputsResult` | Result object. |
finalize(ctx)
Performs cleanup or final actions when the task execution finishes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `FinalizeContext` | Context object. | required |
load_state_dict(state_dict)
process_outputs(ctx)
abstractmethod
Processes the model outputs (e.g. saving to disk, decoding tokens).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `ProcessOutputsContext` | Context containing the model outputs and pipeline state. | required |
register_events(context)
Register task-specific event subscriptions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `context` | `RegisterTaskEventsContext` | Context providing access to the distributed environment and the event bus. | required |
InferenceTaskProvider
Bases: Protocol
Protocol for a callable that creates an InferenceTask instance.
__call__(ctx)
Creates and returns a new InferenceTask.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `InferenceTaskProviderContext` | Context providing distributed environment information. | required |

Returns:

| Type | Description |
|---|---|
| `InferenceTask` | An instantiated InferenceTask. |
InferenceTaskProviderContext
dataclass
Context data provided to the factory creating an InferenceTask.
Attributes:

| Name | Type | Description |
|---|---|---|
| `dist_context` | `DistributedContext` | Information about the distributed environment. |
ProcessOutputsContext
dataclass
Context data provided to process outputs during inference.
Attributes:

| Name | Type | Description |
|---|---|---|
| `pipeline_results` | `dict[str, Tensor]` | The outputs returned by the model's forward pass. |
| `state` | `PipelineState` | The current state of the pipeline. |
RegisterTaskEventsContext
dataclass
Context for registering task-specific events.
Attributes:

| Name | Type | Description |
|---|---|---|
| `dist_context` | `DistributedContext` | The distributed execution context. |
| `event_bus` | `EventBus` | The event bus for subscribing to events. |
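A sketch of what event registration might look like. The EventBus stand-in below, including its subscribe() and emit() methods, is hypothetical (the real bus API may differ), but it shows the shape of linking a custom task method to a framework-wide event:

```python
from collections import defaultdict
from typing import Callable

class EventBus:
    """Stand-in bus; subscribe()/emit() are hypothetical names for this sketch."""

    def __init__(self) -> None:
        self._handlers: dict[str, list[Callable[[], None]]] = defaultdict(list)

    def subscribe(self, event: str, handler: Callable[[], None]) -> None:
        self._handlers[event].append(handler)

    def emit(self, event: str) -> None:
        for handler in self._handlers[event]:
            handler()


class MyTask:
    def __init__(self) -> None:
        self.checkpoints_seen = 0

    def on_checkpoint(self) -> None:
        self.checkpoints_seen += 1

    def register_events(self, bus: EventBus) -> None:
        # Link a custom task method to a framework-wide event hook.
        bus.subscribe("checkpoint_saved", self.on_checkpoint)


bus = EventBus()
task = MyTask()
task.register_events(bus)
bus.emit("checkpoint_saved")
```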
TrainTask
Bases: BaseTask, ABC, Generic[TBatch]
Abstract base class for defining training-specific logic.
build_forward_inputs(ctx)
abstractmethod
Transforms raw data loaded from the DataLoader into arguments for the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `BuildForwardInputsContext[TBatch]` | Context object. | required |

Returns:

| Type | Description |
|---|---|
| `BuildForwardInputsResult` | Result object. |
compute_loss(ctx)
abstractmethod
Calculates the loss based on model outputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `ComputeLossContext` | Context object. | required |

Returns:

| Type | Description |
|---|---|
| `ComputeLossResult` | Result object. |
create_metrics(ctx)
Initializes metrics to be tracked during training.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `CreateMetricsContext` | Context object. | required |

Returns:

| Type | Description |
|---|---|
| `CreateMetricsResult` | Result object. |
dump_hparams()
Exports hyperparameters associated with this task for logging.
Returns:

| Type | Description |
|---|---|
| `ScalarTree` | A dictionary of hyperparameter names and values. |
finalize(ctx)
Performs cleanup or final actions when the task execution finishes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `FinalizeContext` | Context object. | required |
load_state_dict(state_dict)
register_events(context)
Register task-specific event subscriptions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `context` | `RegisterTaskEventsContext` | Context providing access to the distributed environment and the event bus. | required |
state_dict()
update_metrics(ctx)
Updates the state of the metrics at the end of a training step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `UpdateMetricsContext` | Context object. | required |
TrainTaskProvider
Bases: Protocol
Protocol for a callable that creates a TrainTask instance.
__call__(ctx)
Creates and returns a new TrainTask.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ctx` | `TrainTaskProviderContext` | Context object. | required |

Returns:

| Type | Description |
|---|---|
| `TrainTask` | An instantiated TrainTask. |
TrainTaskProviderContext
dataclass
Context data provided to the factory creating a TrainTask.
Attributes:

| Name | Type | Description |
|---|---|---|
| `dist_context` | `DistributedContext` | Information about the distributed environment. |
UpdateMetricsContext
dataclass
Context data provided to update metrics after a step.
Attributes:

| Name | Type | Description |
|---|---|---|
| `state` | `PipelineState` | The current state of the pipeline. |
| `metrics` | `Mapping[str, Metric]` | The dictionary of metrics to be updated. |