
Mixture of Experts (MoE)

About

The d9d.module.block.moe package provides a complete, high-performance implementation of Sparse Mixture-of-Experts layers.

Expert Parallelism

For information on setting up Expert Parallelism, see this page.

Features

Sparse Expert Router

TopKRouter is a learnable router implementation.

It computes routing probabilities in FP32 for numerical stability.

Sparse Expert Token Dispatcher

ExpertCommunicationHandler is the messaging layer.

NoCommunicationHandler is used by default for single-GPU or Tensor Parallel setups where no token movement is needed.

DeepEpCommunicationHandler is used when Expert Parallelism is enabled. It uses the DeepEP library for highly optimized all-to-all communication over NVLink/RDMA, enabling scaling to thousands of experts.

Sparse Experts

GroupedSwiGLU provides a sparse SwiGLU experts module implementation.

Instead of looping over experts, it uses Grouped GEMM kernels to execute all experts in parallel, regardless of how many tokens each expert received.

It uses an efficient fused SiLU-Mul kernel.
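For reference, the unfused computation that a SiLU-Mul kernel replaces is shown below. This is a minimal PyTorch sketch, not d9d's kernel; a fused kernel would typically produce the same values in a single pass without materializing SiLU(gate) as a separate tensor. The function name silu_mul_reference is hypothetical.

```python
import torch
import torch.nn.functional as F


def silu_mul_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Unfused reference for SiLU(gate) * up (hypothetical helper)."""
    # Two separate elementwise ops: SiLU, then multiplication by the up projection.
    return F.silu(gate) * up


gate = torch.randn(4, 8)
up = torch.randn(4, 8)
out = silu_mul_reference(gate, up)

# SiLU(x) is defined as x * sigmoid(x).
expected = gate * torch.sigmoid(gate) * up
```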

Kernel Benchmarks (BF16, H100)

Shared Experts

Currently not supported; feel free to contribute :)

d9d.module.block.moe

Provides building blocks for Mixture-of-Experts (MoE) architectures.

GroupedLinear

Bases: Module, ModuleLateInit

Applies a linear transformation using Grouped GEMM (General Matrix Multiplication).

This module allows efficient execution of multiple linear layers (experts) in parallel, where each expert processes a variable number of tokens. It is the computational core of the Mixture-of-Experts layer.

__init__(n_groups, in_features, out_features, device=None, dtype=None)

Constructs the GroupedLinear layer.

Parameters:

Name Type Description Default
n_groups int

Number of groups (experts).

required
in_features int

Input hidden size.

required
out_features int

Output hidden size.

required
device device | str | None

Target device.

None
dtype dtype | None

Target data type.

None

forward(x, x_groups)

Performs the grouped matrix multiplication.

Parameters:

Name Type Description Default
x Tensor

Flattened input tensor containing tokens for all groups. Shape: (total_tokens, in_features).

required
x_groups Tensor

CPU Tensor indicating the number of tokens assigned to each group. Must sum to total_tokens. Shape: (n_groups,).

required

Returns:

Type Description
Tensor

The output tensor. Shape: (total_tokens, out_features).
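The semantics of the grouped matmul can be illustrated with a slow, loop-based reference. This is a sketch only: it assumes the weights are stored as one (in_features, out_features) matrix per group, and the function name grouped_linear_reference is hypothetical. It shows why x must be sorted by group and why x_groups must sum to total_tokens.

```python
import torch


def grouped_linear_reference(
    x: torch.Tensor, x_groups: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """Loop-based reference for a grouped matmul (hypothetical helper).

    x:        (total_tokens, in_features), rows sorted by group
    x_groups: CPU tensor (n_groups,), tokens per group, sums to total_tokens
    weight:   (n_groups, in_features, out_features), assumed layout
    """
    assert int(x_groups.sum()) == x.shape[0]
    outputs = []
    # Split the flat token tensor into contiguous per-group chunks.
    chunks = torch.split(x, x_groups.tolist())
    for g, chunk in enumerate(chunks):
        # An empty chunk yields an empty (0, out_features) result.
        outputs.append(chunk @ weight[g])
    return torch.cat(outputs, dim=0)


n_groups, in_features, out_features = 3, 4, 5
x_groups = torch.tensor([2, 0, 3])  # group 1 received no tokens
x = torch.randn(int(x_groups.sum()), in_features)
weight = torch.randn(n_groups, in_features, out_features)
y = grouped_linear_reference(x, x_groups, weight)
print(y.shape)  # torch.Size([5, 5])
```

A Grouped GEMM kernel computes the same result in one launch instead of one matmul per group.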

reset_parameters()

Initializes weights from a uniform distribution whose range depends on the number of input features.
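A minimal sketch of such an initialization, assuming a bound of 1/sqrt(in_features) (the bound that PyTorch's default nn.Linear weight initialization works out to). The exact bound and weight layout d9d uses are not stated here, so both are assumptions, and reset_grouped_weight_ is a hypothetical helper.

```python
import math

import torch


@torch.no_grad()
def reset_grouped_weight_(weight: torch.Tensor, in_features: int) -> None:
    # ASSUMPTION: bound = 1/sqrt(in_features); d9d's actual bound may differ.
    bound = 1.0 / math.sqrt(in_features)
    weight.uniform_(-bound, bound)


# (n_groups, in_features, out_features): assumed layout
w = torch.empty(3, 16, 8)
reset_grouped_weight_(w, in_features=16)
print(w.abs().max() <= 0.25)  # tensor(True)
```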

GroupedSwiGLU

Bases: Module, ModuleLateInit

Executes a collection of SwiGLU experts efficiently using Grouped GEMM.

This module implements the architectural pattern: down_proj(SiLU(gate_proj(x)) * up_proj(x)). It applies this operation across multiple discrete experts in parallel without padding or masking.

__init__(hidden_dim, intermediate_dim, num_experts)

Constructs the GroupedSwiGLU module.

Parameters:

Name Type Description Default
hidden_dim int

Dimensionality of the input and output hidden states.

required
intermediate_dim int

Dimensionality of the intermediate projection.

required
num_experts int

Total number of experts managed by this local instance.

required

forward(permuted_x, permuted_probs, tokens_per_expert)

Computes expert outputs for sorted input tokens.

Parameters:

Name Type Description Default
permuted_x Tensor

Input tokens sorted by their assigned expert. Shape: (total_tokens, hidden_dim).

required
permuted_probs Tensor

Routing weights/probabilities corresponding to the sorted tokens. Shape: (total_tokens,).

required
tokens_per_expert Tensor

CPU tensor holding the number of tokens assigned to each expert, in the order the experts appear in permuted_x. Shape: (num_experts,).

required

Returns:

Type Description
Tensor

The computed and weighted output tokens (still permuted). Shape: (total_tokens, hidden_dim).
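The computation GroupedSwiGLU performs can be written as a per-expert loop. This sketch assumes hypothetical weight tensors w_gate and w_up of shape (num_experts, hidden_dim, intermediate_dim) and w_down of shape (num_experts, intermediate_dim, hidden_dim), and it applies the routing weight after down_proj; because the projections have no bias in this sketch, scaling before or after down_proj gives the same result. The real module replaces the loop with Grouped GEMM calls.

```python
import torch
import torch.nn.functional as F


def grouped_swiglu_reference(permuted_x, permuted_probs, tokens_per_expert,
                             w_gate, w_up, w_down):
    """Per-expert loop computing probs * down(SiLU(gate(x)) * up(x)) (hypothetical)."""
    assert int(tokens_per_expert.sum()) == permuted_x.shape[0]
    out = torch.empty_like(permuted_x)
    start = 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        x = permuted_x[start:start + n]  # contiguous block of tokens for expert e
        h = F.silu(x @ w_gate[e]) * (x @ w_up[e])
        # Scale each token's output by its routing weight.
        out[start:start + n] = (h @ w_down[e]) * permuted_probs[start:start + n, None]
        start += n
    return out


num_experts, hidden_dim, intermediate_dim = 2, 4, 6
tokens_per_expert = torch.tensor([3, 1])
permuted_x = torch.randn(4, hidden_dim)
permuted_probs = torch.rand(4)
w_gate = torch.randn(num_experts, hidden_dim, intermediate_dim)
w_up = torch.randn(num_experts, hidden_dim, intermediate_dim)
w_down = torch.randn(num_experts, intermediate_dim, hidden_dim)
out = grouped_swiglu_reference(permuted_x, permuted_probs, tokens_per_expert,
                               w_gate, w_up, w_down)
```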

reset_parameters()

Resets parameters for all internal linear projections.

MoELayer

Bases: Module, ModuleLateInit

A complete Mixture-of-Experts (MoE) block comprising routing, communication, and computation.

This layer integrates:

  1. Router: Selects experts for each token.
  2. Communicator: Handles token dispatch to local or remote experts (EP).
  3. Experts: Performs parallelized computation (Grouped SwiGLU).

__init__(hidden_dim, intermediate_dim_grouped, num_grouped_experts, top_k, router_renormalize_probabilities)

Constructs the MoELayer.

Parameters:

Name Type Description Default
hidden_dim int

Hidden size.

required
intermediate_dim_grouped int

Intermediate dimension for the Expert FFNs.

required
num_grouped_experts int

Total number of experts.

required
top_k int

Number of experts to route each token to.

required
router_renormalize_probabilities bool

If True, the router renormalizes the probabilities of the selected experts to sum to 1 (see TopKRouter's renormalize_probabilities).

required

enable_distributed_communicator(group)

Switches from local no-op communication to distributed DeepEP communication.

This should be called during model initialization if the model is running in a distributed Expert Parallel environment.

Parameters:

Name Type Description Default
group ProcessGroup

The PyTorch process group spanning the expert parallel ranks.

required

forward(hidden_states)

Routes tokens to experts, runs the expert computation, and combines the results.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor. Shape: (batch_size, seq_len, hidden_dim).

required

Returns:

Type Description
Tensor

Output tensor combined from experts. Shape: (batch_size, seq_len, hidden_dim).
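A dense reference is handy for checking a sparse MoE path: each token's output is the routing-weighted sum of its selected experts' outputs. In the sketch below each expert is reduced to a single bias-free matrix for brevity, and moe_dense_reference is a hypothetical name. It also shows the flatten/reshape between (batch_size, seq_len, hidden_dim) and the (num_tokens, ...) layout that TopKRouter works on.

```python
import torch


def moe_dense_reference(hidden_states, topk_ids, topk_weights, expert_weights):
    """out[t] = sum_k w[t, k] * (x[t] @ W[topk_ids[t, k]]) (hypothetical helper)."""
    batch, seq, hidden = hidden_states.shape
    x = hidden_states.reshape(batch * seq, hidden)  # (num_tokens, hidden_dim)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            out[t] += topk_weights[t, k] * (x[t] @ expert_weights[e])
    return out.reshape(batch, seq, hidden)


hidden_states = torch.randn(2, 3, 4)
topk_ids = torch.tensor([[0, 1]] * 6)   # every token picks experts 0 and 1
topk_weights = torch.full((6, 2), 0.5)  # weights already sum to 1 per token
expert_weights = torch.eye(4).expand(2, 4, 4)  # identity "experts"
out = moe_dense_reference(hidden_states, topk_ids, topk_weights, expert_weights)
```

With identity experts and per-token weights summing to 1, the output equals the input, which makes a quick sanity check.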

reset_parameters()

Resets module parameters.

reset_stats()

Resets the expert load-balancing counters.
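What such counters might track can be sketched as a per-expert token tally. The class and its fields below are hypothetical; only the idea of counting token-to-expert assignments and resetting them comes from the description above.

```python
import torch


class ExpertLoadCounter:
    """Hypothetical accumulator of how many tokens were routed to each expert."""

    def __init__(self, num_experts: int):
        self.num_experts = num_experts
        self.counts = torch.zeros(num_experts, dtype=torch.long)

    def update(self, topk_ids: torch.Tensor) -> None:
        # Each (token, k) assignment counts once toward its expert.
        self.counts += torch.bincount(topk_ids.flatten(), minlength=self.num_experts)

    def reset_stats(self) -> None:
        self.counts.zero_()


counter = ExpertLoadCounter(num_experts=4)
counter.update(torch.tensor([[0, 2], [2, 3], [0, 2]]))
print(counter.counts.tolist())  # [2, 0, 3, 1]
counter.reset_stats()
print(counter.counts.tolist())  # [0, 0, 0, 0]
```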

TopKRouter

Bases: Module, ModuleLateInit

Selects the top-K experts based on a learned gating mechanism.

This router:

  1. Projects input tokens into expert space.
  2. Applies softmax and, optionally, adds an expert bias to influence selection.
  3. Selects the top_k experts with the highest scores.
  4. Optionally renormalizes the selected probabilities to sum to 1.
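The four steps above can be sketched in plain PyTorch. This is an illustrative version, not d9d's TopKRouter: the function topk_route is hypothetical, the gate weight layout (num_experts, dim) is assumed, and so is the choice to use the bias only for selection while returning unbiased probabilities.

```python
import torch
import torch.nn.functional as F


def topk_route(hidden_states, gate_weight, top_k, renormalize, expert_bias=None):
    # 1. Project tokens into expert space; compute in FP32 for numerical stability.
    logits = hidden_states.float() @ gate_weight.float().t()  # (num_tokens, num_experts)
    # 2. Softmax over experts.
    probs = F.softmax(logits, dim=-1)
    # ASSUMPTION: the bias only affects which experts are chosen.
    scores = probs if expert_bias is None else probs + expert_bias.float()
    # 3. Pick the top_k experts per token.
    topk_ids = torch.topk(scores, top_k, dim=-1).indices
    topk_weights = probs.gather(-1, topk_ids)
    # 4. Optionally renormalize so each token's weights sum to 1.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_ids, topk_weights


torch.manual_seed(0)
x = torch.randn(5, 8, dtype=torch.bfloat16)
gate_weight = torch.randn(4, 8)
ids, weights = topk_route(x, gate_weight, top_k=2, renormalize=True)
```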

__init__(dim, num_experts, top_k, renormalize_probabilities, enable_expert_bias=False)

Constructs the TopKRouter.

Parameters:

Name Type Description Default
dim int

Input feature dimensionality.

required
num_experts int

Total number of experts to choose from.

required
top_k int

Number of experts to select for each token.

required
renormalize_probabilities bool

If True, the probabilities of the selected experts are renormalized to sum to 1.

required
enable_expert_bias bool

If True, adds a bias term to the routing scores before top-k selection. This can be used for loss-free load balancing.

False
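One published way to drive such a bias is the auxiliary-loss-free balancing rule popularized by DeepSeek-V3: after each step, raise the bias of under-loaded experts and lower it for over-loaded ones by a fixed step size. Whether d9d uses this exact rule is not stated here; the helper update_expert_bias_ and its update_rate value are assumptions.

```python
import torch


@torch.no_grad()
def update_expert_bias_(expert_bias: torch.Tensor,
                        tokens_per_expert: torch.Tensor,
                        update_rate: float = 1e-3) -> None:
    load = tokens_per_expert.float()
    # Positive for under-loaded experts, negative for over-loaded ones.
    error = load.mean() - load
    # In-place sign-based step (hypothetical helper).
    expert_bias += update_rate * torch.sign(error)


bias = torch.zeros(4)
update_expert_bias_(bias, torch.tensor([10, 2, 4, 4]))  # mean load is 5
print(bias.tolist())
```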

forward(hidden_states)

Calculates routing decisions for the input tokens.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tokens. Shape: (num_tokens, dim).

required

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple containing:
  • Selected expert indices. Shape: (num_tokens, top_k).
  • Routing weights for the selected experts (renormalized if renormalize_probabilities is True). Shape: (num_tokens, top_k).

reset_parameters()

Resets module parameters.

d9d.module.block.moe.communications

Provides communication strategies for Mixture-of-Experts routing operations.

DeepEpCommunicationHandler

Bases: ExpertCommunicationHandler

Handles MoE communication using the high-performance DeepEP library.

__init__(num_experts)

Constructs the DeepEpCommunicationHandler.

setup(group, hidden_size, hidden_dtype)

Initializes the backend buffer and calculates expert sharding.

Parameters:

Name Type Description Default
group ProcessGroup

The process group containing all experts.

required
hidden_size int

Dimensionality of the hidden states.

required
hidden_dtype dtype

Data type of the hidden states.

required
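Expert sharding usually reduces to simple arithmetic over the expert-parallel group. This sketch assumes experts are split into equal, contiguous blocks per rank; the helpers expert_sharding and owner_rank are hypothetical. In a real run, ep_world_size and ep_rank would come from torch.distributed.get_world_size(group) and torch.distributed.get_rank(group).

```python
def expert_sharding(num_experts: int, ep_world_size: int, ep_rank: int):
    """Return (num_local_experts, global IDs of this rank's experts)."""
    # ASSUMPTION: equal, contiguous blocks of experts per rank.
    if num_experts % ep_world_size != 0:
        raise ValueError("num_experts must be divisible by the EP world size")
    num_local = num_experts // ep_world_size
    first = ep_rank * num_local
    return num_local, range(first, first + num_local)


def owner_rank(expert_id: int, num_experts: int, ep_world_size: int) -> int:
    """Rank that hosts a given global expert under contiguous sharding."""
    return expert_id // (num_experts // ep_world_size)


num_local, local_ids = expert_sharding(num_experts=64, ep_world_size=8, ep_rank=3)
print(num_local, list(local_ids))  # 8 [24, 25, 26, 27, 28, 29, 30, 31]
```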

ExpertCommunicationHandler

Bases: ABC

Abstract base class for Mixture-of-Experts communication strategies.

combine(hidden_states) abstractmethod

Restores hidden states to their original order and location.

Undoes the permutation and performs the reverse All-to-All communication to return processed results to the workers that originated the requests.

Parameters:

Name Type Description Default
hidden_states Tensor

The processed hidden states. Shape: (num_received_tokens, hidden_size).

required

Returns:

Type Description
Tensor

The combined hidden states with the original shape and order. Shape: (num_tokens, hidden_size).
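On a single device, combine reduces to an un-permute plus a sum over each token's expert copies. A minimal sketch, assuming the expert outputs are already weighted (as GroupedSwiGLU's outputs are) and that dispatch recorded, for every permuted row, the index of the token it came from; local_combine and source_token_idx are hypothetical names.

```python
import torch


def local_combine(expert_out: torch.Tensor,
                  source_token_idx: torch.Tensor,
                  num_tokens: int) -> torch.Tensor:
    """Scatter-add already-weighted expert outputs back to their tokens."""
    out = torch.zeros(num_tokens, expert_out.shape[-1],
                      dtype=expert_out.dtype, device=expert_out.device)
    # Rows that came from the same token (one per selected expert) are summed.
    out.index_add_(0, source_token_idx, expert_out)
    return out


rows = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
src = torch.tensor([1, 0, 1, 0])  # rows 1 and 3 belong to token 0
combined = local_combine(rows, src, num_tokens=2)
print(combined.tolist())  # [[6.0, 6.0], [4.0, 4.0]]
```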

dispatch(hidden_states, topk_ids, topk_weights) abstractmethod

Prepares and routes local hidden states to their target experts (possibly on other workers).

This process involves:

  1. All-to-All Communication: Transfers hidden states to the workers hosting the assigned experts. Tokens assigned to multiple experts are replicated, one copy per expert.

  2. Permutation: Sorts tokens by expert ID to prepare for Grouped GEMM.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tokens. Shape: (num_tokens, hidden_size).

required
topk_ids Tensor

Indices of the top-k experts selected for each token. Shape: (num_tokens, k).

required
topk_weights Tensor

Routing weights associated with the selected experts. Shape: (num_tokens, k).

required

Returns:

Type Description
tuple[Tensor, Tensor, Tensor]

A tuple containing:
  • Permuted hidden states received by this rank. Shape: (num_received_tokens, hidden_size).
  • Permuted weights matching the hidden states order. Shape: (num_received_tokens,).
  • Expert count tensor indicating how many tokens each local expert received. Shape: (num_local_experts,).
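Without any communication, the permutation step can be sketched with a stable sort over the flattened expert IDs. This is a hypothetical single-device version (local_dispatch is not a d9d function): it additionally returns source_token_idx, which a real handler would presumably keep internally for combine, and it moves the expert counts to the CPU because GroupedSwiGLU expects tokens_per_expert there.

```python
import torch


def local_dispatch(hidden_states, topk_ids, topk_weights, num_experts):
    num_tokens, k = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)  # (num_tokens * k,)
    # Stable sort keeps the original token order within each expert.
    _, order = torch.sort(flat_ids, stable=True)
    # Row i of the flattened (token, k) grid belongs to token i // k.
    source_token_idx = order // k
    permuted_x = hidden_states[source_token_idx]  # one copy per selected expert
    permuted_w = topk_weights.reshape(-1)[order]
    tokens_per_expert = torch.bincount(flat_ids, minlength=num_experts).cpu()
    return permuted_x, permuted_w, tokens_per_expert, source_token_idx


x = torch.arange(3, dtype=torch.float32).unsqueeze(1)  # token i has value i
ids = torch.tensor([[1, 0], [1, 2], [0, 1]])
w = torch.tensor([[0.6, 0.4], [0.7, 0.3], [0.5, 0.5]])
px, pw, tpe, src = local_dispatch(x, ids, w, num_experts=3)
print(tpe.tolist())  # [2, 3, 1]
```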

NoCommunicationHandler

Bases: ExpertCommunicationHandler

Handles MoE routing within a single device or when no cross-device routing is needed.

This handler performs no network operations. It only permutes tokens locally, grouping them by expert, which is mainly useful for single-device runs and debugging.

__init__(num_experts)

Constructs the NoCommunicationHandler.