Skip to content

NJUxlj/Chinese-MedQA-Qwen2

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

99 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Chinese-MedQA-Qwen2

Python PyTorch License Project

项目简介

本项目是一个基于 Qwen2 + Agent + RAG 的医疗问答系统,旨在打通从 SFT/Embedding 训练数据生成,到 SFT 微调、奖励模型微调,再到 DPO/DAPO/GSPO/TRPO 强化学习微调,最终使用 vLLM 部署推理,并通过 AgentFactory 调用医疗多 Agent 会诊系统(MDAgents)进行问诊的完整流水线。

最终目标:基于上述后训练方式,微调出一个使用西医知识进行疾病诊疗的垂直 Qwen2 模型。

项目特性

  • 🏋️ 多种训练框架:支持 SFT、DPO、DAPO、GSPO、TRPO、PPO、GRPO 等多种微调算法
  • 🔧 手写 DPOTrainer:完全自主实现的 DPO 训练器,非第三方库直接调用
  • 多种推理后端:支持 vLLM、XInference、Ollama、Transformers 多种推理方式
  • 🔍 混合检索系统:集成相似度、BM25、L2距离、KNN 等多种检索算法
  • 🤖 Multi-Agent 会诊:基于 MDAgents 的多专科医疗会诊系统
  • 📚 知识库增强:支持 FAISS 向量库 + Milvus + Neo4j 知识图谱

技术栈总结

层次 技术
基础模型 Qwen2 / Qwen3(支持本地部署和 API 调用)
SFT 微调 HuggingFace Trainer + Accelerate
DPO 微调 手写 DPOTrainer(参考 LLaMA-Factory)
强化学习 VeRL 框架(支持 PPO/GRPO/DAPO/SPPO/SPiNO/TRPO)
推理加速 vLLM、XInference、Ollama、Transformers
向量数据库 Milvus + FAISS
图数据库 Neo4j
RAG 框架 LangChain (langchain-core, langchain-community)
Agent 框架 LangChain + 自定义 Multi-Agent (MDAgents)
配置管理 Hydra + OmegaConf
Web 框架 FastAPI + Gradio
评估指标 BLEU、ROUGE、Math(自定义)

项目架构图

Chinese-MedQA-Qwen2
├── src/                          # 核心源码
│   ├── agent/                    # Agent 模块
│   │   ├── base_agent.py         # 基础 Agent 类
│   │   ├── medical_agent.py      # 医疗 Agent
│   │   ├── agent_factory.py      # Agent 工厂
│   │   ├── mdagents/             # 多专科会诊系统
│   │   │   ├── core/             # 核心控制器
│   │   │   └── agents/           # 各专科 Agent
│   │   └── tools/                # Agent 工具集
│   │
│   ├── rag/                      # RAG 流水线
│   │   ├── rag_pipeline.py       # RAG 主流程
│   │   ├── query_processor.py   # 查询处理
│   │   ├── context_builder.py    # 上下文构建
│   │   └── response_generator.py # 响应生成
│   │
│   ├── knowledge_base/            # 知识库
│   │   ├── retrieval/            # 检索器实现
│   │   │   ├── similarity_retriever.py  # 相似度检索
│   │   │   ├── bm25_retriever.py        # BM25 检索
│   │   │   ├── l2_retriever.py          # L2 距离检索
│   │   │   └── knn_retriever.py         # KNN 检索
│   │   ├── embedding/            # 嵌入管理
│   │   ├── kg/                    # 知识图谱
│   │   └── milvus/                # Milvus 接口
│   │
│   ├── training/                  # 训练模块
│   │   ├── trainer/               # 多种训练器
│   │   │   ├── sft_trainer.py     # SFT 训练器
│   │   │   ├── dpo_trainer.py     # DPO 训练器(手写)
│   │   │   ├── dapo_trainer.py     # DAPO 训练器
│   │   │   ├── gspo_trainer.py     # GSPO 训练器(序列级 GRPO)
│   │   │   ├── trpo_trainer.py     # TRPO 训练器
│   │   │   └── reward_model_trainer.py  # 奖励模型
│   │   ├── dataset/               # 数据集处理
│   │   │   └── medical_dataset.py # 医疗数据集类
│   │   └── run_trainer/           # 训练启动脚本
│   │       ├── run_sft_trainer.py
│   │       └── run_dpo_trainer.py
│   │
│   ├── inference/                  # 推理模块
│   │   ├── vllm_inference.py       # vLLM 推理
│   │   ├── xinference_inference.py # XInference 推理
│   │   ├── ollama_inference.py     # Ollama 推理
│   │   └── transformers_inference.py # Transformers 推理
│   │
│   ├── evaluation/                 # 评估模块
│   │   ├── base_evaluator.py       # 评估器基类
│   │   ├── bleu_rouge_evaluator.py # BLEU/ROUGE 评估
│   │   ├── medqa_llm_evaluator.py   # MedQA LLM 评估
│   │   └── math_evaluator.py        # 数学题评估
│   │
│   ├── models/                     # 模型封装
│   │   ├── base_model.py           # 基础模型类
│   │   ├── qwen_model.py           # Qwen 模型封装
│   │   └── api_model.py            # API 模型封装
│   │
│   ├── api/                        # FastAPI 服务
│   │   ├── main.py                 # API 主入口
│   │   └── routers/                # API 路由
│   │
│   ├── config/                     # 配置模块
│   │   ├── config.yaml             # 统一配置文件
│   │   ├── settings.py             # 配置加载器
│   │   └── deepspeed_config/       # DeepSpeed 配置
│   │
│   └── utils/                      # 工具模块
│       ├── logger.py               # 日志管理
│       ├── metrics.py              # 评估指标
│       └── text_utils.py           # 文本处理
│
├── examples/                       # 示例代码
│   ├── train/                      # 训练 YAML 配置
│   │   ├── sft.yaml
│   │   ├── dpo.yaml
│   │   ├── grpo.yaml
│   │   ├── gspo.yaml              # GSPO 训练配置
│   │   ├── ppo.yaml
│   │   └── reward_model.yaml
│   └── mdagents_usage_examples.py  # MDAgents 使用示例
│
├── data/                           # 数据目录
│   ├── raw/                        # 原始数据
│   ├── processed/                  # 处理后数据
│   └── indices/                    # FAISS 索引
│
├── requirements.txt                # 项目依赖
└── README.md

