Creating Custom Metrics

The d9d framework allows you to implement custom metrics by adhering to the Metric interface.

Design Guidelines

Metric implementations usually follow this design:

  • GPU Residency: Metrics accumulate data directly in GPU tensors to avoid CPU-GPU synchronization.
  • Linearly Additive States: Instead of storing unstable averages (e.g., "Current Accuracy"), store raw accumulation counts like "Total Correct" and "Total Samples". These are mathematically safe to sum via all_reduce.
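For example, a distributed accuracy metric should accumulate raw counts rather than per-rank ratios. The following framework-free sketch (plain Python lists standing in for an all_reduce over two ranks, with made-up counts) shows why averaging per-rank averages is unsafe:

```python
# Per-rank state after an epoch: raw counts, not ratios.
rank_stats = [
    {"correct": 9, "total": 10},    # rank 0: 90% on 10 samples
    {"correct": 50, "total": 100},  # rank 1: 50% on 100 samples
]

# Unsafe: averaging per-rank accuracies ignores sample counts.
naive = sum(s["correct"] / s["total"] for s in rank_stats) / len(rank_stats)

# Safe: sum the raw counts (what an all_reduce with Sum would do),
# then derive the ratio once from the global totals.
correct = sum(s["correct"] for s in rank_stats)
total = sum(s["total"] for s in rank_stats)
accuracy = correct / total

print(naive)     # 0.7 -- overweights the small rank
print(accuracy)  # 0.536... -- the true global accuracy
```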

Helper Components

We provide the d9d.metric.component package to simplify implementation:

  • MetricAccumulator: A helper object that handles the boilerplate of maintaining Local vs Synchronized versions of a metric state tensor. It supports standard reduction operations like Sum, Max, and Min.
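The pattern MetricAccumulator encapsulates can be sketched in plain Python. The class below is illustrative only, not the actual implementation: method names mirror the d9d API, but sync() here takes a list simulating the other ranks, whereas the real helper takes no arguments and performs an all_reduce over the default process group.

```python
class TinyAccumulator:
    """Illustrative stand-in: keeps a local and a synchronized copy."""

    def __init__(self, initial_value, reduce_op=max):
        self._initial = initial_value
        self._reduce_op = reduce_op
        self._local = initial_value
        self._synced = None  # populated by sync()

    def update(self, value):
        # Local accumulation only; invalidates any previous sync.
        self._local = self._reduce_op(self._local, value)
        self._synced = None

    def sync(self, all_rank_locals):
        # Stand-in for all_reduce: reduce the local states of every rank.
        result = all_rank_locals[0]
        for v in all_rank_locals[1:]:
            result = self._reduce_op(result, v)
        self._synced = result

    @property
    def value(self):
        # Synchronized value if available, otherwise the local one.
        return self._local if self._synced is None else self._synced

    def reset(self):
        self._local = self._initial
        self._synced = None
```

Keeping two copies lets update() stay communication-free while value still reports a globally consistent number once sync() has run.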

Examples

MaxMetric

Below is an example of a MaxMetric that tracks the maximum value seen across all ranks using the MetricAccumulator helper.

import torch
from typing import Any

from d9d.metric import Metric
from d9d.metric.component import MetricAccumulator, MetricReduceOp
from d9d.core.dist_context import DistributedContext

class MaxMetric(Metric[torch.Tensor]):
    def __init__(self):
        # Initialize accumulator with -inf
        self._max_val = MetricAccumulator(
            torch.tensor(float('-inf')),
            reduce_op=MetricReduceOp.max
        )

    def update(self, value: torch.Tensor):
        # Update local max (No communication)
        self._max_val.update(value)

    def sync(self, dist_context: DistributedContext):
        # MetricAccumulator reduces over the default process group,
        # so dist_context is not needed here
        self._max_val.sync()

    def compute(self) -> torch.Tensor:
        # Return the synchronized value
        return self._max_val.value

    def reset(self):
        self._max_val.reset()

    def to(self, device: str | torch.device | int):
        self._max_val.to(device)

    # Stateful Protocol for Checkpointing
    def state_dict(self) -> dict[str, Any]:
        return {'max_val': self._max_val.state_dict()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._max_val.load_state_dict(state_dict['max_val'])

d9d.metric.component

Reusable components for building distributed metrics.

MetricAccumulator

Bases: Stateful

Helper class to track a distributed metric state.

This class manages two copies of the state: a 'local' copy that is updated locally on every step, and a 'synchronized' copy that is populated during the sync phase via distributed reduction (all-reduce).

value property

Returns the current accumulated value.

Returns:

  • Tensor: The global synchronized value if sync() was called recently, otherwise the local accumulated value.

__init__(initial_value, reduce_op=MetricReduceOp.sum)

Constructs a MetricAccumulator object.

Parameters:

  • initial_value (Tensor, required): Tensor representing the starting value (e.g., 0 for sum, -inf for max). This tensor determines the device and dtype of the accumulator.
  • reduce_op (MetricReduceOp, default: sum): The reduction operation to use during updates and synchronization.

load_state_dict(state_dict)

Restores the accumulator state from a checkpoint.

Parameters:

  • state_dict (dict[str, Any], required): Dictionary containing state to load.

reset()

Resets the accumulator to its initial state.

state_dict()

Returns the serialized state of the accumulator.

Returns:

  • dict[str, Any]: Dictionary containing local and synchronized tensors and status flags.

sync()

Synchronizes the accumulator across the default distributed process group.

This method acts as a blocking barrier. It copies the local state to a buffer and performs an all_reduce collective operation.
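As a rough illustration of what the reduction computes (plain Python in place of the collective; the per-rank values are hypothetical):

```python
# Each rank copies its local state into a buffer; the all_reduce then
# combines the buffers so every rank holds the same synchronized value.
local_states = [3.0, 7.0, 5.0]   # hypothetical local maxima, one per rank
buffers = list(local_states)     # copy local state into a send buffer
synchronized = buffers[0]
for v in buffers[1:]:            # all_reduce with MetricReduceOp.max
    synchronized = max(synchronized, v)
# every rank now reads the same value from .value
```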

to(device)

Moves internal tensors to the specified device.

Parameters:

  • device (str | device | int, required): Target device.

update(value)

Updates the local accumulator with a new value.

This operation is performed in-place on the local tensor using the configured reduction operation (e.g., add for Sum, max for Max). It marks the accumulator as not synchronized.

Parameters:

  • value (Tensor | float | bool, required): The value to accumulate.