Creating Custom Metrics
The d9d framework allows you to implement custom metrics by adhering to the Metric interface.
Design Guidelines
Metric implementations usually follow this design:
- GPU Residency: Metrics accumulate data directly in GPU tensors to avoid CPU-GPU synchronization.
- Linearly Additive States: Instead of storing unstable averages (e.g., a running "Current Accuracy"), store raw accumulation counts such as "Total Correct" and "Total Samples". These are mathematically safe to sum via all_reduce.
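To see why additive states matter, here is a toy illustration in plain Python (no real distributed backend; the per-rank numbers are hypothetical): summing raw counts across ranks gives the exact global accuracy, while averaging per-rank accuracies does not when ranks see different sample counts.

```python
# Hypothetical per-rank states: (total_correct, total_samples)
rank_states = [(9, 10), (40, 90)]  # rank 0: 90% of 10; rank 1: ~44.4% of 90

# Linearly additive: all_reduce(sum) each component, then divide once.
total_correct = sum(c for c, _ in rank_states)
total_samples = sum(n for _, n in rank_states)
global_accuracy = total_correct / total_samples  # 49 / 100 = 0.49

# Unsafe alternative: averaging the per-rank averages weights both ranks
# equally even though rank 1 saw 9x more samples.
naive_accuracy = sum(c / n for c, n in rank_states) / len(rank_states)

print(global_accuracy, naive_accuracy)
```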
Helper Components
We provide the d9d.metric.component package to simplify implementation:
- MetricAccumulator: A helper object that handles the boilerplate of maintaining Local vs Synchronized versions of a metric state tensor. It supports standard reduction operations like Sum, Max, and Min.
Examples
MaxMetric
Below is an example of a MaxMetric that tracks the maximum value seen across all ranks using the MetricAccumulator helper.
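A single-process sketch of that shape is shown here. The stand-in accumulator below mimics the update/sync/value surface of MetricAccumulator with plain floats rather than GPU tensors, and the `update`/`compute` method names on MaxMetric are assumptions about the Metric interface, not the framework's exact signatures.

```python
import math

class _ToyMaxAccumulator:
    """Single-process stand-in for a MetricAccumulator configured with a
    max reduce op. The real helper holds tensors and all-reduces the local
    state across ranks during sync()."""

    def __init__(self, initial_value: float):
        self._local = initial_value
        self._synced = initial_value

    def update(self, value: float) -> None:
        # In-place max reduction on the local state.
        self._local = max(self._local, value)

    def sync(self) -> None:
        # The real sync() would all_reduce the local state here.
        self._synced = self._local

    @property
    def value(self) -> float:
        return self._synced

class MaxMetric:
    """Tracks the maximum value observed; starts at -inf so any update wins."""

    def __init__(self):
        self._acc = _ToyMaxAccumulator(initial_value=-math.inf)

    def update(self, value: float) -> None:
        self._acc.update(value)

    def compute(self) -> float:
        self._acc.sync()
        return self._acc.value

metric = MaxMetric()
for v in (0.3, 2.5, 1.7):
    metric.update(v)
print(metric.compute())  # 2.5
```

Starting from `-inf` matters: it is the identity element for max, so an empty accumulator never masks a real observation.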
d9d.metric.component
Reusable components for building distributed metrics.
MetricAccumulator
Bases: Stateful
Helper class to track a distributed metric state.
This class manages two copies of the state: a 'local' copy that is updated locally on every step, and a 'synchronized' copy that is populated during the sync phase via distributed reduction (all-reduce).
value (property)
__init__(initial_value, reduce_op=MetricReduceOp.sum)
Constructs MetricAccumulator object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `initial_value` | `Tensor` | Tensor representing the starting value (e.g., 0 for sum, -inf for max). This tensor determines the device and dtype of the accumulator. | *required* |
| `reduce_op` | `MetricReduceOp` | The reduction operation to use during updates and synchronization. | `sum` |
load_state_dict(state_dict)
reset()
Resets the accumulator to its initial state.
state_dict()
sync()
Synchronizes the accumulator across the default distributed process group.
This method acts as a blocking barrier. It copies the local state to a buffer and performs an all_reduce collective operation.
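As a rough mental model of that sync phase (plain Python over a list standing in for the ranks; the values are hypothetical and the real implementation uses a distributed all_reduce on tensors):

```python
# Each rank copies its local state into a buffer, then all_reduce combines
# the buffers so every rank ends up holding the same reduced value.

local_states = [3.0, 7.5, 1.2]  # one local accumulator value per rank

def all_reduce(buffers, op):
    # Every rank receives the same reduction over all ranks' buffers.
    reduced = op(buffers)
    return [reduced for _ in buffers]

synced_sum = all_reduce(list(local_states), sum)  # sum reduce op
synced_max = all_reduce(list(local_states), max)  # max reduce op

print(synced_sum[0], synced_max[0])
```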
to(device)
update(value)
Updates the local accumulator with a new value.
This operation is performed in-place on the local tensor using the configured reduction operation (e.g., add for Sum, max for Max). It marks the accumulator as not synchronized.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `value` | `Tensor \| float \| bool` | The value to accumulate. | *required* |