Overview
About
The d9d.metric package provides a unified interface for tracking, accumulating, and synchronizing statistics (such as Accuracy) across a distributed environment.
Why and How
The Single-GPU Trap
Some practitioners coming from single-GPU training or standard data-science backgrounds are used to workflows built on good old CPU-based libraries such as scikit-learn:
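A minimal sketch of that familiar pattern (the toy model and loader are illustrative stand-ins, not part of d9d):

```python
import torch
from sklearn.metrics import accuracy_score

# Toy stand-ins for a real model and dataloader (illustrative only).
model = torch.nn.Linear(4, 2)
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)]

# Familiar single-GPU pattern: pull every batch back to the CPU,
# accumulate it in Python lists, and score with a CPU library.
all_preds, all_labels = [], []
for inputs, labels in loader:
    preds = model(inputs).argmax(dim=-1)
    all_preds.extend(preds.cpu().tolist())  # .cpu() forces a device sync
    all_labels.extend(labels.tolist())

acc = accuracy_score(all_labels, all_preds)
```

This works fine on one GPU with small runs, which is exactly why the habits carry over.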
In a large-scale distributed environment, this approach causes critical failures:
- Pipeline Stalls: Calling `.item()` or `.cpu()` forces a synchronization that waits for the GPU to finish. This destroys the pipelining efficiency required for training large models.
- Out-of-Memory Errors: Accumulating prediction history for many steps in a Python list will rapidly exhaust RAM.
- No Synchronization (Partial View): Rank 0 only sees its own data shard, so logging loss from Rank 0 alone is misleading.
Metric implementations therefore have to be designed for both performance and accuracy in a distributed setting.
The d9d Solution
This package addresses the issues described above by providing a Metric interface that is:
- Distributed Aware: Each metric knows how to synchronize its state across an ND-parallel environment via the `sync` method.
- Async Compatible: While `Metric` implementations themselves can remain simple and synchronous, they are designed to be driven by the `AsyncMetricCollector`. This wrapper offloads synchronization and computation to a side stream, allowing the main training loop to continue while metrics are being reduced.
- Stateful: Metrics implement the `torch.distributed.checkpoint.stateful.Stateful` interface, allowing their state to be checkpointed seamlessly.
- Clear: Unlike some other libraries, d9d's `Metric` is a lightweight interface. It has no hidden state accounting or complex contracts. Just implement the interface and ensure you don't break the lifecycle.
The Metric Lifecycle
A Metric in d9d follows a specific lifecycle:
- Update: Happens every train step. Data is aggregated locally on the GPU using in-place methods like `.add_()`. No communication occurs here.
- Sync: Happens at the logging interval. The metric aggregates its state across the world (e.g. via `all_reduce`).
- Compute: Calculates the final scalar (e.g., dividing total loss by total samples) using the synchronized state.
- Reset: Clears the internal state for the next logging window.
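The four phases above can be sketched with a minimal mean-loss metric. This is a self-contained illustration, not d9d's implementation: `sync` is a no-op here because the sketch runs on a single process, and `MeanLoss`, `log_every`, and the loss values are all made up for the example.

```python
import torch

# Minimal mean-loss metric illustrating the lifecycle (hypothetical class,
# not part of d9d; in a real run `sync` would all_reduce total/count).
class MeanLoss:
    def __init__(self):
        self.total = torch.zeros(())
        self.count = torch.zeros(())

    def update(self, loss):
        # Update: local, in-place accumulation only; no communication.
        self.total.add_(loss.detach())
        self.count.add_(1)

    def sync(self):
        # Sync: cross-rank reduction would happen here (single-process no-op).
        pass

    def compute(self):
        # Compute: final scalar from the (synchronized) state.
        return (self.total / self.count).item()

    def reset(self):
        # Reset: clear state for the next logging window.
        self.total.zero_()
        self.count.zero_()

metric = MeanLoss()
log_every, results = 2, []
for step in range(4):
    loss = torch.tensor(float(step))  # stand-in for a real training loss
    metric.update(loss)
    if (step + 1) % log_every == 0:
        metric.sync()
        results.append(metric.compute())
        metric.reset()
# results == [0.5, 2.5]
```

Note that `compute` is only meaningful after `sync`, and `reset` must run before the next window begins, otherwise windows bleed into each other.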
Usage Examples
Basic Usage
Typically, you simply instantiate and update metrics within your TrainTask object.
See the related examples in the Trainer documentation.
Manual Usage
You may want to drive d9d metrics manually, without the Trainer object.
When used directly, the sync() method is blocking by default. You may call it within a torch.cuda.stream(...)
context to overlap it with computation.
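A sketch of that manual pattern, assuming a metric object implementing the interface described below (the `log_metric` helper and the metric/context objects are hypothetical; the side-stream branch requires CUDA, so it is guarded):

```python
import torch

def log_metric(metric, dist_context):
    """Sync, compute, and reset a metric outside the Trainer (sketch)."""
    if torch.cuda.is_available():
        # Run the blocking sync on a side stream to overlap it with
        # whatever the main stream is doing.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            metric.sync(dist_context)
        # Make the main stream wait before consuming the synced state.
        torch.cuda.current_stream().wait_stream(side)
    else:
        metric.sync(dist_context)  # plain blocking call on CPU
    value = metric.compute()
    metric.reset()
    return value
```

The explicit `wait_stream` calls matter: without them, `compute` could read state that the side stream has not finished reducing.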
d9d.metric
Distributed metric abstractions and implementations.
Metric
Bases: ABC, Stateful, Generic[TComputeResult]
Abstract base class for all metrics.
Metrics track statistics over time (e.g., during training) and can be synchronized across distributed processes. They also support state persistence via the Stateful interface.
compute()
abstractmethod
Computes the current value of the metric.
Returns:

| Type | Description |
|---|---|
| `TComputeResult` | The computed metric result (of type `TComputeResult`). |
reset()
abstractmethod
Resets the internal state of the metric to the initial values.
sync(dist_context)
abstractmethod
Synchronizes the metric state across distributed processes.
This method aggregates statistics from all ranks (e.g., via all-reduce) to ensure the metric state is consistent globally.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dist_context` | `DistributedContext` | The distributed context. | required |
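As a concrete illustration of the interface above, here is a minimal accuracy metric. The base class here is a stand-in defined inline so the sketch is self-contained; in real code you would subclass `d9d.metric.Metric` (which is also `Stateful` and `Generic[TComputeResult]`) and receive the library's `DistributedContext` in `sync`.

```python
import abc
import torch

# Stand-in mirroring the described abstract interface (the real base is
# d9d.metric.Metric).
class Metric(abc.ABC):
    @abc.abstractmethod
    def compute(self): ...
    @abc.abstractmethod
    def reset(self): ...
    @abc.abstractmethod
    def sync(self, dist_context): ...

class Accuracy(Metric):
    def __init__(self):
        self.correct = torch.zeros((), dtype=torch.long)
        self.total = torch.zeros((), dtype=torch.long)

    def update(self, preds, labels):
        # Local, in-place accumulation; no communication here.
        self.correct.add_((preds == labels).sum())
        self.total.add_(labels.numel())

    def sync(self, dist_context):
        # In a multi-rank run this would reduce the state globally, e.g.
        # torch.distributed.all_reduce on both counters; no-op here.
        pass

    def compute(self):
        return (self.correct / self.total).item()

    def reset(self):
        self.correct.zero_()
        self.total.zero_()

m = Accuracy()
m.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
m.sync(None)
# m.compute() is approximately 2/3
```

Keeping the counters as tensors (rather than Python ints) is what lets `sync` reduce them with collective ops and lets the state be checkpointed.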