What Is Stochastic Rounding and Why Is It Useful for Optimization?
Standard floating-point casting (e.g., `tensor.to(torch.bfloat16)`) typically uses round-to-nearest-even. While this minimizes the error of any single cast, it is deterministic, and that determinism becomes a systematic bias when many small values are accumulated.
When training models in reduced precision (such as BF16), this causes updates to stall: if a weight update is smaller than half the gap between adjacent representable values, round-to-nearest discards it completely, on every step.
Stochastic Rounding replaces rigid rounding with a probabilistic approach: for instance, if a value \(x\) lies \(30\%\) of the way between representable numbers \(A\) and \(B\), it has a \(30\%\) chance of rounding to \(B\) and a \(70\%\) chance of rounding to \(A\). Over many updates the expectation matches the true high-precision value, \(E[\mathrm{round}(x)] = x\), allowing training to converge even when individual updates are technically "too small" for the format.
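To make this concrete, here is a minimal, unoptimized PyTorch sketch of stochastic FP32-to-BF16 rounding. The helper `stochastic_round_to_bf16` is illustrative only (it is not part of this module) and ignores inf/NaN edge cases:

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # BF16 keeps the top 16 bits of an FP32 bit pattern. Adding uniform
    # noise to the 16 bits that truncation discards rounds each value up
    # with probability equal to the discarded fraction.
    bits = x.contiguous().view(torch.int32)
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    # Zero the low 16 bits after adding noise; arithmetic shifts keep the sign bit.
    rounded = ((bits + noise) >> 16) << 16
    return rounded.view(torch.float32).to(torch.bfloat16)

# A value 30% of the way between two adjacent BF16 numbers (spacing 2^-8 near 1.0):
x = torch.full((1_000_000,), 1.0 + 0.3 * 2**-8)
print(x.to(torch.bfloat16).float().mean().item())         # 1.0 -- round-to-nearest always rounds down
print(stochastic_round_to_bf16(x).float().mean().item())  # ~1.00117 -- matches E[x]
```

The deterministic cast loses the increment on every element, while the stochastic cast preserves it in expectation.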
For more information, please refer to:
- Zamirai, Pedram, et al. "Revisiting BFloat16 Training." Version 2.
- Ozkara, Kaan, et al. “Stochastic Rounding for LLM Training: Theory and Practice.”
About
This module provides optimizers for low-precision training with stochastic rounding, implemented with highly optimized Triton kernels.
Benchmarks
All benchmarks were performed on a single NVIDIA H100 80GB GPU.
copy_fp32_to_bf16_stochastic_

*(benchmark plot)*

adamw_stochastic_bf16_

*(benchmark plot)*
d9d.optim.stochastic
StochasticAdamW
Bases: Optimizer
Implements the AdamW algorithm with Stochastic Rounding.
This optimizer is designed primarily for BF16 training and applies stochastic rounding through a custom Triton kernel.
Parameters must be in BF16. Gradients may be either BF16 or FP32.
It natively supports PyTorch distributed DTensor parameters.
It maintains its own random number generator state to ensure reproducibility.
Source code in d9d/optim/stochastic/adamw.py
__init__(params, lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, generator=None, state_dtype=torch.float32)
Constructs a new StochasticAdamW optimizer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsT` | Iterable of parameters to optimize or dicts defining parameter groups. | *required* |
| `lr` | `float` | Learning rate. | *required* |
| `betas` | `tuple[float, float]` | Coefficients used for computing running averages of the gradient and its square. | `(0.9, 0.999)` |
| `eps` | `float` | Term added to the denominator to improve numerical stability. | `1e-08` |
| `weight_decay` | `float` | Weight decay coefficient. | `0.01` |
| `generator` | `Generator \| None` | Pseudorandom number generator for stochastic rounding. If `None`, a new generator is created and seeded from the main PyTorch generator. | `None` |
| `state_dtype` | `dtype` | Data type to use for the optimizer states. | `float32` |
Source code in d9d/optim/stochastic/adamw.py
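A minimal usage sketch (the toy model below is hypothetical; the import path follows the module name shown above):

```python
import torch
from d9d.optim.stochastic import StochasticAdamW

# Parameters must be BF16; gradients may be BF16 or FP32.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
opt = StochasticAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```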
d9d.kernel.stochastic
Utilities for stochastic type casting (e.g., FP32 to BF16).
adamw_stochastic_bf16_(params, grads, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay, step, generator=None)
Performs a single in-place AdamW optimization step.
It is specifically designed for scenarios where parameters are stored in BFloat16.
To mitigate precision loss during the parameter update, it utilizes stochastic rounding when casting FP32 calculation results back to BFloat16.
This function supports mixed precision for gradients and optimizer states (they can be either FP32 or BFloat16).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `Tensor` | The tensor of model parameters to update. Must be BFloat16 and contiguous. | *required* |
| `grads` | `Tensor` | The gradient tensor. | *required* |
| `exp_avg` | `Tensor` | The exponential moving average of gradient values (first moment). | *required* |
| `exp_avg_sq` | `Tensor` | The exponential moving average of squared gradient values (second moment). | *required* |
| `lr` | `float` | The learning rate. | *required* |
| `beta1` | `float` | Decay rate for the first moment estimate. | *required* |
| `beta2` | `float` | Decay rate for the second moment estimate. | *required* |
| `eps` | `float` | Term added to the denominator to improve numerical stability. | *required* |
| `weight_decay` | `float` | Weight decay coefficient. | *required* |
| `step` | `int` | The current optimization step count, used for bias correction. | *required* |
| `generator` | `Generator \| None` | PyTorch random number generator used to create the seed for stochastic rounding. | `None` |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If main parameters are not BFloat16, if input tensor shapes do not match, if input tensors that require in-place modification are not contiguous, or if the optimizer states (`exp_avg`, `exp_avg_sq`) have different dtypes. |
Source code in d9d/kernel/stochastic/adamw_step.py
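For reference, the update this kernel fuses corresponds to the standard decoupled-weight-decay AdamW recurrence. Below is a hedged pure-PyTorch sketch in FP32 (illustrative only; the actual Triton kernel fuses these steps and stochastically rounds the result back to BFloat16):

```python
import torch

def adamw_step_reference(p, g, m, v, lr, beta1, beta2, eps, weight_decay, step):
    # Decoupled weight decay applied directly to the parameters.
    p.mul_(1 - lr * weight_decay)
    # Exponential moving averages of the gradient and its square.
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    # Bias correction for the zero-initialized moment estimates.
    m_hat = m / (1 - beta1**step)
    v_hat = v / (1 - beta2**step)
    # Parameter update; the real kernel stochastically rounds this to BF16.
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
```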
copy_fp32_to_bf16_stochastic_(target, source, generator=None)
Copies elements from a Float32 tensor to a BFloat16 tensor using stochastic rounding.
Unlike standard round-to-nearest casting, stochastic rounding probabilistically rounds numbers up or down based on the value of the bits being truncated. This preserves the expected value of the tensor, \(E[\mathrm{round}(x)] = x\), which is crucial for accumulating gradients or parameters in low precision without stagnation.
This operation is performed in-place on the target tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `target` | `Tensor` | The output tensor where results are written. Must be of type BFloat16 and contiguous. | *required* |
| `source` | `Tensor` | The input tensor containing values to copy. Must be of type Float32. | *required* |
| `generator` | `Generator \| None` | An optional PyTorch RNG generator to strictly control the random noise used for rounding. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The `target` tensor, modified in-place. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `target` is not contiguous, if `source`/`target` shapes do not match, or if dtypes are not FP32 and BF16 respectively. |
Source code in d9d/kernel/stochastic/copy.py
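A short usage sketch (assuming the function is importable from the module path shown above):

```python
import torch
from d9d.kernel.stochastic import copy_fp32_to_bf16_stochastic_

gen = torch.Generator(device="cuda").manual_seed(0)
source = torch.randn(1 << 20, device="cuda", dtype=torch.float32)
target = torch.empty_like(source, dtype=torch.bfloat16)

copy_fp32_to_bf16_stochastic_(target, source, generator=gen)

# Over many elements the stochastic cast preserves the mean of the source,
# whereas accumulating small biased inputs through round-to-nearest would drift.
print(source.mean().item(), target.float().mean().item())
```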