Qwen3 MoE

About

The d9d.module.model.qwen3_moe package implements the Qwen3 Mixture-of-Experts model architecture.

The d9d.module.parallelism.model.qwen3_moe package implements the default horizontal parallelism strategy for this model.

d9d.module.model.qwen3_moe

Qwen3MoEForCausalLM

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

A Qwen3 MoE model wrapped with a Causal Language Modeling head.

It is designed to be split across multiple pipeline stages.

moe_tokens_per_expert property

Accesses MoE routing statistics from the backbone.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3MoEForCausalLM object.

Parameters:

Name Type Description Default
params Qwen3MoEForCausalLMParameters

Full model configuration parameters.

required
stage PipelineStageInfo

Pipeline stage information for this instance.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

Whether to enable activation checkpointing.

required

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, labels=None)

Executes the model forward pass.

If this is the last stage, labels must be provided; the model computes the cross-entropy loss and returns it as 'logps' (typically representing per-token losses).

Parameters:

Name Type Description Default
input_ids Tensor | None

Input token IDs (for Stage 0).

None
hidden_states Tensor | None

Hidden states from previous stage (for Stage > 0).

None
position_ids Tensor | None

Positional indices for RoPE.

None
hidden_states_snapshot Tensor | None

Intermediate state collector.

None
hidden_states_agg_mask Tensor | None

Mask for state aggregation.

None
labels Tensor | None

Target tokens for loss computation (Last Stage).

None

Returns:

Type Description
dict[str, Tensor]

Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot', and per-token 'logps' if on the last stage.
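The per-token 'logps' entry can be understood as an ordinary cross-entropy loss computed independently at each position. A minimal plain-Python sketch (the names `per_token_cross_entropy`, `logits`, and `labels` are illustrative, not part of the d9d API; the real model operates on batched tensors and applies the usual causal shift between logits and labels):

```python
import math

def per_token_cross_entropy(logits, labels):
    """Per-token negative log-likelihood for a [seq_len, vocab] logit list."""
    losses = []
    for row, target in zip(logits, labels):
        # numerically stable log-softmax normalizer over the vocabulary
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[target])  # -log p(target)
    return losses

# Two positions, vocabulary of 3; the second position is predicted confidently.
losses = per_token_cross_entropy([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]], [1, 0])
```

A uniform distribution over 3 tokens yields a loss of log(3) at the first position, while the confident second position yields a loss near zero.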

reset_moe_stats()

Resets MoE routing statistics in the backbone.

reset_parameters()

Resets module parameters.

Qwen3MoEForCausalLMParameters

Bases: BaseModel

Configuration parameters for Qwen3 Mixture-of-Experts model with a Causal Language Modeling head.

Attributes:

Name Type Description
model Qwen3MoEParameters

The configuration for the underlying Qwen3 MoE model.

Qwen3MoEForClassification

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

A Qwen3 MoE model wrapped with a Sequence/Token Classification head.

It is designed to be split across multiple pipeline stages.

moe_tokens_per_expert property

Accesses MoE routing statistics from the backbone.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3MoEForClassification object.

Parameters:

Name Type Description Default
params Qwen3MoEForClassificationParameters

Full model configuration parameters.

required
stage PipelineStageInfo

Pipeline stage information for this instance.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

Whether to enable activation checkpointing.

required

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, pooling_mask=None)

Executes the classification model forward pass.

Parameters:

Name Type Description Default
input_ids Tensor | None

Input token IDs (for Stage 0).

None
hidden_states Tensor | None

Hidden states from previous stage (for Stage > 0).

None
position_ids Tensor | None

Positional indices for RoPE.

None
hidden_states_snapshot Tensor | None

Intermediate state collector.

None
hidden_states_agg_mask Tensor | None

Mask for state aggregation.

None
pooling_mask Tensor | None

Binary mask indicating which token(s) to pool for classification. Note: you can use d9d.dataset.token_pooling_mask_from_attention_mask in your Dataset to preallocate the pooling mask from attention mask.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot'. If on the last stage, also contains 'scores' (logits) of shape [batch, num_labels].
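The idea behind a pooling mask can be sketched in plain Python: mark, per sequence, the token(s) whose hidden state feeds the classification head. The helper below is a hypothetical re-implementation of last-token pooling and is not the actual d9d.dataset.token_pooling_mask_from_attention_mask, which may pool differently:

```python
def last_token_pooling_mask(attention_mask):
    """Build a pooling mask selecting the last attended token per sequence.

    Illustrative only: attention_mask is a [batch, seq_len] list of 0/1
    values where 0 marks padding.
    """
    pooling = []
    for row in attention_mask:
        out = [0] * len(row)
        # index of the last position that is actually attended (== 1)
        last = max(i for i, v in enumerate(row) if v == 1)
        out[last] = 1
        pooling.append(out)
    return pooling

# Two sequences of lengths 3 and 2 inside a padded batch of width 4.
mask = last_token_pooling_mask([[1, 1, 1, 0], [1, 1, 0, 0]])
```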

reset_moe_stats()

Resets MoE routing statistics in the backbone.

reset_parameters()

Resets module parameters.

Qwen3MoEForClassificationParameters

Bases: BaseModel

Configuration parameters for a Qwen3 Mixture-of-Experts model with a token/sequence classification head.

Attributes:

Name Type Description
model Qwen3MoEParameters

The configuration for the underlying Qwen3 MoE model.

num_labels int

The number of output labels for classification.

classifier_dropout float

The dropout probability for the classification head.

Qwen3MoELayer

Bases: Module, ModuleLateInit

Implements a single Qwen3 Mixture-of-Experts (MoE) transformer layer.

This layer consists of a Grouped Query Attention mechanism followed by an MoE MLP block, with pre-RMSNorm applied before each sub-layer.

moe_tokens_per_expert property

Returns the number of tokens routed to each expert.

__init__(params)

Constructs a Qwen3MoELayer object.

Parameters:

Name Type Description Default
params Qwen3MoELayerParameters

Configuration parameters for the layer.

required

forward(hidden_states, position_embeddings)

Performs the forward pass of the MoE layer.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor of shape (batch, seq_len, hidden_dim).

required
position_embeddings tuple[Tensor, Tensor]

Tuple containing RoPE precomputed embeddings (cos, sin).

required

Returns:

Type Description
Tensor

Output tensor after attention and MoE blocks, shape (batch, seq_len, hidden_dim).
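The moe_tokens_per_expert statistic counts how many tokens the router assigned to each expert. A self-contained sketch of top-k routing and counting (the function name and list-based inputs are illustrative; the real router operates on tensors):

```python
def tokens_per_expert(router_logits, num_experts, top_k):
    """Count how many tokens are routed to each expert.

    router_logits is a [num_tokens, num_experts] list of router scores;
    each token is sent to its top_k highest-scoring experts.
    """
    counts = [0] * num_experts
    for scores in router_logits:
        # pick the top_k experts by router score for this token
        chosen = sorted(range(num_experts),
                        key=lambda e: scores[e], reverse=True)[:top_k]
        for e in chosen:
            counts[e] += 1
    return counts

# 3 tokens, 3 experts, top-2 routing.
counts = tokens_per_expert(
    [[0.9, 0.1, 0.5], [0.2, 0.8, 0.7], [0.6, 0.3, 0.4]],
    num_experts=3, top_k=2,
)
```

Every token contributes top_k assignments, so the counts always sum to num_tokens * top_k.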

reset_moe_stats()

Resets statistical counters inside the MoE router (e.g., token counts per expert).

reset_parameters()

Resets module parameters.

Qwen3MoELayerParameters

Bases: BaseModel

Configuration parameters for a single Qwen3 MoE layer.

Attributes:

Name Type Description
hidden_size int

Dimension of the model's hidden states.

intermediate_size int

Dimension of the feed-forward hidden state.

num_experts int

Total number of experts in the MoE layer.

experts_top_k int

Number of experts to route tokens to.

num_attention_heads int

Number of attention heads for the query.

num_key_value_heads int

Number of attention heads for key and value.

rms_norm_eps float

Epsilon value used in the RMSNorm layers.

head_dim int

Dimension of a single attention head.
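The num_attention_heads / num_key_value_heads pair implements Grouped Query Attention: consecutive query heads share one key/value head. A small sketch of that grouping (the helper name is illustrative, not part of the d9d API):

```python
def gqa_query_groups(num_attention_heads, num_key_value_heads):
    """Map each query head to the key/value head it shares under GQA."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("query heads must divide evenly into KV groups")
    group_size = num_attention_heads // num_key_value_heads
    # query head q uses KV head q // group_size
    return [q // group_size for q in range(num_attention_heads)]

# 8 query heads sharing 2 KV heads -> groups of 4 query heads per KV head
groups = gqa_query_groups(8, 2)
```

Setting num_key_value_heads equal to num_attention_heads recovers standard multi-head attention; setting it to 1 recovers multi-query attention.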

Qwen3MoEModel

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

The Qwen3 Mixture-of-Experts (MoE) Transformer Decoder backbone.

It is designed to be split across multiple pipeline stages.

moe_tokens_per_expert property

Retrieves the number of tokens routed to each expert across all layers.

Returns:

Type Description
Tensor

A tensor of shape (num_local_layers, num_experts) containing counts.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3MoEModel object.

Parameters:

Name Type Description Default
params Qwen3MoEParameters

Configuration parameters for the full model.

required
stage PipelineStageInfo

Information about the pipeline stage this instance belongs to.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

If True, enables activation checkpointing for transformer layers to save memory.

required

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None)

Executes the forward pass for the current pipeline stage.

Parameters:

Name Type Description Default
input_ids Tensor | None

Indices of input sequence tokens. Required if this is the first pipeline stage.

None
hidden_states Tensor | None

Hidden states from the previous pipeline stage. Required if this is not the first pipeline stage.

None
position_ids Tensor | None

Position index of each input sequence token within the position embeddings.

None
hidden_states_snapshot Tensor | None

Accumulated tensor of aggregated hidden states from previous stages. Used if snapshotting is enabled.

None
hidden_states_agg_mask Tensor | None

Mask used to aggregate hidden states for snapshots.

None

Returns:

Type Description
dict[str, Tensor | None]

A dictionary containing:

* 'hidden_states': The output of the last layer in this stage.
* 'hidden_states_snapshot': (Optional) The updated snapshot tensor.
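The stage-dependent input contract (input_ids on the first stage, hidden_states on every later stage) can be summarized by a small validation sketch. This is a hypothetical helper for illustration, not a d9d function:

```python
def validate_stage_inputs(is_first_stage, input_ids, hidden_states):
    """Check which input the current pipeline stage requires.

    The first stage consumes token IDs and embeds them; every later
    stage consumes the hidden states produced by its predecessor.
    """
    if is_first_stage:
        if input_ids is None:
            raise ValueError("first stage requires input_ids")
    else:
        if hidden_states is None:
            raise ValueError("non-first stage requires hidden_states")
```

For example, calling the first stage with only hidden_states, or a later stage with only input_ids, would be rejected.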

output_dtype()

Returns the data type of the model output hidden states.

reset_moe_stats()

Resets routing statistics for all MoE layers in this stage.

reset_parameters()

Resets module parameters.

Qwen3MoEParameters

Bases: BaseModel

Configuration parameters for the Qwen3 Mixture-of-Experts model backbone.

Attributes:

Name Type Description
layer Qwen3MoELayerParameters

Configuration shared across all transformer layers.

num_hidden_layers int

The total number of transformer layers.

rope_base int

Base value for RoPE frequency calculation.

max_position_ids int

Maximum sequence length.

split_vocab_size dict[str, int]

A dictionary mapping vocabulary segment names to their sizes.

split_vocab_order list[str]

The order in which the vocabulary splits are arranged.

pipeline_num_virtual_layers_pre int

The number of 'virtual' layers representing the computational cost of modules on the first stage, before the main layers (e.g., token and positional embeddings).

pipeline_num_virtual_layers_post int

The number of 'virtual' layers representing the computational cost of modules on the last stage, after the main layers (e.g., the final layer normalization and LM head).
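One way to see how virtual layers help is as extra weight added to the first and last stages when splitting transformer layers across a pipeline. The balancing scheme below is a hypothetical sketch under that assumption; d9d's actual splitter may differ:

```python
def stage_layer_counts(num_hidden_layers, num_stages,
                       virtual_pre, virtual_post):
    """Split transformer layers across pipeline stages, treating the
    embedding (pre) and head (post) modules as extra 'virtual' layers."""
    total = num_hidden_layers + virtual_pre + virtual_post
    base, rem = divmod(total, num_stages)
    # distribute any remainder to the earliest stages
    weighted = [base + (1 if s < rem else 0) for s in range(num_stages)]
    # the virtual layers stand in for embeddings / LM head, so remove
    # them again to get the count of real layers per stage
    weighted[0] -= virtual_pre
    weighted[-1] -= virtual_post
    return weighted

# 30 real layers, 4 stages; embeddings cost ~1 layer, the head ~1 layer
counts = stage_layer_counts(30, 4, 1, 1)
```

The boundary stages receive fewer real layers because part of their budget is consumed by the embedding and head modules.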

d9d.module.parallelism.model.qwen3_moe

parallelize_qwen3_moe_for_causal_lm(dist_context, model, stage)

Parallelizes the Qwen3 MoE Causal LM model.

This function delegates backbone parallelization to parallelize_qwen3_moe_model and additionally configures the language model head with Hybrid Sharded Data Parallelism (HSDP).

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context containing device meshes and topology info.

required
model Qwen3MoEForCausalLM

The Qwen3 MoE Causal LM model to parallelize.

required
stage PipelineStageInfo

Information about the current pipeline stage.

required

parallelize_qwen3_moe_for_classification(dist_context, model, stage)

Parallelizes the Qwen3 MoE classification model.

This function delegates backbone parallelization to parallelize_qwen3_moe_model and additionally configures the classification head with Hybrid Sharded Data Parallelism (HSDP).

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context containing device meshes and topology info.

required
model Qwen3MoEForClassification

The Qwen3 MoE classification model to parallelize.

required
stage PipelineStageInfo

Information about the current pipeline stage.

required

parallelize_qwen3_moe_model(dist_context, model, stage)

Parallelizes the base Qwen3 MoE model components.

This function configures the model layers for distributed execution within a pipeline stage. It applies Hybrid Sharded Data Parallelism (HSDP) to dense components (embeddings, norms, attention) and Expert Parallelism (EP) to the Mixture-of-Experts (MLP) layers.

Current usage constraints:

* Tensor Parallelism is not supported (we may implement it later).
* Context Parallelism is not supported (we will implement it later).

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context.

required
model Qwen3MoEModel

The Qwen3 MoE base model to parallelize.

required
stage PipelineStageInfo

Information about the current pipeline stage.

required

Raises:

Type Description
ValueError

If Tensor Parallel or Context Parallel is enabled in the context.
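The constraint check can be sketched as follows. This is an illustrative stand-in: the real function inspects a DistributedContext rather than raw parallelism sizes, but it raises a plain ValueError as documented above:

```python
def check_parallelism_constraints(tensor_parallel_size, context_parallel_size):
    """Reject parallelism configurations the parallelizer does not support.

    Hypothetical helper: sizes > 1 mean the corresponding parallelism
    dimension is enabled.
    """
    if tensor_parallel_size > 1:
        raise ValueError("Tensor Parallelism is not supported")
    if context_parallel_size > 1:
        raise ValueError("Context Parallelism is not supported")

# HSDP + EP only: no tensor or context parallelism.
check_parallelism_constraints(tensor_parallel_size=1, context_parallel_size=1)
```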