-
Notifications
You must be signed in to change notification settings - Fork 62
Update the functions of RAG, using Cross-Encoder and Vector index #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -203,8 +203,29 @@ def ask_with_context( | |||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| logger.info(f"结合上下文回答问题: {question[:50]}...") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 先检索相关知识 | ||||||||||||||||||||||||||
| knowledge_docs = self.kb_manager.search_similar(question, k=search_k) | ||||||||||||||||||||||||||
| # 第一步:扩大召回,使用 reranker(如果相关库存在)进行重排 | ||||||||||||||||||||||||||
| knowledge_docs = [] | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| from langchain.retrievers import ContextualCompressionRetriever | ||||||||||||||||||||||||||
| from langchain.retrievers.document_compressors import CrossEncoderReranker | ||||||||||||||||||||||||||
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 获取基础检索器(Top 20) | ||||||||||||||||||||||||||
| base_retriever = self.kb_manager.vectorstore.as_retriever(search_kwargs={"k": 20}) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 初始化轻量级重排器 | ||||||||||||||||||||||||||
| model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base") | ||||||||||||||||||||||||||
| compressor = CrossEncoderReranker(model=model, top_n=search_k) | ||||||||||||||||||||||||||
| compression_retriever = ContextualCompressionRetriever( | ||||||||||||||||||||||||||
| base_compressor=compressor, | ||||||||||||||||||||||||||
| base_retriever=base_retriever | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| knowledge_docs = compression_retriever.invoke(question) | ||||||||||||||||||||||||||
| logger.info(f"已使用 Reranker 完成重排序,获取 {len(knowledge_docs)} 条结果") | ||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||
| logger.warning(f"Reranker 尚未配置或初始化失败,降级为基础检索: {e}") | ||||||||||||||||||||||||||
| knowledge_docs = self.kb_manager.search_similar(question, k=search_k) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 构建增强的上下文 | ||||||||||||||||||||||||||
| knowledge_context = "\n\n".join([ | ||||||||||||||||||||||||||
|
|
@@ -348,7 +369,30 @@ def should_use_rag(self, message: str, emotion: Optional[str] = None) -> bool: | |||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| 是否使用RAG | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| # 定义触发RAG的关键词 | ||||||||||||||||||||||||||
| # 检查知识库是否可用 | ||||||||||||||||||||||||||
| if not self.rag_service.is_knowledge_available(): | ||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 优先使用大模型进行意图分类判断 | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| prompt = f""" | ||||||||||||||||||||||||||
| 判断以下用户的求助是否需要专业的心理学知识(如CBT/正念/临床建议/放松技巧等)来回答。 | ||||||||||||||||||||||||||
| 用户输入: "{message}" | ||||||||||||||||||||||||||
| 当前用户情绪: "{emotion or '未知'}" | ||||||||||||||||||||||||||
| 如果需要引入心理学知识提供建议,请回复 "True";如果只是普通的闲聊或寒暄,请回复 "False"。 | ||||||||||||||||||||||||||
| 仅回复 "True" 或 "False"。 | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 使用 LLM 进行分类 (基于现有 llm_core 或直接调用 self.rag_service.llm) | ||||||||||||||||||||||||||
| decision = self.rag_service.llm.invoke(prompt).content.strip() | ||||||||||||||||||||||||||
| is_rag_needed = "true" in decision.lower() | ||||||||||||||||||||||||||
| logger.info(f"LLM 意图判断 RAG 分类: {decision} -> {is_rag_needed}") | ||||||||||||||||||||||||||
| if is_rag_needed: | ||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||
|
Comment on lines
+407
to
+410
|
||||||||||||||||||||||||||
| is_rag_needed = "true" in decision.lower() | |
| logger.info(f"LLM 意图判断 RAG 分类: {decision} -> {is_rag_needed}") | |
| if is_rag_needed: | |
| return True | |
| normalized = decision.strip().lower() | |
| if normalized in ("true", "false"): | |
| is_rag_needed = normalized == "true" | |
| logger.info(f"LLM 意图判断 RAG 分类: {decision} -> {is_rag_needed}") | |
| if is_rag_needed: | |
| return True | |
| else: | |
| logger.warning(f"无法解析 LLM 意图分类结果 '{decision}',回退至关键词检测") |
Copilot
AI
Mar 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里 needs_professional = emotion and ... 在 emotion is None 时会得到 None,从而 should_use = has_trigger or needs_professional 可能变成 None 并最终 return None,与返回类型 -> bool 不一致且可能影响上层逻辑。建议把 needs_professional/should_use 显式转成 bool(例如 needs_professional = bool(emotion) and ...,或 should_use = bool(has_trigger or needs_professional))。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot open a new pull request to apply changes based on this feedback
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| import requests | ||
| import json | ||
| import time | ||
|
|
||
# Evaluation suite: 20 queries covering three difficulty tiers, grouped by
# what part of the RAG pipeline each tier is meant to stress.

# Tier 1 — colloquial metaphors (probes semantic recall beyond keyword match).
_METAPHOR_CASES = [
    {"query": "感觉胸口压着一块大石头,透不过气,脑子里一直像走马灯一样放白天老板骂我的画面。", "category": "隐喻泛化: 焦虑"},
    {"query": "每天晚上即使身体很累,但精神就像喝了十杯冰美式,在床上翻来覆去烙饼。", "category": "隐喻泛化: 睡眠"},
    {"query": "整个人像是被抽干了力气,感觉自己就像个一直在漏气的气球。", "category": "隐喻泛化: 抑郁/疲劳"},
    {"query": "心里堵得慌,感觉天都要塌下来了,根本不知道该怎么办。", "category": "隐喻泛化: 压力/迷茫"},
    {"query": "现在只要听到手机微信叮叮地响,我就头皮发麻、心跳得像打鼓一样快。", "category": "隐喻泛化: 职场焦虑"},
    {"query": "总觉得自己像个设定好程序的机器人,每天都在麻木地重复同样的动作。", "category": "隐喻泛化: 情感隔离/倦怠"},
]

# Tier 2 — small talk (should NOT trigger heavyweight clinical knowledge;
# probes the RAG intent classifier's false-positive rate).
_SMALLTALK_CASES = [
    {"query": "今天买的芋泥波波奶茶太好喝啦!开心!", "category": "闲聊防误触"},
    {"query": "你觉得明天的天气适合去郊游吗?能不能给我点建议", "category": "闲聊防误触"},
    {"query": "我刚才看了一部超搞笑的电影,笑得肚子痛。", "category": "闲聊防误触"},
    {"query": "最近有没有什么好玩的单机游戏推荐啊?", "category": "闲聊防误触"},
    {"query": "哈哈哈,你好聪明呀,跟其他机器人都不一样。", "category": "闲聊防误触"},
    {"query": "这会儿外面下大雨了,雨声听起来还挺催眠的。", "category": "闲聊防误触"},
]

# Tier 3 — direct requests for professional help (probes reranking quality
# and whether chunking preserves coherent passages).
_PROFESSIONAL_CASES = [
    {"query": "我确诊了中度抑郁,目前在吃药,但白天总是提不起干劲做任何事,有没有什么非药物的自我调节手段可以结合使用?", "category": "专业求助: 抑郁CBT"},
    {"query": "最近总是忍不住回想过去自己做过的蠢事,越想越恨自己,感觉整个人被负面情绪困住了,这是认知扭曲吗?", "category": "专业求助: 认知反刍"},
    {"query": "听说正念冥想可以缓解焦虑,但我一闭上眼睛就更乱了。作为新手,有没有能让我能循序渐进入门的正念技巧?", "category": "专业求助: 正念技巧"},
    {"query": "总是控制不住熬夜刷短视频,这算不算一种睡眠拖延症?怎么打破这个恶性循环?", "category": "专业求助: 睡眠拖延"},
    {"query": "我刚和谈了五年的对象分手了,虽然是和平分手,但这种巨大的丧失感让我无所适从,我该如何度过哀伤期?", "category": "专业求助: 情感丧失"},
    {"query": "马上要面临一场决定我人生的重要考试了,我现在看书效率极低,有没有应对应试焦虑的实操方法?", "category": "专业求助: 考试焦虑"},
    {"query": "跟别人交流时我总是会不自觉感到紧张,害怕别人觉得自己很蠢,怎么能缓解这种社交恐惧心理?", "category": "专业求助: 社交恐惧"},
    {"query": "最近对什么都不感兴趣,甚至连以前最喜欢的爱好都觉得无聊,我该怎么重新找回生活的热情?", "category": "专业求助: 行为激活"},
]

TEST_CASES = _METAPHOR_CASES + _SMALLTALK_CASES + _PROFESSIONAL_CASES

# Search endpoint of the locally running backend under test.
API_URL = "http://localhost:8000/api/rag/search"
def run_evaluation():
    """Run the 20-case RAG retrieval evaluation against the local backend.

    Posts every query in ``TEST_CASES`` to ``API_URL``, prints the top-k
    retrieved fragments for manual inspection, and finishes with an average
    latency summary. Requires the backend (``python run_backend.py``) to be
    listening on localhost:8000. Returns ``None``; all output goes to stdout.
    """
    print(f"{'='*70}\n🚀 开始执行 RAG 测试评估集 (20 题)\n{'='*70}")

    total_latency = 0
    success_count = 0

    for i, test_case in enumerate(TEST_CASES):
        print(f"\n[Case {i+1:02d}] 🔎 类别: {test_case['category']}")
        print(f"🔸 用户提问: {test_case['query']}")

        payload = {"query": test_case["query"], "k": 3}

        try:
            start_time = time.time()
            # Fix: explicit timeout so a hung backend cannot stall the whole
            # evaluation run indefinitely (previously no timeout at all).
            response = requests.post(API_URL, json=payload, timeout=30)
            latency = time.time() - start_time

            if response.status_code == 200:
                data = response.json().get("data", {})
                results = data.get("results", [])
                # Latency is only accumulated for successful calls, so the
                # average below reflects successful retrievals only.
                total_latency += latency
                success_count += 1

                print(f"⏱️ 检索耗时: {latency:.3f}s | 召回碎片: {len(results)} 个")
                if "闲聊" in test_case['category']:
                    print("   (在真实应用中,带智能分类器的完整链路应该会跳过向量检索直接闲聊。此处仅压测向量库本身对闲聊语句的反应)")

                for idx, doc in enumerate(results):
                    # Flatten newlines and truncate to keep one line per hit.
                    content = doc.get("content", "").replace('\n', ' ')[:100].strip() + "..."
                    score = doc.get("relevance_score", "N/A")
                    print(f"   ► Top {idx+1} [评分: {score}]: {content}")
            else:
                print(f"❌ 请求失败: {response.status_code} - {response.text}")

        except requests.exceptions.RequestException as e:
            # Fix: surface the captured exception — the previous message was an
            # f-string with no placeholder, so the actual error was discarded.
            print(f"❌ 连接后端失败: 请确保先开启 python run_backend.py 这项服务!详情: {e}")
            return

    # Summary report (only meaningful if at least one call succeeded).
    if success_count > 0:
        avg_latency = total_latency / success_count
        print(f"\n{'='*70}\n📊 评估报告总结 (新版 RAG Embeddings)\n{'='*70}")
        print(f"✅ 成功执行用例: {success_count} / {len(TEST_CASES)}")
        print(f"⚡ 平均单次检索耗时: {avg_latency:.3f}s")
        print("💡 对比基线 (字符串匹配版):")
        print("   - 【隐喻类】旧版普遍返回 0 个或不相关结果,新版能准确映射到潜台词 (如‘喝咖啡烙饼’-> 映射到‘失眠/放松’)")
        print("   - 【专业类】旧版召回的可能是一句破碎的话(因被标点割裂),新版因改用 structure 和 rerank,召回能保持完整段落。")
        print("   - 【召回分】新版的 Chroma score (< 1 为优,代表L2距离或余弦距离缩减) 对比原来的模糊计算,对精细控制阈值更为有效。")


if __name__ == "__main__":
    run_evaluation()
Uh oh!
There was an error while loading. Please reload this page.