
Async Metric Collection

Internal API Warning

If you are using the standard d9d Trainer, you do not need to interact with this package directly. It is handled automatically. This documentation is intended for users implementing custom training loops or logging infrastructure.

About

The d9d.internals.metric_collector package provides the infrastructure for non-blocking metric processing.

While the Metric interface is synchronous by design, the AsyncMetricCollector wraps a metric instance and schedules its synchronization and computation on a secondary CUDA stream. This allows the main training loop to proceed immediately without waiting for metric reductions (all-reduce) to complete.
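The overlap described above can be illustrated with a CPU-only analogy. This is not how d9d is implemented: the sketch swaps the secondary CUDA stream for a single-worker thread pool from the standard library, and `reduce_metric` is a hypothetical stand-in for the synchronization and computation step. It only shows the shape of the pattern: queue the work, keep training, and block only when the value is needed.

```python
from concurrent.futures import ThreadPoolExecutor


def reduce_metric(values):
    # Hypothetical stand-in for the expensive synchronization + computation.
    return sum(values) / len(values)


# Assumption: one worker thread plays the role of the side stream,
# so queued work runs in submission order, off the main thread.
executor = ThreadPoolExecutor(max_workers=1)

losses = [0.9, 0.7, 0.5]
pending = executor.submit(reduce_metric, list(losses))  # returns immediately

# The main loop keeps going while the reduction runs in the background.
steps_done = 0
for _ in range(3):
    steps_done += 1

mean_loss = pending.result()  # blocks only here, when the value is needed
executor.shutdown()
```

In the real collector, the equivalent of `submit` is `schedule_collection` and the equivalent of `result` is `collect_results`.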

d9d.internals.metric_collector

AsyncMetricCollector

Helper class to synchronize and compute metrics asynchronously on a separate CUDA stream.

This class decouples metric synchronization and computation from the main training loop. It schedules the heavy lifting (distributed reduction and tensor operations) on a secondary stream.

__init__(metric)

Constructs an AsyncMetricCollector object.

Parameters:

metric (Metric, required): The metric instance to collect and compute asynchronously.

bind()

Moves the underlying metric to CUDA and initializes the side stream.

collect_results()

Waits for the async computation to finish and retrieves results.

This method synchronizes the current stream with the side stream, moves results to CPU, converts them to Python scalars, and resets the underlying metric.

Returns:

PyTree[float | int | bool]: A PyTree structure matching the metric's output, containing Python scalars (float, int, or bool) located on the CPU.

Raises:

RuntimeError: If the collector is not bound or if schedule_collection was not called prior to this method.
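To make the return shape concrete, the following is a minimal sketch of converting a nested result structure into Python scalars. It is an illustration under stated assumptions, not the package's code: `FakeScalarTensor` and `to_python_scalars` are hypothetical names, the fake tensor models only an `.item()` method, and only dicts, lists, and tuples are treated as containers.

```python
from typing import Any


class FakeScalarTensor:
    """Hypothetical stand-in for a 0-dim tensor; only .item() is modelled."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def to_python_scalars(tree: Any) -> Any:
    """Recursively replace scalar-like leaves with plain Python values."""
    if isinstance(tree, dict):
        return {key: to_python_scalars(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type around the converted children.
        return type(tree)(to_python_scalars(value) for value in tree)
    if hasattr(tree, "item"):
        return tree.item()
    return tree  # already a plain Python value


raw = {
    "loss": FakeScalarTensor(0.25),
    "counts": [FakeScalarTensor(3), FakeScalarTensor(True)],
}
results = to_python_scalars(raw)
```

The output keeps the nesting of the input, which is what "matching the metric's output" means in the return description.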

schedule_collection(dist_context)

Schedules metric synchronization and computation on the side stream.

This method records a dependency on the current stream to ensure all data required for the metric is available, then launches the synchronization (if distributed) and computation tasks on the dedicated side stream.

Parameters:

dist_context (DistributedContext, required): Distributed context used for metric synchronization across ranks.

Raises:

RuntimeError: If the collector has not been bound via .bind().

unbind()

Releases the reference to the side stream.
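The call order implied by the methods above (bind, schedule_collection, collect_results, unbind) can be sketched with a toy stand-in. `ToyCollector` is hypothetical and involves no CUDA or distributed work. It mirrors only the documented RuntimeError conditions, takes a zero-argument callable in place of a Metric, and ignores `dist_context`.

```python
class ToyCollector:
    """Hypothetical stand-in that mirrors only the documented call order."""

    def __init__(self, metric):
        self._metric = metric  # assumption: a zero-argument callable
        self._bound = False
        self._pending = None

    def bind(self):
        self._bound = True

    def unbind(self):
        self._bound = False
        self._pending = None

    def schedule_collection(self, dist_context=None):
        if not self._bound:
            raise RuntimeError("collector is not bound; call bind() first")
        self._pending = self._metric  # defer the work until collection

    def collect_results(self):
        if not self._bound or self._pending is None:
            raise RuntimeError("bind() and schedule_collection() must be called first")
        pending, self._pending = self._pending, None
        return pending()


collector = ToyCollector(lambda: {"accuracy": 0.5})

errors = []
try:
    collector.schedule_collection()  # not bound yet
except RuntimeError as exc:
    errors.append(str(exc))

collector.bind()
try:
    collector.collect_results()  # nothing scheduled yet
except RuntimeError as exc:
    errors.append(str(exc))

collector.schedule_collection()
results = collector.collect_results()
collector.unbind()
```

In a custom training loop, the gap between scheduling and collecting is where the next training step would typically run.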