About

The d9d.module.model.qwen3_moe package implements the Qwen3 Mixture-of-Experts model architecture.

d9d.module.model.qwen3_moe

Qwen3MoEForCausalLM

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

A Qwen3 MoE model wrapped with a Causal Language Modeling head.

It is designed to be split across multiple pipeline stages.

Source code in d9d/module/model/qwen3_moe/model.py, lines 249-373
class Qwen3MoEForCausalLM(nn.Module, ModuleLateInit, ModuleSupportsPipelining):
    """
    A Qwen3 MoE model wrapped with a Causal Language Modeling head.

    It is designed to be split across multiple pipeline stages.
    """

    def __init__(
            self,
            params: Qwen3MoEForCausalLMParameters,
            stage: PipelineStageInfo,
            hidden_states_snapshot_mode: HiddenStatesAggregationMode,
            enable_checkpointing: bool
    ):
        """
        Constructs the Qwen3MoEForCausalLM object.

        Args:
            params: Full model configuration parameters.
            stage: Pipeline stage information for this instance.
            hidden_states_snapshot_mode: Configures intermediate hidden state aggregation & snapshotting mode.
            enable_checkpointing: Whether to enable activation checkpointing.
        """

        super().__init__()

        self.model = Qwen3MoEModel(
            params.model,
            stage,
            hidden_states_snapshot_mode=hidden_states_snapshot_mode,
            enable_checkpointing=enable_checkpointing
        )

        if stage.is_current_stage_last:
            self.lm_head = SplitLanguageModellingHead(
                split_vocab_size=params.model.split_vocab_size,
                split_order=params.model.split_vocab_order,
                hidden_size=params.model.layer.hidden_size
            )

        self._stage = stage
        self._hidden_size = params.model.layer.hidden_size

    def forward(
            self,
            input_ids: torch.Tensor | None = None,
            hidden_states: torch.Tensor | None = None,
            position_ids: torch.Tensor | None = None,
            hidden_states_snapshot: torch.Tensor | None = None,
            hidden_states_agg_mask: torch.Tensor | None = None,
            labels: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Executes the model forward pass.

        If this is the last stage, it expects `labels` to be provided and computes
        the cross-entropy loss (returned as 'logps', typically representing per-token loss).

        Args:
            input_ids: Input token IDs (for Stage 0).
            hidden_states: Hidden states from previous stage (for Stage > 0).
            position_ids: Positional indices for RoPE.
            hidden_states_snapshot: Intermediate state collector.
            hidden_states_agg_mask: Mask for state aggregation.
            labels: Target tokens for loss computation (Last Stage).

        Returns:
            Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot',
            and per-token 'logps' if on the last stage.
        """

        model_outputs = self.model(
            input_ids=input_ids,
            hidden_states=hidden_states,
            position_ids=position_ids,
            hidden_states_snapshot=hidden_states_snapshot,
            hidden_states_agg_mask=hidden_states_agg_mask
        )
        if self._stage.is_current_stage_last:
            lm_out = self.lm_head(
                hidden_states=model_outputs["hidden_states"],
                labels=labels
            )
            model_outputs["logps"] = lm_out
        return model_outputs

    def reset_parameters(self):
        """
        Resets module parameters.
        """

        self.model.reset_parameters()

        if self._stage.is_current_stage_last:
            self.lm_head.reset_parameters()

    def reset_moe_stats(self):
        """
        Resets MoE routing statistics in the backbone.
        """

        self.model.reset_moe_stats()

    @property
    def moe_tokens_per_expert(self) -> torch.Tensor:
        """
        Accesses MoE routing statistics from the backbone.
        """

        return self.model.moe_tokens_per_expert

    def infer_stage_inputs_from_pipeline_inputs(
            self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        return self.model.infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)

    def infer_stage_outputs_from_pipeline_inputs(
            self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        pp_outputs = self.model.infer_stage_outputs_from_pipeline_inputs(inputs, n_microbatches)

        if self._stage.is_current_stage_last:
            pp_outputs["logps"] = torch.empty(inputs["input_ids"].shape, dtype=torch.float32)

        return pp_outputs

moe_tokens_per_expert property

Accesses MoE routing statistics from the backbone.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3MoEForCausalLM object.

Parameters:

Name Type Description Default
params Qwen3MoEForCausalLMParameters

Full model configuration parameters.

required
stage PipelineStageInfo

Pipeline stage information for this instance.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

Whether to enable activation checkpointing.

required
Source code in d9d/module/model/qwen3_moe/model.py, lines 256-290
def __init__(
        self,
        params: Qwen3MoEForCausalLMParameters,
        stage: PipelineStageInfo,
        hidden_states_snapshot_mode: HiddenStatesAggregationMode,
        enable_checkpointing: bool
):
    """
    Constructs the Qwen3MoEForCausalLM object.

    Args:
        params: Full model configuration parameters.
        stage: Pipeline stage information for this instance.
        hidden_states_snapshot_mode: Configures intermediate hidden state aggregation & snapshotting mode.
        enable_checkpointing: Whether to enable activation checkpointing.
    """

    super().__init__()

    self.model = Qwen3MoEModel(
        params.model,
        stage,
        hidden_states_snapshot_mode=hidden_states_snapshot_mode,
        enable_checkpointing=enable_checkpointing
    )

    if stage.is_current_stage_last:
        self.lm_head = SplitLanguageModellingHead(
            split_vocab_size=params.model.split_vocab_size,
            split_order=params.model.split_vocab_order,
            hidden_size=params.model.layer.hidden_size
        )

    self._stage = stage
    self._hidden_size = params.model.layer.hidden_size

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, labels=None)

Executes the model forward pass.

If this is the last stage, it expects labels to be provided and computes the cross-entropy loss (returned as 'logps', typically representing per-token loss).

Parameters:

Name Type Description Default
input_ids Tensor | None

Input token IDs (for Stage 0).

None
hidden_states Tensor | None

Hidden states from previous stage (for Stage > 0).

None
position_ids Tensor | None

Positional indices for RoPE.

None
hidden_states_snapshot Tensor | None

Intermediate state collector.

None
hidden_states_agg_mask Tensor | None

Mask for state aggregation.

None
labels Tensor | None

Target tokens for loss computation (Last Stage).

None

Returns:

Type Description
dict[str, Tensor]

Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot', and per-token 'logps' if on the last stage.

Source code in d9d/module/model/qwen3_moe/model.py, lines 292-333
def forward(
        self,
        input_ids: torch.Tensor | None = None,
        hidden_states: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        hidden_states_snapshot: torch.Tensor | None = None,
        hidden_states_agg_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None
) -> dict[str, torch.Tensor]:
    """
    Executes the model forward pass.

    If this is the last stage, it expects `labels` to be provided and computes
    the cross-entropy loss (returned as 'logps', typically representing per-token loss).

    Args:
        input_ids: Input token IDs (for Stage 0).
        hidden_states: Hidden states from previous stage (for Stage > 0).
        position_ids: Positional indices for RoPE.
        hidden_states_snapshot: Intermediate state collector.
        hidden_states_agg_mask: Mask for state aggregation.
        labels: Target tokens for loss computation (Last Stage).

    Returns:
        Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot',
        and per-token 'logps' if on the last stage.
    """

    model_outputs = self.model(
        input_ids=input_ids,
        hidden_states=hidden_states,
        position_ids=position_ids,
        hidden_states_snapshot=hidden_states_snapshot,
        hidden_states_agg_mask=hidden_states_agg_mask
    )
    if self._stage.is_current_stage_last:
        lm_out = self.lm_head(
            hidden_states=model_outputs["hidden_states"],
            labels=labels
        )
        model_outputs["logps"] = lm_out
    return model_outputs
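
A rough usage sketch for building a single-stage model and running one forward pass is shown below. The import path, and how a PipelineStageInfo and HiddenStatesAggregationMode are obtained, are assumptions that depend on the surrounding pipeline setup; the forward inputs and output keys follow the documentation above.

import torch

# Hypothetical import path -- the real module layout may differ.
from d9d.module.model.qwen3_moe import (
    Qwen3MoEForCausalLM,
    Qwen3MoEForCausalLMParameters,
)


def run_forward_once(params: Qwen3MoEForCausalLMParameters, stage, snapshot_mode):
    # `stage` is a PipelineStageInfo and `snapshot_mode` a HiddenStatesAggregationMode;
    # constructing them is pipeline-specific and not shown here.
    model = Qwen3MoEForCausalLM(
        params=params,
        stage=stage,
        hidden_states_snapshot_mode=snapshot_mode,
        enable_checkpointing=False,
    )
    model.reset_parameters()

    batch, seq_len = 2, 128
    # Assumption: the total vocabulary is the sum of the configured splits.
    vocab = sum(params.model.split_vocab_size.values())
    input_ids = torch.randint(0, vocab, (batch, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
    labels = torch.randint(0, vocab, (batch, seq_len))

    # On the first stage pass `input_ids`; later stages receive `hidden_states` instead.
    outputs = model(input_ids=input_ids, position_ids=position_ids, labels=labels)

    print(outputs["hidden_states"].shape)   # (batch, seq_len, hidden_size)
    if "logps" in outputs:                  # only produced by the last stage
        print(outputs["logps"].shape)
    return outputs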

reset_moe_stats()

Resets MoE routing statistics in the backbone.

Source code in d9d/module/model/qwen3_moe/model.py, lines 345-350
def reset_moe_stats(self):
    """
    Resets MoE routing statistics in the backbone.
    """

    self.model.reset_moe_stats()
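
For example, a training loop might clear the counters at the start of a logging window and read moe_tokens_per_expert at its end. A minimal sketch, assuming model is an already-built Qwen3MoEForCausalLM (or Qwen3MoEModel) and step_fn runs one training step:

def log_expert_load(model, step_fn, window_steps: int = 100):
    # Clear the per-expert token counters before the window.
    model.reset_moe_stats()

    for _ in range(window_steps):
        step_fn()

    # Shape: (num_local_layers, num_experts); tokens routed to each expert.
    counts = model.moe_tokens_per_expert
    fractions = counts.float() / counts.sum(dim=-1, keepdim=True).clamp_min(1)
    print("per-layer expert load fractions:", fractions)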

reset_parameters()

Resets module parameters.

Source code in d9d/module/model/qwen3_moe/model.py, lines 335-343
def reset_parameters(self):
    """
    Resets module parameters.
    """

    self.model.reset_parameters()

    if self._stage.is_current_stage_last:
        self.lm_head.reset_parameters()
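
The ModuleLateInit base suggests a deferred-initialization workflow. Below is a minimal sketch, under the assumption that the standard PyTorch meta-device pattern applies here: construct without allocating storage, materialize with to_empty, then call reset_parameters() on the target device.

import torch


def build_late_init(make_module, device: str = "cuda"):
    # `make_module()` is assumed to return a Qwen3MoEForCausalLM (or any module
    # exposing reset_parameters()).
    with torch.device("meta"):
        module = make_module()               # parameters exist only as meta tensors

    module = module.to_empty(device=device)  # allocate real, uninitialized storage
    module.reset_parameters()                # fill in proper initial values
    return module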

Qwen3MoEForCausalLMParameters

Bases: BaseModel

Configuration parameters for Qwen3 Mixture-of-Experts model with a Causal Language Modeling head.

Attributes:

Name Type Description
model Qwen3MoEParameters

The configuration for the underlying Qwen3 MoE model.

Source code in d9d/module/model/qwen3_moe/params.py, lines 52-60
class Qwen3MoEForCausalLMParameters(BaseModel):
    """
    Configuration parameters for Qwen3 Mixture-of-Experts model with a Causal Language Modeling head.

    Attributes:
        model: The configuration for the underlying Qwen3 MoE model.
    """

    model: Qwen3MoEParameters

Qwen3MoELayer

Bases: Module, ModuleLateInit

Implements a single Qwen3 Mixture-of-Experts (MoE) transformer layer.

This layer consists of a Grouped Query Attention mechanism followed by an MoE MLP block, with pre-RMSNorm applied before each sub-layer.

Source code in d9d/module/model/qwen3_moe/decoder_layer.py, lines 11-110
class Qwen3MoELayer(nn.Module, ModuleLateInit):
    """
    Implements a single Qwen3 Mixture-of-Experts (MoE) transformer layer.

    This layer consists of a Grouped Query Attention mechanism followed by an MoE
    MLP block, with pre-RMSNorm applied before each sub-layer.
    """

    def __init__(
            self,
            params: Qwen3MoELayerParameters
    ):
        """
        Constructs a Qwen3MoELayer object.

        Args:
            params: Configuration parameters for the layer.
        """

        super().__init__()

        self.self_attn = GroupedQueryAttention(
            hidden_size=params.hidden_size,
            num_attention_heads=params.num_attention_heads,
            num_key_value_heads=params.num_key_value_heads,
            is_causal=True,
            qk_norm_eps=params.rms_norm_eps,
            head_dim=params.head_dim
        )

        self.mlp = MoELayer(
            hidden_dim=params.hidden_size,
            num_grouped_experts=params.num_experts,
            intermediate_dim_grouped=params.intermediate_size,
            top_k=params.experts_top_k,
            router_renormalize_probabilities=True
        )

        self.input_layernorm = nn.RMSNorm(params.hidden_size, eps=params.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(params.hidden_size, eps=params.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Performs the forward pass of the MoE layer.

        Args:
            hidden_states: Input tensor of shape `(batch, seq_len, hidden_dim)`.
            position_embeddings: Tuple containing RoPE precomputed embeddings (cos, sin).

        Returns:
            Output tensor after attention and MoE blocks, shape `(batch, seq_len, hidden_dim)`.
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=None  # no mask for moe decoder
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states

    def reset_moe_stats(self):
        """
        Resets statistical counters inside the MoE router (e.g., token counts per expert).
        """

        self.mlp.reset_stats()

    @property
    def moe_tokens_per_expert(self) -> torch.Tensor:
        """
        Returns the number of tokens routed to each expert.
        """

        return self.mlp.tokens_per_expert

    def reset_parameters(self):
        """
        Resets module parameters.
        """

        self.self_attn.reset_parameters()
        self.mlp.reset_parameters()
        self.input_layernorm.reset_parameters()
        self.post_attention_layernorm.reset_parameters()

moe_tokens_per_expert property

Returns the number of tokens routed to each expert.

__init__(params)

Constructs a Qwen3MoELayer object.

Parameters:

Name Type Description Default
params Qwen3MoELayerParameters

Configuration parameters for the layer.

required
Source code in d9d/module/model/qwen3_moe/decoder_layer.py, lines 19-50
def __init__(
        self,
        params: Qwen3MoELayerParameters
):
    """
    Constructs a Qwen3MoELayer object.

    Args:
        params: Configuration parameters for the layer.
    """

    super().__init__()

    self.self_attn = GroupedQueryAttention(
        hidden_size=params.hidden_size,
        num_attention_heads=params.num_attention_heads,
        num_key_value_heads=params.num_key_value_heads,
        is_causal=True,
        qk_norm_eps=params.rms_norm_eps,
        head_dim=params.head_dim
    )

    self.mlp = MoELayer(
        hidden_dim=params.hidden_size,
        num_grouped_experts=params.num_experts,
        intermediate_dim_grouped=params.intermediate_size,
        top_k=params.experts_top_k,
        router_renormalize_probabilities=True
    )

    self.input_layernorm = nn.RMSNorm(params.hidden_size, eps=params.rms_norm_eps)
    self.post_attention_layernorm = nn.RMSNorm(params.hidden_size, eps=params.rms_norm_eps)

forward(hidden_states, position_embeddings)

Performs the forward pass of the MoE layer.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor of shape (batch, seq_len, hidden_dim).

required
position_embeddings tuple[Tensor, Tensor]

Tuple containing RoPE precomputed embeddings (cos, sin).

required

Returns:

Type Description
Tensor

Output tensor after attention and MoE blocks, shape (batch, seq_len, hidden_dim).

Source code in d9d/module/model/qwen3_moe/decoder_layer.py, lines 52-85
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
    """
    Performs the forward pass of the MoE layer.

    Args:
        hidden_states: Input tensor of shape `(batch, seq_len, hidden_dim)`.
        position_embeddings: Tuple containing RoPE precomputed embeddings (cos, sin).

    Returns:
        Output tensor after attention and MoE blocks, shape `(batch, seq_len, hidden_dim)`.
    """

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        position_embeddings=position_embeddings,
        attention_mask=None  # no mask for moe decoder
    )
    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)

    hidden_states = residual + hidden_states

    return hidden_states
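
A rough sketch of driving a single layer outside the full backbone. The import path is an assumption, and rope_provider stands in for the RotaryEmbeddingProvider that Qwen3MoEModel uses to turn position_ids into the (cos, sin) pair:

import torch

# Hypothetical import path -- the real module layout may differ.
from d9d.module.model.qwen3_moe import Qwen3MoELayer, Qwen3MoELayerParameters


def run_layer_once(rope_provider, layer_params: Qwen3MoELayerParameters):
    # `rope_provider` is assumed to behave like the RotaryEmbeddingProvider used by
    # Qwen3MoEModel: calling it with position_ids yields the (cos, sin) pair.
    layer = Qwen3MoELayer(layer_params)
    layer.reset_parameters()

    batch, seq_len = 2, 16
    hidden_states = torch.randn(batch, seq_len, layer_params.hidden_size)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

    position_embeddings = rope_provider(position_ids)   # (cos, sin)
    out = layer(hidden_states, position_embeddings)     # same shape as the input
    return out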

reset_moe_stats()

Resets statistical counters inside the MoE router (e.g., token counts per expert).

Source code in d9d/module/model/qwen3_moe/decoder_layer.py, lines 87-92
def reset_moe_stats(self):
    """
    Resets statistical counters inside the MoE router (e.g., token counts per expert).
    """

    self.mlp.reset_stats()

reset_parameters()

Resets module parameters.

Source code in d9d/module/model/qwen3_moe/decoder_layer.py, lines 102-110
def reset_parameters(self):
    """
    Resets module parameters.
    """

    self.self_attn.reset_parameters()
    self.mlp.reset_parameters()
    self.input_layernorm.reset_parameters()
    self.post_attention_layernorm.reset_parameters()

Qwen3MoELayerParameters

Bases: BaseModel

Configuration parameters for a single Qwen3 MoE layer.

Attributes:

Name Type Description
hidden_size int

Dimension of the model's hidden states.

intermediate_size int

Dimension of the feed-forward hidden state.

num_experts int

Total number of experts in the MoE layer.

experts_top_k int

Number of experts to route tokens to.

num_attention_heads int

Number of attention heads for the query.

num_key_value_heads int

Number of attention heads for key and value.

rms_norm_eps float

Epsilon value used by the RMSNorm layers.

head_dim int

Dimension of a single attention head.

Source code in d9d/module/model/qwen3_moe/params.py, lines 4-26
class Qwen3MoELayerParameters(BaseModel):
    """
    Configuration parameters for a single Qwen3 MoE layer.

    Attributes:
        hidden_size: Dimension of the model's hidden states.
        intermediate_size: Dimension of the feed-forward hidden state.
        num_experts: Total number of experts in the MoE layer.
        experts_top_k: Number of experts to route tokens to.
        num_attention_heads: Number of attention heads for the query.
        num_key_value_heads: Number of attention heads for key and value.
        rms_norm_eps: Epsilon value used by the RMSNorm layers.
        head_dim: Dimension of a single attention head.
    """

    hidden_size: int
    intermediate_size: int
    num_experts: int
    experts_top_k: int
    num_attention_heads: int
    num_key_value_heads: int
    rms_norm_eps: float
    head_dim: int
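
Because the parameter classes are Pydantic models, a layer configuration can be validated straight from plain data (for example, values loaded from a JSON or YAML file). A small sketch, assuming Pydantic v2 and a hypothetical import path; the numeric values are arbitrary:

# Hypothetical import path -- the real module layout may differ.
from d9d.module.model.qwen3_moe import Qwen3MoELayerParameters

raw = {
    "hidden_size": 1024,
    "intermediate_size": 2816,
    "num_experts": 64,
    "experts_top_k": 8,
    "num_attention_heads": 16,
    "num_key_value_heads": 4,
    "rms_norm_eps": 1e-6,
    "head_dim": 128,
}

# Pydantic v2 validation; Qwen3MoELayerParameters(**raw) works as well.
layer_params = Qwen3MoELayerParameters.model_validate(raw)
print(layer_params.num_experts, layer_params.experts_top_k)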

Qwen3MoEModel

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

The Qwen3 Mixture-of-Experts (MoE) Transformer Decoder backbone.

It is designed to be split across multiple pipeline stages.

Source code in d9d/module/model/qwen3_moe/model.py, lines 23-246
class Qwen3MoEModel(nn.Module, ModuleLateInit, ModuleSupportsPipelining):
    """
    The Qwen3 Mixture-of-Experts (MoE) Transformer Decoder backbone.

    It is designed to be split across multiple pipeline stages.
    """

    def __init__(
            self,
            params: Qwen3MoEParameters,
            stage: PipelineStageInfo,
            hidden_states_snapshot_mode: HiddenStatesAggregationMode,
            enable_checkpointing: bool
    ):
        """
        Constructs the Qwen3MoEModel object.

        Args:
            params: Configuration parameters for the full model.
            stage: Information about the pipeline stage this instance belongs to.
            hidden_states_snapshot_mode: Configures intermediate hidden state aggregation & snapshotting mode
            enable_checkpointing: If True, enables activation checkpointing for transformer layers to save memory.
        """

        super().__init__()

        if stage.is_current_stage_first:
            self.embed_tokens = SplitTokenEmbeddings(
                hidden_size=params.layer.hidden_size,
                split_vocab_size=params.split_vocab_size,
                split_order=params.split_vocab_order
            )

        # we use ModuleDict here to properly handle pipelining and loading weights after the model
        # was pipelined
        layer_start, layer_end = distribute_layers_for_pipeline_stage(
            num_layers=params.num_hidden_layers,
            num_virtual_layers_pre=0,  # embeddings
            num_virtual_layers_post=2,  # LM head
            stage=stage
        )

        self._num_layers_before = layer_start
        self._layers_iter = list(map(str, range(layer_start, layer_end)))
        layers = nn.ModuleDict({
            str(layer_idx): Qwen3MoELayer(params=params.layer) for layer_idx in self._layers_iter
        })
        self.layers: Mapping[str, Qwen3MoELayer] = cast(Mapping[str, Qwen3MoELayer], layers)

        self.rope_provider = RotaryEmbeddingProvider(
            max_position_ids=params.max_position_ids,
            rope_base=params.rope_base,
            head_dim=params.layer.head_dim
        )

        if stage.is_current_stage_last:
            self.norm = nn.RMSNorm(
                normalized_shape=params.layer.hidden_size,
                eps=params.layer.rms_norm_eps
            )

        self._stage = stage
        self._hidden_states_snapshot_mode = hidden_states_snapshot_mode
        self._hidden_size = params.layer.hidden_size
        self._enable_checkpointing = enable_checkpointing

    def output_dtype(self) -> torch.dtype:
        """
        Returns the data type of the model output hidden states.
        """
        return self.layers[self._layers_iter[0]].input_layernorm.weight.dtype

    def forward(
            self,
            input_ids: torch.Tensor | None = None,
            hidden_states: torch.Tensor | None = None,
            position_ids: torch.Tensor | None = None,
            hidden_states_snapshot: torch.Tensor | None = None,
            hidden_states_agg_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Executes the forward pass for the current pipeline stage.

        Args:
            input_ids: Indices of input sequence tokens. Required if this is the
                first pipeline stage.
            hidden_states: Hidden states from the previous pipeline stage. Required
                if this is not the first pipeline stage.
            position_ids: Indices of positions of each input sequence token in the
                position embeddings.
            hidden_states_snapshot: Accumulated tensor of aggregated hidden states
                from previous stages. Used if snapshotting is enabled.
            hidden_states_agg_mask: Mask used to aggregate hidden states for
                snapshots.

        Returns:
            A dictionary containing:
                *   'hidden_states': The output of the last layer in this stage.
                *   'hidden_states_snapshot': (Optional) The updated snapshot tensor.
        """
        state_aggregator = create_hidden_states_aggregator(self._hidden_states_snapshot_mode, hidden_states_agg_mask)

        if input_ids is not None:
            last_hidden_states = self.embed_tokens(input_ids)
            state_aggregator.add_hidden_states(last_hidden_states)
        else:
            last_hidden_states = hidden_states

        rope_params = self.rope_provider(position_ids)

        for decoder_layer_name in self._layers_iter:
            decoder_layer = self.layers[decoder_layer_name]

            if self._enable_checkpointing:
                last_hidden_states = checkpoint(
                    decoder_layer, last_hidden_states, rope_params,
                    use_reentrant=False
                )
            else:
                last_hidden_states = decoder_layer(last_hidden_states, rope_params)

            state_aggregator.add_hidden_states(last_hidden_states)

        if self._stage.is_current_stage_last:
            last_hidden_states = self.norm(last_hidden_states)

        return {
            "hidden_states": last_hidden_states,
            "hidden_states_snapshot": state_aggregator.pack_with_snapshot(hidden_states_snapshot)
        }

    def reset_moe_stats(self):
        """
        Resets routing statistics for all MoE layers in this stage.
        """

        for layer_name in self._layers_iter:
            self.layers[layer_name].reset_moe_stats()

    @property
    def moe_tokens_per_expert(self) -> torch.Tensor:
        """
        Retrieves the number of tokens routed to each expert across all layers.

        Returns:
            A tensor of shape (num_local_layers, num_experts) containing counts.
        """

        return torch.stack(
            [self.layers[layer_name].moe_tokens_per_expert for layer_name in self._layers_iter],
            dim=0
        )

    def reset_parameters(self):
        """Resets module parameters"""

        if self._stage.is_current_stage_first:
            self.embed_tokens.reset_parameters()

        self.rope_provider.reset_parameters()

        for decoder_layer_name in self._layers_iter:
            decoder_layer = self.layers[decoder_layer_name]
            decoder_layer.reset_parameters()

        if self._stage.is_current_stage_last:
            self.norm.reset_parameters()

    def infer_stage_inputs_from_pipeline_inputs(
            self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        input_ids = inputs["input_ids"]

        pp_inputs = {}

        # for calculation - input ids or prev hidden state
        if self._stage.is_current_stage_first:
            pp_inputs["input_ids"] = torch.empty(
                (input_ids.shape[0] // n_microbatches, input_ids.shape[1]),
                dtype=torch.long,
                device=input_ids.device
            )
        else:
            pp_inputs["hidden_states"] = torch.empty(
                (input_ids.shape[0] // n_microbatches, input_ids.shape[1], self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )
            if self._hidden_states_snapshot_mode != HiddenStatesAggregationMode.no:
                num_layers_before = self._num_layers_before + 1  # 1 for embedding
                pp_inputs["hidden_states_snapshot"] = torch.empty(
                    (num_layers_before, input_ids.shape[0] // n_microbatches, self._hidden_size),
                    dtype=self.output_dtype(),
                    device=input_ids.device
                )

        return pp_inputs

    def infer_stage_outputs_from_pipeline_inputs(
            self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        input_ids = inputs["input_ids"]

        # for calculation - last hidden state
        pp_outputs = {
            "hidden_states": torch.empty(
                (input_ids.shape[0] // n_microbatches, input_ids.shape[1], self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )
        }

        # for state caching
        if self._hidden_states_snapshot_mode != HiddenStatesAggregationMode.no:
            num_layers_before = self._num_layers_before + 1
            num_layers_current = len(self.layers)
            num_layers_after = num_layers_before + num_layers_current
            pp_outputs["hidden_states_snapshot"] = torch.empty(
                (num_layers_after, input_ids.shape[0] // n_microbatches, self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )

        return pp_outputs

moe_tokens_per_expert property

Retrieves the number of tokens routed to each expert across all layers.

Returns:

Type Description
Tensor

A tensor of shape (num_local_layers, num_experts) containing counts.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3MoEModel object.

Parameters:

Name Type Description Default
params Qwen3MoEParameters

Configuration parameters for the full model.

required
stage PipelineStageInfo

Information about the pipeline stage this instance belongs to.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode

required
enable_checkpointing bool

If True, enables activation checkpointing for transformer layers to save memory.

required
Source code in d9d/module/model/qwen3_moe/model.py, lines 30-87
def __init__(
        self,
        params: Qwen3MoEParameters,
        stage: PipelineStageInfo,
        hidden_states_snapshot_mode: HiddenStatesAggregationMode,
        enable_checkpointing: bool
):
    """
    Constructs the Qwen3MoEModel object.

    Args:
        params: Configuration parameters for the full model.
        stage: Information about the pipeline stage this instance belongs to.
        hidden_states_snapshot_mode: Configures intermediate hidden state aggregation & snapshotting mode
        enable_checkpointing: If True, enables activation checkpointing for transformer layers to save memory.
    """

    super().__init__()

    if stage.is_current_stage_first:
        self.embed_tokens = SplitTokenEmbeddings(
            hidden_size=params.layer.hidden_size,
            split_vocab_size=params.split_vocab_size,
            split_order=params.split_vocab_order
        )

    # we use ModuleDict here to properly handle pipelining and loading weights after the model
    # was pipelined
    layer_start, layer_end = distribute_layers_for_pipeline_stage(
        num_layers=params.num_hidden_layers,
        num_virtual_layers_pre=0,  # embeddings
        num_virtual_layers_post=2,  # LM head
        stage=stage
    )

    self._num_layers_before = layer_start
    self._layers_iter = list(map(str, range(layer_start, layer_end)))
    layers = nn.ModuleDict({
        str(layer_idx): Qwen3MoELayer(params=params.layer) for layer_idx in self._layers_iter
    })
    self.layers: Mapping[str, Qwen3MoELayer] = cast(Mapping[str, Qwen3MoELayer], layers)

    self.rope_provider = RotaryEmbeddingProvider(
        max_position_ids=params.max_position_ids,
        rope_base=params.rope_base,
        head_dim=params.layer.head_dim
    )

    if stage.is_current_stage_last:
        self.norm = nn.RMSNorm(
            normalized_shape=params.layer.hidden_size,
            eps=params.layer.rms_norm_eps
        )

    self._stage = stage
    self._hidden_states_snapshot_mode = hidden_states_snapshot_mode
    self._hidden_size = params.layer.hidden_size
    self._enable_checkpointing = enable_checkpointing

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None)

Executes the forward pass for the current pipeline stage.

Parameters:

Name Type Description Default
input_ids Tensor | None

Indices of input sequence tokens. Required if this is the first pipeline stage.

None
hidden_states Tensor | None

Hidden states from the previous pipeline stage. Required if this is not the first pipeline stage.

None
position_ids Tensor | None

Indices of positions of each input sequence token in the position embeddings.

None
hidden_states_snapshot Tensor | None

Accumulated tensor of aggregated hidden states from previous stages. Used if snapshotting is enabled.

None
hidden_states_agg_mask Tensor | None

Mask used to aggregate hidden states for snapshots.

None

Returns:

Type Description
dict[str, Tensor]

A dictionary containing:

'hidden_states': The output of the last layer in this stage.

'hidden_states_snapshot': (Optional) The updated snapshot tensor.

Source code in d9d/module/model/qwen3_moe/model.py, lines 95-152
def forward(
        self,
        input_ids: torch.Tensor | None = None,
        hidden_states: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        hidden_states_snapshot: torch.Tensor | None = None,
        hidden_states_agg_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """
    Executes the forward pass for the current pipeline stage.

    Args:
        input_ids: Indices of input sequence tokens. Required if this is the
            first pipeline stage.
        hidden_states: Hidden states from the previous pipeline stage. Required
            if this is not the first pipeline stage.
        position_ids: Indices of positions of each input sequence token in the
            position embeddings.
        hidden_states_snapshot: Accumulated tensor of aggregated hidden states
            from previous stages. Used if snapshotting is enabled.
        hidden_states_agg_mask: Mask used to aggregate hidden states for
            snapshots.

    Returns:
        A dictionary containing:
            *   'hidden_states': The output of the last layer in this stage.
            *   'hidden_states_snapshot': (Optional) The updated snapshot tensor.
    """
    state_aggregator = create_hidden_states_aggregator(self._hidden_states_snapshot_mode, hidden_states_agg_mask)

    if input_ids is not None:
        last_hidden_states = self.embed_tokens(input_ids)
        state_aggregator.add_hidden_states(last_hidden_states)
    else:
        last_hidden_states = hidden_states

    rope_params = self.rope_provider(position_ids)

    for decoder_layer_name in self._layers_iter:
        decoder_layer = self.layers[decoder_layer_name]

        if self._enable_checkpointing:
            last_hidden_states = checkpoint(
                decoder_layer, last_hidden_states, rope_params,
                use_reentrant=False
            )
        else:
            last_hidden_states = decoder_layer(last_hidden_states, rope_params)

        state_aggregator.add_hidden_states(last_hidden_states)

    if self._stage.is_current_stage_last:
        last_hidden_states = self.norm(last_hidden_states)

    return {
        "hidden_states": last_hidden_states,
        "hidden_states_snapshot": state_aggregator.pack_with_snapshot(hidden_states_snapshot)
    }

output_dtype()

Returns the data type of the model output hidden states.

Source code in d9d/module/model/qwen3_moe/model.py, lines 89-93
def output_dtype(self) -> torch.dtype:
    """
    Returns the data type of the model output hidden states.
    """
    return self.layers[self._layers_iter[0]].input_layernorm.weight.dtype

reset_moe_stats()

Resets routing statistics for all MoE layers in this stage.

Source code in d9d/module/model/qwen3_moe/model.py, lines 154-160
def reset_moe_stats(self):
    """
    Resets routing statistics for all MoE layers in this stage.
    """

    for layer_name in self._layers_iter:
        self.layers[layer_name].reset_moe_stats()

reset_parameters()

Resets module parameters

Source code in d9d/module/model/qwen3_moe/model.py, lines 176-189
def reset_parameters(self):
    """Resets module parameters"""

    if self._stage.is_current_stage_first:
        self.embed_tokens.reset_parameters()

    self.rope_provider.reset_parameters()

    for decoder_layer_name in self._layers_iter:
        decoder_layer = self.layers[decoder_layer_name]
        decoder_layer.reset_parameters()

    if self._stage.is_current_stage_last:
        self.norm.reset_parameters()
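
The infer_stage_inputs_from_pipeline_inputs and infer_stage_outputs_from_pipeline_inputs methods let a pipeline runtime pre-allocate per-microbatch communication buffers without running the model. A rough sketch of how such a runtime might call them; how stage_model was built and what the runtime does with the buffers are assumptions:

import torch


def allocate_stage_buffers(
    stage_model, pipeline_inputs: dict[str, torch.Tensor], n_microbatches: int
):
    # `stage_model` is assumed to be a Qwen3MoEModel or Qwen3MoEForCausalLM built for
    # one pipeline stage; `pipeline_inputs` must contain the full-batch "input_ids".
    recv_buffers = stage_model.infer_stage_inputs_from_pipeline_inputs(
        pipeline_inputs, n_microbatches
    )
    send_buffers = stage_model.infer_stage_outputs_from_pipeline_inputs(
        pipeline_inputs, n_microbatches
    )

    # First stage: receives {"input_ids": ...}; later stages receive {"hidden_states": ...}
    # plus a growing "hidden_states_snapshot" buffer when snapshotting is enabled. The last
    # causal-LM stage additionally reports a per-token "logps" output.
    for role, buffers in (("recv", recv_buffers), ("send", send_buffers)):
        for name, buf in buffers.items():
            print(role, name, tuple(buf.shape), buf.dtype)
    return recv_buffers, send_buffers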

Qwen3MoEParameters

Bases: BaseModel

Configuration parameters for the Qwen3 Mixture-of-Experts model backbone.

Attributes:

Name Type Description
layer Qwen3MoELayerParameters

Configuration shared across all transformer layers.

num_hidden_layers int

The total number of transformer layers.

rope_base int

Base value for RoPE frequency calculation.

max_position_ids int

Maximum sequence length.

split_vocab_size dict[str, int]

A dictionary mapping vocabulary segment names to their sizes.

split_vocab_order list[str]

The order in which the vocabulary splits are arranged.

Source code in d9d/module/model/qwen3_moe/params.py, lines 29-49
class Qwen3MoEParameters(BaseModel):
    """
    Configuration parameters for the Qwen3 Mixture-of-Experts model backbone.

    Attributes:
        layer: Configuration shared across all transformer layers.
        num_hidden_layers: The total number of transformer layers.
        rope_base: Base value for RoPE frequency calculation.
        max_position_ids: Maximum sequence length.
        split_vocab_size: A dictionary mapping vocabulary segment names to their sizes.
        split_vocab_order: The order in which the vocabulary splits are arranged.
    """

    layer: Qwen3MoELayerParameters

    num_hidden_layers: int
    rope_base: int
    max_position_ids: int

    split_vocab_size: dict[str, int]
    split_vocab_order: list[str]
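
Putting the parameter classes together, a complete (toy-sized) configuration for the causal-LM wrapper might look like the following. The import path and the split-vocabulary layout ("base" plus "special" segments) are illustrative assumptions; all numeric values are arbitrary:

# Hypothetical import path -- the real module layout may differ.
from d9d.module.model.qwen3_moe import (
    Qwen3MoEForCausalLMParameters,
    Qwen3MoELayerParameters,
    Qwen3MoEParameters,
)

params = Qwen3MoEForCausalLMParameters(
    model=Qwen3MoEParameters(
        layer=Qwen3MoELayerParameters(
            hidden_size=256,
            intermediate_size=512,
            num_experts=8,
            experts_top_k=2,
            num_attention_heads=8,
            num_key_value_heads=2,
            rms_norm_eps=1e-6,
            head_dim=32,
        ),
        num_hidden_layers=4,
        rope_base=10000,
        max_position_ids=4096,
        # Illustrative split-vocabulary layout: segment names and sizes are arbitrary.
        split_vocab_size={"base": 32000, "special": 64},
        split_vocab_order=["base", "special"],
    )
)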