Normalization Layers
About
The d9d.module.block.normalization module implements memory-efficient normalization layers.
Features
RMSNorm
RMSNorm implements Root Mean Square Layer Normalization.
It uses an efficient custom Triton kernel for both the forward and backward passes.
It includes native support for zero-centered scaling weights.
Kernel Benchmarks (BF16, H100)
[Benchmark plots: forward and backward passes at hidden sizes 128, 256, 1024, 4096, and 7168.]
d9d.module.block.normalization
RMSNorm
Bases: Module, ModuleLateInit
Implements Root Mean Square (RMS) Normalization.
This module normalizes the input tensor across its last dimension using the root mean square statistic, applying learnable scaling weights. It can optionally use zero-centered weights.
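As a reading aid, the computation described above can be written out as follows. This is a sketch that assumes the standard RMSNorm formulation, with $d$ standing for hidden_size, $w$ for the learnable scaling weights, and eps ($\varepsilon$) added inside the square root; the exact placement of eps and the internal precision of the Triton kernel are assumptions, not something this page states.

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \varepsilon},
\qquad
y_i = \frac{x_i}{\mathrm{RMS}(x)} \cdot g_i,
\qquad
g_i =
\begin{cases}
w_i & \text{if zero\_centered is False} \\
1 + w_i & \text{if zero\_centered is True}
\end{cases}
```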
__init__(hidden_size, eps=1e-06, zero_centered=False)
Constructs an RMSNorm object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_size | int | The size of the last (hidden) dimension to normalize. | required |
| eps | float | A small value added to the mean of squares for numerical stability, to prevent division by zero. | 1e-06 |
| zero_centered | bool | If True, the scaling weights are initialized to 0.0 and implicitly offset by 1.0 during computation. Otherwise, they are initialized to 1.0. | False |
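The effect of zero_centered can be illustrated with a small NumPy reference. This is a minimal sketch, not the module's Triton kernel: `rms_norm_reference` is a hypothetical helper written for this example, and the float32 accumulation and the placement of eps inside the square root are assumptions. It shows that both initialization modes produce the same output before any training step.

```python
import numpy as np


def rms_norm_reference(x, weight, eps=1e-6, zero_centered=False):
    """Reference RMSNorm over the last axis (hypothetical helper, not the d9d API)."""
    # Compute in float32; this is an assumption, not a claim about the Triton kernel.
    x32 = np.asarray(x, dtype=np.float32)
    # Mean of squares over the hidden (last) dimension.
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    normed = x32 / np.sqrt(mean_sq + eps)
    # Zero-centered weights store (scale - 1), so add 1.0 back before scaling.
    scale = weight + 1.0 if zero_centered else weight
    return normed * scale


hidden_size = 8
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, hidden_size)).astype(np.float32)

# Initial weights as described in the parameter table.
w_standard = np.ones(hidden_size, dtype=np.float32)
w_zero_centered = np.zeros(hidden_size, dtype=np.float32)

y_standard = rms_norm_reference(x, w_standard, zero_centered=False)
y_zero_centered = rms_norm_reference(x, w_zero_centered, zero_centered=True)
```

In this sketch the two modes differ only in how the scale is stored. One commonly cited motivation for the zero-centered form is that weight decay then pulls the effective scale toward 1.0 rather than toward 0.0; whether that is the motivation in this library is not stated here.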
forward(x)
reset_parameters()
Resets module parameters.