What Is Stochastic Rounding and Why Is It Useful for Optimization

Standard floating-point casting (e.g., tensor.to(torch.bfloat16)) typically uses round-to-nearest-even. This method is deterministic and, for the purpose of accumulating small updates, statistically biased: the same input always rounds the same way, so the rounding error never averages out.

When training models in reduced precision (such as BF16), round-to-nearest can therefore cause training to stall: if a weight update is smaller than half the gap between the current weight and its nearest representable neighbour, the rounded result lands back on the original weight and the update disappears completely.

Stochastic rounding replaces this deterministic rule with a probabilistic one: if a value \(x\) lies \(30\%\) of the way from representable number \(A\) to representable number \(B\), it rounds to \(B\) with probability \(30\%\) and to \(A\) with probability \(70\%\). In expectation the rounded value matches the true high-precision value, \(\mathbb{E}[\mathrm{Round}(x)] = x\), so training can keep making progress even when individual updates are technically "too small" for the format.
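
For intuition, here is a minimal pure-PyTorch sketch of the usual bit trick for FP32 to BF16 stochastic rounding (an illustrative reference, not the Triton kernel this module ships): add uniform random noise to the 16 low-order bits that BF16 discards, then truncate.

import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch for finite FP32 inputs; the library uses a Triton kernel instead.
    bits = x.view(torch.int32)                                  # reinterpret the FP32 bit pattern
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    rounded = (bits + noise) & ~0xFFFF                          # random carry, then drop the low 16 bits
    return rounded.view(torch.float32).to(torch.bfloat16)       # exact cast: low bits are already zero

x = torch.full((1_000_000,), 1.0 + 1e-4)                        # the +1e-4 is below BF16 resolution near 1.0
print(x.to(torch.bfloat16).float().mean().item())               # round-to-nearest: exactly 1.0, the update is lost
print(stochastic_round_to_bf16(x).float().mean().item())        # ~1.0001 on average, E[Round(x)] = x

The optimizer and kernels documented below apply the same idea inside the AdamW update itself.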

About

This module provides optimizers for low-precision training with stochastic rounding, using highly optimized Triton kernels.
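
A minimal usage sketch, assuming StochasticAdamW is importable from d9d.optim.stochastic as documented below (the toy model and data here are placeholders):

import torch
from torch import nn
from d9d.optim.stochastic import StochasticAdamW  # import path assumed from the module heading below

# Parameters must be BF16 for this optimizer; gradients may be BF16 or FP32.
model = nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optimizer = StochasticAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()       # BF16 weights are updated with stochastic rounding
optimizer.zero_grad()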

Benchmarks

All benchmarks were performed on a single NVIDIA H100 80GB GPU.

copy_fp32_to_bf16_stochastic_ (benchmark plot)

adamw_stochastic_bf16_ (benchmark plot)

d9d.optim.stochastic

StochasticAdamW

Bases: Optimizer

Implements the AdamW algorithm with Stochastic Rounding.

This optimizer is designed to handle stochastic rounding primarily for BF16 training, leveraging a custom kernel.

Parameters must be in BF16. Gradients may be either BF16 or FP32.

It natively supports PyTorch distributed DTensor parameters.

It maintains its own random number generator state to ensure reproducibility.
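
Continuing the sketch above, the rounding RNG state is carried inside the optimizer's state_dict (see state_dict and load_state_dict in the source below), so checkpointing also restores the rounding stream:

# The optimizer's state_dict carries the stochastic-rounding generator state alongside
# the usual AdamW moments, so restoring it also resumes the same rounding sequence.
ckpt = optimizer.state_dict()
optimizer.load_state_dict(ckpt)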

Source code in d9d/optim/stochastic/adamw.py
class StochasticAdamW(Optimizer):
    """Implements the AdamW algorithm with Stochastic Rounding.

    This optimizer is designed to handle stochastic rounding primarily for BF16 training,
    leveraging a custom kernel.

    Parameters must be in BF16. Gradients may be either BF16 or FP32.

    It natively supports PyTorch distributed ``DTensor`` parameters.

    It maintains its own random number generator state to ensure reproducibility.
    """

    def __init__(
            self,
            params: ParamsT,
            lr: float,
            betas: tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-8,
            weight_decay: float = 1e-2,
            generator: torch.Generator | None = None,
            state_dtype: torch.dtype = torch.float32,
    ):
        """Constructs a new StochasticAdamW optimizer.

         Args:
             params: Iterable of parameters to optimize or dicts defining parameter groups.
             lr: Learning rate.
             betas: Coefficients used for computing running averages of gradient and its square.
             eps: Term added to the denominator to improve numerical stability.
             weight_decay: Weight decay coefficient.
             generator: Pseudorandom number generator for stochastic rounding. If None,
                 a new generator is created and seeded from the main PyTorch generator.
             state_dtype: Data Type to use for the optimizer states.
         """

        if lr <= 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps <= 0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay <= 0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        if generator is None:
            generator = torch.Generator(device="cpu")
            # make the generator fork from pytorch's main generator
            seed = cast(int, torch.randint(0, 2**32, (1,)).item())
            generator.manual_seed(seed)

        self._generator = generator

        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "state_dtype": state_dtype
        }
        super().__init__(params, defaults)

    def state_dict(self) -> StateDict:
        state_dict = super().state_dict()
        state_dict[_GENERATOR_STATE_KEY] = self._generator.get_state()
        return state_dict

    def load_state_dict(self, state_dict: StateDict) -> None:
        if _GENERATOR_STATE_KEY in state_dict:
            self._generator.set_state(state_dict.pop(_GENERATOR_STATE_KEY))
        super().load_state_dict(state_dict)

    @torch.no_grad()
    def step(self, closure: None = None) -> None:  # type: ignore[override]
        if closure is not None:
            raise ValueError("Closure is not supported")

        for group in self.param_groups:
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            state_dtype = group["state_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("StochasticAdamW does not support sparse gradients")

                state = self.state[p]

                # State Initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = _new_buffer(p, dtype_override=state_dtype)
                    state["exp_avg_sq"] = _new_buffer(p, dtype_override=state_dtype)

                state["step"] += 1
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                adamw_stochastic_bf16_(
                    params=_tensor_to_local(p),
                    grads=_tensor_to_local(grad),
                    exp_avg=_tensor_to_local(exp_avg),
                    exp_avg_sq=_tensor_to_local(exp_avg_sq),
                    lr=lr,
                    beta1=beta1,
                    beta2=beta2,
                    eps=eps,
                    weight_decay=weight_decay,
                    step=state["step"],
                    generator=self._generator
                )

__init__(params, lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, generator=None, state_dtype=torch.float32)

Constructs a new StochasticAdamW optimizer.

Parameters:

params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups. Required.
lr (float): Learning rate. Required.
betas (tuple[float, float]): Coefficients used for computing running averages of gradient and its square. Default: (0.9, 0.999).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-08.
weight_decay (float): Weight decay coefficient. Default: 0.01.
generator (Generator | None): Pseudorandom number generator for stochastic rounding. If None, a new generator is created and seeded from the main PyTorch generator. Default: None.
state_dtype (dtype): Data Type to use for the optimizer states. Default: float32.
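
For example, construction with an explicit generator and BF16 optimizer states might look like the sketch below (model is a placeholder holding BF16 parameters; keeping the moments in BF16 roughly halves optimizer-state memory at some cost in accuracy):

import torch
from d9d.optim.stochastic import StochasticAdamW

# A fixed seed for the rounding RNG makes the stochastic rounding stream reproducible.
generator = torch.Generator(device="cpu").manual_seed(1234)

optimizer = StochasticAdamW(
    model.parameters(),              # `model` is a placeholder holding BF16 parameters
    lr=3e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.1,
    generator=generator,
    state_dtype=torch.bfloat16,      # keep exp_avg / exp_avg_sq in BF16 instead of FP32
)
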
Source code in d9d/optim/stochastic/adamw.py
def __init__(
        self,
        params: ParamsT,
        lr: float,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        generator: torch.Generator | None = None,
        state_dtype: torch.dtype = torch.float32,
):
    """Constructs a new StochasticAdamW optimizer.

     Args:
         params: Iterable of parameters to optimize or dicts defining parameter groups.
         lr: Learning rate.
         betas: Coefficients used for computing running averages of gradient and its square.
         eps: Term added to the denominator to improve numerical stability.
         weight_decay: Weight decay coefficient.
         generator: Pseudorandom number generator for stochastic rounding. If None,
             a new generator is created and seeded from the main PyTorch generator.
         state_dtype: Data Type to use for the optimizer states.
     """

    if lr <= 0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps <= 0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
    if weight_decay <= 0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")

    if generator is None:
        generator = torch.Generator(device="cpu")
        # make the generator fork from pytorch's main generator
        seed = cast(int, torch.randint(0, 2**32, (1,)).item())
        generator.manual_seed(seed)

    self._generator = generator

    defaults = {
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay,
        "state_dtype": state_dtype
    }
    super().__init__(params, defaults)

d9d.kernel.stochastic

Utilities for stochastic type casting (e.g., FP32 to BF16).

adamw_stochastic_bf16_(params, grads, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay, step, generator=None)

Performs a single in-place AdamW optimization step.

It is specifically designed for scenarios where parameters are stored in BFloat16.

To mitigate precision loss during the parameter update, it utilizes stochastic rounding when casting FP32 calculation results back to BFloat16.

This function supports mixed precision for gradients and optimizer states (they can be either FP32 or BFloat16).

Parameters:

params (Tensor): The tensor of model parameters to update. Must be BFloat16 and contiguous. Required.
grads (Tensor): The gradient tensor. Required.
exp_avg (Tensor): The exponential moving average of gradient values (first moment). Required.
exp_avg_sq (Tensor): The exponential moving average of squared gradient values (second moment). Required.
lr (float): The learning rate. Required.
beta1 (float): Decay rate for the first moment estimate. Required.
beta2 (float): Decay rate for the second moment estimate. Required.
eps (float): Term added to the denominator to improve numerical stability. Required.
weight_decay (float): Weight decay coefficient. Required.
step (int): The current optimization step count, used for bias correction. Required.
generator (Generator | None): PyTorch random number generator used to create the seed for stochastic rounding. Default: None.

Raises:

ValueError: If main parameters are not BFloat16, if input tensor shapes do not match, if input tensors are not contiguous (for those that require in-place modification), if the optimizer states (exp_avg, exp_avg_sq) have different dtypes.
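
A direct call might look like the following sketch (the import path is assumed from the module heading above; shapes and hyperparameters are illustrative):

import torch
from d9d.kernel.stochastic import adamw_stochastic_bf16_  # import path assumed

params = torch.randn(4096, device="cuda").bfloat16()    # BF16 and contiguous, updated in place
grads = torch.randn(4096, device="cuda")                # FP32 gradients are accepted (BF16 also works)
exp_avg = torch.zeros(4096, device="cuda")              # first moment, FP32 state
exp_avg_sq = torch.zeros(4096, device="cuda")           # second moment, FP32 state

adamw_stochastic_bf16_(
    params=params, grads=grads, exp_avg=exp_avg, exp_avg_sq=exp_avg_sq,
    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
    weight_decay=1e-2, step=1,
)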

Source code in d9d/kernel/stochastic/adamw_step.py
def adamw_stochastic_bf16_(  # noqa: C901
        params: torch.Tensor,
        grads: torch.Tensor,
        exp_avg: torch.Tensor,
        exp_avg_sq: torch.Tensor,
        lr: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        step: int,
        generator: torch.Generator | None = None
) -> None:
    """
    Performs a single in-place AdamW optimization step.

    It is specifically designed for scenarios where parameters are stored in BFloat16.

    To mitigate precision loss during the parameter update, it utilizes stochastic rounding when casting
    FP32 calculation results back to BFloat16.

    This function supports mixed precision for gradients and optimizer states (they can be
    either FP32 or BFloat16).

    Args:
        params: The tensor of model parameters to update. Must be BFloat16 and contiguous.
        grads: The gradient tensor.
        exp_avg: The exponential moving average of gradient values (first moment).
        exp_avg_sq: The exponential moving average of squared gradient values (second moment).
        lr: The learning rate.
        beta1: Decay rate for the first moment estimate.
        beta2: Decay rate for the second moment estimate.
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Weight decay coefficient.
        step: The current optimization step count, used for bias correction.
        generator: PyTorch random number generator used to create the seed for stochastic rounding.

    Raises:
        ValueError: If main parameters are not BFloat16, if input tensor shapes do not match,
            if input tensors are not contiguous (for those that require in-place modification),
            if the optimizer states (exp_avg, exp_avg_sq) have different dtypes.
    """

    # check shape equality
    if grads.shape != params.shape:
        raise ValueError("Shape mismatch between grads and params.")

    if exp_avg.shape != params.shape:
        raise ValueError("Shape mismatch between exp_avg state and params.")

    if exp_avg_sq.shape != params.shape:
        raise ValueError("Shape mismatch between exp_avg_sq state and params.")

    # check params
    if params.dtype != torch.bfloat16:
        raise ValueError("Params must be BFloat16 for this kernel.")

    if not params.is_contiguous():
        raise ValueError("Params must be contiguous since it is an in-place kernel.")

    # check grads
    if not grads.is_contiguous():
        grads = grads.contiguous()

    # check states
    if not exp_avg.is_contiguous():
        raise ValueError("Exp_avg state must be contiguous since it is an in-place kernel.")

    if not exp_avg_sq.is_contiguous():
        raise ValueError("Exp_avg_sq state must be contiguous since it is an in-place kernel.")

    if exp_avg.dtype != exp_avg_sq.dtype:
        raise ValueError("States have different dtypes.")

    n_elements = params.numel()

    grad_is_bf16 = (grads.dtype == torch.bfloat16)
    state_is_bf16 = (exp_avg.dtype == torch.bfloat16)

    # Generate random seed
    seed = torch.randint(
        0, 2 ** 31 - 1, (1,),
        device="cpu",
        generator=generator
    ).item()

    def _grid(meta: dict[str, int]) -> tuple[int, ...]:
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _adamw_stochastic_bf16_kernel[_grid](
        params,
        grads,
        exp_avg,
        exp_avg_sq,

        n_elements,

        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
        seed,

        GRAD_IS_BF16=grad_is_bf16,
        STATE_IS_BF16=state_is_bf16
    )

copy_fp32_to_bf16_stochastic_(target, source, generator=None)

Copies elements from a Float32 tensor to a BFloat16 tensor using stochastic rounding.

Unlike standard round-to-nearest casting, stochastic rounding probabilistically rounds numbers up or down based on the value of the bits being truncated. This preserves the expected value of the tensor (E[round(x)] = x), which is crucial for accumulating gradients or parameters in low precision without stagnation.

This operation is performed in-place on the target tensor.

Parameters:

target (Tensor): The output tensor where results are written. Must be of type BFloat16 and contiguous. Required.
source (Tensor): The input tensor containing values to copy. Must be of type Float32. Required.
generator (Generator | None): An optional PyTorch RNG generator to strictly control the random noise used for rounding. Default: None.

Returns:

Tensor: The target tensor, modified in-place.

Raises:

ValueError: If target is not contiguous, if source/target shapes do not match, or if dtypes are not FP32 and BF16 respectively.
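
A short usage sketch (the import path is assumed from the module heading above):

import torch
from d9d.kernel.stochastic import copy_fp32_to_bf16_stochastic_  # import path assumed

source = torch.randn(1_000_000, device="cuda")             # FP32 master values
target = torch.empty_like(source, dtype=torch.bfloat16)    # BF16 destination, contiguous

copy_fp32_to_bf16_stochastic_(target, source)

# Rounding is unbiased, so the mean is preserved in expectation.
print(source.mean().item(), target.float().mean().item())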

Source code in d9d/kernel/stochastic/copy.py
def copy_fp32_to_bf16_stochastic_(
        target: torch.Tensor,
        source: torch.Tensor,
        generator: torch.Generator | None = None
) -> torch.Tensor:
    """
    Copies elements from a Float32 tensor to a BFloat16 tensor using stochastic rounding.

    Unlike standard round-to-nearest casting, stochastic rounding probabilistically rounds
    numbers up or down based on the value of the bits being truncated. This preserves the
    expected value of the tensor (E[round(x)] = x), which is crucial for accumulating
    gradients or parameters in low precision without stagnation.

    This operation is performed in-place on the target tensor.

    Args:
        target: The output tensor where results are written. Must be of type BFloat16
            and contiguous.
        source: The input tensor containing values to copy. Must be of type Float32.
        generator: An optional PyTorch RNG generator to strictly control the random
            noise used for rounding.

    Returns:
        The target tensor, modified in-place.

    Raises:
        ValueError: If target is not contiguous, if source/target shapes do not match,
            or if dtypes are not FP32 and BF16 respectively.
    """

    if not source.is_contiguous():
        source = source.contiguous()

    if not target.is_contiguous():
        raise ValueError("Since this is an in-place operation, target should be a contiguous tensor!")

    if source.shape != target.shape:
        raise ValueError("Source and Target Tensors are of different shapes")

    if source.dtype != torch.float32:
        raise ValueError("Source must be Float32")
    if target.dtype != torch.bfloat16:
        raise ValueError("Target must be BFloat16")

    n_elements = source.numel()

    # Generate a random seed for this specific kernel launch
    seed = torch.randint(
        0, 2 ** 31 - 1, (1,),
        device="cpu",
        generator=generator
    ).item()

    def _grid(meta: dict[str, int]) -> tuple[int, ...]:
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _copy_fp32_to_bf16_kernel[_grid](
        source,
        target,
        n_elements,
        seed
    )
    return target