Qwen3 Dense
About
The d9d.module.model.qwen3_dense package implements the Qwen3 Dense model architecture.
The d9d.module.parallelism.model.qwen3_dense package implements the default horizontal parallelism strategy for this model.
d9d.module.model.qwen3_dense
Qwen3DenseForCausalLM
Bases: Module, ModuleLateInit, ModuleSupportsPipelining
A Qwen3 Dense model wrapped with a Causal Language Modeling head.
It is designed to be split across multiple pipeline stages.
__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)
Constructs the Qwen3DenseForCausalLM object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3DenseForCausalLMParameters | Full model configuration parameters. | required |
| stage | PipelineStageInfo | Pipeline stage information for this instance. | required |
| hidden_states_snapshot_mode | HiddenStatesAggregationMode | Configures intermediate hidden state aggregation & snapshotting mode. | required |
| enable_checkpointing | bool | Whether to enable activation checkpointing. | required |
forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, labels=None)
Executes the model forward pass.
If this is the last stage, it expects labels to be provided and computes the cross-entropy loss. The loss is returned under the 'logps' key, which typically holds per-token loss values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| None | Input token IDs (for Stage 0). | None |
| hidden_states | Tensor \| None | Hidden states from the previous stage (for Stage > 0). | None |
| position_ids | Tensor \| None | Positional indices for RoPE. | None |
| hidden_states_snapshot | Tensor \| None | Intermediate state collector. | None |
| hidden_states_agg_mask | Tensor \| None | Mask for state aggregation. | None |
| labels | Tensor \| None | Target tokens for loss computation (last stage only). | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot', and per-token 'logps' if on the last stage. |
reset_parameters()
Resets module parameters.
Qwen3DenseForCausalLMParameters
Bases: BaseModel
Configuration parameters for Qwen3 Dense model with a Causal Language Modeling head.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Qwen3DenseParameters | The configuration for the underlying Qwen3 Dense model. |
Qwen3DenseForClassification
Bases: Module, ModuleLateInit, ModuleSupportsPipelining
A Qwen3 Dense model wrapped with a Sequence/Token Classification head.
It is designed to be split across multiple pipeline stages.
__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)
Constructs the Qwen3DenseForClassification object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3DenseForClassificationParameters | Full model configuration parameters. | required |
| stage | PipelineStageInfo | Pipeline stage information for this instance. | required |
| hidden_states_snapshot_mode | HiddenStatesAggregationMode | Configures intermediate hidden state aggregation & snapshotting mode. | required |
| enable_checkpointing | bool | Whether to enable activation checkpointing. | required |
forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, pooling_mask=None)
Executes the classification model forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| None | Input token IDs (for Stage 0). | None |
| hidden_states | Tensor \| None | Hidden states from the previous stage (for Stage > 0). | None |
| position_ids | Tensor \| None | Positional indices for RoPE. | None |
| hidden_states_snapshot | Tensor \| None | Intermediate state collector. | None |
| hidden_states_agg_mask | Tensor \| None | Mask for state aggregation. | None |
| pooling_mask | Tensor \| None | Binary mask indicating which token(s) to pool for classification. Note: you can use | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary containing 'hidden_states' and, optionally, 'hidden_states_snapshot'. If on the last stage, it also contains 'scores' (logits) of shape [batch, num_labels]. |
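As a rough illustration of how a binary `pooling_mask` can reduce per-token hidden states to one vector per sequence before the head produces `[batch, num_labels]` scores, here is a pure-Python sketch. The mean-over-selected-tokens reduction, the helper names `masked_mean_pool` and `classify`, and the plain linear head are assumptions; d9d may pool differently, and classifier dropout is omitted.

```python
def masked_mean_pool(
    hidden_states: list[list[list[float]]], pooling_mask: list[list[int]]
) -> list[list[float]]:
    """Average the hidden vectors of tokens where pooling_mask is 1 (hypothetical).

    hidden_states: [batch][seq_len][hidden_size]; pooling_mask: [batch][seq_len].
    """
    pooled = []
    for seq, mask in zip(hidden_states, pooling_mask):
        selected = [vec for vec, keep in zip(seq, mask) if keep]
        if not selected:
            raise ValueError("pooling_mask selects no tokens for a sequence")
        # zip(*selected) walks the hidden dimension column by column.
        pooled.append([sum(col) / len(selected) for col in zip(*selected)])
    return pooled


def classify(
    pooled: list[list[float]], weight: list[list[float]], bias: list[float]
) -> list[list[float]]:
    """Linear head: scores[b][k] = sum_i weight[k][i] * pooled[b][i] + bias[k]."""
    return [
        [sum(w * x for w, x in zip(row, vec)) + b for row, b in zip(weight, bias)]
        for vec in pooled
    ]


if __name__ == "__main__":
    hidden = [[[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]]]  # batch=1, seq_len=3, hidden=2
    mask = [[1, 1, 0]]  # pool only the first two tokens
    pooled = masked_mean_pool(hidden, mask)
    print(pooled)  # [[2.0, 1.0]]
    weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # num_labels=3
    print(classify(pooled, weight, bias=[0.0, 0.0, 0.5]))  # [[2.0, 1.0, 3.5]]
```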
reset_parameters()
Resets module parameters.
Qwen3DenseForClassificationParameters
Bases: BaseModel
Configuration parameters for Qwen3 Dense model with a token/sequence classification head.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Qwen3DenseParameters | The configuration for the underlying Qwen3 Dense model. |
| num_labels | int | The number of output labels for classification. |
| classifier_dropout | float | The dropout probability for the classification head. |
Qwen3DenseLayer
Bases: Module, ModuleLateInit
Implements a single Qwen3 Dense transformer layer.
This layer consists of a Grouped Query Attention mechanism followed by a SwiGLU MLP block, with pre-RMSNorm applied before each sub-layer.
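A minimal pure-Python sketch of the pre-norm wiring described above, operating on a single token vector. `rms_norm`, `swiglu_mlp`, and `pre_norm_layer` are hypothetical helpers, not d9d functions. Attention is passed in as a callable stub because real Grouped Query Attention mixes information across the sequence, which this per-vector sketch does not model.

```python
import math
from typing import Callable

Vector = list[float]
Matrix = list[Vector]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """RMSNorm: scale x by 1 / sqrt(mean(x^2) + eps), then by a learned weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def swiglu_mlp(x: Vector, w_gate: Matrix, w_up: Matrix, w_down: Matrix) -> Vector:
    """SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    gate, up = matvec(w_gate, x), matvec(w_up, x)
    return matvec(w_down, [silu(g) * u for g, u in zip(gate, up)])


def pre_norm_layer(
    x: Vector,
    attention: Callable[[Vector], Vector],
    mlp: Callable[[Vector], Vector],
    attn_norm_weight: Vector,
    mlp_norm_weight: Vector,
    eps: float = 1e-6,
) -> Vector:
    """Pre-norm residual block: h = x + attn(norm(x)); out = h + mlp(norm(h))."""
    h = [a + b for a, b in zip(x, attention(rms_norm(x, attn_norm_weight, eps)))]
    return [a + b for a, b in zip(h, mlp(rms_norm(h, mlp_norm_weight, eps)))]


if __name__ == "__main__":
    def identity_attention(v: Vector) -> Vector:
        return v  # stand-in only: real GQA mixes information across tokens

    def tiny_mlp(v: Vector) -> Vector:
        return swiglu_mlp(v, [[1.0, 0.0]], [[0.0, 1.0]], [[1.0], [0.5]])

    print(pre_norm_layer([3.0, 4.0], identity_attention, tiny_mlp, [1.0, 1.0], [1.0, 1.0]))
```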
__init__(params)
Constructs a Qwen3DenseLayer object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3DenseLayerParameters | Configuration parameters for the layer. | required |
forward(hidden_states, position_embeddings)
Performs the forward pass of the dense layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | Input tensor of shape | required |
| position_embeddings | tuple[Tensor, Tensor] | Tuple containing the precomputed RoPE embeddings (cos, sin). | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor after attention and MLP blocks, shape |
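The `(cos, sin)` pair passed as `position_embeddings` can be illustrated with the standard rotary embedding recipe. This sketch assumes the common "rotate-half" layout (as used in the Hugging Face Qwen implementations) with inverse frequencies `rope_base ** (-2i / head_dim)`; whether d9d lays out its tables the same way is an assumption, and `rope_cos_sin` and `apply_rope` are hypothetical names.

```python
import math


def rope_cos_sin(
    position_ids: list[int], head_dim: int, rope_base: float
) -> tuple[list[list[float]], list[list[float]]]:
    """Precompute per-position cos/sin tables of width head_dim (rotate-half layout)."""
    if head_dim % 2:
        raise ValueError("head_dim must be even for RoPE")
    half = head_dim // 2
    inv_freq = [rope_base ** (-2.0 * i / head_dim) for i in range(half)]
    cos, sin = [], []
    for pos in position_ids:
        angles = [pos * f for f in inv_freq]
        angles = angles + angles  # duplicate so the table spans the full head_dim
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


def apply_rope(x: list[float], cos: list[float], sin: list[float]) -> list[float]:
    """Rotate one head vector: x * cos + rotate_half(x) * sin."""
    half = len(x) // 2
    rotated = [-v for v in x[half:]] + x[:half]  # rotate_half: (-x2, x1)
    return [a * c + r * s for a, r, c, s in zip(x, rotated, cos, sin)]


if __name__ == "__main__":
    cos, sin = rope_cos_sin([0, 5], head_dim=4, rope_base=10000)
    q = [1.0, 2.0, 3.0, 4.0]
    print(apply_rope(q, cos[0], sin[0]))  # position 0: unchanged
    print(apply_rope(q, cos[1], sin[1]))
```

At position 0 every angle is zero, so the rotation leaves the vector unchanged; at any position each dimension pair is rotated by the same angle, which preserves the vector's norm.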
reset_parameters()
Resets module parameters.
Qwen3DenseLayerParameters
Bases: BaseModel
Configuration parameters for a single Qwen3 Dense layer.
Attributes:
| Name | Type | Description |
|---|---|---|
| hidden_size | int | Dimension of the model's hidden states. |
| intermediate_size | int | Dimension of the MLP (feed-forward) intermediate hidden state. |
| num_attention_heads | int | Number of query attention heads. |
| num_key_value_heads | int | Number of key/value attention heads (shared across query heads by Grouped Query Attention). |
| rms_norm_eps | float | Epsilon value used in the RMSNorm layers. |
| head_dim | int | Dimension of a single attention head. |
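The attention fields above constrain each other under Grouped Query Attention. The sketch below, built around a hypothetical `LayerDims` dataclass rather than d9d's own parameter class, checks that query heads divide evenly into key/value groups, maps each query head to its shared K/V head (consecutive query heads share one, which is the usual convention), and derives projection shapes. The numbers in the usage example are illustrative only.

```python
from dataclasses import dataclass


@dataclass
class LayerDims:
    """Hypothetical mirror of the attention-related layer parameters."""

    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int

    def __post_init__(self) -> None:
        if self.num_attention_heads % self.num_key_value_heads:
            raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")

    @property
    def queries_per_kv_head(self) -> int:
        return self.num_attention_heads // self.num_key_value_heads

    def kv_head_for_query_head(self, q_head: int) -> int:
        """Consecutive query heads share one key/value head."""
        return q_head // self.queries_per_kv_head

    def projection_shapes(self) -> dict[str, tuple[int, int]]:
        """(in_features, out_features) of the attention projections."""
        q_out = self.num_attention_heads * self.head_dim
        kv_out = self.num_key_value_heads * self.head_dim
        return {
            "q_proj": (self.hidden_size, q_out),
            "k_proj": (self.hidden_size, kv_out),
            "v_proj": (self.hidden_size, kv_out),
            "o_proj": (q_out, self.hidden_size),
        }


if __name__ == "__main__":
    dims = LayerDims(hidden_size=1024, num_attention_heads=16, num_key_value_heads=8, head_dim=128)
    print(dims.queries_per_kv_head)  # 2
    print([dims.kv_head_for_query_head(h) for h in range(4)])  # [0, 0, 1, 1]
    print(dims.projection_shapes())
```

Because `num_attention_heads * head_dim` (2048 in this example) need not equal `hidden_size` (1024), `head_dim` is configured explicitly rather than derived.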
Qwen3DenseModel
Bases: Module, ModuleLateInit, ModuleSupportsPipelining
The Qwen3 Dense Transformer Decoder backbone.
It is designed to be split across multiple pipeline stages.
__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)
Constructs the Qwen3DenseModel object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | Qwen3DenseParameters | Configuration parameters for the full model. | required |
| stage | PipelineStageInfo | Information about the pipeline stage this instance belongs to. | required |
| hidden_states_snapshot_mode | HiddenStatesAggregationMode | Configures intermediate hidden state aggregation & snapshotting mode. | required |
| enable_checkpointing | bool | If True, enables activation checkpointing for transformer layers to save memory. | required |
forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None)
Executes the forward pass for the current pipeline stage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_ids | Tensor \| None | Indices of input sequence tokens. Required if this is the first pipeline stage. | None |
| hidden_states | Tensor \| None | Hidden states from the previous pipeline stage. Required if this is not the first pipeline stage. | None |
| position_ids | Tensor \| None | Position index of each input sequence token, used for the rotary position embeddings. | None |
| hidden_states_snapshot | Tensor \| None | Accumulated tensor of aggregated hidden states from previous stages. Used if snapshotting is enabled. | None |
| hidden_states_agg_mask | Tensor \| None | Mask used to aggregate hidden states for snapshots. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Tensor \| None] | A dictionary containing 'hidden_states' (the output of the last layer in this stage) and, optionally, 'hidden_states_snapshot' (the updated snapshot tensor). |
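The input contract above (token IDs on the first stage, incoming hidden states on every other stage) can be sketched as follows. `StageSketch` is a hypothetical stand-in with toy embedding and layer callables; it omits position IDs, snapshotting, and real tensor types.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

Vector = list[float]


@dataclass
class StageSketch:
    """Hypothetical stand-in for one pipeline stage of the backbone."""

    is_first: bool
    layers: Sequence[Callable[[Vector], Vector]]
    embed: Optional[Callable[[int], Vector]] = None

    def forward(
        self,
        input_ids: Optional[list[int]] = None,
        hidden_states: Optional[list[Vector]] = None,
    ) -> dict[str, list[Vector]]:
        if self.is_first:
            # The first stage turns token IDs into hidden vectors.
            if input_ids is None or self.embed is None:
                raise ValueError("the first stage needs input_ids and an embedding")
            states = [self.embed(t) for t in input_ids]
        else:
            # Later stages continue from the previous stage's output.
            if hidden_states is None:
                raise ValueError("later stages need hidden_states from the previous stage")
            states = hidden_states
        for layer in self.layers:
            states = [layer(vec) for vec in states]
        return {"hidden_states": states}


if __name__ == "__main__":
    def embed(token_id: int) -> Vector:
        return [float(token_id), 1.0]

    def add_one(vec: Vector) -> Vector:
        return [x + 1.0 for x in vec]

    stage0 = StageSketch(is_first=True, layers=[add_one], embed=embed)
    stage1 = StageSketch(is_first=False, layers=[add_one, add_one])
    out0 = stage0.forward(input_ids=[3, 7])
    print(stage1.forward(hidden_states=out0["hidden_states"]))
    # {'hidden_states': [[6.0, 4.0], [10.0, 4.0]]}
```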
output_dtype()
Returns the data type of the model output hidden states.
reset_parameters()
Resets module parameters.
Qwen3DenseParameters
Bases: BaseModel
Configuration parameters for the Qwen3 Dense model backbone.
Attributes:
| Name | Type | Description |
|---|---|---|
| layer | Qwen3DenseLayerParameters | Configuration shared across all transformer layers. |
| num_hidden_layers | int | The total number of transformer layers. |
| rope_base | int | Base value for RoPE frequency calculation. |
| max_position_ids | int | Maximum sequence length. |
| split_vocab_size | dict[str, int] | A dictionary mapping vocabulary segment names to their sizes. |
| split_vocab_order | list[str] | The order in which the vocabulary segments are arranged. |
| pipeline_num_virtual_layers_pre | int | The number of 'virtual' layers representing the computational cost of modules on the first stage that run before the main layers (e.g., token and positional embeddings). |
| pipeline_num_virtual_layers_post | int | The number of 'virtual' layers representing the computational cost of modules on the last stage that run after the main layers (e.g., the final layer normalization and LM head). |
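To show why `pipeline_num_virtual_layers_pre` and `pipeline_num_virtual_layers_post` matter, here is a hypothetical balancing routine that counts virtual layers as extra cost units on the first and last stages and then splits all units evenly. It is not d9d's actual stage-assignment algorithm, only an illustration of the idea; `split_layers_across_stages` is an invented name.

```python
def split_layers_across_stages(
    num_hidden_layers: int,
    num_stages: int,
    num_virtual_pre: int = 0,
    num_virtual_post: int = 0,
) -> list[range]:
    """Assign contiguous ranges of real layer indices to pipeline stages.

    Virtual layers stand in for the work done before layer 0 (embeddings) and
    after the last layer (final norm, head), so the first and last stages end
    up with fewer real layers.
    """
    total = num_virtual_pre + num_hidden_layers + num_virtual_post
    base, extra = divmod(total, num_stages)
    ranges = []
    start_unit = 0
    for stage in range(num_stages):
        size = base + (1 if stage < extra else 0)
        end_unit = start_unit + size
        # Map the cost-unit span back to real layer indices: drop the leading
        # virtual units and clamp to [0, num_hidden_layers].
        first = min(max(start_unit - num_virtual_pre, 0), num_hidden_layers)
        last = min(max(end_unit - num_virtual_pre, 0), num_hidden_layers)
        ranges.append(range(first, last))
        start_unit = end_unit
    return ranges


if __name__ == "__main__":
    stages = split_layers_across_stages(28, 4, num_virtual_pre=2, num_virtual_post=2)
    print([len(r) for r in stages])  # [6, 8, 8, 6]
```

With 28 layers, 4 stages, and 2 virtual layers on each end, the stages receive 6, 8, 8, and 6 real layers respectively.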
d9d.module.parallelism.model.qwen3_dense
parallelize_qwen3_dense_for_causal_lm(dist_context, model, stage)
Parallelizes the Qwen3 Dense Causal LM model.
This function delegates backbone parallelization to parallelize_qwen3_dense_model
and additionally configures the language model head with Hybrid Sharded Data
Parallelism (HSDP).
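Hybrid Sharded Data Parallelism arranges ranks in a two-dimensional mesh: parameters are sharded within each shard group and replicated across groups. The sketch below computes those groups for a row-major layout and the flat slice of a parameter each rank would own. The helpers `hsdp_groups` and `local_shard` are hypothetical; how `DistributedContext` actually builds its meshes and pads shards is not described here.

```python
def hsdp_groups(world_size: int, shard_size: int) -> tuple[list[list[int]], list[list[int]]]:
    """Return (shard_groups, replicate_groups) for a row-major 2D rank mesh.

    Parameters are sharded across the ranks of one shard group and replicated
    across shard groups; gradients are reduce-scattered within a shard group
    and all-reduced across replicas.
    """
    if world_size % shard_size:
        raise ValueError("world_size must be divisible by shard_size")
    num_replicas = world_size // shard_size
    shard_groups = [
        [r * shard_size + s for s in range(shard_size)] for r in range(num_replicas)
    ]
    replicate_groups = [
        [r * shard_size + s for r in range(num_replicas)] for s in range(shard_size)
    ]
    return shard_groups, replicate_groups


def local_shard(numel: int, shard_size: int, shard_rank: int) -> range:
    """Flat element range one rank owns under an even split (last chunk may be short)."""
    chunk = -(-numel // shard_size)  # ceiling division
    start = min(shard_rank * chunk, numel)
    return range(start, min(start + chunk, numel))


if __name__ == "__main__":
    shard_groups, replicate_groups = hsdp_groups(world_size=8, shard_size=4)
    print(shard_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(replicate_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print([list(local_shard(10, 4, i)) for i in range(4)])  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```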
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context containing device meshes and topology info. | required |
| model | Qwen3DenseForCausalLM | The Qwen3 Dense Causal LM model to parallelize. | required |
| stage | PipelineStageInfo | Information about the current pipeline stage. | required |
parallelize_qwen3_dense_for_classification(dist_context, model, stage)
Parallelizes the Qwen3 Dense classification model.
This function delegates backbone parallelization to parallelize_qwen3_dense_model
and additionally configures the classification head with Hybrid Sharded Data
Parallelism (HSDP).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context containing device meshes and topology info. | required |
| model | Qwen3DenseForClassification | The Qwen3 Dense classification model to parallelize. | required |
| stage | PipelineStageInfo | Information about the current pipeline stage. | required |
parallelize_qwen3_dense_model(dist_context, model, stage)
Parallelizes the base Qwen3 Dense model components.
This function configures the model layers for distributed execution within a pipeline stage. It applies Hybrid Sharded Data Parallelism (HSDP) to dense components (embeddings, norms, attention, MLP).
Current usage constraints:
* Tensor Parallelism is not supported (we may implement it later).
* Context Parallelism is not supported (we will implement it later).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dist_context | DistributedContext | The distributed context. | required |
| model | Qwen3DenseModel | The Qwen3 Dense base model to parallelize. | required |
| stage | PipelineStageInfo | Information about the current pipeline stage. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If Tensor Parallel or Context Parallel is enabled in the context. |
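The documented constraint amounts to a simple up-front check. In this sketch `ParallelDegrees` is a hypothetical summary of the degrees a distributed context might expose (the real `DistributedContext` attributes are not shown in this reference), and `check_qwen3_dense_supported` is an invented helper that mirrors the documented `ValueError` behavior.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelDegrees:
    """Hypothetical summary of the parallelism degrees in a distributed context."""

    data_parallel_replicate: int = 1
    data_parallel_shard: int = 1
    tensor_parallel: int = 1
    context_parallel: int = 1
    pipeline_parallel: int = 1


def check_qwen3_dense_supported(degrees: ParallelDegrees) -> None:
    """Reject configurations the Qwen3 Dense strategy does not handle."""
    if degrees.tensor_parallel > 1:
        raise ValueError("Tensor Parallelism is not supported for Qwen3 Dense")
    if degrees.context_parallel > 1:
        raise ValueError("Context Parallelism is not supported for Qwen3 Dense")


if __name__ == "__main__":
    # HSDP plus pipelining passes the check.
    check_qwen3_dense_supported(ParallelDegrees(data_parallel_shard=8, pipeline_parallel=2))
    try:
        check_qwen3_dense_supported(ParallelDegrees(tensor_parallel=2))
    except ValueError as err:
        print(f"rejected: {err}")
```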