krafton-ai/moe-to-dense


Pruning and Distilling Mixture-of-Experts into Dense Language Models

Paper

Description

We present the first systematic framework for converting a trained Mixture-of-Experts (MoE) language model into a standard fully dense architecture: experts are scored, selected, and grouped, then concatenated into a dense FFN and refined by knowledge distillation from the MoE teacher. Our key finding is that expert scoring is the dominant design choice: diversity-aware DO-ACP scoring with pure pruning outperforms dense-to-dense pruning by 6.3 pp in average downstream accuracy after ~4B-token distillation, while training 1.6× faster in wall-clock time.

Primary target: Qwen3-30B-A3B (128 experts, top-8 routing, 30B total / 3B active) → ~3.3B dense student. The codebase also covers three additional families for cross-model validation:

  • DeepSeek-V2-Lite (64 routed + 2 shared experts, top-6)
  • GPT-OSS-20B (32 experts, top-4)
  • EMO / StdMoE 1B-active-14B pair (127 routed + 1 shared experts, top-7) — a matched modularity-aware vs. standard-MoE pair, used to study compatibility with compression-aware pretraining

The MoE-to-Dense Pipeline

The pipeline proceeds in three stages (Section 3 of the paper):

  1. Score and select

    • Score all experts per layer from teacher routing statistics on a small calibration set.
    • Retain the top-K experts by importance.
  2. Group, merge, and concatenate

    • Assign the top-K experts to k groups — copied directly when K=k (pure pruning), or merged by score-weighted averaging when K>k.
    • Concatenate the k group representatives into a single dense FFN, with down-projection scaling to approximate the average routing behavior.
  3. Distill

    • Recover quality lost during compression by distilling the dense student from the original MoE teacher with forward-KL on logits.
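
The concatenation step in stage 2 can be illustrated with a toy example. The following is a minimal sketch, not the repository's implementation: it assumes SwiGLU-style experts with gate, up, and down matrices stored as plain Python lists, and the helper names (expert_forward, concat_experts) are hypothetical. It shows that stacking the gate/up rows, joining the down-projection columns, and scaling the down projection by 1/k yields a dense FFN whose output is the uniform average of the k expert outputs.

```python
import math

def silu(z):
    return z / (1.0 + math.exp(-z))

def matvec(matrix, vector):
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def expert_forward(expert, x):
    """Assumed SwiGLU-style FFN: down @ (silu(gate @ x) * (up @ x))."""
    g = matvec(expert["gate"], x)
    u = matvec(expert["up"], x)
    hidden = [silu(a) * b for a, b in zip(g, u)]
    return matvec(expert["down"], hidden)

def concat_experts(experts, scale):
    """Stack gate/up rows and join down columns; scale the down projection."""
    gate = [row for e in experts for row in e["gate"]]
    up = [row for e in experts for row in e["up"]]
    d_model = len(experts[0]["down"])
    down = [[scale * w for e in experts for w in e["down"][i]]
            for i in range(d_model)]
    return {"gate": gate, "up": up, "down": down}

# Two toy experts with d_model = 2 and expert width 2 (made-up weights).
experts = [
    {"gate": [[0.5, -1.0], [1.0, 0.25]],
     "up": [[1.0, 0.0], [0.5, 0.5]],
     "down": [[1.0, -0.5], [0.0, 2.0]]},
    {"gate": [[-0.5, 0.75], [0.2, 0.1]],
     "up": [[0.3, -0.7], [1.0, 1.0]],
     "down": [[0.5, 0.5], [-1.0, 1.0]]},
]
x = [1.0, 2.0]

# Uniform down-projection scaling: 1/k for k concatenated experts.
dense = concat_experts(experts, scale=1.0 / len(experts))
dense_out = expert_forward(dense, x)
mean_out = [sum(col) / len(experts)
            for col in zip(*(expert_forward(e, x) for e in experts))]
```

Here dense_out and mean_out agree up to floating-point rounding, which is the sense in which the scaled dense FFN approximates average routing behavior.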

Design Space

The design space is parameterized along three axes. The code retains internal identifiers from development; the table below maps each to the term used in the paper.

| Axis | Paper term | Code identifier |
|---|---|---|
| Scoring | SF: selection frequency | selection_freq |
| Scoring | PP: pre-selection probability | preselect_prob |
| Scoring | PS: post-selection probability | postselect_prob |
| Scoring | CP: conditional probability | conditional_prob |
| Scoring | ACP: activation-weighted conditional probability | reap |
| Scoring | DO-CP: D-optimal selection on CP | gram_logdet |
| Scoring | DO-ACP: D-optimal selection on ACP (best) | gld_reap |
| Grouping | RR: round-robin | round_robin |
| Grouping | WC: weight clustering | weight_cluster |
| Grouping | RC: router clustering | router_cluster |
| Grouping | AB: anchor-based | dominant_expert |
| Grouping | OC: output clustering | output_cluster |
| DP scaling | uniform (1/k) | uniform |
| DP scaling | proportional (by selected-expert importance) | proportional |
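
To make the grouping axis concrete, here is a minimal sketch of round-robin grouping followed by score-weighted averaging. It is an illustration under stated assumptions, not the code in modules/expert_merger.py: the function names are hypothetical, round-robin is assumed to deal experts into groups in descending score order, and each expert's weights are flattened to a short list.

```python
def round_robin_groups(scores, top_k, num_groups):
    """Keep the top_k experts by score and deal them into groups in rank order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    groups = [[] for _ in range(num_groups)]
    for rank, expert_id in enumerate(ranked[:top_k]):
        groups[rank % num_groups].append(expert_id)
    return groups

def merge_group(weights, scores, group):
    """Score-weighted average of the (flattened) weights of one group."""
    total = sum(scores[i] for i in group)
    size = len(weights[group[0]])
    return [sum(scores[i] * weights[i][j] for i in group) / total
            for j in range(size)]

# Toy importance scores for six experts and made-up flattened weights.
scores = [0.07, 0.30, 0.10, 0.25, 0.20, 0.08]
weights = [[float(i), float(-i)] for i in range(len(scores))]

# K > k: four retained experts merged into two group representatives.
groups = round_robin_groups(scores, top_k=4, num_groups=2)
merged = [merge_group(weights, scores, g) for g in groups]

# K = k: pure pruning, each group holds one expert and is copied unchanged.
pruned_groups = round_robin_groups(scores, top_k=2, num_groups=2)
pruned = [merge_group(weights, scores, g) for g in pruned_groups]
```

With these toy scores the groups are [[1, 4], [3, 2]], and in the pure-pruning case the "merged" weights are just the weights of experts 1 and 3.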

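Three scorings are produced by the full-grid pipeline (see Reproducing the Paper) rather than by run_merging_ablation.py: ACP (reap) and the two D-optimal scorings, DO-CP (gram_logdet) and DO-ACP (gld_reap). The D-optimal scorings use a submodular log-determinant criterion on the expert-output Gram matrix to select a diverse subset, so they require the Gram matrix that the full-grid pipeline collects. The four base scorings (selection_freq, preselect_prob, postselect_prob, conditional_prob) run directly through run_merging_ablation.py.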
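
The idea behind log-determinant selection can be sketched with a small greedy loop. This is a generic illustration of greedy log-det subset selection on a Gram matrix, not the algorithm in modules/gram_logdet.py: the function names, the ridge term, and the toy feature vectors are assumptions, and the sketch does not show how CP or ACP scores enter the criterion.

```python
import math

def logdet(matrix):
    """Log-determinant of a symmetric positive-definite matrix via Cholesky."""
    n = len(matrix)
    chol = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = matrix[i][j] - sum(chol[i][p] * chol[j][p] for p in range(j))
            chol[i][j] = math.sqrt(s) if i == j else s / chol[j][j]
    return 2.0 * sum(math.log(chol[i][i]) for i in range(n))

def greedy_logdet_select(gram, k, ridge=1e-6):
    """Greedily add the item that maximizes log det of the selected Gram block."""
    selected = []
    remaining = set(range(len(gram)))
    for _ in range(k):
        def gain(candidate):
            idx = selected + [candidate]
            sub = [[gram[a][b] + (ridge if a == b else 0.0) for b in idx]
                   for a in idx]
            return logdet(sub)
        best = max(sorted(remaining), key=gain)
        selected.append(best)
        remaining.remove(best)
    return selected

# Toy "expert output" features: expert 1 is nearly parallel to expert 0.
features = [[1.2, 0.0], [1.1, 0.1], [0.0, 1.0], [0.5, 0.5]]
gram = [[sum(a * b for a, b in zip(fi, fj)) for fj in features]
        for fi in features]

selected = greedy_logdet_select(gram, k=2)
by_norm = sorted(range(len(gram)), key=lambda i: gram[i][i], reverse=True)[:2]
```

In this toy case, ranking by magnitude alone keeps the two near-duplicate experts (0 and 1), while the log-det criterion keeps experts 0 and 2, which span different directions.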

Installation

The distillation trainer is built on NVIDIA Model Optimizer's modelopt.torch.distill plugins, installed automatically as a dependency.

# clone repository
git clone https://github.com/krafton-ai/moe-to-dense
cd moe-to-dense

# install project (nvidia-modelopt, torch, transformers, datasets, trl)
pip install -e .

# optional extras
pip install -e ".[train]"   # adds deepspeed
pip install -e ".[eval]"    # adds vllm + lm-eval

Requires Python ≥3.10. Distillation uses DeepSpeed ZeRO-2 by default; configs live in configs/. Dependency windows are pinned in pyproject.toml (notably nvidia-modelopt==0.37.0 and transformers>=4.51,<4.58); see the comments there for the rationale.

Repository Layout

.
├── modules/                       # Core library
│   ├── expert_importance.py       # Routing-stat collection
│   ├── expert_merger.py           # Scoring × grouping × scaling → dense FFN
│   ├── gram_logdet.py             # D-optimal (log-det) selection + ACP scores
│   ├── dense_pruner.py            # Dense-to-dense (D2D) pruning baseline
│   ├── distillation_trainer.py    # Forward-KL KD trainer wrapper
│   ├── evaluation.py              # Benchmark + WikiText-2 PPL eval
│   ├── checkpoint_manager.py
│   └── model_adapters.py          # Qwen3 / DeepSeek-V2 / GPT-OSS / EMO adapters
├── scripts/                       # Entry points (see below)
├── configs/                       # DeepSpeed + training configs
└── pyproject.toml

Quick Start

The three canonical stages, shown for a single configuration. Replace model IDs and paths as needed.

1. Merge experts into a dense student

python scripts/run_merging_ablation.py \
    --grid-sweep \
    --model Qwen/Qwen3-30B-A3B \
    --scoring conditional_prob \
    --grouping round_robin \
    --top-k 8

This collects routing stats once and merges each (scoring, grouping, K) config into ./checkpoints/merging_ablation/<config>/, reporting pre-distill WikiText-2 perplexity for each. To build the paper's best student (DO-ACP, gld_reap), use the full-grid pipeline below — it collects the expert-output Gram matrix the D-optimal scorings need.

2. Distill the dense student from the MoE teacher

deepspeed --num_gpus=4 scripts/run_distillation.py \
    --teacher Qwen/Qwen3-30B-A3B \
    --student ./checkpoints/merging_ablation/<config>/ \
    --output-dir ./checkpoints/distilled/<config>/ \
    --deepspeed configs/ds_zero2_config.json \
    --max-steps 3000 \
    --warmup-steps 100 --decay-steps 300 --scheduler-type wsd \
    --learning-rate 1e-4 --min-lr 1e-5 \
    --per-device-batch-size 4 --gradient-accumulation-steps 24 \
    --temperature 1.0
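
For orientation, the schedule flags in the command above can be read as a warmup-stable-decay (WSD) learning-rate curve. The sketch below is an assumption about the shape (linear warmup, constant plateau, linear decay to the minimum learning rate over the final decay steps); the trainer's actual curve may differ, and wsd_lr is a hypothetical helper. It also works out the number of sequences per optimizer step implied by the flags.

```python
def wsd_lr(step, max_steps=3000, warmup_steps=100, decay_steps=300,
           peak_lr=1e-4, min_lr=1e-5):
    """Assumed WSD shape: linear warmup, flat plateau, linear decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    decay_start = max_steps - decay_steps
    if step < decay_start:
        return peak_lr
    progress = min(1.0, (step - decay_start) / decay_steps)
    return peak_lr + (min_lr - peak_lr) * progress

# Sequences consumed per optimizer step with the flags shown above:
# 4 GPUs x per-device batch 4 x 24 gradient-accumulation steps.
num_gpus, per_device_batch, grad_accum = 4, 4, 24
sequences_per_step = num_gpus * per_device_batch * grad_accum

schedule = [wsd_lr(s) for s in (0, 99, 1500, 2700, 2850, 3000)]
```

Under these assumptions the learning rate reaches 1e-4 at the end of warmup, stays there until step 2700, and falls to 1e-5 at step 3000; each optimizer step consumes 384 sequences.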

Forward KL on logits (no intermediate / hidden-state loss) is the default and the paper's best objective. Pass --teacher-top-k 16 to distill from expanded teacher routing (k'=2k, +0.70 pp).
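
As a reference for the objective, forward KL on logits for a single token position is KL(teacher || student) between the two softmax distributions. The sketch below is a dependency-free illustration with hypothetical function names, not the modelopt-based trainer. It omits batching, masking, and any temperature-squared rescaling, which a real trainer may or may not apply.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - lse for z in scaled]

def forward_kl(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) for one token position."""
    t = log_softmax(teacher_logits, temperature)
    s = log_softmax(student_logits, temperature)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t, s))

teacher = [2.0, 0.5, -1.0]
student = [0.0, 0.0, 0.0]

loss_same = forward_kl(teacher, teacher)      # identical distributions
loss_diff = forward_kl(teacher, student)      # teacher-weighted mismatch
loss_reverse = forward_kl(student, teacher)   # reverse KL, for comparison
```

The loss is zero when the student matches the teacher and positive otherwise. Forward and reverse KL generally differ, because the forward direction weights the mismatch by the teacher's probabilities.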

3. Evaluate

# benchmark accuracy (vLLM backend; Qwen-family models)
python scripts/evaluate_benchmark_vllm.py \
    --model ./checkpoints/distilled/<config>/final \
    --tasks winogrande,hellaswag,arc_easy,arc_challenge,mmlu

# WikiText-2 perplexity
python scripts/eval_checkpoint_ppl.py \
    ./checkpoints/distilled/<config>/final

For DeepSeek-V2-Lite and GPT-OSS-20B, use the HuggingFace backend (vLLM is not supported for those families in this codebase).
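
For readers unfamiliar with the metric, perplexity is the exponential of the mean negative log-likelihood per token. The following is a minimal sketch of that definition only; it is not eval_checkpoint_ppl.py, and it leaves out tokenization and any windowing over the WikiText-2 text.

```python
import math

def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) per token."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)

# A model that assigns probability 1/4 to every observed token.
uniform_ppl = perplexity([math.log(0.25)] * 10)

# A more confident model gets a lower perplexity.
confident_ppl = perplexity([math.log(0.9), math.log(0.8), math.log(0.95)])
```

A model that spreads probability uniformly over four choices has perplexity 4, which is a useful sanity check when reading pre-distill and post-distill numbers.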

Reproducing the Paper

End-to-end pipeline scripts assemble the three stages above plus baselines.

| Script | What it runs |
|---|---|
| scripts/run_full_grid_pipeline.sh | 350-config pre-distill PPL sweep (7 scoring × 5 grouping × 2 DP scaling × 5 K) → distill the 35 best scoring×grouping pairs (0.3B tokens) |
| scripts/run_scaleup_pipeline.sh | Extended ~4B-token distillation of the four headline configs (DO-ACP, SF, D2D, Random FFN) |
| scripts/run_deepseek_pipeline.sh | DeepSeek-V2-Lite cross-model validation |
| scripts/run_gptoss_pipeline.sh | GPT-OSS-20B cross-model validation |
| scripts/run_emo_pipeline.sh · scripts/run_stdmoe_pipeline.sh | DO-ACP pure-pruning distillation on the matched modularity-aware / standard-MoE pair (Table 11); eval with run_emo_eval.sh + run_emo_wt2_ppl.sh |
| scripts/run_base_ppl_sweep.py + run_base_distill_pipeline.sh + run_base_k8_pipeline.sh | Base-teacher A/B (pre-distill PPL sweep → distill top configs) |
| scripts/run_all_evals.sh | Parallel benchmark evals across checkpoints |
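
The size of the full-grid sweep follows directly from the design-space axes. The sketch below enumerates the grid using the code identifiers from the Design Space table. The five K values are hypothetical placeholders, since this README does not list them, and the shell pipeline may build its configs differently.

```python
from itertools import product

SCORINGS = ["selection_freq", "preselect_prob", "postselect_prob",
            "conditional_prob", "reap", "gram_logdet", "gld_reap"]
GROUPINGS = ["round_robin", "weight_cluster", "router_cluster",
             "dominant_expert", "output_cluster"]
DP_SCALINGS = ["uniform", "proportional"]
K_VALUES = [8, 16, 24, 32, 64]  # hypothetical: the actual K values are not given here

# Every (scoring, grouping, DP scaling, K) combination in the sweep.
grid = list(product(SCORINGS, GROUPINGS, DP_SCALINGS, K_VALUES))

# Distinct scoring x grouping pairs, one distilled config per pair.
pairs = {(scoring, grouping) for scoring, grouping, _, _ in grid}
```

This gives 7 × 5 × 2 × 5 = 350 configurations and 35 scoring×grouping pairs, matching the counts in the table.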

Pipeline scripts honor the DEEPSPEED env var (default deepspeed) and read the HuggingFace cache from $HF_HOME. Checkpoints land under ./checkpoints/ (gitignored); expect tens to hundreds of GB depending on which pipeline you run.

Baselines

Two baselines provide context for the MoE-to-dense results:

  • D2D pruning (modules/dense_pruner.py): dense-to-dense structured pruning of a matched-parameter dense teacher (Qwen3-32B → ~3.4B), distilled from that dense teacher under the same token budget.
  • Random FFN + teacher attention: copies the teacher's attention layers and randomly initializes the dense FFN — a lower bound on what distillation alone recovers.

Citation

@article{kim2026pruning,
  title={Pruning and Distilling Mixture-of-Experts into Dense Language Models},
  author={Kim, Junhyuck and Yun, Jihun and Kim, Haechan and Kim, Gyeongman and Bae, Joonghyun and Cho, Jaewoong},
  journal={arXiv preprint arXiv:2605.28207},
  year={2026}
}

License

Apache License 2.0 — see LICENSE.

Acknowledgements

This work builds on NVIDIA Model Optimizer, whose modelopt.torch.distill plugins power our knowledge-distillation trainer.
