About

The d9d.module.block.moe package provides a complete, high-performance implementation of sparse Mixture-of-Experts (MoE) layers.

Expert Parallelism

For information on setting up Expert Parallelism, see this page.

Features

Sparse Expert Router

TopKRouter is a learnable top-k router implementation.

It computes routing probabilities in FP32 to ensure numerical stability, and it can optionally apply a per-expert bias before top-k selection for loss-free load balancing.

Sparse Expert Token Dispatcher

ExpertCommunicationHandler is the abstract messaging layer that moves routed tokens to their experts; two implementations are provided.

NoCommunicationHandler is used by default for single-GPU or Tensor Parallel setups where no token movement is needed.

DeepEpCommunicationHandler is used when Expert Parallelism is enabled. It relies on the DeepEP library for highly optimized all-to-all communication over NVLink/RDMA, enabling scaling to thousands of experts.

Sparse Experts

GroupedSwiGLU implements the sparse SwiGLU experts module.

Instead of looping over experts, it uses Grouped GEMM kernels to execute all experts in parallel, regardless of how many tokens each expert received.
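To make the execution model concrete, the reference semantics of a grouped linear layer can be written as a plain per-expert loop. The sketch below is plain PyTorch, not the library's kernel path; it is what a single Grouped GEMM kernel launch computes.

# Reference semantics of grouped linear execution: each expert owns its own
# weight matrix and processes a contiguous, variable-length slice of the
# (already expert-sorted) token stream. GroupedLinear fuses this loop into
# one Grouped GEMM kernel call.
import torch

def grouped_linear_reference(x, weight, tokens_per_expert):
    # x: (total_tokens, in_features), sorted so tokens of expert i are contiguous
    # weight: (n_experts, in_features, out_features)
    # tokens_per_expert: (n_experts,) CPU tensor of counts summing to total_tokens
    outputs, start = [], 0
    for expert_id, count in enumerate(tokens_per_expert.tolist()):
        chunk = x[start:start + count]              # tokens routed to this expert
        outputs.append(chunk @ weight[expert_id])   # per-expert matmul
        start += count
    return torch.cat(outputs, dim=0)                # (total_tokens, out_features)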

Shared Experts

Shared experts are currently not supported; feel free to contribute :)

d9d.module.block.moe

Provides building blocks for Mixture-of-Experts (MoE) architectures.

GroupedLinear

Bases: Module, ModuleLateInit

Applies a linear transformation using Grouped GEMM (Generalized Matrix Multiplication).

This module allows efficient execution of multiple linear layers (experts) in parallel, where each expert processes a variable number of tokens. It is the computational core of the Mixture-of-Experts layer.
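A minimal usage sketch, assuming the class is imported from d9d.module.block.moe.grouped_linear (matching the source path below) and that a CUDA device with the grouped GEMM kernel is available:

import torch
from d9d.module.block.moe.grouped_linear import GroupedLinear  # assumed import path

layer = GroupedLinear(n_groups=4, in_features=256, out_features=512, device="cuda")

# 10 tokens in total, already sorted by expert: counts 3 + 2 + 0 + 5
x = torch.randn(10, 256, device="cuda")
x_groups = torch.tensor([3, 2, 0, 5])   # CPU tensor; must sum to total_tokens

y = layer(x, x_groups)                   # shape: (10, 512)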

Source code in d9d/module/block/moe/grouped_linear.py
class GroupedLinear(nn.Module, ModuleLateInit):
    """
    Applies a linear transformation using Grouped GEMM (Generalized Matrix Multiplication).

    This module allows efficient execution of multiple linear layers (experts) in parallel, where each expert
    processes a variable number of tokens.
    It is the computational core of the Mixture-of-Experts layer.
    """

    def __init__(
            self,
            n_groups: int,
            in_features: int,
            out_features: int,
            device: torch.device | str | None = None,
            dtype: torch.dtype | None = None
    ):
        """
        Constructs the GroupedLinear layer.

        Args:
            n_groups: Number of groups (experts).
            in_features: Input hidden size.
            out_features: Output hidden size.
            device: Target device.
            dtype: Target data type.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_groups, in_features, out_features,
                                               device=device, dtype=dtype))

        self.n_groups = n_groups
        self.in_features = in_features
        self.out_features = out_features

        self.reset_parameters()

    def forward(self, x: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
        """
        Performs the grouped matrix multiplication.

        Args:
            x: Flattened input tensor containing tokens for all groups.
                Shape: `(total_tokens, in_features)`.
            x_groups: CPU Tensor indicating the number of tokens assigned to each group.
                Must sum to `total_tokens`. Shape: `(n_groups,)`.

        Returns:
            The output tensor. Shape: `(total_tokens, out_features)`.
        """

        weight: torch.Tensor = self.weight

        if isinstance(weight, DTensor):
            weight = weight.to_local()

        return gmm(
            x,
            weight,
            x_groups,
            a_grad_direction=GradDirection.inputs,
            b_grad_direction=GradDirection.weight
        )

    def reset_parameters(self):
        """Initializes weights using a uniform distribution based on input features."""
        nn.init.uniform_(self.weight, -1 / math.sqrt(self.in_features), 1 / math.sqrt(self.in_features))

__init__(n_groups, in_features, out_features, device=None, dtype=None)

Constructs the GroupedLinear layer.

Parameters:

  • n_groups (int, required): Number of groups (experts).
  • in_features (int, required): Input hidden size.
  • out_features (int, required): Output hidden size.
  • device (device | str | None, default None): Target device.
  • dtype (dtype | None, default None): Target data type.
Source code in d9d/module/block/moe/grouped_linear.py
def __init__(
        self,
        n_groups: int,
        in_features: int,
        out_features: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None
):
    """
    Constructs the GroupedLinear layer.

    Args:
        n_groups: Number of groups (experts).
        in_features: Input hidden size.
        out_features: Output hidden size.
        device: Target device.
        dtype: Target data type.
    """
    super().__init__()
    self.weight = nn.Parameter(torch.empty(n_groups, in_features, out_features,
                                           device=device, dtype=dtype))

    self.n_groups = n_groups
    self.in_features = in_features
    self.out_features = out_features

    self.reset_parameters()

forward(x, x_groups)

Performs the grouped matrix multiplication.

Parameters:

  • x (Tensor, required): Flattened input tensor containing tokens for all groups. Shape: (total_tokens, in_features).
  • x_groups (Tensor, required): CPU tensor indicating the number of tokens assigned to each group. Must sum to total_tokens. Shape: (n_groups,).

Returns:

  • Tensor: The output tensor. Shape: (total_tokens, out_features).

Source code in d9d/module/block/moe/grouped_linear.py
def forward(self, x: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
    """
    Performs the grouped matrix multiplication.

    Args:
        x: Flattened input tensor containing tokens for all groups.
            Shape: `(total_tokens, in_features)`.
        x_groups: CPU Tensor indicating the number of tokens assigned to each group.
            Must sum to `total_tokens`. Shape: `(n_groups,)`.

    Returns:
        The output tensor. Shape: `(total_tokens, out_features)`.
    """

    weight: torch.Tensor = self.weight

    if isinstance(weight, DTensor):
        weight = weight.to_local()

    return gmm(
        x,
        weight,
        x_groups,
        a_grad_direction=GradDirection.inputs,
        b_grad_direction=GradDirection.weight
    )

reset_parameters()

Initializes weights using a uniform distribution based on input features.

Source code in d9d/module/block/moe/grouped_linear.py
def reset_parameters(self):
    """Initializes weights using a uniform distribution based on input features."""
    nn.init.uniform_(self.weight, -1 / math.sqrt(self.in_features), 1 / math.sqrt(self.in_features))

GroupedSwiGLU

Bases: Module, ModuleLateInit

Executes a collection of SwiGLU experts efficiently using Grouped GEMM.

This module implements the architectural pattern: down_proj(SiLU(gate_proj(x)) * up_proj(x)). It applies this operation across multiple discrete experts in parallel without padding or masking.
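A minimal usage sketch, assuming the class is imported from d9d.module.block.moe.grouped_experts (matching the source path below) and that the grouped GEMM and Liger kernels are available on a CUDA device:

import torch
from d9d.module.block.moe.grouped_experts import GroupedSwiGLU  # assumed import path

experts = GroupedSwiGLU(hidden_dim=256, intermediate_dim=512, num_experts=4).to("cuda")

# 10 token copies, already permuted so tokens of the same expert are contiguous
permuted_x = torch.randn(10, 256, device="cuda")
permuted_probs = torch.rand(10, device="cuda")    # routing weight per token copy
tokens_per_expert = torch.tensor([3, 2, 0, 5])    # CPU tensor; sums to 10

out = experts(permuted_x, permuted_probs, tokens_per_expert)   # shape: (10, 256)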

Source code in d9d/module/block/moe/grouped_experts.py
class GroupedSwiGLU(nn.Module, ModuleLateInit):
    """
    Executes a collection of SwiGLU experts efficiently using Grouped GEMM.

    This module implements the architectural pattern: `down_proj(SiLU(gate_proj(x)) * up_proj(x))`.
    It applies this operation across multiple discrete experts in parallel without padding or masking.
    """

    def __init__(
            self,
            hidden_dim: int,
            intermediate_dim: int,
            num_experts: int
    ):
        """
        Constructs the GroupedSwiGLU module.

        Args:
            hidden_dim: Dimensionality of the input and output hidden states.
            intermediate_dim: Dimensionality of the intermediate projection.
            num_experts: Total number of experts managed by this local instance.
        """

        super().__init__()
        self._num_experts = num_experts

        self.gate_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
        self.up_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
        self.down_proj = GroupedLinear(num_experts, intermediate_dim, hidden_dim)

    def forward(
            self,
            permuted_x: torch.Tensor,
            permuted_probs: torch.Tensor,
            tokens_per_expert: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes expert outputs for sorted input tokens.

        Args:
            permuted_x: Input tokens sorted by their assigned expert.
                Shape: `(total_tokens, hidden_dim)`.
            permuted_probs: Routing weights/probabilities corresponding to the sorted tokens.
                Shape: `(total_tokens)`.
            tokens_per_expert: Number of tokens assigned to each consecutive expert. It is a CPU tensor.
                Shape: `(num_experts)`.

        Returns:
            The computed and weighted output tokens (still permuted).
            Shape: `(total_tokens, hidden_dim)`.
        """

        if permuted_x.numel() == 0:  # handle cases when there are no routed experts to this instance
            return permuted_x

        probs = permuted_probs[:, None].to(permuted_x.dtype)
        values = self.down_proj(
            LigerSiLUMulFunction.apply(
                self.gate_proj(permuted_x, tokens_per_expert),
                self.up_proj(permuted_x, tokens_per_expert)
            ),
            tokens_per_expert
        )

        return probs * values

    def reset_parameters(self):
        """Resets parameters for all internal linear projections."""

        self.gate_proj.reset_parameters()
        self.up_proj.reset_parameters()
        self.down_proj.reset_parameters()

__init__(hidden_dim, intermediate_dim, num_experts)

Constructs the GroupedSwiGLU module.

Parameters:

  • hidden_dim (int, required): Dimensionality of the input and output hidden states.
  • intermediate_dim (int, required): Dimensionality of the intermediate projection.
  • num_experts (int, required): Total number of experts managed by this local instance.
Source code in d9d/module/block/moe/grouped_experts.py
def __init__(
        self,
        hidden_dim: int,
        intermediate_dim: int,
        num_experts: int
):
    """
    Constructs the GroupedSwiGLU module.

    Args:
        hidden_dim: Dimensionality of the input and output hidden states.
        intermediate_dim: Dimensionality of the intermediate projection.
        num_experts: Total number of experts managed by this local instance.
    """

    super().__init__()
    self._num_experts = num_experts

    self.gate_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
    self.up_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
    self.down_proj = GroupedLinear(num_experts, intermediate_dim, hidden_dim)

forward(permuted_x, permuted_probs, tokens_per_expert)

Computes expert outputs for sorted input tokens.

Parameters:

  • permuted_x (Tensor, required): Input tokens sorted by their assigned expert. Shape: (total_tokens, hidden_dim).
  • permuted_probs (Tensor, required): Routing weights/probabilities corresponding to the sorted tokens. Shape: (total_tokens).
  • tokens_per_expert (Tensor, required): Number of tokens assigned to each consecutive expert; a CPU tensor. Shape: (num_experts).

Returns:

  • Tensor: The computed and weighted output tokens (still permuted). Shape: (total_tokens, hidden_dim).

Source code in d9d/module/block/moe/grouped_experts.py
def forward(
        self,
        permuted_x: torch.Tensor,
        permuted_probs: torch.Tensor,
        tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """
    Computes expert outputs for sorted input tokens.

    Args:
        permuted_x: Input tokens sorted by their assigned expert.
            Shape: `(total_tokens, hidden_dim)`.
        permuted_probs: Routing weights/probabilities corresponding to the sorted tokens.
            Shape: `(total_tokens)`.
        tokens_per_expert: Number of tokens assigned to each consecutive expert. It is a CPU tensor.
            Shape: `(num_experts)`.

    Returns:
        The computed and weighted output tokens (still permuted).
        Shape: `(total_tokens, hidden_dim)`.
    """

    if permuted_x.numel() == 0:  # handle cases when there are no routed experts to this instance
        return permuted_x

    probs = permuted_probs[:, None].to(permuted_x.dtype)
    values = self.down_proj(
        LigerSiLUMulFunction.apply(
            self.gate_proj(permuted_x, tokens_per_expert),
            self.up_proj(permuted_x, tokens_per_expert)
        ),
        tokens_per_expert
    )

    return probs * values

reset_parameters()

Resets parameters for all internal linear projections.

Source code in d9d/module/block/moe/grouped_experts.py
def reset_parameters(self):
    """Resets parameters for all internal linear projections."""

    self.gate_proj.reset_parameters()
    self.up_proj.reset_parameters()
    self.down_proj.reset_parameters()

MoELayer

Bases: Module, ModuleLateInit

A complete Mixture-of-Experts (MoE) block comprising routing, communication, and computation.

This layer integrates:

  1. Router: Selects experts for each token.
  2. Communicator: Handles token dispatch to local or remote experts (EP).
  3. Experts: Performs parallelized computation (Grouped SwiGLU).
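A minimal single-device usage sketch (so the default NoCommunicationHandler is used), assuming MoELayer is imported from d9d.module.block.moe.layer (matching the source path below) and the grouped GEMM and permutation kernels are installed on a CUDA device:

import torch
from d9d.module.block.moe.layer import MoELayer  # assumed import path

moe = MoELayer(
    hidden_dim=256,
    intermediate_dim_grouped=512,
    num_grouped_experts=8,
    top_k=2,
    router_renormalize_probabilities=True,
).to("cuda")

hidden = torch.randn(2, 16, 256, device="cuda")   # (batch_size, seq_len, hidden_dim)
out = moe(hidden)                                  # same shape: (2, 16, 256)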
Source code in d9d/module/block/moe/layer.py
class MoELayer(nn.Module, ModuleLateInit):
    """
    A complete Mixture-of-Experts (MoE) block comprising routing, communication, and computation.

    This layer integrates:

    1.  **Router**: Selects experts for each token.
    2.  **Communicator**: Handles token dispatch to local or remote experts (EP).
    3.  **Experts**: Performs parallelized computation (Grouped SwiGLU).
    """

    def __init__(
            self,
            hidden_dim: int,
            intermediate_dim_grouped: int,
            num_grouped_experts: int,
            top_k: int,
            router_renormalize_probabilities: bool
    ):
        """
        Constructs the MoELayer.

        Args:
            hidden_dim: Hidden size.
            intermediate_dim_grouped: Intermediate dimension for the Expert FFNs.
            num_grouped_experts: Total number of experts.
            top_k: Number of experts to route each token to.
            router_renormalize_probabilities: Configures router probability normalization behavior.
        """

        super().__init__()
        self.router = TopKRouter(
            dim=hidden_dim, num_experts=num_grouped_experts, top_k=top_k,
            renormalize_probabilities=router_renormalize_probabilities
        )
        self.grouped_experts = GroupedSwiGLU(
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim_grouped,
            num_experts=num_grouped_experts
        )
        self._communicator: ExpertCommunicationHandler = NoCommunicationHandler(num_grouped_experts)

        self._num_grouped_experts = num_grouped_experts
        self._hidden_dim = hidden_dim

        self.tokens_per_expert = nn.Buffer(torch.empty((num_grouped_experts,), dtype=torch.int64), persistent=False)

    def enable_distributed_communicator(self, group: ProcessGroup):
        """
        Switches from local no-op communication to distributed DeepEP communication.

        This should be called during model initialization if the model is running in a
        distributed Expert Parallel environment.

        Args:
            group: The PyTorch process group spanning the expert parallel ranks.
        """

        communicator = DeepEpCommunicationHandler(num_experts=self._num_grouped_experts)
        communicator.setup(group, self._hidden_dim, self.router.gate.weight.dtype)
        self._communicator = communicator

    @torch.no_grad()
    def _update_tokens_per_expert(self, expert_indices: torch.Tensor):
        self.tokens_per_expert.add_(expert_indices.view(-1).bincount(minlength=self._num_grouped_experts))

    @torch.no_grad()
    def reset_stats(self):
        """Resets the expert load balancing counters."""
        self.tokens_per_expert.zero_()

    def forward(
            self,
            hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Routes tokens to experts, computes, and combines results.

        Args:
            hidden_states: Input tensor. Shape: `(batch_size, seq_len, hidden_dim)`.

        Returns:
            Output tensor combined from experts. Shape: `(batch_size, seq_len, hidden_dim)`.
        """

        old_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        expert_indices, expert_scores = self.router(hidden_states)
        self._update_tokens_per_expert(expert_indices)
        hidden_states, expert_scores, expert_count = self._communicator.dispatch(
            hidden_states, expert_indices, expert_scores
        )
        hidden_states = self.grouped_experts(hidden_states, expert_scores, expert_count)
        hidden_states = self._communicator.combine(hidden_states)
        hidden_states = hidden_states.reshape(*old_shape)

        return hidden_states

    def reset_parameters(self):
        """Resets module parameters."""
        self.router.reset_parameters()
        self.grouped_experts.reset_parameters()

        nn.init.zeros_(self.tokens_per_expert)

__init__(hidden_dim, intermediate_dim_grouped, num_grouped_experts, top_k, router_renormalize_probabilities)

Constructs the MoELayer.

Parameters:

  • hidden_dim (int, required): Hidden size.
  • intermediate_dim_grouped (int, required): Intermediate dimension for the Expert FFNs.
  • num_grouped_experts (int, required): Total number of experts.
  • top_k (int, required): Number of experts to route each token to.
  • router_renormalize_probabilities (bool, required): Configures router probability normalization behavior.
Source code in d9d/module/block/moe/layer.py
def __init__(
        self,
        hidden_dim: int,
        intermediate_dim_grouped: int,
        num_grouped_experts: int,
        top_k: int,
        router_renormalize_probabilities: bool
):
    """
    Constructs the MoELayer.

    Args:
        hidden_dim: Hidden size.
        intermediate_dim_grouped: Intermediate dimension for the Expert FFNs.
        num_grouped_experts: Total number of experts.
        top_k: Number of experts to route each token to.
        router_renormalize_probabilities: Configures router probability normalization behavior.
    """

    super().__init__()
    self.router = TopKRouter(
        dim=hidden_dim, num_experts=num_grouped_experts, top_k=top_k,
        renormalize_probabilities=router_renormalize_probabilities
    )
    self.grouped_experts = GroupedSwiGLU(
        hidden_dim=hidden_dim,
        intermediate_dim=intermediate_dim_grouped,
        num_experts=num_grouped_experts
    )
    self._communicator: ExpertCommunicationHandler = NoCommunicationHandler(num_grouped_experts)

    self._num_grouped_experts = num_grouped_experts
    self._hidden_dim = hidden_dim

    self.tokens_per_expert = nn.Buffer(torch.empty((num_grouped_experts,), dtype=torch.int64), persistent=False)

enable_distributed_communicator(group)

Switches from local no-op communication to distributed DeepEP communication.

This should be called during model initialization if the model is running in a distributed Expert Parallel environment.

Parameters:

  • group (ProcessGroup, required): The PyTorch process group spanning the expert parallel ranks.
Source code in d9d/module/block/moe/layer.py
def enable_distributed_communicator(self, group: ProcessGroup):
    """
    Switches from local no-op communication to distributed DeepEP communication.

    This should be called during model initialization if the model is running in a
    distributed Expert Parallel environment.

    Args:
        group: The PyTorch process group spanning the expert parallel ranks.
    """

    communicator = DeepEpCommunicationHandler(num_experts=self._num_grouped_experts)
    communicator.setup(group, self._hidden_dim, self.router.gate.weight.dtype)
    self._communicator = communicator
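A sketch of switching every MoE layer in a model over to DeepEP-backed Expert Parallelism. It assumes torchrun has launched one process per GPU, DeepEP is installed, and MoELayer is importable from d9d.module.block.moe.layer; the tiny nn.Sequential model is only a placeholder.

import torch
import torch.distributed as dist
import torch.nn as nn
from d9d.module.block.moe.layer import MoELayer  # assumed import path

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
ep_group = dist.group.WORLD   # here every rank participates in Expert Parallelism

model = nn.Sequential(
    MoELayer(hidden_dim=256, intermediate_dim_grouped=512, num_grouped_experts=8,
             top_k=2, router_renormalize_probabilities=True)
).to("cuda")

# replace the default NoCommunicationHandler with the DeepEP handler on every MoE layer
for module in model.modules():
    if isinstance(module, MoELayer):
        module.enable_distributed_communicator(ep_group)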

forward(hidden_states)

Routes tokens to experts, computes, and combines results.

Parameters:

  • hidden_states (Tensor, required): Input tensor. Shape: (batch_size, seq_len, hidden_dim).

Returns:

  • Tensor: Output tensor combined from experts. Shape: (batch_size, seq_len, hidden_dim).

Source code in d9d/module/block/moe/layer.py
def forward(
        self,
        hidden_states: torch.Tensor
) -> torch.Tensor:
    """
    Routes tokens to experts, computes, and combines results.

    Args:
        hidden_states: Input tensor. Shape: `(batch_size, seq_len, hidden_dim)`.

    Returns:
        Output tensor combined from experts. Shape: `(batch_size, seq_len, hidden_dim)`.
    """

    old_shape = hidden_states.shape
    hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    expert_indices, expert_scores = self.router(hidden_states)
    self._update_tokens_per_expert(expert_indices)
    hidden_states, expert_scores, expert_count = self._communicator.dispatch(
        hidden_states, expert_indices, expert_scores
    )
    hidden_states = self.grouped_experts(hidden_states, expert_scores, expert_count)
    hidden_states = self._communicator.combine(hidden_states)
    hidden_states = hidden_states.reshape(*old_shape)

    return hidden_states

reset_parameters()

Resets module parameters.

Source code in d9d/module/block/moe/layer.py
def reset_parameters(self):
    """Resets module parameters."""
    self.router.reset_parameters()
    self.grouped_experts.reset_parameters()

    nn.init.zeros_(self.tokens_per_expert)

reset_stats()

Resets the expert load balancing counters.

Source code in d9d/module/block/moe/layer.py
@torch.no_grad()
def reset_stats(self):
    """Resets the expert load balancing counters."""
    self.tokens_per_expert.zero_()

TopKRouter

Bases: Module, ModuleLateInit

Selects the top-K experts based on a learned gating mechanism.

This router:

  1. Projects input tokens into expert space
  2. Applies softmax, optionally adds expert bias to influence selection
  3. Selects the experts with the highest probabilities
  4. Re-normalizes the probabilities of the selected experts to sum to 1 if needed
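A minimal usage sketch, assuming the class is imported from d9d.module.block.moe.router (matching the source path below):

import torch
from d9d.module.block.moe.router import TopKRouter  # assumed import path

router = TopKRouter(dim=256, num_experts=8, top_k=2, renormalize_probabilities=True)

tokens = torch.randn(10, 256)                     # (num_tokens, dim)
expert_indices, expert_scores = router(tokens)    # each: (num_tokens, top_k) = (10, 2)

# expert_scores holds the routing weights of the selected experts
print(expert_indices[0], expert_scores[0])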
Source code in d9d/module/block/moe/router.py
class TopKRouter(nn.Module, ModuleLateInit):
    """
    Selects the top-K experts based on a learned gating mechanism.

    This router:

    1. Projects input tokens into expert space
    2. Applies softmax, optionally adds expert bias to influence selection
    3. Selects the experts with the highest probabilities
    4. Selected probabilities are then re-normalized to sum to 1 if needed.
    """

    def __init__(
            self,
            dim: int,
            num_experts: int,
            top_k: int,
            renormalize_probabilities: bool,
            enable_expert_bias: bool = False
    ):
        """
        Constructs the TopKRouter.

        Args:
            dim: Input feature dimensionality.
            num_experts: Total number of experts to choose from.
            top_k: Number of experts to select for each token.
            renormalize_probabilities: If True, probabilities of selected experts will be renormalized to sum up to 1
            enable_expert_bias: If True, adds a bias term to the routing scores before top-k selection. This can be
                used for loss-free load balancing.
        """

        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)

        self.expert_bias: nn.Buffer | None
        if enable_expert_bias:
            self.expert_bias = nn.Buffer(
                torch.empty(num_experts, dtype=torch.float32),
                persistent=True,
            )
        else:
            self.expert_bias = None

        self._num_experts = num_experts
        self._top_k = top_k
        self._renormalize_probabilities = renormalize_probabilities

    def forward(
            self,
            hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates routing decisions for the input tokens.

        Args:
            hidden_states: Input tokens. Shape: `(num_tokens, dim)`.

        Returns:
            A tuple containing:

            - Selected expert indices. Shape: `(num_tokens, top_k)`.
            - Normalized routing weights for the selected experts. Shape: `(num_tokens, top_k)`.
        """

        # scores shape (bs*slen, num_experts)

        # gate
        scores = self.gate(hidden_states)

        # and now do softmax (before top-k to be able to apply expert bias)
        scores = F.softmax(scores, dim=-1, dtype=torch.float32)

        # select top-k
        if self.expert_bias is None:
            scores, selected_experts_indices = torch.topk(
                scores, k=self._top_k, dim=-1
            )
        else:
            _, selected_experts_indices = torch.topk(
                scores + self.expert_bias, k=self._top_k, dim=-1
            )
            scores = scores.gather(dim=-1, index=selected_experts_indices)

        # re-normalize scores
        denominator = scores.sum(dim=-1, keepdim=True) + 1e-20
        scores = scores / denominator

        return selected_experts_indices, scores

    def reset_parameters(self):
        """Resets module parameters."""
        if self.expert_bias is not None:
            nn.init.zeros_(self.expert_bias)

        self.gate.reset_parameters()

__init__(dim, num_experts, top_k, renormalize_probabilities, enable_expert_bias=False)

Constructs the TopKRouter.

Parameters:

  • dim (int, required): Input feature dimensionality.
  • num_experts (int, required): Total number of experts to choose from.
  • top_k (int, required): Number of experts to select for each token.
  • renormalize_probabilities (bool, required): If True, probabilities of the selected experts will be renormalized to sum to 1.
  • enable_expert_bias (bool, default False): If True, adds a bias term to the routing scores before top-k selection. This can be used for loss-free load balancing.
Source code in d9d/module/block/moe/router.py
def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        renormalize_probabilities: bool,
        enable_expert_bias: bool = False
):
    """
    Constructs the TopKRouter.

    Args:
        dim: Input feature dimensionality.
        num_experts: Total number of experts to choose from.
        top_k: Number of experts to select for each token.
        renormalize_probabilities: If True, probabilities of selected experts will be renormalized to sum up to 1
        enable_expert_bias: If True, adds a bias term to the routing scores before top-k selection. This can be
            used for loss-free load balancing.
    """

    super().__init__()
    self.gate = nn.Linear(dim, num_experts, bias=False)

    self.expert_bias: nn.Buffer | None
    if enable_expert_bias:
        self.expert_bias = nn.Buffer(
            torch.empty(num_experts, dtype=torch.float32),
            persistent=True,
        )
    else:
        self.expert_bias = None

    self._num_experts = num_experts
    self._top_k = top_k
    self._renormalize_probabilities = renormalize_probabilities

forward(hidden_states)

Calculates routing decisions for the input tokens.

Parameters:

  • hidden_states (Tensor, required): Input tokens. Shape: (num_tokens, dim).

Returns:

  • tuple[Tensor, Tensor]: A tuple containing:
    • Selected expert indices. Shape: (num_tokens, top_k).
    • Normalized routing weights for the selected experts. Shape: (num_tokens, top_k).
Source code in d9d/module/block/moe/router.py
def forward(
        self,
        hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Calculates routing decisions for the input tokens.

    Args:
        hidden_states: Input tokens. Shape: `(num_tokens, dim)`.

    Returns:
        A tuple containing:

        - Selected expert indices. Shape: `(num_tokens, top_k)`.
        - Normalized routing weights for the selected experts. Shape: `(num_tokens, top_k)`.
    """

    # scores shape (bs*slen, num_experts)

    # gate
    scores = self.gate(hidden_states)

    # and now do softmax (before top-k to be able to apply expert bias)
    scores = F.softmax(scores, dim=-1, dtype=torch.float32)

    # select top-k
    if self.expert_bias is None:
        scores, selected_experts_indices = torch.topk(
            scores, k=self._top_k, dim=-1
        )
    else:
        _, selected_experts_indices = torch.topk(
            scores + self.expert_bias, k=self._top_k, dim=-1
        )
        scores = scores.gather(dim=-1, index=selected_experts_indices)

    # re-normalize scores
    denominator = scores.sum(dim=-1, keepdim=True) + 1e-20
    scores = scores / denominator

    return selected_experts_indices, scores

reset_parameters()

Resets module parameters.

Source code in d9d/module/block/moe/router.py
def reset_parameters(self):
    """Resets module parameters."""
    if self.expert_bias is not None:
        nn.init.zeros_(self.expert_bias)

    self.gate.reset_parameters()

d9d.module.block.moe.communications

Provides communication strategies for Mixture-of-Experts routing operations.

DeepEpCommunicationHandler

Bases: ExpertCommunicationHandler

Handles MoE communication using the high-performance DeepEP library.

Source code in d9d/module/block/moe/communications/deepep.py
class DeepEpCommunicationHandler(ExpertCommunicationHandler):
    """Handles MoE communication using the high-performance DeepEP library."""

    def __init__(self, num_experts: int):
        """Constructs the DeepEpCommunicationHandler."""

        self._num_experts = num_experts
        self._num_experts_per_shard = None  # late-initialization

        # == fields saved for post-dispatch ==

        self._handle = None
        self._hidden_shape_before_permute = None
        self._unpermute_mapping = None

    def setup(self, group: torch.distributed.ProcessGroup, hidden_size: int, hidden_dtype: torch.dtype):
        """
        Initializes the backend buffer and calculates expert sharding.

        Args:
            group: The process group containing all experts.
            hidden_size: Dimensionality of the hidden states.
            hidden_dtype: Data type of the hidden states.
        """

        init_deepep_buffer(group, hidden_size * hidden_dtype.itemsize)

        if self._num_experts % group.size() != 0:
            raise ValueError("num_experts must be divisible by distributed group size")

        self._num_experts_per_shard = self._num_experts // group.size()

    def dispatch(
            self,
            hidden_states: torch.Tensor,
            topk_ids: torch.Tensor,
            topk_weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (
            hidden_states,
            topk_ids,
            topk_weights,
            tokens_per_expert,
            handle
        ) = DeepEpDispatch.apply(
            hidden_states,
            topk_ids,
            topk_weights,
            self._num_experts
        )

        routing_map, routing_probs = fused_indices_to_multihot(
            topk_ids, topk_weights, self._num_experts_per_shard
        )

        self._hidden_shape_before_permute = hidden_states.shape

        hidden_states, routing_probs, reverse_permute_map = moe_permute_with_probs(
            hidden_states,
            routing_probs,
            routing_map,
            num_out_tokens=tokens_per_expert.sum().item()
        )

        self._handle = handle
        self._unpermute_mapping = reverse_permute_map

        return hidden_states, routing_probs, tokens_per_expert

    def combine(
            self,
            hidden_states: torch.Tensor
    ) -> torch.Tensor:
        if self._handle is None:
            raise ValueError("you fucked up moe communication order: you should dispatch first and after that combine")

        hidden_states = moe_unpermute_mask(
            hidden_states,
            self._unpermute_mapping,
            restore_shape=self._hidden_shape_before_permute,
        )

        hidden_states = DeepEpCombine.apply(
            hidden_states,
            self._handle
        )

        self._handle = None
        self._unpermute_mapping = None
        self._hidden_shape_before_permute = None

        return hidden_states

__init__(num_experts)

Constructs the DeepEpCommunicationHandler.

Source code in d9d/module/block/moe/communications/deepep.py
def __init__(self, num_experts: int):
    """Constructs the DeepEpCommunicationHandler."""

    self._num_experts = num_experts
    self._num_experts_per_shard = None  # late-initialization

    # == fields saved for post-dispatch ==

    self._handle = None
    self._hidden_shape_before_permute = None
    self._unpermute_mapping = None

setup(group, hidden_size, hidden_dtype)

Initializes the backend buffer and calculates expert sharding.

Parameters:

  • group (ProcessGroup, required): The process group containing all experts.
  • hidden_size (int, required): Dimensionality of the hidden states.
  • hidden_dtype (dtype, required): Data type of the hidden states.
Source code in d9d/module/block/moe/communications/deepep.py
def setup(self, group: torch.distributed.ProcessGroup, hidden_size: int, hidden_dtype: torch.dtype):
    """
    Initializes the backend buffer and calculates expert sharding.

    Args:
        group: The process group containing all experts.
        hidden_size: Dimensionality of the hidden states.
        hidden_dtype: Data type of the hidden states.
    """

    init_deepep_buffer(group, hidden_size * hidden_dtype.itemsize)

    if self._num_experts % group.size() != 0:
        raise ValueError("num_experts must be divisible by distributed group size")

    self._num_experts_per_shard = self._num_experts // group.size()

ExpertCommunicationHandler

Bases: ABC

Abstract base class for Mixture-of-Experts communication strategies.

Source code in d9d/module/block/moe/communications/base.py
class ExpertCommunicationHandler(abc.ABC):
    """Abstract base class for Mixture-of-Experts communication strategies."""

    @abc.abstractmethod
    def dispatch(
            self,
            hidden_states: torch.Tensor,
            topk_ids: torch.Tensor,
            topk_weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepares and routes local hidden states to their target experts (possibly on other workers).

        This process involves:

        1. All-to-All Communication: Transfers hidden states to workers containing the assigned experts. States
        assigned to multiple experts are replicated.

        2. Permutation: Sorts tokens by expert ID to prepare for Grouped GEMM.

        Args:
            hidden_states: Input tokens. Shape: `(num_tokens, hidden_size)`.
            topk_ids: Indices of the top-k experts selected for each token. Shape: `(num_tokens, k)`.
            topk_weights: Routing weights associated with the selected experts. Shape: `(num_tokens, k)`.

        Returns:
            A tuple containing:

            - Permuted hidden states received by this rank. Shape: `(num_received_tokens, hidden_size)`.
            - Permuted weights matching the hidden states order. Shape: `(num_received_tokens)`.
            - Expert count tensor indicating how many tokens each local expert received. Shape: `(num_local_experts)`.
        """

        ...

    @abc.abstractmethod
    def combine(
            self,
            hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Restores hidden states to their original order and location.

        Undoes the permutation and performs the reverse All-to-All communication
        to return processed results to the workers that originated the requests.

        Args:
            hidden_states: The processed hidden states. Shape: `(num_received_tokens, hidden_size)`.

        Returns:
            The combined hidden states with the original shape and order. Shape: `(num_tokens, hidden_size)`.
        """
        ...

combine(hidden_states) abstractmethod

Restores hidden states to their original order and location.

Undoes the permutation and performs the reverse All-to-All communication to return processed results to the workers that originated the requests.

Parameters:

  • hidden_states (Tensor, required): The processed hidden states. Shape: (num_received_tokens, hidden_size).

Returns:

  • Tensor: The combined hidden states with the original shape and order. Shape: (num_tokens, hidden_size).

Source code in d9d/module/block/moe/communications/base.py
@abc.abstractmethod
def combine(
        self,
        hidden_states: torch.Tensor
) -> torch.Tensor:
    """
    Restores hidden states to their original order and location.

    Undoes the permutation and performs the reverse All-to-All communication
    to return processed results to the workers that originated the requests.

    Args:
        hidden_states: The processed hidden states. Shape: `(num_received_tokens, hidden_size)`.

    Returns:
        The combined hidden states with the original shape and order. Shape: `(num_tokens, hidden_size)`.
    """
    ...

dispatch(hidden_states, topk_ids, topk_weights) abstractmethod

Prepares and routes local hidden states to their target experts (possibly on other workers).

This process involves:

  1. All-to-All Communication: Transfers hidden states to workers containing the assigned experts. States assigned to multiple experts are replicated.

  2. Permutation: Sorts tokens by expert ID to prepare for Grouped GEMM.

Parameters:

  • hidden_states (Tensor, required): Input tokens. Shape: (num_tokens, hidden_size).
  • topk_ids (Tensor, required): Indices of the top-k experts selected for each token. Shape: (num_tokens, k).
  • topk_weights (Tensor, required): Routing weights associated with the selected experts. Shape: (num_tokens, k).

Returns:

  • tuple[Tensor, Tensor, Tensor]: A tuple containing:
    • Permuted hidden states received by this rank. Shape: (num_received_tokens, hidden_size).
    • Permuted weights matching the hidden states order. Shape: (num_received_tokens).
    • Expert count tensor indicating how many tokens each local expert received. Shape: (num_local_experts).
Source code in d9d/module/block/moe/communications/base.py
@abc.abstractmethod
def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Prepares and routes local hidden states to their target experts (possibly on other workers).

    This process involves:

    1. All-to-All Communication: Transfers hidden states to workers containing the assigned experts. States
    assigned to multiple experts are replicated.

    2. Permutation: Sorts tokens by expert ID to prepare for Grouped GEMM.

    Args:
        hidden_states: Input tokens. Shape: `(num_tokens, hidden_size)`.
        topk_ids: Indices of the top-k experts selected for each token. Shape: `(num_tokens, k)`.
        topk_weights: Routing weights associated with the selected experts. Shape: `(num_tokens, k)`.

    Returns:
        A tuple containing:

        - Permuted hidden states received by this rank. Shape: `(num_received_tokens, hidden_size)`.
        - Permuted weights matching the hidden states order. Shape: `(num_received_tokens)`.
        - Expert count tensor indicating how many tokens each local expert received. Shape: `(num_local_experts)`.
    """

    ...
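The contract this class defines is easiest to see with a toy, single-device handler. The sketch below is not the library implementation and only covers the top_k = 1 case; it merely illustrates how dispatch sorts tokens by expert and reports per-expert counts, and how combine restores the original order.

import torch

class ToySortHandler:
    """Illustrative only: sorts tokens by expert id and undoes the sort in combine()."""

    def __init__(self, num_experts: int):
        self._num_experts = num_experts
        self._inverse_order = None   # state carried from dispatch() to combine()

    def dispatch(self, hidden_states, topk_ids, topk_weights):
        expert_ids = topk_ids.view(-1)                       # top_k == 1
        order = torch.argsort(expert_ids, stable=True)       # group tokens by expert
        self._inverse_order = torch.argsort(order)           # how to undo the permutation
        tokens_per_expert = torch.bincount(expert_ids, minlength=self._num_experts).cpu()
        return hidden_states[order], topk_weights.view(-1)[order], tokens_per_expert

    def combine(self, hidden_states):
        assert self._inverse_order is not None, "dispatch() must be called before combine()"
        restored = hidden_states[self._inverse_order]
        self._inverse_order = None
        return restored

# dispatch -> per-expert computation -> combine
handler = ToySortHandler(num_experts=4)
x, ids, w = torch.randn(8, 16), torch.randint(0, 4, (8, 1)), torch.ones(8, 1)
px, pw, counts = handler.dispatch(x, ids, w)
restored = handler.combine(px * pw[:, None])   # the multiply stands in for the expert FFNs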

NoCommunicationHandler

Bases: ExpertCommunicationHandler

Handles MoE routing within a single device or when no cross-device routing is needed.

This handler performs no network operations; it only permutes tokens locally so they are grouped by expert, e.g. for single-device setups or debugging.

Source code in d9d/module/block/moe/communications/naive.py
class NoCommunicationHandler(ExpertCommunicationHandler):
    """
    Handles MoE routing within a single device or when no cross-device routing is needed.

    This handler does not perform network operations. It only permutes elements
    mostly for local logical grouping or debugging.
    """

    def __init__(self, num_experts: int):
        """Constructs the NoCommunicationHandler."""
        self._num_experts = num_experts

        self._hidden_shape_before_permute: Size | None = None
        self._unpermute_mapping: torch.Tensor | None = None

    def dispatch(
            self,
            hidden_states: torch.Tensor,
            topk_ids: torch.Tensor,
            topk_weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            tokens_per_expert = torch.bincount(topk_ids.flatten(), minlength=self._num_experts).cpu()

        routing_map, routing_probs = fused_indices_to_multihot(
            topk_ids, topk_weights, self._num_experts
        )

        self._hidden_shape_before_permute = hidden_states.shape

        hidden_states, routing_probs, reverse_permute_map = moe_permute_with_probs(
            hidden_states,
            routing_probs,
            routing_map,
            num_out_tokens=cast(int, tokens_per_expert.sum().item())
        )

        self._unpermute_mapping = reverse_permute_map

        return hidden_states, routing_probs, tokens_per_expert

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self._unpermute_mapping is None:
            raise ValueError("Cannot run combine before running dispatch!")

        hidden_states = moe_unpermute_mask(
            hidden_states,
            self._unpermute_mapping,
            restore_shape=self._hidden_shape_before_permute,
        )

        self._unpermute_mapping = None
        self._hidden_shape_before_permute = None

        return hidden_states

__init__(num_experts)

Constructs the NoCommunicationHandler.

Source code in d9d/module/block/moe/communications/naive.py
def __init__(self, num_experts: int):
    """Constructs the NoCommunicationHandler."""
    self._num_experts = num_experts

    self._hidden_shape_before_permute: Size | None = None
    self._unpermute_mapping: torch.Tensor | None = None