
Stochastic Optimizers

What is Stochastic Rounding and Why is It Useful For Optimization

Standard floating-point casting (e.g., `tensor.to(torch.bfloat16)`) typically uses round-to-nearest-even. Because this rounding is deterministic, its error does not average out across repeated operations.

When training models in reduced precision (such as BF16), this can stall training: whenever a weight update is smaller than half the spacing between adjacent representable values around a weight, round-to-nearest snaps the result back to the original weight, and the update disappears completely.
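The stall is easy to reproduce in plain PyTorch (a minimal illustration, independent of this module). The spacing between adjacent BF16 values around 1.0 is \(2^{-7} \approx 0.0078\), so an update of \(10^{-3}\) is always rounded away:

```python
import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)
update = torch.tensor(1e-3, dtype=torch.bfloat16)

for _ in range(100):
    # 1.0 + 0.001 lies closer to 1.0 than to the next BF16 value
    # (1.0078125), so round-to-nearest snaps it back to 1.0 every time.
    w = w + update

print(w.item())  # 1.0 -- a hundred updates vanished without a trace
```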

Stochastic rounding replaces this rigid rule with a probabilistic one: if a value \(x\) lies \(30\%\) of the way between representable numbers \(A\) and \(B\), it has a \(30\%\) chance of rounding to \(B\) and a \(70\%\) chance of rounding to \(A\). Over many updates, the statistical expectation matches the true high-precision value, \(E[\mathrm{Round}(x)] = x\), allowing training to converge even when individual updates are technically "too small" for the format.
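This rule can be sketched with a short bit-manipulation trick (an illustration in pure PyTorch, not this module's Triton kernel): BF16 keeps the top 16 bits of an FP32 value, so adding uniform random noise to the 16 bits about to be truncated, then truncating, rounds up with probability equal to the discarded fraction.

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor, generator=None) -> torch.Tensor:
    """Round an FP32 tensor to BF16 stochastically via integer bit tricks."""
    bits = x.contiguous().view(torch.int32)
    # Uniform noise over the 16 low bits that the BF16 cast truncates away.
    noise = torch.randint(
        0, 1 << 16, x.shape, dtype=torch.int32,
        generator=generator, device=x.device,
    )
    # Adding the noise and then zeroing the low 16 bits rounds up with
    # probability proportional to the discarded fraction.
    rounded = (bits + noise) & -65536  # -65536 == ~0xFFFF in int32
    return rounded.view(torch.float32).to(torch.bfloat16)

# The expectation matches the FP32 input even though each individual
# result is one of just two adjacent BF16 values.
g = torch.Generator().manual_seed(0)
x = torch.full((200_000,), 1.0 + 0.3 * 2**-7)  # 30% of the way from 1.0
y = stochastic_round_to_bf16(x, generator=g).float()
print(sorted(y.unique().tolist()))  # the two neighbors: 1.0 and 1.0078125
print(y.mean().item())              # close to the FP32 input 1.00234375
```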


About

This module provides optimizers for low-precision training with stochastic rounding, implemented with highly optimized Triton kernels.

Benchmarks

All benchmarks were performed on a single NVIDIA H100 80GB GPU.

copy_fp32_to_bf16_stochastic_

adamw_stochastic_bf16_

d9d.optim.stochastic

StochasticAdamW

Bases: Optimizer

Implements the AdamW algorithm with Stochastic Rounding.

This optimizer applies stochastic rounding to parameter updates, primarily for BF16 training, using a custom kernel.

Parameters must be in BF16. Gradients may be in either BF16 or FP32.

It natively supports PyTorch distributed DTensor parameters.

It maintains its own random number generator state to ensure reproducibility.

__init__(params, lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, generator=None, state_dtype=torch.float32)

Constructs a new StochasticAdamW optimizer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `ParamsT` | Iterable of parameters to optimize or dicts defining parameter groups. | *required* |
| `lr` | `float` | Learning rate. | *required* |
| `betas` | `tuple[float, float]` | Coefficients used for computing running averages of the gradient and its square. | `(0.9, 0.999)` |
| `eps` | `float` | Term added to the denominator to improve numerical stability. | `1e-08` |
| `weight_decay` | `float` | Weight decay coefficient. | `0.01` |
| `generator` | `Generator \| None` | Pseudorandom number generator for stochastic rounding. If `None`, a new generator is created and seeded from the main PyTorch generator. | `None` |
| `state_dtype` | `dtype` | Data type to use for the optimizer states. | `torch.float32` |

d9d.kernel.stochastic

Utilities for stochastic type casting (e.g., FP32 to BF16).

adamw_stochastic_bf16_(params, grads, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, weight_decay, step, generator=None)

Performs a single in-place AdamW optimization step.

It is specifically designed for scenarios where parameters are stored in BFloat16.

To mitigate precision loss during the parameter update, it utilizes stochastic rounding when casting FP32 calculation results back to BFloat16.

This function supports mixed precision for gradients and optimizer states (they can be either FP32 or BFloat16).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `Tensor` | The tensor of model parameters to update. Must be BFloat16 and contiguous. | *required* |
| `grads` | `Tensor` | The gradient tensor. | *required* |
| `exp_avg` | `Tensor` | The exponential moving average of gradient values (first moment). | *required* |
| `exp_avg_sq` | `Tensor` | The exponential moving average of squared gradient values (second moment). | *required* |
| `lr` | `float` | The learning rate. | *required* |
| `beta1` | `float` | Decay rate for the first moment estimate. | *required* |
| `beta2` | `float` | Decay rate for the second moment estimate. | *required* |
| `eps` | `float` | Term added to the denominator to improve numerical stability. | *required* |
| `weight_decay` | `float` | Weight decay coefficient. | *required* |
| `step` | `int` | The current optimization step count, used for bias correction. | *required* |
| `generator` | `Generator \| None` | PyTorch random number generator used to create the seed for stochastic rounding. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the main parameters are not BFloat16, if input tensor shapes do not match, if tensors that require in-place modification are not contiguous, or if the optimizer states (`exp_avg`, `exp_avg_sq`) have different dtypes. |
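For reference, the arithmetic this kernel performs can be sketched in plain PyTorch (a readable reference sketch, not the Triton implementation; the exact operation order inside the fused kernel may differ, and the stochastic write-back here reuses the standard 16-bit-noise trick):

```python
import torch

def _stochastic_cast_bf16(x32: torch.Tensor, generator=None) -> torch.Tensor:
    # Add uniform noise to the 16 bits BF16 truncates, then truncate.
    bits = x32.contiguous().view(torch.int32)
    noise = torch.randint(0, 1 << 16, x32.shape, dtype=torch.int32,
                          generator=generator, device=x32.device)
    return ((bits + noise) & -65536).view(torch.float32).to(torch.bfloat16)

def adamw_stochastic_bf16_reference_(params, grads, exp_avg, exp_avg_sq,
                                     lr, beta1, beta2, eps, weight_decay,
                                     step, generator=None):
    # All math happens in FP32, regardless of the storage dtypes.
    p = params.float()
    g = grads.float()
    m = exp_avg.float().mul_(beta1).add_(g, alpha=1.0 - beta1)
    v = exp_avg_sq.float().mul_(beta2).addcmul_(g, g, value=1.0 - beta2)

    # Decoupled weight decay, then the bias-corrected Adam update.
    p.mul_(1.0 - lr * weight_decay)
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)

    # States go back in their storage dtype; parameters are written
    # back to BF16 with stochastic rounding.
    exp_avg.copy_(m)
    exp_avg_sq.copy_(v)
    params.copy_(_stochastic_cast_bf16(p, generator=generator))

# Example: one step on random BF16 parameters with FP32 states.
torch.manual_seed(0)
p = torch.randn(1024).to(torch.bfloat16)
g = torch.randn(1024).to(torch.bfloat16)
m, v = torch.zeros(1024), torch.zeros(1024)
p_before = p.clone()
adamw_stochastic_bf16_reference_(p, g, m, v, lr=1e-3, beta1=0.9,
                                 beta2=0.999, eps=1e-8,
                                 weight_decay=0.01, step=1)
```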

copy_fp32_to_bf16_stochastic_(target, source, generator=None)

Copies elements from a Float32 tensor to a BFloat16 tensor using stochastic rounding.

Unlike standard round-to-nearest casting, stochastic rounding probabilistically rounds numbers up or down based on the value of the bits being truncated. This preserves the expected value of the tensor (E[round(x)] = x), which is crucial for accumulating gradients or parameters in low precision without stagnation.

This operation is performed in-place on the target tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `target` | `Tensor` | The output tensor where results are written. Must be of type BFloat16 and contiguous. | *required* |
| `source` | `Tensor` | The input tensor containing values to copy. Must be of type Float32. | *required* |
| `generator` | `Generator \| None` | An optional PyTorch RNG generator to strictly control the random noise used for rounding. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The `target` tensor, modified in-place. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `target` is not contiguous, if `source`/`target` shapes do not match, or if the dtypes are not FP32 and BF16 respectively. |