About

The d9d.metric package provides a unified interface for tracking, accumulating, and synchronizing statistics (such as Accuracy) across a distributed environment.

Why and How

The Single-GPU Trap

Practitioners coming from single-GPU training or standard data science backgrounds are often used to workflows built on CPU-based libraries such as scikit-learn:

# Typical single-node pattern
loss_val = loss_fn(pred, target).item()  # <--- forces a GPU->CPU sync every step
history.append(loss_val)                 # <--- unbounded Python-side history
# ... later ...
avg = np.mean(history)
sklearn.metrics.f1_score(all_preds, all_targets)  # <--- needs all predictions in host memory

In a large-scale distributed environment, this approach causes critical failures:

  • Pipeline Stalls: Calling .item() or .cpu() forces a synchronization that waits for the GPU to finish. This destroys the pipelining efficiency required for training large models.
  • Out-of-Memory Errors: Accumulating prediction history across many steps in a Python list rapidly exhausts host memory.
  • No Synchronization, Partial View: Each rank only sees its own data shard, so logging the loss from rank 0 alone is misleading.

Metric implementations therefore need to be designed for both performance and accuracy from the start.
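The core of the fix is to keep accumulation on-device and defer any host synchronization to the logging boundary. A minimal sketch of that pattern (the loop body stands in for a real training step):

```python
import torch

# Accumulate on the device the loss already lives on -- no .item() per step
total_loss = torch.zeros((), dtype=torch.float32)
total_count = torch.zeros((), dtype=torch.float32)

for step in range(100):
    loss = torch.rand(())  # stand-in for loss_fn(pred, target)
    total_loss += loss     # device-resident add, no host round-trip
    total_count += 1

    if (step + 1) % 50 == 0:
        # Single sync point at the logging interval, not every step
        print(f"avg loss: {(total_loss / total_count).item():.4f}")
        total_loss.zero_()
        total_count.zero_()
```

This solves the stall and memory problems on one process; the remaining piece, aggregating across ranks, is what the package below adds.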

The d9d Solution

This package addresses the issues described above by providing a Metric interface that is:

  • Distributed Aware: Each metric can synchronize its state across an ND-parallel environment.
  • Async Compatible: We separate the triggering of communication from the waiting for results. This allows the communication to happen in the background while the GPU continues computing the next micro-batch.
  • Stateful: Metrics implement the torch.distributed.checkpoint.stateful.Stateful interface, allowing their state to be checkpointed seamlessly.
  • Clear: Unlike some other distributed metric libraries such as torchmetrics, d9d's Metric interface is really just an interface. There is no hidden state to account for and no special contract to follow: implement the interface however you like, as long as you respect the metric lifecycle.

The Metric Lifecycle

A Metric in d9d follows a specific lifecycle:

  1. Update: Happens every train step. Data is aggregated locally on the GPU. No communication occurs here.
  2. Trigger Sync: Happens at the logging interval. The metric schedules asynchronous collective operations (like all_reduce) to aggregate data across the world.
  3. Wait Sync: Acts as a barrier. Ensures the collective operations from the previous step have finished.
  4. Compute: Calculates the final scalar (e.g., dividing total loss by total samples) using the synchronized data.
  5. Reset: Clears the internal state for the next logging window.

Usage Examples

Basic Usage

Typically, you simply instantiate and update metrics within your Trainable object.

See the related examples in the Trainable documentation.

Implementing a Custom Metric

Metric implementations included in d9d usually follow this design:

  • GPU Residency: Metrics accumulate data directly in GPU tensors, so no GPU-CPU synchronization is involved.
  • Linearly Additive States: For instance, instead of storing the "Current Accuracy" (which is hard to average), we store "Total Correct" and "Total Samples". These values are mathematically safe to sum via all_reduce.
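A quick numeric check of why additive states matter: averaging per-rank accuracies is biased when shard sizes differ, while summing the counts is exact. The two "ranks" here are simulated locally:

```python
# Rank 0: 90 correct out of 100 samples; Rank 1: 5 correct out of 10
correct = [90, 5]
total = [100, 10]

# Naive: average the per-rank accuracies -- overweights the small shard
naive = (correct[0] / total[0] + correct[1] / total[1]) / 2  # 0.7

# Additive states: all_reduce(SUM) the counts, then divide once
exact = sum(correct) / sum(total)  # 95/110 ~ 0.8636
```

This is why a SUM all_reduce over "Total Correct" and "Total Samples" is correct regardless of how unevenly data is sharded.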

Below is an example of a MaxMetric that tracks the maximum value seen across all ranks (e.g., max GPU memory usage or max gradient norm).

import torch
import torch.distributed as dist
from typing import Any

from d9d.metric import Metric
from d9d.core.dist_context import DistributedContext

class MaxMetric(Metric[torch.Tensor]):
    def __init__(self):
        self._max_val = torch.tensor(float('-inf'))
        self._handle: dist.Work | None = None

    def update(self, value: torch.Tensor):
        # Keep local max
        self._max_val = torch.max(self._max_val, value)

    def trigger_sync(self, dist_context: DistributedContext):
        # Schedule async reduction across the world
        self._handle = dist.all_reduce(
            self._max_val, 
            op=dist.ReduceOp.MAX, 
            async_op=True
        )

    def wait_sync(self, dist_context: DistributedContext):
        if self._handle is None:
            raise RuntimeError("Sync was not triggered before")
        self._handle.wait()
        self._handle = None

    def compute(self) -> torch.Tensor:
        return self._max_val

    def reset(self):
        self._max_val.fill_(float('-inf'))
        self._handle = None

    def to(self, device: str | torch.device | int):
        self._max_val = self._max_val.to(device)

    # Stateful Protocol for Checkpointing
    def state_dict(self) -> dict[str, Any]:
        return {'max_val': self._max_val}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._max_val = state_dict['max_val']
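The update/compute path above needs no process group, so its core logic is easy to smoke-test on a single machine. A minimal sketch that mirrors the class without the d9d imports:

```python
import torch

# Local replica of MaxMetric's update/compute logic -- no process group needed
max_val = torch.tensor(float("-inf"))

for v in [0.5, 2.0, 1.25]:
    max_val = torch.max(max_val, torch.tensor(v))  # update()

result = max_val.clone()       # compute(); sync steps only matter multi-rank
max_val.fill_(float("-inf"))   # reset()
```

In a real run, trigger_sync/wait_sync would fold the per-rank maxima into a global one before compute() is read.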

Manual Usage

You may want to use d9d metrics manually, without using the Trainable object.

The example below uses the built-in WeightedMeanMetric, which is commonly used for tracking loss (weighted by batch size or token count).

import torch
from d9d.metric.impl import WeightedMeanMetric
from d9d.core.dist_context import DistributedContext

# 1. Initialize
metric = WeightedMeanMetric()

dataloader = ...
dist_ctx = ...

# 2. Training Loop
for step, batch in enumerate(dataloader):
    # ... forward, backward ...
    loss = ... # scalar tensor
    num_tokens_in_loss = ... # scalar tensor

    # Update local state (No communication)
    metric.update(values=loss, weights=num_tokens_in_loss)

# 3. Synchronize
# Initiate communication across all GPUs
metric.trigger_sync(dist_ctx)

# Do other work here
# to hide communication latency.

# 4. Finalize and Print
# Block until communication is done
metric.wait_sync(dist_ctx)
print(f"Global Average Loss: {metric.compute()}")

# 5. Reset for next epoch
metric.reset()

d9d.metric

Distributed metric abstractions and implementations.

Metric

Bases: ABC, Stateful, Generic[TComputeResult]

Abstract base class for all metrics.

Metrics track statistics over time (e.g., during training) and can be synchronized across distributed processes. They also support state persistence via the Stateful interface.

Source code in d9d/metric/abc.py
class Metric(abc.ABC, Stateful, Generic[TComputeResult]):
    """
    Abstract base class for all metrics.

    Metrics track statistics over time (e.g., during training) and can be synchronized
    across distributed processes. They also support state persistence via the Stateful
    interface.
    """

    @abc.abstractmethod
    def update(self, *args: Any, **kwargs: Any):
        """
        Updates the metric state with a new batch of data.

        Args:
            *args: Positional arguments required by the specific metric implementation.
            **kwargs: Keyword arguments required by the specific metric implementation.
        """

    @abc.abstractmethod
    def trigger_sync(self, dist_context: DistributedContext):
        """
        Initiates the synchronization of the metric state across distributed processes.

        This method should start the collective operations (e.g., all-reduce) required
        to aggregate statistics, but should not block waiting for completion if possible.

        Args:
            dist_context: The distributed context.
        """

    @abc.abstractmethod
    def wait_sync(self, dist_context: DistributedContext):
        """
        Waits for the synchronization initiated by `trigger_sync` to complete.

        After this method returns, the metric state must be fully aggregated and
        consistent across ranks.

        Args:
            dist_context: The distributed context.
        """

    @abc.abstractmethod
    def compute(self) -> TComputeResult:
        """
        Computes the current value of the metric.

        Returns:
            The computed metric result (of type `TComputeResult`).
                This can be a single `torch.Tensor` or `PyTree` structure (dict, list, etc.)
                containing tensors, depending on how the subclass was typed.
        """

    @abc.abstractmethod
    def reset(self):
        """
        Resets the internal state of the metric to the initial values.
        """

    def to(self, device: str | torch.device | int):
        """
        Moves a metric state to a specified device.

        Args:
            device: The device to move the metric state to.
        """

compute() abstractmethod

Computes the current value of the metric.

Returns:

Type Description
TComputeResult

The computed metric result (of type TComputeResult). This can be a single torch.Tensor or PyTree structure (dict, list, etc.) containing tensors, depending on how the subclass was typed.

Source code in d9d/metric/abc.py
@abc.abstractmethod
def compute(self) -> TComputeResult:
    """
    Computes the current value of the metric.

    Returns:
        The computed metric result (of type `TComputeResult`).
            This can be a single `torch.Tensor` or `PyTree` structure (dict, list, etc.)
            containing tensors, depending on how the subclass was typed.
    """

reset() abstractmethod

Resets the internal state of the metric to the initial values.

Source code in d9d/metric/abc.py
@abc.abstractmethod
def reset(self):
    """
    Resets the internal state of the metric to the initial values.
    """

to(device)

Moves a metric state to a specified device.

Parameters:

Name Type Description Default
device str | device | int

The device to move the metric state to.

required
Source code in d9d/metric/abc.py
def to(self, device: str | torch.device | int):
    """
    Moves a metric state to a specified device.

    Args:
        device: The device to move the metric state to.
    """

trigger_sync(dist_context) abstractmethod

Initiates the synchronization of the metric state across distributed processes.

This method should start the collective operations (e.g., all-reduce) required to aggregate statistics, but should not block waiting for completion if possible.

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context.

required
Source code in d9d/metric/abc.py
@abc.abstractmethod
def trigger_sync(self, dist_context: DistributedContext):
    """
    Initiates the synchronization of the metric state across distributed processes.

    This method should start the collective operations (e.g., all-reduce) required
    to aggregate statistics, but should not block waiting for completion if possible.

    Args:
        dist_context: The distributed context.
    """

update(*args, **kwargs) abstractmethod

Updates the metric state with a new batch of data.

Parameters:

Name Type Description Default
*args Any

Positional arguments required by the specific metric implementation.

()
**kwargs Any

Keyword arguments required by the specific metric implementation.

{}
Source code in d9d/metric/abc.py
@abc.abstractmethod
def update(self, *args: Any, **kwargs: Any):
    """
    Updates the metric state with a new batch of data.

    Args:
        *args: Positional arguments required by the specific metric implementation.
        **kwargs: Keyword arguments required by the specific metric implementation.
    """

wait_sync(dist_context) abstractmethod

Waits for the synchronization initiated by trigger_sync to complete.

After this method returns, the metric state must be fully aggregated and consistent across ranks.

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context.

required
Source code in d9d/metric/abc.py
@abc.abstractmethod
def wait_sync(self, dist_context: DistributedContext):
    """
    Waits for the synchronization initiated by `trigger_sync` to complete.

    After this method returns, the metric state must be fully aggregated and
    consistent across ranks.

    Args:
        dist_context: The distributed context.
    """

d9d.metric.impl

WeightedMeanMetric

Bases: Metric[Tensor]

Computes the weighted mean of values.

Tracks the sum of weighted values and the sum of weights.

Source code in d9d/metric/impl/mean.py
class WeightedMeanMetric(Metric[torch.Tensor]):
    """
    Computes the weighted mean of values.

    Tracks the sum of weighted values and the sum of weights.
    """

    def __init__(self):
        """Constructs a WeightedMeanMetric object."""

        super().__init__()
        self._value = torch.scalar_tensor(0, dtype=torch.float32)
        self._weight = torch.scalar_tensor(0, dtype=torch.float32)

        self._is_synced = False
        self._synced_value = torch.scalar_tensor(0, dtype=torch.float32)
        self._synced_weight = torch.scalar_tensor(0, dtype=torch.float32)

        self._handles: list[dist.Work] | None = None

    def update(self, values: torch.Tensor, weights: torch.Tensor):
        self._value += (values * weights).sum()
        self._weight += weights.sum()

        self._is_synced = False

    def trigger_sync(self, dist_context: DistributedContext):
        self._synced_value = self._value.clone()
        self._synced_weight = self._weight.clone()
        self._is_synced = True

        self._handles = [
            dist.all_reduce(self._synced_value, op=dist.ReduceOp.SUM, async_op=True),
            dist.all_reduce(self._synced_weight, op=dist.ReduceOp.SUM, async_op=True)
        ]

    def wait_sync(self, dist_context: DistributedContext):
        if self._handles is None:
            raise RuntimeError("Sync was not triggered before")

        for handle in self._handles:
            handle.wait()
        self._handles = None

    def compute(self) -> torch.Tensor:
        if self._is_synced:
            return self._synced_value / self._synced_weight
        else:
            return self._value / self._weight

    def reset(self):
        self._value.fill_(0)
        self._weight.fill_(0)
        self._is_synced = False
        self._handles = None

    def to(self, device: str | torch.device | int):
        self._weight = self._weight.to(device)
        self._value = self._value.to(device)
        self._synced_weight = self._synced_weight.to(device)
        self._synced_value = self._synced_value.to(device)

    @property
    def accumulated_weight(self) -> torch.Tensor:
        """
        Returns the total weight accumulated so far.

        Returns:
            Scalar tensor with total weight.
        """

        if self._is_synced:
            return self._synced_weight

        return self._weight

    def state_dict(self) -> dict[str, Any]:
        return {
            "value": self._value,
            "weight": self._weight
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._value = state_dict["value"]
        self._weight = state_dict["weight"]

accumulated_weight property

Returns the total weight accumulated so far.

Returns:

Type Description
Tensor

Scalar tensor with total weight.

__init__()

Constructs a WeightedMeanMetric object.

Source code in d9d/metric/impl/mean.py
def __init__(self):
    """Constructs a WeightedMeanMetric object."""

    super().__init__()
    self._value = torch.scalar_tensor(0, dtype=torch.float32)
    self._weight = torch.scalar_tensor(0, dtype=torch.float32)

    self._is_synced = False
    self._synced_value = torch.scalar_tensor(0, dtype=torch.float32)
    self._synced_weight = torch.scalar_tensor(0, dtype=torch.float32)

    self._handles: list[dist.Work] | None = None