How to enable Expert Parallelism (EP) for Qwen MoE during decoding (token generation) inference? #8


Description

@dinghongsong

Hi team,

I am experimenting with enabling Expert Parallelism (EP) for the Qwen MoE model during the decoding / token generation stage. While EP works as expected during the prefill stage, I encountered issues when enabling EP for token generation.

This issue is mainly about two things: missing expert-parallelism configuration options in the official main.py, and a compilation failure caused by selective-loading constraints during decoding.


Environment

  • Model: Qwen3-30B-A3B
  • Instance type: AWS EC2 trn2.3xlarge
  • Neuron cores: 4
  • Framework: NeuronX / NXDI
  • EP configuration:
    • moe-tp-degree = 1
    • moe-ep-degree = 4
  • Stage: Token generation (decoding)

Issue 1: Missing Expert Parallelism arguments in main.py

Currently, main.py does not expose any command-line arguments for configuring expert parallelism.

Based on the latest NXDI parallelism configuration for MoE models (here), there are two additional arguments:

  • --moe-tp-degree
  • --moe-ep-degree

These options are necessary to control tensor parallelism and expert parallelism independently.
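As a rough illustration, adding them could look like the sketch below, assuming main.py builds its arguments with argparse and forwards parallelism settings into the NxDI configuration object (the config attribute names are an assumption, not taken from the official main.py):

    # Hypothetical addition to main.py's argument parser. The config attribute
    # names below are assumptions, not confirmed against the official code.
    import argparse

    def add_moe_parallelism_args(parser: argparse.ArgumentParser) -> None:
        # Tensor-parallel degree used inside each expert's MLP.
        parser.add_argument("--moe-tp-degree", type=int, default=None)
        # Expert-parallel degree: how many Neuron cores the routed experts are sharded across.
        parser.add_argument("--moe-ep-degree", type=int, default=None)

    # After parsing, the values would be forwarded into the model configuration,
    # e.g. (assumed attribute names):
    #   neuron_config.moe_tp_degree = args.moe_tp_degree
    #   neuron_config.moe_ep_degree = args.moe_ep_degree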

Question

Is it planned to officially add --moe-tp-degree and --moe-ep-degree as supported arguments in a future update of main.py?
Or are participants expected to modify main.py themselves to add expert-parallelism support?


Issue 2: Compilation failure when enabling EP for token generation

After manually adding support for expert parallelism in main.py and running with:

python3 main.py \
        --mode evaluate_all \
        --model-path ~/qwen-30b-a3b/hf_model \
        --compiled-model-path ~/qwen-30b-a3b/traced_model \
        --prompt "What is the capital of France?" \
        --moe-tp-degree 1 \
        --moe-ep-degree 4 

the compilation fails during token generation with the following error:

NotImplementedError: Selective Loading with Expert parallelism is not supported in token generation.

Root Cause Analysis

The failure appears to be caused by the batch size configuration during decoding.

In main.py, the batch size is always set as args.batch_size = len(args.prompts) (here), which is always 1 for a single prompt. As a result, perc_experts_loaded is always lower than DEFAULT_SELECTIVE_LOADING_THRESHOLD (here), which triggers the NotImplementedError once EP is enabled.

According to the definition of perc_experts_loaded (here), to satisfy perc_experts_loaded >= DEFAULT_SELECTIVE_LOADING_THRESHOLD, the minimum required total token count for Qwen3-30B-A3B (which equals the batch size during decoding) is:

routed_experts_mlp_config.num_experts (128) // routed_experts_mlp_config.top_k (8) = 16
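As a quick sanity check of that number, here is a small sketch assuming perc_experts_loaded is roughly total_tokens * top_k / num_experts (capped at 1.0) and that the selective-loading check effectively requires this ratio to reach 1.0; the exact formula and DEFAULT_SELECTIVE_LOADING_THRESHOLD value live in the NxDI source:

    # Hedged sketch of the selective-loading condition for Qwen3-30B-A3B.
    num_experts = 128  # routed_experts_mlp_config.num_experts
    top_k = 8          # routed_experts_mlp_config.top_k
    threshold = 1.0    # stand-in for DEFAULT_SELECTIVE_LOADING_THRESHOLD under this assumption

    def perc_experts_loaded(total_tokens: int) -> float:
        # Fraction of experts that can receive at least one token, in the best case.
        return min(total_tokens * top_k, num_experts) / num_experts

    # During decoding, total_tokens equals the batch size, so the smallest passing batch is:
    min_batch = next(b for b in range(1, num_experts + 1)
                     if perc_experts_loaded(b) >= threshold)
    print(min_batch)  # 16, i.e. num_experts // top_k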

After changing args.batch_size = len(args.prompts) in main.py to the minimum value args.batch_size = 16, the model compiles successfully with expert parallelism.
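For reference, the workaround is a one-line change in main.py (only the assignment itself comes from the actual file; the comments are mine):

    # Before: the decoding batch size follows the number of prompts, which is 1 here.
    # args.batch_size = len(args.prompts)

    # Workaround: use the minimum batch size that satisfies the selective-loading
    # check for Qwen3-30B-A3B (num_experts // top_k = 128 // 8 = 16).
    args.batch_size = 16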

Question

In future versions of main.py, is it expected that args.batch_size = len(args.prompts) will be changed to args.batch_size = 16, or will participants be required to implement Selective Loading with Expert Parallelism in token generation (here) themselves?

Thank you!