数据集

SFT 数据集

DPO 数据集

RL 训练数据

  • GSM8K 数学题: ~/data/rlhf/gsm8k/

模型权重

主要基座模型: Qwen/Qwen3-4B (HuggingFace)

所有模型路径与服务参数均集中在 src/config/config.yaml

  • local_model.model_path / inference.model_name_or_path - 本地模型路径
  • embedding.model_path - Embedding 模型路径
  • embedding.reranker_model_path - Reranker 模型路径

如何运行本项目

1. 安装依赖

pip install -r requirements.txt

2. 配置项目

本项目只使用 src/config/config.yaml 作为配置入口,不再读取任何环境变量, 不再使用 .env / python-dotenv。所有路径、密钥、端口等参数均在 YAML 中维护:

cp src/config/config.yaml.example src/config/config.yaml
# 然后按需编辑 config.yaml 中的字段,例如:
#   llm.api_key / llm.base_url
#   local_model.model_path / embedding.model_path
#   api_server.host / api_server.port / api_server.admin_api_key
#   milvus.uri / neo4j.password ...

3. 启动 API 服务

docker compose up
cd src/api
python main.py
#
uvicorn main:app --host 0.0.0.0 --port 8000


#
cd Chinese-MedQA-Qwen2
uvicorn src.api.main:app --host 0.0.0.0 --port 8000 --reload

4. 运行 SFT 训练

python src/training/run_trainer/run_sft_trainer.py \
    --model_name_or_path Qwen/Qwen3-4B \
    --train_data_dir ./data \
    --output_dir ./output/sft \
    --per_device_train_batch_size 4 \
    --learning_rate 2e-5 \
    --num_epochs 3 \
    --finetuning_type lora \
    --lora_rank 16

5. 运行 DPO 训练

python src/training/run_trainer/run_dpo_trainer.py \
    --model_name_or_path Qwen/Qwen3-4B \
    --train_data_dir ./data \
    --output_dir ./output/dpo \
    --per_device_train_batch_size 2 \
    --learning_rate 1e-5 \
    --num_epochs 3

6. 使用 YAML 配置训练

# SFT 训练
python src/training/run_trainer/run_sft_trainer.py --config examples/train/sft.yaml

# DPO 训练
python src/training/run_trainer/run_dpo_trainer.py --config examples/train/dpo.yaml

7. vLLM 部署

python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-4B \
    --host 0.0.0.0 \
    --port 8000

Evaluation 评估

项目提供多种评估器,位于 src/evaluation/:

评估器 说明
BleuRougeEvaluator BLEU/ROUGE 指标评估
MedQALLMEvaluator 基于 LLM 的医疗 QA 评估
MathEvaluator 数学问题评估
CodeEvaluator 代码执行评估

评估配置在 src/config/config.yamlevaluator 部分。

参考项目

修复日志

  • 2026-04-28:
    • MedQAEvaluator 评估器format_prompt 改为语义匹配(允许同义词),新增 _extract_answer() 方法解析 <answer>True/False</answer> 标签,适配带思考过程的推理模型(MiniMax-M2.7)
    • LLMProvider 流式输出generate() 新增 streaming 参数(默认 False),新增 generate_streaming() 便捷方法;_generate_api() 支持流式 SSE 响应;_generate_local() 向各本地后端传递 streaming 参数;_generate_transformers() 支持 TextIteratorStreamer_generate_vllm() 支持 SamplingParams(stream=True)_generate_ollama() 支持 stream=True
    • QA Router 修复routers/qa.py_build_llm_provider() 新增对 local 模式的支持(传入 model_pathlocal_backend);修复 max_new_tokensmax_tokens 参数名不匹配问题
    • Config 配置更新llmagentvllm 配置段均新增 streaming: false 字段
    • QA API 测试:非流式、流式、模型列表接口均通过本地模型(Qwen3-0.6B)测试验证
    • RAG Service 修复generate_response() 现正确调用 RAGPipeline 而非绕过;修复 hybrid 检索器初始化逻辑(dense_retriever 未创建);rag_minimal_main.py 添加 lifespan 初始化 rag_service;config.yaml 新增缺失的 retriever.bm25 配置段
  • 2026-04-14: 修复 17 项严重问题(P0),包括导入路径修正、配置安全加固、训练逻辑修复、LoggerManager 统一、训练 YAML 配置创建。详见 docs/修复文档/FIX_1.md

License

MIT License

About

基于Qwen2+SFT+DPO的医疗问答系统,项目中使用了自定义的 SFTTrainer/DPOTrainer/TRPOTrainer用于训练,其次,项目还调用各种知识库工具(neo4j, milvus, LDA, 等)进行自动化训练数据生成。另外,使用 vllm 用于推理和部署训好的模型, 该模型会通过 vllm API 来接入一个基于 embedder + Reranker 的 RAG 系统。另外还参考 MDAgents 论文实现了一个多智能体会诊系统,同样也支持 vllm api 接入。

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors