
Normalization Layers

About

The d9d.module.block.normalization module implements memory-efficient normalization layers.

Features

RMSNorm

RMSNorm implements Root Mean Square Layer Normalization.

The forward and backward passes run through an efficient custom Triton kernel.

Zero-centered scaling weights are supported natively.

Kernel Benchmarks (BF16, H100)

[Benchmark plots: forward- and backward-pass kernel performance at hidden sizes 128, 256, 1024, 4096, and 7168.]

d9d.module.block.normalization

RMSNorm

Bases: Module, ModuleLateInit

Implements Root Mean Square (RMS) Normalization.

This module normalizes the input tensor across its last dimension using the root mean square statistic, applying learnable scaling weights. It can optionally use zero-centered weights.
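For reference, the computation is equivalent to the following plain-PyTorch sketch (illustrative only: the module itself uses a fused Triton kernel, and the float32 upcast here is an assumption rather than documented behavior):

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6, zero_centered: bool = False) -> torch.Tensor:
    # Inverse root mean square over the last dimension
    # (upcast to float32 is assumed for numerical stability).
    x_fp32 = x.float()
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # With zero-centered weights the effective scale is (1 + weight);
    # otherwise the learnable weight is applied directly.
    scale = weight.float() + 1.0 if zero_centered else weight.float()
    return (x_fp32 * inv_rms * scale).to(x.dtype)
```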

References

Root Mean Square Layer Normalization

__init__(hidden_size, eps=1e-06, zero_centered=False)

Constructs an RMSNorm object.

Parameters:

hidden_size (int, required)
    The size of the hidden dimension to normalize.

eps (float, default 1e-06)
    A small value added to the denominator for numerical stability, preventing division by zero.

zero_centered (bool, default False)
    If True, the scaling weights are initialized to 0.0 and implicitly offset by 1.0 during computation. Otherwise, they are initialized to 1.0.
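A minimal construction sketch, assuming the class is exported from the d9d.module.block.normalization module named above:

```python
from d9d.module.block.normalization import RMSNorm

# Default configuration: scaling weights are initialized to 1.0.
norm = RMSNorm(hidden_size=4096)

# Zero-centered variant: weights start at 0.0 and are offset by 1.0 at compute time.
norm_zc = RMSNorm(hidden_size=4096, eps=1e-6, zero_centered=True)
```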

forward(x)

Applies RMS Normalization to the input tensor.

Parameters:

x (Tensor, required)
    The input tensor to normalize. Normalization is applied over the last dimension.

Returns:

Tensor
    The normalized tensor, with the same shape as the input.
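A usage sketch (shapes, device, and dtype are illustrative; BF16 on GPU matches the benchmark setup above):

```python
import torch

x = torch.randn(8, 1024, 4096, dtype=torch.bfloat16, device="cuda")

norm = RMSNorm(hidden_size=4096).to(device="cuda", dtype=torch.bfloat16)
y = norm(x)

assert y.shape == x.shape  # the output keeps the input shape
```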

reset_parameters()

Re-initializes the module's scaling weights to their initial values (0.0 when zero_centered is True, otherwise 1.0).
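Since the class also derives from ModuleLateInit, one plausible late-initialization flow is the standard PyTorch meta-device pattern sketched below (an assumption, not documented ModuleLateInit behavior):

```python
import torch

with torch.device("meta"):
    norm = RMSNorm(hidden_size=4096)   # parameters created on the meta device, no real storage yet

norm = norm.to_empty(device="cuda")    # materialize uninitialized storage on the target device
norm.reset_parameters()                # fill the weights with their init values (1.0, or 0.0 if zero_centered)
```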