Qwen3 MoE
About
The d9d.module.model.qwen3_moe package implements the Qwen3 Mixture-of-Experts model architecture.
The d9d.module.parallelism.model.qwen3_moe package implements the default horizontal parallelism strategy for this model.
d9d.module.model.qwen3_moe
Qwen3MoEForCausalLM
Bases: Module, ModuleLateInit, ModuleSupportsPipelining
A Qwen3 MoE model wrapped with a Causal Language Modeling head.
It is designed to be split across multiple pipeline stages.
moe_tokens_per_expert
property
Accesses MoE routing statistics from the backbone.
__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)
Constructs the Qwen3MoEForCausalLM object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3MoEForCausalLMParameters | Full model configuration parameters. | required |
| stage | PipelineStageInfo | Pipeline stage information for this instance. | required |
| hidden_states_snapshot_mode | HiddenStatesAggregationMode | Configures intermediate hidden state aggregation & snapshotting mode. | required |
| enable_checkpointing | bool | Whether to enable activation checkpointing. | required |
forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, labels=None)
Executes the model forward pass.
If this is the last stage, labels must be provided; the method then computes the cross-entropy loss and returns it under the 'logps' key (typically a per-token loss).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| None | Input token IDs (for Stage 0). | None |
| hidden_states | Tensor \| None | Hidden states from the previous stage (for Stage > 0). | None |
| position_ids | Tensor \| None | Positional indices for RoPE. | None |
| hidden_states_snapshot | Tensor \| None | Intermediate state collector: accumulated tensor of aggregated hidden states from previous stages. | None |
| hidden_states_agg_mask | Tensor \| None | Mask used to aggregate hidden states for snapshots. | None |
| labels | Tensor \| None | Target tokens for loss computation (last stage). | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot', and per-token 'logps' if on the last stage. |
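Because 'logps' holds one value per token rather than a reduced mean, it may help to see what an unreduced cross-entropy looks like. The plain-Python sketch below assumes that logits and labels are already aligned position by position (no label shifting, no ignore index); the actual implementation may handle alignment and padding differently, and the function names are illustrative only.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def per_token_cross_entropy(logits: List[List[float]], labels: List[int]) -> List[float]:
    """One cross-entropy value per position, with no reduction.

    logits: [seq_len][vocab_size] scores; labels: [seq_len] target token IDs,
    assumed to be already aligned with the logits (no shifting here).
    """
    return [-log_softmax(row)[label] for row, label in zip(logits, labels)]


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
losses = per_token_cross_entropy(logits, labels)
mean_loss = sum(losses) / len(losses)  # a scalar loss, if one is needed downstream
```

A confident correct prediction (first position) yields a smaller loss than a uniform guess (second position, loss log 3); any reduction to a scalar is left to the caller.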
reset_moe_stats()
Resets MoE routing statistics in the backbone.
reset_parameters()
Resets module parameters.
Qwen3MoEForCausalLMParameters
Bases: BaseModel
Configuration parameters for the Qwen3 Mixture-of-Experts model with a Causal Language Modeling head.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Qwen3MoEParameters | The configuration for the underlying Qwen3 MoE model. |
Qwen3MoEForClassification
Bases: Module, ModuleLateInit, ModuleSupportsPipelining
A Qwen3 MoE model wrapped with a Sequence/Token Classification head.
It is designed to be split across multiple pipeline stages.
moe_tokens_per_expert
property
Accesses MoE routing statistics from the backbone.
__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)
Constructs the Qwen3MoEForClassification object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3MoEForClassificationParameters | Full model configuration parameters. | required |
| stage | PipelineStageInfo | Pipeline stage information for this instance. | required |
| hidden_states_snapshot_mode | HiddenStatesAggregationMode | Configures intermediate hidden state aggregation & snapshotting mode. | required |
| enable_checkpointing | bool | Whether to enable activation checkpointing. | required |
forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, pooling_mask=None)
Executes the classification model forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| None | Input token IDs (for Stage 0). | None |
| hidden_states | Tensor \| None | Hidden states from the previous stage (for Stage > 0). | None |
| position_ids | Tensor \| None | Positional indices for RoPE. | None |
| hidden_states_snapshot | Tensor \| None | Intermediate state collector: accumulated tensor of aggregated hidden states from previous stages. | None |
| hidden_states_agg_mask | Tensor \| None | Mask used to aggregate hidden states for snapshots. | None |
| pooling_mask | Tensor \| None | Binary mask indicating which token(s) to pool for classification. Note: you can use | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot'. If on the last stage, also contains 'scores' (logits) of shape [batch, num_labels]. |
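The pooling_mask decides which token positions feed the classification head. The exact pooling rule is not documented here, so the sketch below assumes a masked mean over the selected positions followed by a linear projection to num_labels scores; with this assumption, a mask containing a single 1 (for example on the last token) simply picks that token's hidden state. It handles one sequence in plain Python, omits classifier_dropout (as at inference time), and all function names are hypothetical.

```python
from typing import List


def masked_mean_pool(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Average the hidden vectors at positions where mask == 1.

    hidden: [seq_len][hidden_size]; mask: [seq_len] of 0/1 values.
    """
    selected = [h for h, m in zip(hidden, mask) if m]
    if not selected:
        raise ValueError("pooling mask selects no tokens")
    size = len(selected[0])
    return [sum(h[i] for h in selected) / len(selected) for i in range(size)]


def linear_head(pooled: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """Project a pooled vector to num_labels scores; weight is [num_labels][hidden_size]."""
    return [sum(w * x for w, x in zip(row, pooled)) + b for row, b in zip(weight, bias)]


hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
last_token_mask = [0, 0, 1]   # pool only the final token
all_tokens_mask = [1, 1, 1]   # mean over the whole sequence

weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # num_labels = 3, hidden_size = 2
bias = [0.0, 0.0, 0.0]

scores = linear_head(masked_mean_pool(hidden, last_token_mask), weight, bias)
```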
reset_moe_stats()
Resets MoE routing statistics in the backbone.
reset_parameters()
Resets module parameters.
Qwen3MoEForClassificationParameters
Bases: BaseModel
Configuration parameters for the Qwen3 Mixture-of-Experts model with a token/sequence classification head.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Qwen3MoEParameters | The configuration for the underlying Qwen3 MoE model. |
| num_labels | int | The number of output labels for classification. |
| classifier_dropout | float | The dropout probability for the classification head. |
Qwen3MoELayer
Bases: Module, ModuleLateInit
Implements a single Qwen3 Mixture-of-Experts (MoE) transformer layer.
This layer consists of a Grouped Query Attention mechanism followed by an MoE MLP block, with pre-RMSNorm applied before each sub-layer.
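The pre-norm arrangement described above means each sub-layer sees an RMS-normalized input and its output is added back to the residual stream. The plain-Python sketch below shows only that data flow on a single token vector: the attention and MoE blocks are hypothetical stand-in callables (real attention mixes information across positions), and the learnable RMSNorm scale defaults to all ones.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]
SubLayer = Callable[[Vector], Vector]


def rms_norm(x: Vector, eps: float = 1e-6, weight: Optional[Vector] = None) -> Vector:
    """RMSNorm: x / sqrt(mean(x**2) + eps), then an element-wise scale (ones by default)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = weight if weight is not None else [1.0] * len(x)
    return [v * inv_rms * s for v, s in zip(x, scale)]


def add(a: Vector, b: Vector) -> Vector:
    return [u + v for u, v in zip(a, b)]


def prenorm_moe_layer(x: Vector, attention: SubLayer, moe_mlp: SubLayer,
                      eps: float = 1e-6) -> Vector:
    """Pre-norm residual block: normalize, run the sub-layer, add back to the stream."""
    h = add(x, attention(rms_norm(x, eps)))   # attention sub-layer
    return add(h, moe_mlp(rms_norm(h, eps)))  # MoE MLP sub-layer


def zero(v: Vector) -> Vector:
    """Stand-in sub-layer that contributes nothing."""
    return [0.0] * len(v)


token = [1.0, 2.0, 3.0, 4.0]
unchanged = prenorm_moe_layer(token, zero, zero)  # zero sub-layers leave the residual intact
normalized = rms_norm(token)
```

Keeping the residual path un-normalized is what lets the identity mapping pass through unchanged, as the zero sub-layer case shows.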
moe_tokens_per_expert
property
Returns the number of tokens routed to each expert.
__init__(params)
Constructs a Qwen3MoELayer object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3MoELayerParameters | Configuration parameters for the layer. | required |
forward(hidden_states, position_embeddings)
Performs the forward pass of the MoE layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | Input tensor of shape | required |
| position_embeddings | tuple[Tensor, Tensor] | Tuple containing RoPE precomputed embeddings (cos, sin). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor after attention and MoE blocks, shape |
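position_embeddings carries precomputed RoPE (cos, sin) tables. A common way to build and apply them, used for example by Hugging Face style Qwen implementations, is the "rotate half" formulation sketched below for a single head vector; whether d9d uses exactly this layout is an assumption. The base argument plays the role of rope_base in Qwen3MoEParameters, and the helper names are illustrative.

```python
import math
from typing import List, Tuple

Vector = List[float]


def rope_cos_sin(position: int, head_dim: int, base: float = 10000.0) -> Tuple[Vector, Vector]:
    """(cos, sin) tables of length head_dim for one position.

    Inverse frequencies are base ** (-2i / head_dim) for i in [0, head_dim / 2),
    repeated so both halves of the head dimension use the same angles.
    """
    half = head_dim // 2
    inv_freq = [base ** (-2 * i / head_dim) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rotate_half(x: Vector) -> Vector:
    """Map the halves (x1, x2) to (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: Vector, cos: Vector, sin: Vector) -> Vector:
    """Rotate each (x[i], x[i + half]) pair by its position-dependent angle."""
    rotated = rotate_half(x)
    return [v * c + r * s for v, r, c, s in zip(x, rotated, cos, sin)]


def rope_score(q: Vector, k: Vector, q_pos: int, k_pos: int, base: float = 10000.0) -> float:
    """Dot product of RoPE-rotated query and key; depends only on q_pos - k_pos."""
    head_dim = len(q)
    q_rot = apply_rope(q, *rope_cos_sin(q_pos, head_dim, base))
    k_rot = apply_rope(k, *rope_cos_sin(k_pos, head_dim, base))
    return sum(a * b for a, b in zip(q_rot, k_rot))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
near = rope_score(q, k, 3, 1)
shifted = rope_score(q, k, 7, 5)  # same offset, different absolute positions
```

The two scores agree because both query and key are rotated, so attention logits only see the relative offset between positions.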
reset_moe_stats()
Resets statistical counters inside the MoE router (e.g., token counts per expert).
reset_parameters()
Resets module parameters.
Qwen3MoELayerParameters
Bases: BaseModel
Configuration parameters for a single Qwen3 MoE layer.
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_size | int | Dimension of the model's hidden states. |
| intermediate_size | int | Dimension of the feed-forward hidden state. |
| num_experts | int | Total number of experts in the MoE layer. |
| experts_top_k | int | Number of experts each token is routed to. |
| num_attention_heads | int | Number of attention heads for queries. |
| num_key_value_heads | int | Number of attention heads for keys and values. |
| rms_norm_eps | float | Epsilon value used in the RMSNorm layers. |
| head_dim | int | Dimension of a single attention head. |
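With Grouped Query Attention, num_attention_heads query heads share num_key_value_heads key/value heads, so each key/value head serves a group of num_attention_heads // num_key_value_heads query heads. The sketch below computes that mapping and the resulting projection widths; the contiguous grouping follows the common "repeat KV" convention and is an assumption about this implementation, and the sizes are hypothetical.

```python
from typing import Dict, List


def kv_head_for_query_head(num_attention_heads: int, num_key_value_heads: int) -> List[int]:
    """Index of the key/value head used by each query head (contiguous groups)."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads
    return [q_head // group_size for q_head in range(num_attention_heads)]


def attention_projection_widths(num_attention_heads: int, num_key_value_heads: int,
                                head_dim: int) -> Dict[str, int]:
    """Output width of the Q, K and V projections (their input width is hidden_size)."""
    return {
        "q": num_attention_heads * head_dim,
        "k": num_key_value_heads * head_dim,
        "v": num_key_value_heads * head_dim,
    }


# Hypothetical sizes, chosen only for illustration.
mapping = kv_head_for_query_head(num_attention_heads=8, num_key_value_heads=2)
widths = attention_projection_widths(num_attention_heads=8, num_key_value_heads=2, head_dim=128)
```

With 8 query heads and 2 key/value heads, the K and V projections are a quarter of the width of Q, which is where GQA saves KV-cache memory. Because head_dim is a separate parameter, num_attention_heads * head_dim need not equal hidden_size.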
Qwen3MoEModel
Bases: Module, ModuleLateInit, ModuleSupportsPipelining
The Qwen3 Mixture-of-Experts (MoE) Transformer Decoder backbone.
It is designed to be split across multiple pipeline stages.
moe_tokens_per_expert
property
Retrieves the number of tokens routed to each expert across all layers.
Returns:
| Type | Description |
|---|---|
| Tensor | A tensor of shape (num_local_layers, num_experts) containing counts. |
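To make the (num_local_layers, num_experts) counts concrete, the sketch below shows how such statistics could arise under a softmax-then-top-k router that sends each token to experts_top_k experts, how the per-layer rows stack into one table, and what resetting them means. The router details (softmax before top-k, renormalized weights, no capacity limits) and the RouterStats class are assumptions for illustration, not the module's internals.

```python
import math
from typing import List, Tuple


def top_k_experts(router_logits: List[float], k: int) -> List[Tuple[int, float]]:
    """Softmax over all experts, keep the k most probable, renormalize their weights."""
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept = sum(probs[i] for i in chosen)
    return [(i, probs[i] / kept) for i in chosen]


class RouterStats:
    """Hypothetical per-layer counter of how many tokens each expert received."""

    def __init__(self, num_experts: int) -> None:
        self.tokens_per_expert = [0] * num_experts

    def record(self, router_logits_per_token: List[List[float]], k: int) -> None:
        for logits in router_logits_per_token:
            for expert, _weight in top_k_experts(logits, k):
                self.tokens_per_expert[expert] += 1

    def reset(self) -> None:
        self.tokens_per_expert = [0] * len(self.tokens_per_expert)


num_experts, experts_top_k = 4, 2
local_layers = [RouterStats(num_experts) for _ in range(2)]
router_logits = [[3.0, 1.0, 2.0, 0.0], [0.0, 2.0, 1.0, 3.0], [1.0, 0.0, 3.0, 2.0]]
for layer in local_layers:
    layer.record(router_logits, experts_top_k)  # same logits reused for brevity

# One row per local layer: shape (num_local_layers, num_experts).
counts = [list(layer.tokens_per_expert) for layer in local_layers]

for layer in local_layers:
    layer.reset()
```

Each row sums to tokens × experts_top_k, so comparing a row against its mean is a quick way to spot load imbalance between experts.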
__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)
Constructs the Qwen3MoEModel object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3MoEParameters | Configuration parameters for the full model. | required |
| stage | PipelineStageInfo | Information about the pipeline stage this instance belongs to. | required |
| hidden_states_snapshot_mode | HiddenStatesAggregationMode | Configures intermediate hidden state aggregation & snapshotting mode. | required |
| enable_checkpointing | bool | If True, enables activation checkpointing for transformer layers to save memory. | required |
forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None)
Executes the forward pass for the current pipeline stage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| None | Indices of input sequence tokens. Required if this is the first pipeline stage. | None |
| hidden_states | Tensor \| None | Hidden states from the previous pipeline stage. Required if this is not the first pipeline stage. | None |
| position_ids | Tensor \| None | Position index of each input sequence token in the position embeddings. | None |
| hidden_states_snapshot | Tensor \| None | Accumulated tensor of aggregated hidden states from previous stages. Used if snapshotting is enabled. | None |
| hidden_states_agg_mask | Tensor \| None | Mask used to aggregate hidden states for snapshots. | None |
|
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor \| None] | A dictionary containing 'hidden_states' (the output of the last layer in this stage) and, optionally, 'hidden_states_snapshot' (the updated snapshot tensor). |
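The stage contract above (token IDs into the first stage, hidden states between stages) can be illustrated with a toy driver. The sketch below runs hypothetical stage callables one after another in a single process, standing in for the real pipeline schedule and inter-stage communication; each "layer" just adds a constant so the data flow is easy to follow.

```python
from typing import Callable, Dict, List, Optional

Stage = Callable[..., Dict[str, List[float]]]


def make_stage(is_first: bool, layer_offsets: List[float]) -> Stage:
    """Build a toy stage: the first consumes input_ids, later ones hidden_states."""

    def forward(input_ids: Optional[List[int]] = None,
                hidden_states: Optional[List[float]] = None) -> Dict[str, List[float]]:
        if is_first:
            if input_ids is None:
                raise ValueError("the first stage requires input_ids")
            hidden = [float(t) for t in input_ids]  # stand-in for the embedding lookup
        else:
            if hidden_states is None:
                raise ValueError("non-first stages require hidden_states")
            hidden = list(hidden_states)
        for offset in layer_offsets:  # stand-in for this stage's transformer layers
            hidden = [h + offset for h in hidden]
        return {"hidden_states": hidden}

    return forward


def run_pipeline(stages: List[Stage], input_ids: List[int]) -> List[float]:
    """Feed each stage's 'hidden_states' output into the next stage."""
    out = stages[0](input_ids=input_ids)
    for stage in stages[1:]:
        out = stage(hidden_states=out["hidden_states"])
    return out["hidden_states"]


stages = [make_stage(True, [1.0, 1.0]), make_stage(False, [10.0]), make_stage(False, [100.0])]
result = run_pipeline(stages, [0, 1, 2])
```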
output_dtype()
Returns the data type of the model output hidden states.
reset_moe_stats()
Resets routing statistics for all MoE layers in this stage.
reset_parameters()
Resets module parameters.
Qwen3MoEParameters
Bases: BaseModel
Configuration parameters for the Qwen3 Mixture-of-Experts model backbone.
Attributes:
| Name | Type | Description |
|---|---|---|
| layer | Qwen3MoELayerParameters | Configuration shared across all transformer layers. |
| num_hidden_layers | int | The total number of transformer layers. |
| rope_base | int | Base value for RoPE frequency calculation. |
| max_position_ids | int | Maximum sequence length. |
| split_vocab_size | dict[str, int] | A dictionary mapping vocabulary segment names to their sizes. |
| split_vocab_order | list[str] | The order in which the vocabulary segments are arranged. |
| pipeline_num_virtual_layers_pre | int | The number of 'virtual' layers representing the computational cost of modules on the first stage, before the main layers (e.g., token and positional embeddings). |
| pipeline_num_virtual_layers_post | int | The number of 'virtual' layers representing the computational cost of modules on the last stage, after the main layers (e.g., the final layer normalization and LM head). |
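Two of these fields lend themselves to simple arithmetic. The sketch below assumes (this is not confirmed by the reference above) that the segments in split_vocab_size are laid out contiguously in split_vocab_order, and that a pipeline splitter balances stages by counting the pre/post virtual layers as extra cost on the first and last stages. The even-boundary balancing rule, the function names and the segment names are hypothetical.

```python
from typing import Dict, List, Tuple


def vocab_ranges(split_vocab_size: Dict[str, int],
                 split_vocab_order: List[str]) -> Dict[str, Tuple[int, int]]:
    """[start, end) token-ID range of each segment, laid out in split_vocab_order."""
    if sorted(split_vocab_order) != sorted(split_vocab_size):
        raise ValueError("split_vocab_order must list every segment exactly once")
    ranges, start = {}, 0
    for name in split_vocab_order:
        end = start + split_vocab_size[name]
        ranges[name] = (start, end)
        start = end
    return ranges


def layers_per_stage(num_hidden_layers: int, num_stages: int,
                     virtual_pre: int, virtual_post: int) -> List[int]:
    """Split real layers so every stage carries a near-equal share of total cost.

    Cost units are laid out as [virtual_pre | real layers | virtual_post];
    stage i takes the units between two evenly spaced (rounded) boundaries
    and keeps only the real layers that fall inside them.
    """
    total = virtual_pre + num_hidden_layers + virtual_post
    bounds = [(2 * i * total + num_stages) // (2 * num_stages) for i in range(num_stages + 1)]
    real_start, real_end = virtual_pre, virtual_pre + num_hidden_layers
    return [
        max(0, min(hi, real_end) - max(lo, real_start))
        for lo, hi in zip(bounds, bounds[1:])
    ]


ranges = vocab_ranges({"base": 1000, "extra": 8}, ["base", "extra"])
split = layers_per_stage(num_hidden_layers=6, num_stages=2, virtual_pre=2, virtual_post=0)
```

With two virtual pre layers and six real layers on two stages, the first stage receives two real layers and the second four, so each stage carries four cost units.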
d9d.module.parallelism.model.qwen3_moe
parallelize_qwen3_moe_for_causal_lm(dist_context, model, stage)
Parallelizes the Qwen3 MoE Causal LM model.
This function delegates backbone parallelization to parallelize_qwen3_moe_model
and additionally configures the language model head with Hybrid Sharded Data
Parallelism (HSDP).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context containing device meshes and topology info. | required |
| model | Qwen3MoEForCausalLM | The Qwen3 MoE Causal LM model to parallelize. | required |
| stage | PipelineStageInfo | Information about the current pipeline stage. | required |
parallelize_qwen3_moe_for_classification(dist_context, model, stage)
Parallelizes the Qwen3 MoE classification model.
This function delegates backbone parallelization to parallelize_qwen3_moe_model
and additionally configures the classification head with Hybrid Sharded Data
Parallelism (HSDP).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context containing device meshes and topology info. | required |
| model | Qwen3MoEForClassification | The Qwen3 MoE classification model to parallelize. | required |
| stage | PipelineStageInfo | Information about the current pipeline stage. | required |
parallelize_qwen3_moe_model(dist_context, model, stage)
Parallelizes the base Qwen3 MoE model components.
This function configures the model layers for distributed execution within a pipeline stage. It applies Hybrid Sharded Data Parallelism (HSDP) to dense components (embeddings, norms, attention) and Expert Parallelism (EP) to the Mixture-of-Experts (MLP) layers.
Current usage constraints:
* Tensor Parallelism is not supported (we may implement it later).
* Context Parallelism is not supported (we will implement it later).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context. | required |
| model | Qwen3MoEModel | The Qwen3 MoE base model to parallelize. | required |
| stage | PipelineStageInfo | Information about the current pipeline stage. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If Tensor Parallel or Context Parallel is enabled in the context. |
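Expert Parallelism places disjoint subsets of experts on different ranks, so routed tokens are typically exchanged with an all-to-all to reach the rank that owns their expert. The sketch below shows a contiguous expert-to-rank assignment, the per-rank send counts such an exchange would need, and a check mirroring the documented ValueError. The ParallelDegrees dataclass, its field names, the contiguous layout and the sizes are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ParallelDegrees:
    """Hypothetical stand-in for the parallel sizes a DistributedContext would expose."""
    expert_parallel: int
    tensor_parallel: int = 1
    context_parallel: int = 1


def check_supported(degrees: ParallelDegrees) -> None:
    """Mirror the documented constraint: no Tensor or Context Parallelism."""
    if degrees.tensor_parallel > 1 or degrees.context_parallel > 1:
        raise ValueError("Tensor Parallel and Context Parallel are not supported")


def local_experts(num_experts: int, ep_size: int, ep_rank: int) -> List[int]:
    """Contiguous block of expert indices owned by one EP rank."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by the EP size")
    per_rank = num_experts // ep_size
    return list(range(ep_rank * per_rank, (ep_rank + 1) * per_rank))


def send_counts(routed_experts: List[int], num_experts: int, ep_size: int) -> List[int]:
    """How many token-to-expert assignments go to each EP rank in the all-to-all."""
    per_rank = num_experts // ep_size
    counts = [0] * ep_size
    for expert in routed_experts:
        counts[expert // per_rank] += 1
    return counts


degrees = ParallelDegrees(expert_parallel=8)
check_supported(degrees)  # passes: TP and CP are both 1
owned = local_experts(num_experts=128, ep_size=8, ep_rank=3)
outgoing = send_counts([0, 17, 63, 127, 16], num_experts=128, ep_size=8)
```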