Async Metric Collection
Internal API Warning
If you are using the standard d9d Trainer, you do not need to interact with this package directly. It is handled automatically. This documentation is intended for users implementing custom training loops or logging infrastructure.
About
The d9d.internals.metric_collector package provides the infrastructure for non-blocking metric processing.
While the Metric interface is synchronous by design, the AsyncMetricCollector wraps a metric instance and schedules its synchronization and computation on a secondary CUDA stream. This allows the main training loop to proceed immediately without waiting for metric reductions (all-reduce) to complete.
d9d.internals.metric_collector
AsyncMetricCollector
Helper class to synchronize and compute metrics asynchronously on a separate CUDA stream.
This class decouples metric synchronization and computation from the main training loop. It schedules the heavy lifting (distributed reduction and tensor operations) on a secondary stream.
__init__(metric)
Constructs an AsyncMetricCollector object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metric | Metric | The metric instance to collect and compute asynchronously. | required |
bind()
Moves the underlying metric to CUDA and initializes the side stream.
collect_results()
Waits for the async computation to finish and retrieves results.
This method synchronizes the current stream with the side stream, moves results to CPU, converts them to Python scalars, and resets the underlying metric.
Returns:
| Type | Description |
|---|---|
| PyTree[float \| int \| bool] | A PyTree structure matching the metric's output, containing Python scalars (float, int, or bool) located on the CPU. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the collector is not bound or if schedule_collection was not called prior to this method. |
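The conversion step described for collect_results (tensor leaves become Python scalars while the container layout is preserved) can be illustrated with a small recursive helper. This is only a sketch under stated assumptions: `FakeTensor` and `to_python_scalars` are hypothetical names introduced here, not part of d9d, and the real implementation may rely on a dedicated PyTree utility instead.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for a 0-dim tensor that exposes .item()."""

    def __init__(self, value: Any) -> None:
        self._value = value

    def item(self) -> Any:
        return self._value


def to_python_scalars(tree: Any) -> Any:
    """Recursively replace tensor-like leaves with Python scalars.

    Dicts, lists, and plain tuples keep their layout; anything with an
    .item() method is treated as a leaf (an assumption of this sketch).
    """
    if isinstance(tree, dict):
        return {key: to_python_scalars(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(to_python_scalars(value) for value in tree)
    if hasattr(tree, "item"):
        return tree.item()
    return tree


raw = {"loss": FakeTensor(0.25), "counts": [FakeTensor(3), FakeTensor(True)]}
scalars = to_python_scalars(raw)
print(scalars)  # {'loss': 0.25, 'counts': [3, True]}
```

The result has the same shape as the input, which is what makes it convenient to pass straight to a logger.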
schedule_collection(dist_context)
Schedules metric synchronization and computation on the side stream.
This method records a dependency on the current stream to ensure all data required for the metric is available, then launches the synchronization (if distributed) and computation tasks on the dedicated side stream.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | Distributed context used for metric synchronization across ranks. | required |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the collector has not been bound via .bind(). |
unbind()
Releases the reference to the side stream.
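The call order implied by the methods above (bind, then schedule_collection, then collect_results, then unbind) can be sketched with a toy stand-in that runs on the CPU. Everything in this block is an assumption made for illustration: `ToyMeanMetric` and `ToyAsyncCollector` are hypothetical classes, the metric method names `update`, `compute`, and `reset` are guesses at a typical metric interface, and a single worker thread is swapped in for the CUDA side stream. It is not how d9d implements the collector.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


class ToyMeanMetric:
    """Hypothetical metric: accumulates values and reports their mean."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> dict:
        return {"mean": self.total / max(self.count, 1)}

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0


class ToyAsyncCollector:
    """Mimics the documented lifecycle; a worker thread plays the side stream."""

    def __init__(self, metric: ToyMeanMetric) -> None:
        self._metric = metric
        self._executor: Optional[ThreadPoolExecutor] = None
        self._pending: Optional[Future] = None

    def bind(self) -> None:
        # Stand-in for "initialize the side stream".
        self._executor = ThreadPoolExecutor(max_workers=1)

    def schedule_collection(self, dist_context: object = None) -> None:
        # dist_context is accepted only to mirror the documented signature.
        if self._executor is None:
            raise RuntimeError("collector is not bound; call bind() first")
        # Returns immediately; the computation runs in the background.
        self._pending = self._executor.submit(self._metric.compute)

    def collect_results(self) -> dict:
        if self._executor is None or self._pending is None:
            raise RuntimeError("call bind() and schedule_collection() first")
        results = self._pending.result()  # the only blocking point
        self._pending = None
        self._metric.reset()
        return results

    def unbind(self) -> None:
        if self._executor is not None:
            self._executor.shutdown(wait=True)
        self._executor = None
        self._pending = None


metric = ToyMeanMetric()
collector = ToyAsyncCollector(metric)
collector.bind()
for loss in (1.0, 2.0, 3.0):
    metric.update(loss)
collector.schedule_collection()
# ... a training loop would keep running here while the work is pending ...
results = collector.collect_results()
collector.unbind()
print(results)  # {'mean': 2.0}
```

In a custom loop, the gap between schedule_collection and collect_results is where the benefit comes from: the later the results are collected, the more of the reduction overlaps with training work.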