Qwen3 Dense

About

The d9d.module.model.qwen3_dense package implements the Qwen3 Dense model architecture.

The d9d.module.parallelism.model.qwen3_dense package implements the default horizontal parallelism strategy for this model.

d9d.module.model.qwen3_dense

Qwen3DenseForCausalLM

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

A Qwen3 Dense model wrapped with a Causal Language Modeling head.

It is designed to be split across multiple pipeline stages.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3DenseForCausalLM object.

Parameters:

Name Type Description Default
params Qwen3DenseForCausalLMParameters

Full model configuration parameters.

required
stage PipelineStageInfo

Pipeline stage information for this instance.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

Whether to enable activation checkpointing.

required

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, labels=None)

Executes the model forward pass.

If this is the last stage, labels are expected and the cross-entropy loss is computed. The loss is returned under the 'logps' key, which typically holds the per-token loss.

Parameters:

Name Type Description Default
input_ids Tensor | None

Input token IDs (for Stage 0).

None
hidden_states Tensor | None

Hidden states from previous stage (for Stage > 0).

None
position_ids Tensor | None

Positional indices for RoPE.

None
hidden_states_snapshot Tensor | None

Intermediate state collector.

None
hidden_states_agg_mask Tensor | None

Mask for state aggregation.

None
labels Tensor | None

Target tokens for loss computation (Last Stage).

None

Returns:

Type Description
dict[str, Tensor]

Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot', and per-token 'logps' if on the last stage.
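To make the idea of a per-token loss on the last stage concrete, the sketch below computes one cross-entropy value per position from raw logits in plain Python. The function per_token_cross_entropy is hypothetical and not part of d9d. The real model works on tensors, and the exact sign and reduction conventions behind 'logps' are not specified on this page, so treat this only as an illustration of the underlying arithmetic.

```python
import math


def per_token_cross_entropy(logits, labels):
    """Return one cross-entropy value per position.

    logits: list of per-position score lists (seq_len x vocab_size).
    labels: list of target token ids, one per position.
    """
    losses = []
    for scores, target in zip(logits, labels):
        # Numerically stable log-sum-exp over the vocabulary.
        peak = max(scores)
        log_z = peak + math.log(sum(math.exp(s - peak) for s in scores))
        # Cross-entropy is the negative log-probability of the target token.
        losses.append(log_z - scores[target])
    return losses


# Toy input: 2 positions, vocabulary of 3 tokens (illustrative values only).
losses = per_token_cross_entropy(
    logits=[[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    labels=[0, 2],
)
```

A position whose target token already has the highest score gets a smaller loss than a position with uniform scores.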

reset_parameters()

Resets module parameters.

Qwen3DenseForCausalLMParameters

Bases: BaseModel

Configuration parameters for Qwen3 Dense model with a Causal Language Modeling head.

Attributes:

Name Type Description
model Qwen3DenseParameters

The configuration for the underlying Qwen3 Dense model.

Qwen3DenseForClassification

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

A Qwen3 Dense model wrapped with a Sequence/Token Classification head.

It is designed to be split across multiple pipeline stages.

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3DenseForClassification object.

Parameters:

Name Type Description Default
params Qwen3DenseForClassificationParameters

Full model configuration parameters.

required
stage PipelineStageInfo

Pipeline stage information for this instance.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

Whether to enable activation checkpointing.

required

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None, pooling_mask=None)

Executes the classification model forward pass.

Parameters:

Name Type Description Default
input_ids Tensor | None

Input token IDs (for Stage 0).

None
hidden_states Tensor | None

Hidden states from previous stage (for Stage > 0).

None
position_ids Tensor | None

Positional indices for RoPE.

None
hidden_states_snapshot Tensor | None

Intermediate state collector.

None
hidden_states_agg_mask Tensor | None

Mask for state aggregation.

None
pooling_mask Tensor | None

Binary mask indicating which token(s) to pool for classification. Note: you can use d9d.dataset.token_pooling_mask_from_attention_mask in your Dataset to preallocate the pooling mask from the attention mask.

None

Returns:

Type Description
dict[str, Tensor]

Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot'. If on the last stage, also contains 'scores' (logits) of shape [batch, num_labels].
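The pooling_mask argument can be pictured with a small plain-Python sketch. It assumes one plausible convention, pooling the last attended token of each sequence. The helper last_token_pooling_mask is hypothetical, and this page does not say whether d9d.dataset.token_pooling_mask_from_attention_mask follows the same rule.

```python
def last_token_pooling_mask(attention_mask):
    """Mark the last attended position of each row with 1, all others with 0."""
    pooling_mask = []
    for row in attention_mask:
        mask_row = [0] * len(row)
        # Positions where the attention mask is set.
        attended = [i for i, flag in enumerate(row) if flag]
        if attended:
            mask_row[attended[-1]] = 1
        pooling_mask.append(mask_row)
    return pooling_mask


# Two right-padded sequences of lengths 3 and 2.
example = last_token_pooling_mask([[1, 1, 1, 0], [1, 1, 0, 0]])
```

Because the helper looks for the last set position, it also handles left-padded rows.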

reset_parameters()

Resets module parameters.

Qwen3DenseForClassificationParameters

Bases: BaseModel

Configuration parameters for Qwen3 Dense model with a token/sequence classification head.

Attributes:

Name Type Description
model Qwen3DenseParameters

The configuration for the underlying Qwen3 Dense model.

num_labels int

The number of output labels for classification.

classifier_dropout float

The dropout probability for the classification head.

Qwen3DenseLayer

Bases: Module, ModuleLateInit

Implements a single Qwen3 Dense transformer layer.

This layer consists of a Grouped Query Attention mechanism followed by a SwiGLU MLP block, with RMSNorm applied before each sub-layer (pre-norm).
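The ordering of normalization, sub-layer, and residual connection can be sketched for a single token vector in plain Python. This is a simplified illustration, not the d9d implementation: rms_norm and pre_norm_layer are hypothetical names, and the attention and MLP sub-layers are replaced by trivial stand-in functions.

```python
import math


def rms_norm(x, weight, eps):
    """RMSNorm: scale x by its reciprocal root-mean-square, then by a weight."""
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def pre_norm_layer(x, attention, mlp, attn_norm_weight, mlp_norm_weight, eps=1e-6):
    """Apply norm -> attention -> residual add, then norm -> MLP -> residual add."""
    attn_out = attention(rms_norm(x, attn_norm_weight, eps))
    x = [a + b for a, b in zip(x, attn_out)]
    mlp_out = mlp(rms_norm(x, mlp_norm_weight, eps))
    return [a + b for a, b in zip(x, mlp_out)]


# Stand-in sub-layers; the real ones are Grouped Query Attention and a SwiGLU MLP.
def identity_sublayer(v):
    return list(v)


def zero_sublayer(v):
    return [0.0] * len(v)


out = pre_norm_layer(
    [3.0, 4.0], identity_sublayer, zero_sublayer, [1.0, 1.0], [1.0, 1.0], eps=0.0
)
```

Each sub-layer sees a normalized input, while the residual path carries the unnormalized hidden state forward.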

__init__(params)

Constructs a Qwen3DenseLayer object.

Parameters:

Name Type Description Default
params Qwen3DenseLayerParameters

Configuration parameters for the layer.

required

forward(hidden_states, position_embeddings)

Performs the forward pass of the dense layer.

Parameters:

Name Type Description Default
hidden_states Tensor

Input tensor of shape (batch, seq_len, hidden_dim).

required
position_embeddings tuple[Tensor, Tensor]

Tuple containing RoPE precomputed embeddings (cos, sin).

required

Returns:

Type Description
Tensor

Output tensor after attention and MLP blocks, shape (batch, seq_len, hidden_dim).
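For readers unfamiliar with the (cos, sin) pair passed as position_embeddings, the following plain-Python sketch shows the standard RoPE frequency computation. The function rope_cos_sin is hypothetical, and the table layout (one column per channel pair, without duplication to the full head dimension) is an assumption. The tensors produced inside d9d may be laid out differently.

```python
import math


def rope_cos_sin(position_ids, head_dim, rope_base):
    """Return (cos, sin) tables of shape (len(position_ids), head_dim // 2)."""
    half = head_dim // 2
    # One inverse frequency per pair of channels: base^(-2i / head_dim).
    inv_freq = [rope_base ** (-2.0 * i / head_dim) for i in range(half)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in position_ids]
    sin = [[math.sin(p * f) for f in inv_freq] for p in position_ids]
    return cos, sin


# Illustrative values only.
cos, sin = rope_cos_sin(position_ids=[0, 1, 2], head_dim=4, rope_base=10000)
```

Position 0 always maps to a rotation angle of zero, so its cosine entries are 1 and its sine entries are 0.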

reset_parameters()

Resets module parameters.

Qwen3DenseLayerParameters

Bases: BaseModel

Configuration parameters for a single Qwen3 Dense layer.

Attributes:

Name Type Description
hidden_size int

Dimension of the model's hidden states.

intermediate_size int

Dimension of the feed-forward hidden state.

num_attention_heads int

Number of attention heads for the query.

num_key_value_heads int

Number of attention heads for key and value.

rms_norm_eps float

Epsilon value used in the RMSNorm layers.

head_dim int

Dimension of a single attention head.
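The relationship between these attributes under Grouped Query Attention can be sketched with a small stand-in dataclass. LayerShape is hypothetical and only mirrors the attribute names documented above. It is not the d9d Qwen3DenseLayerParameters class, and the numbers are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerShape:
    """Hypothetical stand-in mirroring the documented attribute names."""

    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int

    def __post_init__(self):
        # Each key/value head must serve a whole number of query heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                "num_attention_heads must be divisible by num_key_value_heads"
            )

    @property
    def queries_per_kv_head(self):
        return self.num_attention_heads // self.num_key_value_heads

    @property
    def q_proj_out_features(self):
        return self.num_attention_heads * self.head_dim

    @property
    def kv_proj_out_features(self):
        return self.num_key_value_heads * self.head_dim


shape = LayerShape(
    hidden_size=1024,
    intermediate_size=3072,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=128,
)
```

Since head_dim is a separate attribute, the query projection width (num_attention_heads * head_dim) does not have to equal hidden_size. In this example it is 2048 against a hidden size of 1024.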

Qwen3DenseModel

Bases: Module, ModuleLateInit, ModuleSupportsPipelining

The Qwen3 Dense Transformer Decoder backbone.

It is designed to be split across multiple pipeline stages.
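One way to picture how a layer stack might be divided across stages, taking into account the pipeline_num_virtual_layers_pre and pipeline_num_virtual_layers_post settings of Qwen3DenseParameters, is the sketch below. The function split_layers and its balancing rule are assumptions made for illustration. The algorithm d9d actually uses is not described on this page and may differ.

```python
def split_layers(num_hidden_layers, num_stages, virtual_pre=0, virtual_post=0):
    """Return the number of real transformer layers assigned to each stage."""
    total_units = num_hidden_layers + virtual_pre + virtual_post
    base, extra = divmod(total_units, num_stages)
    # Budget of (real + virtual) units per stage, as even as possible.
    budget = [base + (1 if i < extra else 0) for i in range(num_stages)]
    # Virtual layers consume part of the first and last stage budgets.
    budget[0] -= virtual_pre
    budget[-1] -= virtual_post
    if min(budget) < 0:
        raise ValueError("virtual layers exceed the per-stage budget")
    return budget


# Illustrative values only.
layers_per_stage = split_layers(
    num_hidden_layers=28, num_stages=4, virtual_pre=1, virtual_post=3
)
```

With these inputs the first and last stages receive fewer real layers, leaving room for the embedding and head modules they also host.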

__init__(params, stage, hidden_states_snapshot_mode, enable_checkpointing)

Constructs the Qwen3DenseModel object.

Parameters:

Name Type Description Default
params Qwen3DenseParameters

Configuration parameters for the full model.

required
stage PipelineStageInfo

Information about the pipeline stage this instance belongs to.

required
hidden_states_snapshot_mode HiddenStatesAggregationMode

Configures intermediate hidden state aggregation & snapshotting mode.

required
enable_checkpointing bool

If True, enables activation checkpointing for transformer layers to save memory.

required

forward(input_ids=None, hidden_states=None, position_ids=None, hidden_states_snapshot=None, hidden_states_agg_mask=None)

Executes the forward pass for the current pipeline stage.

Parameters:

Name Type Description Default
input_ids Tensor | None

Indices of input sequence tokens. Required if this is the first pipeline stage.

None
hidden_states Tensor | None

Hidden states from the previous pipeline stage. Required if this is not the first pipeline stage.

None
position_ids Tensor | None

Indices of the positions of each input sequence token in the position embeddings.

None
hidden_states_snapshot Tensor | None

Accumulated tensor of aggregated hidden states from previous stages. Used if snapshotting is enabled.

None
hidden_states_agg_mask Tensor | None

Mask used to aggregate hidden states for snapshots.

None

Returns:

Type Description
dict[str, Tensor | None]

A dictionary containing:

* 'hidden_states': The output of the last layer in this stage.
* 'hidden_states_snapshot': (Optional) The updated snapshot tensor.
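The way one stage's output dictionary lines up with the next stage's keyword arguments can be sketched with toy stand-ins. Everything here is hypothetical: run_stages, first_stage, and later_stage are plain functions that only mimic the documented argument and key names, and hidden_states_agg_mask is left out for brevity. In a real run, a pipeline runtime moves these tensors between ranks.

```python
def run_stages(stages, input_ids, position_ids):
    """Feed each stage's output dictionary into the next stage's arguments."""
    outputs = stages[0](input_ids=input_ids, position_ids=position_ids)
    for stage in stages[1:]:
        outputs = stage(
            hidden_states=outputs["hidden_states"],
            position_ids=position_ids,
            hidden_states_snapshot=outputs.get("hidden_states_snapshot"),
        )
    return outputs


# Toy stand-in for the first stage: consumes input_ids.
def first_stage(input_ids=None, position_ids=None, **_):
    return {
        "hidden_states": [float(t) for t in input_ids],
        "hidden_states_snapshot": None,
    }


# Toy stand-in for any later stage: consumes hidden_states.
def later_stage(hidden_states=None, position_ids=None, hidden_states_snapshot=None, **_):
    return {
        "hidden_states": [h + 1.0 for h in hidden_states],
        "hidden_states_snapshot": hidden_states_snapshot,
    }


result = run_stages(
    [first_stage, later_stage, later_stage], input_ids=[1, 2], position_ids=[0, 1]
)
```

Only the first stage receives input_ids. Every later stage receives hidden_states instead, while position_ids is passed to all of them.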

output_dtype()

Returns the data type of the model output hidden states.

reset_parameters()

Resets module parameters.

Qwen3DenseParameters

Bases: BaseModel

Configuration parameters for the Qwen3 Dense model backbone.

Attributes:

Name Type Description
layer Qwen3DenseLayerParameters

Configuration shared across all transformer layers.

num_hidden_layers int

The total number of transformer layers.

rope_base int

Base value for RoPE frequency calculation.

max_position_ids int

Maximum sequence length.

split_vocab_size dict[str, int]

A dictionary mapping vocabulary segment names to their sizes.

split_vocab_order list[str]

The correct ordering of the vocabulary splits.

pipeline_num_virtual_layers_pre int

The number of 'virtual' layers representing the computational cost of modules on the first stage, before the main layers (e.g., token and positional embeddings).

pipeline_num_virtual_layers_post int

The number of 'virtual' layers representing the computational cost of modules on the last stage, after the main layers (e.g., the final layer normalization and LM head).
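The interplay of split_vocab_size and split_vocab_order can be illustrated with a short sketch. It assumes the splits are concatenated in the given order to form one global token ID space. That reading is an assumption, and the helper vocab_split_offsets and the split names "base" and "extra" are hypothetical.

```python
def vocab_split_offsets(split_vocab_size, split_vocab_order):
    """Map each split name to its half-open [start, end) range of token ids."""
    ranges = {}
    start = 0
    for name in split_vocab_order:
        end = start + split_vocab_size[name]
        ranges[name] = (start, end)
        start = end
    return ranges


# The dict order is irrelevant; split_vocab_order decides the layout.
ranges = vocab_split_offsets(
    split_vocab_size={"extra": 3, "base": 10},
    split_vocab_order=["base", "extra"],
)
```

Keeping the order in a separate list makes the layout explicit instead of relying on dictionary insertion order.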

d9d.module.parallelism.model.qwen3_dense

parallelize_qwen3_dense_for_causal_lm(dist_context, model, stage)

Parallelizes the Qwen3 Dense Causal LM model.

This function delegates backbone parallelization to parallelize_qwen3_dense_model and additionally configures the language model head with Hybrid Sharded Data Parallelism (HSDP).

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context containing device meshes and topology info.

required
model Qwen3DenseForCausalLM

The Qwen3 Dense Causal LM model to parallelize.

required
stage PipelineStageInfo

Information about the current pipeline stage.

required

parallelize_qwen3_dense_for_classification(dist_context, model, stage)

Parallelizes the Qwen3 Dense classification model.

This function delegates backbone parallelization to parallelize_qwen3_dense_model and additionally configures the classification head with Hybrid Sharded Data Parallelism (HSDP).

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context containing device meshes and topology info.

required
model Qwen3DenseForClassification

The Qwen3 Dense classification model to parallelize.

required
stage PipelineStageInfo

Information about the current pipeline stage.

required

parallelize_qwen3_dense_model(dist_context, model, stage)

Parallelizes the base Qwen3 Dense model components.

This function configures the model layers for distributed execution within a pipeline stage. It applies Hybrid Sharded Data Parallelism (HSDP) to dense components (embeddings, norms, attention, MLP).

Current usage constraints:

* Tensor Parallelism is not supported (we may implement it later).
* Context Parallelism is not supported (we will implement it later).
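As background on HSDP, the sketch below arranges ranks into a two-dimensional (replicate x shard) grid in plain Python. Parameters are sharded within each shard group and replicated across shard groups. The helper hsdp_groups and the row-major rank layout are assumptions for illustration. The actual device meshes come from the DistributedContext and may be organized differently.

```python
def hsdp_groups(world_size, shard_size):
    """Arrange ranks into a (replicate x shard) grid and return both group sets."""
    if world_size % shard_size != 0:
        raise ValueError("world_size must be divisible by shard_size")
    num_replicas = world_size // shard_size
    # Ranks in the same row shard parameters among themselves.
    shard_groups = [
        list(range(r * shard_size, (r + 1) * shard_size))
        for r in range(num_replicas)
    ]
    # Ranks in the same column hold the same shard and synchronize gradients.
    replicate_groups = [
        [r * shard_size + c for r in range(num_replicas)]
        for c in range(shard_size)
    ]
    return shard_groups, replicate_groups


# Illustrative values only: 8 ranks, sharding across groups of 4.
shard_groups, replicate_groups = hsdp_groups(world_size=8, shard_size=4)
```

Every rank belongs to exactly one shard group and one replicate group, which is what distinguishes HSDP from fully sharded or fully replicated data parallelism.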

Parameters:

Name Type Description Default
dist_context DistributedContext

The distributed context.

required
model Qwen3DenseModel

The Qwen3 Dense base model to parallelize.

required
stage PipelineStageInfo

Information about the current pipeline stage.

required

Raises:

Type Description
ValueError

If Tensor Parallel or Context Parallel is enabled in the context.