From 83e62b3d35f3b2112111c488b2561117a963f606 Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 31 Aug 2025 22:24:15 +0800 Subject: [PATCH 01/17] runable --- ms_agent/agent/llm_agent.py | 13 +- ms_agent/agent/memory/base.py | 2 + ms_agent/agent/memory/default_memory.py | 253 ++++++++++++++++++++++++ ms_agent/agent/memory/mem0.py | 2 - ms_agent/utils/prompts.py | 77 ++++++++ 5 files changed, 339 insertions(+), 8 deletions(-) create mode 100644 ms_agent/agent/memory/default_memory.py delete mode 100644 ms_agent/agent/memory/mem0.py create mode 100644 ms_agent/utils/prompts.py diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 18c192752..b1ad5c1df 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -223,15 +223,19 @@ async def _prepare_messages( messages = await self.rag.query(messages[1].content) return messages - async def _prepare_memory(self): + async def _prepare_memory(self, messages: Optional[List[Message]] = None, **kwargs): """Load and initialize memory components from the config.""" + config, runtime, cache_messages = self._read_history( + messages, **kwargs) if hasattr(self.config, 'memory'): for _memory in (self.config.memory or []): assert _memory.name in memory_mapping, ( f'{_memory.name} not in memory_mapping, ' f'which supports: {list(memory_mapping.keys())}') self.memory_tools.append(memory_mapping[_memory.name]( - self.config)) + self.config, cache_messages, conversation_id=self.task)) + return config, runtime, messages + async def _prepare_planer(self): """Load and initialize the planer component from the config.""" @@ -469,14 +473,11 @@ async def _run(self, messages: Union[List[Message], str], self._prepare_llm() self._prepare_runtime() await self._prepare_tools() - await self._prepare_memory() await self._prepare_planer() await self._prepare_rag() + self.config, self.runtime, messages = await self._prepare_memory(messages, **kwargs) self.runtime.tag = self.tag - self.config, self.runtime, messages = self._read_history( - messages, **kwargs) - if self.runtime.round == 0: # 0 means no history messages = await self._prepare_messages(messages) diff --git a/ms_agent/agent/memory/base.py b/ms_agent/agent/memory/base.py index b8ea7496c..5a7285fcb 100644 --- a/ms_agent/agent/memory/base.py +++ b/ms_agent/agent/memory/base.py @@ -7,6 +7,8 @@ class Memory: """The memory refine tool""" + def __init__(self, config): + self.config = config @abstractmethod async def run(self, messages: List[Message]) -> List[Message]: diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py new file mode 100644 index 000000000..5e2424d40 --- /dev/null +++ b/ms_agent/agent/memory/default_memory.py @@ -0,0 +1,253 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from copy import deepcopy +from typing import List, Dict, Any, Literal, Optional, Set, Tuple + +from langchain.chains.question_answering.map_reduce_prompt import messages + +from ms_agent.llm.utils import Message +from ms_agent.agent.memory import Memory +from ms_agent.utils.prompts import FACT_RETRIEVAL_PROMPT +from mem0 import Memory as Mem0Memory +from ms_agent.utils.logger import logger +from omegaconf import DictConfig, OmegaConf + + +class DefaultMemory(Memory): + """The memory refine tool""" + def __init__(self, + config: DictConfig, + cache_messages: Optional[List[Message]] = None, + conversation_id: Optional[str] = None, + persist: bool = False, + path: str = None, + history_mode: Literal['add', 'overwrite'] = 'overwrite', + current_memory_cache_position: int = 0): + super().__init__(config) + self.cache_messages = cache_messages + self.conversation_id: Optional[str] = conversation_id or getattr(config.memory, 'conversation_id', None) + self.persist: Optional[bool] = persist or getattr(config.memory, 'persist', None) + self.compress: Optional[bool] = getattr(config.memory, 'compress', None) + self.embedder: Optional[str] = getattr(config.memory, 'embedder', None) + self.is_retrieve: Optional[bool] = getattr(config.memory, 'is_retrieve', None) + self.path: Optional[str] = path or getattr(config.memory, 'path', None) + self.history_mode = history_mode or getattr(config.memory, 'history_mode', None) + self.current_memory_cache_position = current_memory_cache_position + self.memory = self._init_memory() + + + def _should_update_memory(self, messages: List[Message]) -> bool: + return True + + def _find_messages_common_prefix(self, + messages: List[Dict], + ignore_role: Optional[Set[str]] = {'system'}, + ignore_fields: Optional[Set[str]] = {'reasoning_content'}, + ) -> Tuple[List[Dict], int, int]: + """ + 比对 messages 和缓存messages的差异,并提取最长公共前缀。 + + Args: + messages: 本次 List[Dict],符合 OpenAI API 格式 + ignore_role: 是否忽略 role="system"、或者role="tool" 的message + ignore_fields: 可选,要忽略比较的字段名集合,如 {"reasoning_content"} + + Returns: + 最长公共前缀(List[Dict]) + """ + if not messages or not isinstance(messages, list): + return [], -1, -1 + + if ignore_fields is None: + ignore_fields = set() + + # 预处理:根据 ignore_role 过滤消息 + def _ignore_role(msgs): + filtered = [] + indices = [] # 每个 filtered 消息对应的原始索引 + for idx, msg in enumerate(msgs): + if ignore_role and msg.get("role") in ignore_role: + continue + filtered.append(msg) + indices.append(idx) + return filtered, indices + + filtered_messages, indices = _ignore_role(messages) + filtered_cache_messages, cache_indices = _ignore_role(self.cache_messages) + + # 找最短长度,避免越界 + min_length = min(len(msgs) for msgs in [filtered_messages, filtered_cache_messages]) + common_prefix = [] + + idx = 0 + for idx in range(min_length): + current_cache_msg = filtered_cache_messages[idx] + current_msg = filtered_messages[idx] + is_common = True + + # 比较其他字段(除了忽略的字段) + all_keys = set(current_cache_msg.keys()).union(set(current_msg.keys())) + for key in all_keys: + if key in ignore_fields: + continue + if current_cache_msg.get(key) != current_msg.get(key): + is_common = False + break + + if not is_common: + break + + # 添加当前消息的深拷贝到结果中(保留原始结构) + common_prefix.append(deepcopy(current_msg)) + + if len(common_prefix) == 0: + return [], -1, -1 + + return common_prefix, indices[idx], cache_indices[idx] + + def rollback(self, common_prefix_messages, cache_message_idx): + # 支持retry机制,将memory回退到 self.cache_messages的第idx 条message + if self.history_mode == 'add': + # 只有覆盖更新模式才支持回退;回退涉及删除 + return + # TODO: 真正的回退 + self.memory.delete_all(user_id=self.conversation_id) + self.memory.add(common_prefix_messages, user_id=self.conversation_id) + + def run(self, messages, ignore_role = None, ignore_fields = None): + print(f'ahahahah?1 : {self.memory.get_all(user_id=self.conversation_id)}') + if not self.cache_messages: + self.cache_messages = messages + common_prefix_messages, messages_idx, cache_message_idx\ + = self._find_messages_common_prefix(messages, + ignore_role=ignore_role, + ignore_fields=ignore_fields) + print(f'ahahahah?2 : {self.memory.get_all(user_id=self.conversation_id)}') + if not self.is_retrieve or not self._should_update_memory(messages): + return messages + print(f'ahahahah?3 : {self.memory.get_all(user_id=self.conversation_id)}') + if self.history_mode == 'add': + print(f'ahahahah?4 : {self.memory.get_all(user_id=self.conversation_id)}') + self.memory.add(messages, user_id=self.conversation_id) + res = self.memory.get_all(user_id=self.conversation_id) + print(f'res: {res}') + else: + print(f'ahahahah?5 : {self.memory.get_all(user_id=self.conversation_id)}') + if cache_message_idx < len(self.cache_messages): + self.rollback(common_prefix_messages, cache_message_idx) + self.cache_messages = messages + print(f'messages: {messages}') + self.memory.add(messages[messages_idx:], user_id=self.conversation_id) + res = self.memory.get_all(user_id=self.conversation_id) + print(f'res: {res}') + print(f'messages[-1]["content"]: {messages[-1]["content"]}') + relevant_memories = self.memory.search(messages[-1]['content'], user_id=self.conversation_id, limit=3) + memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories["results"]) + print(f'memories_str: {memories_str}') + # 将memory对应的messages段删除,并添加相关的memory_str信息 + if messages[0].get('role') == 'system': + system_prompt = messages[0]['content'] + f'\nUser Memories: {memories_str}' + else: + system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\nUser Memories: {memories_str}' + new_messages = [{'role': 'system', 'content': system_prompt}] + messages[messages_idx:] + + return new_messages + + def _init_memory(self) -> Mem0Memory | None: + if not self.is_retrieve: + return + + if self.embedder is None: + # TODO: set default + raise ValueError('embedder must be set when is_retrieve=True.') + embedder = self.embedder + + llm = {} + if self.compress: + llm_config = getattr(self.config.memory, 'llm', None) + if llm_config is not None: + # follow mem0 config + model = llm_config.get('model') + provider = llm_config.get('provider', 'openai') + openai_base_url = llm_config.get('openai_base_url', None) + openai_api_key = llm_config.get('api_key', None) + else: + llm_config = self.config.llm + model = llm_config.model + provider = llm_config.service + openai_base_url = getattr(llm_config, f'{provider}_base_url', None) + openai_api_key = getattr(llm_config, f'{provider}_api_key', None) + + llm = { + "provider": provider, + "config": { + "model": model, + "openai_base_url": openai_base_url, + "api_key": openai_api_key + } + } + + mem0_config = { + "is_infer": self.compress, + "llm": llm, + "custom_fact_extraction_prompt": getattr(self.config.memory, 'fact_retrieval_prompt', FACT_RETRIEVAL_PROMPT), + "vector_store": { + "provider": "qdrant", + "config": { + "path": self.path, + # "on_disk": self.persist + "on_disk": True + } + }, + "embedder": embedder + } + #logger.info(f'Memory config: {mem0_config}') + memory = Mem0Memory.from_config(mem0_config) + memory.add(self.cache_messages, user_id=self.conversation_id) + res = memory.get_all(user_id=self.conversation_id) + print(f'res: {res}') + return memory + +async def main(): + import os + import json + cfg = { + "memory": { + "conversation_id": "default_id", + "persist": True, + "compress": True, + "is_retrieve": True, + "history_mode": "add", + # "embedding_model": "text-embedding-v4", + "llm": { + "provider": "openai", + "model": "qwen3-235b-a22b-instruct-2507", + "openai_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "api_key": os.getenv("DASHSCOPE_API_KEY"), + }, + "embedder": { + "provider": "openai", + "config": { + "api_key": os.getenv("DASHSCOPE_API_KEY"), + "openai_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model": "text-embedding-v4", + } + } + # "vector_store": { + # "provider": "qdrant", + # "config": { + # "path": "/Users/luyan/workspace/mem0/storage", + # "on_disk": False + # } + # } + } + } + with open('openai_format_test_case1.json', 'r') as f: + data = json.load(f) + config = OmegaConf.create(cfg) + memory = DefaultMemory(config, path='./output', cache_messages=data, history_mode='add') + res = memory.run(messages = [{'role': 'user', 'content': '使用bun会对新项目的影响大吗,有哪些新特性'}]) + print(res) + +if __name__ == '__main__': + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/ms_agent/agent/memory/mem0.py b/ms_agent/agent/memory/mem0.py deleted file mode 100644 index da45545af..000000000 --- a/ms_agent/agent/memory/mem0.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# TODO diff --git a/ms_agent/utils/prompts.py b/ms_agent/utils/prompts.py new file mode 100644 index 000000000..c72880031 --- /dev/null +++ b/ms_agent/utils/prompts.py @@ -0,0 +1,77 @@ +from datetime import datetime + + +FACT_RETRIEVAL_PROMPT = f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. + +Types of Information to Remember: +1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment. +2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates. +3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared. +4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services. +5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information. +6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information. +7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares. + +Tool Interaction Processing Instructions (Additional Responsibilities): +When tool calls and their results are included in the conversation, perform the following in addition to fact extraction: + +1. Extract and Organize Factual Information from Tool Outputs: + - Parse the returned data from successful tool calls (e.g., weather, calendar, search, maps). + - Identify and store objective, user-relevant facts derived from these results (e.g., "It will rain in Paris on 2025-08-25", "The restaurant Little Italy is located at 123 Main St"). + - Integrate these into the "facts" list only if they reflect new, meaningful information about the user's context or environment. +2. Analyze and Summarize Error-Prone Tools: + - Identify tools that frequently fail, time out, or return inconsistent results. + - For such tools, generate a brief internal summary noting the pattern of failure (e.g., "Search tool often returns incomplete results for restaurant queries"). + - This summary does not go into the JSON output but informs future handling (e.g., suggesting alternative tools or double-checking outputs). +3. Identify and Log Tools That Cannot Be Called: + - If a tool was intended but not invoked (e.g., due to missing permissions, unavailability, or misconfiguration), note this in a separate internal log. + - Examples: "Calendar tool unavailable — cannot retrieve user's meeting schedule", "Location access denied — weather tool cannot auto-detect city". + - Include a user-facing reminder if relevant: add a fact like "Could not access calendar due to permission restrictions" only if it impacts user understanding. +4. Ensure Clarity and Non-Disclosure: + - Do not expose tool names, system architecture, or internal logs in the output. + - If asked why information is missing, respond: "I tried to retrieve it from publicly available sources, but the information may not be accessible right now." + +Here are some few-shot examples: +Input: Hi. +Output: {{"facts" : []}} + +Input: There are branches in trees. +Output: {{"facts" : []}} + +Input: Hi, I am looking for a restaurant in San Francisco. +Output: {{"facts" : ["Looking for a restaurant in San Francisco"]}} + +Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project. +Output: {{"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}} + +Input: Hi, my name is John. I am a software engineer. +Output: {{"facts" : ["Name is John", "Is a Software engineer"]}} + +Input: My favourite movies are Inception and Interstellar. +Output: {{"facts" : ["Favourite movies are Inception and Interstellar"]}} + +Input (with tool call): What's the weather like in Tokyo today? +[Tool Call: get_weather(location="Tokyo", date="2025-08-22") → Result: {{"status": "success", "data": {{"temp": 32°C, "condition": "Sunny", "humidity": 65%}}}}] +Output: {{"facts": ["It is 32°C and sunny in Tokyo today", "Humidity level in Tokyo is 65%"]}} + +Input (with failed tool): Check my calendar for tomorrow's meetings. +[Tool Call: get_calendar(date="2025-08-23") → Failed: "Access denied – calendar not connected"] +Output: {{"facts": ["Could not access calendar due to connection issues"]}} + +Input (with unreliable tool pattern): Search for vegan restaurants near Central Park. +[Tool Call: search(query="vegan restaurants near Central Park") → Returns incomplete/no results multiple times] +Output: {{"facts": ["Searching for vegan restaurants near Central Park yielded limited results"]}} +(Internal note: Search tool shows low reliability for location-based queries — consider fallback sources.) + +Final Output Rules: + - Today's date is {datetime.now().strftime("%Y-%m-%d")}. + - If the user asks where you fetched my information, answer that you found from publicly available sources on internet. + - Return only a JSON object with key "facts" and value as a list of strings. + - Do not include anything from the example prompts or system instructions. + - Do not reveal tool usage, internal logs, or model behavior. + - If no relevant personal or environmental facts are found, return: {{"facts": []}} + - Extract facts only from user and assistant messages — ignore system-level instructions. + - Detect the input language and record facts in the same language. + +Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation, process any tool call results, and return them in the JSON format as shown above. +""" From 7a8242557187561be33a2e2f08767b3af716d4af Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 31 Aug 2025 22:26:03 +0800 Subject: [PATCH 02/17] fix lint --- ms_agent/agent/llm_agent.py | 8 +- ms_agent/agent/memory/base.py | 1 + ms_agent/agent/memory/default_memory.py | 174 +++++++++++++++--------- ms_agent/utils/prompts.py | 1 - 4 files changed, 113 insertions(+), 71 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index b1ad5c1df..0f305f76c 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -223,7 +223,9 @@ async def _prepare_messages( messages = await self.rag.query(messages[1].content) return messages - async def _prepare_memory(self, messages: Optional[List[Message]] = None, **kwargs): + async def _prepare_memory(self, + messages: Optional[List[Message]] = None, + **kwargs): """Load and initialize memory components from the config.""" config, runtime, cache_messages = self._read_history( messages, **kwargs) @@ -236,7 +238,6 @@ async def _prepare_memory(self, messages: Optional[List[Message]] = None, **kwar self.config, cache_messages, conversation_id=self.task)) return config, runtime, messages - async def _prepare_planer(self): """Load and initialize the planer component from the config.""" if hasattr(self.config, 'planer'): @@ -475,7 +476,8 @@ async def _run(self, messages: Union[List[Message], str], await self._prepare_tools() await self._prepare_planer() await self._prepare_rag() - self.config, self.runtime, messages = await self._prepare_memory(messages, **kwargs) + self.config, self.runtime, messages = await self._prepare_memory( + messages, **kwargs) self.runtime.tag = self.tag if self.runtime.round == 0: diff --git a/ms_agent/agent/memory/base.py b/ms_agent/agent/memory/base.py index 5a7285fcb..409f1d483 100644 --- a/ms_agent/agent/memory/base.py +++ b/ms_agent/agent/memory/base.py @@ -7,6 +7,7 @@ class Memory: """The memory refine tool""" + def __init__(self, config): self.config = config diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 5e2424d40..241166e9f 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -1,19 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from copy import deepcopy -from typing import List, Dict, Any, Literal, Optional, Set, Tuple +from typing import Any, Dict, List, Literal, Optional, Set, Tuple from langchain.chains.question_answering.map_reduce_prompt import messages - -from ms_agent.llm.utils import Message -from ms_agent.agent.memory import Memory -from ms_agent.utils.prompts import FACT_RETRIEVAL_PROMPT from mem0 import Memory as Mem0Memory +from ms_agent.agent.memory import Memory +from ms_agent.llm.utils import Message from ms_agent.utils.logger import logger +from ms_agent.utils.prompts import FACT_RETRIEVAL_PROMPT from omegaconf import DictConfig, OmegaConf class DefaultMemory(Memory): """The memory refine tool""" + def __init__(self, config: DictConfig, cache_messages: Optional[List[Message]] = None, @@ -24,24 +24,29 @@ def __init__(self, current_memory_cache_position: int = 0): super().__init__(config) self.cache_messages = cache_messages - self.conversation_id: Optional[str] = conversation_id or getattr(config.memory, 'conversation_id', None) - self.persist: Optional[bool] = persist or getattr(config.memory, 'persist', None) - self.compress: Optional[bool] = getattr(config.memory, 'compress', None) + self.conversation_id: Optional[str] = conversation_id or getattr( + config.memory, 'conversation_id', None) + self.persist: Optional[bool] = persist or getattr( + config.memory, 'persist', None) + self.compress: Optional[bool] = getattr(config.memory, 'compress', + None) self.embedder: Optional[str] = getattr(config.memory, 'embedder', None) - self.is_retrieve: Optional[bool] = getattr(config.memory, 'is_retrieve', None) + self.is_retrieve: Optional[bool] = getattr(config.memory, + 'is_retrieve', None) self.path: Optional[str] = path or getattr(config.memory, 'path', None) - self.history_mode = history_mode or getattr(config.memory, 'history_mode', None) + self.history_mode = history_mode or getattr(config.memory, + 'history_mode', None) self.current_memory_cache_position = current_memory_cache_position self.memory = self._init_memory() - def _should_update_memory(self, messages: List[Message]) -> bool: return True - def _find_messages_common_prefix(self, - messages: List[Dict], - ignore_role: Optional[Set[str]] = {'system'}, - ignore_fields: Optional[Set[str]] = {'reasoning_content'}, + def _find_messages_common_prefix( + self, + messages: List[Dict], + ignore_role: Optional[Set[str]] = {'system'}, + ignore_fields: Optional[Set[str]] = {'reasoning_content'}, ) -> Tuple[List[Dict], int, int]: """ 比对 messages 和缓存messages的差异,并提取最长公共前缀。 @@ -63,19 +68,21 @@ def _find_messages_common_prefix(self, # 预处理:根据 ignore_role 过滤消息 def _ignore_role(msgs): filtered = [] - indices = [] # 每个 filtered 消息对应的原始索引 + indices = [] # 每个 filtered 消息对应的原始索引 for idx, msg in enumerate(msgs): - if ignore_role and msg.get("role") in ignore_role: + if ignore_role and msg.get('role') in ignore_role: continue filtered.append(msg) indices.append(idx) return filtered, indices filtered_messages, indices = _ignore_role(messages) - filtered_cache_messages, cache_indices = _ignore_role(self.cache_messages) + filtered_cache_messages, cache_indices = _ignore_role( + self.cache_messages) # 找最短长度,避免越界 - min_length = min(len(msgs) for msgs in [filtered_messages, filtered_cache_messages]) + min_length = min( + len(msgs) for msgs in [filtered_messages, filtered_cache_messages]) common_prefix = [] idx = 0 @@ -85,7 +92,8 @@ def _ignore_role(msgs): is_common = True # 比较其他字段(除了忽略的字段) - all_keys = set(current_cache_msg.keys()).union(set(current_msg.keys())) + all_keys = set(current_cache_msg.keys()).union( + set(current_msg.keys())) for key in all_keys: if key in ignore_fields: continue @@ -113,42 +121,59 @@ def rollback(self, common_prefix_messages, cache_message_idx): self.memory.delete_all(user_id=self.conversation_id) self.memory.add(common_prefix_messages, user_id=self.conversation_id) - def run(self, messages, ignore_role = None, ignore_fields = None): - print(f'ahahahah?1 : {self.memory.get_all(user_id=self.conversation_id)}') + def run(self, messages, ignore_role=None, ignore_fields=None): + print( + f'ahahahah?1 : {self.memory.get_all(user_id=self.conversation_id)}' + ) if not self.cache_messages: self.cache_messages = messages common_prefix_messages, messages_idx, cache_message_idx\ = self._find_messages_common_prefix(messages, ignore_role=ignore_role, ignore_fields=ignore_fields) - print(f'ahahahah?2 : {self.memory.get_all(user_id=self.conversation_id)}') + print( + f'ahahahah?2 : {self.memory.get_all(user_id=self.conversation_id)}' + ) if not self.is_retrieve or not self._should_update_memory(messages): return messages - print(f'ahahahah?3 : {self.memory.get_all(user_id=self.conversation_id)}') + print( + f'ahahahah?3 : {self.memory.get_all(user_id=self.conversation_id)}' + ) if self.history_mode == 'add': - print(f'ahahahah?4 : {self.memory.get_all(user_id=self.conversation_id)}') + print( + f'ahahahah?4 : {self.memory.get_all(user_id=self.conversation_id)}' + ) self.memory.add(messages, user_id=self.conversation_id) res = self.memory.get_all(user_id=self.conversation_id) print(f'res: {res}') else: - print(f'ahahahah?5 : {self.memory.get_all(user_id=self.conversation_id)}') + print( + f'ahahahah?5 : {self.memory.get_all(user_id=self.conversation_id)}' + ) if cache_message_idx < len(self.cache_messages): self.rollback(common_prefix_messages, cache_message_idx) self.cache_messages = messages print(f'messages: {messages}') - self.memory.add(messages[messages_idx:], user_id=self.conversation_id) + self.memory.add( + messages[messages_idx:], user_id=self.conversation_id) res = self.memory.get_all(user_id=self.conversation_id) print(f'res: {res}') print(f'messages[-1]["content"]: {messages[-1]["content"]}') - relevant_memories = self.memory.search(messages[-1]['content'], user_id=self.conversation_id, limit=3) - memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories["results"]) + relevant_memories = self.memory.search( + messages[-1]['content'], user_id=self.conversation_id, limit=3) + memories_str = '\n'.join(f"- {entry['memory']}" + for entry in relevant_memories['results']) print(f'memories_str: {memories_str}') # 将memory对应的messages段删除,并添加相关的memory_str信息 if messages[0].get('role') == 'system': - system_prompt = messages[0]['content'] + f'\nUser Memories: {memories_str}' + system_prompt = messages[0][ + 'content'] + f'\nUser Memories: {memories_str}' else: system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\nUser Memories: {memories_str}' - new_messages = [{'role': 'system', 'content': system_prompt}] + messages[messages_idx:] + new_messages = [{ + 'role': 'system', + 'content': system_prompt + }] + messages[messages_idx:] return new_messages @@ -174,31 +199,38 @@ def _init_memory(self) -> Mem0Memory | None: llm_config = self.config.llm model = llm_config.model provider = llm_config.service - openai_base_url = getattr(llm_config, f'{provider}_base_url', None) - openai_api_key = getattr(llm_config, f'{provider}_api_key', None) + openai_base_url = getattr(llm_config, f'{provider}_base_url', + None) + openai_api_key = getattr(llm_config, f'{provider}_api_key', + None) llm = { - "provider": provider, - "config": { - "model": model, - "openai_base_url": openai_base_url, - "api_key": openai_api_key + 'provider': provider, + 'config': { + 'model': model, + 'openai_base_url': openai_base_url, + 'api_key': openai_api_key } } mem0_config = { - "is_infer": self.compress, - "llm": llm, - "custom_fact_extraction_prompt": getattr(self.config.memory, 'fact_retrieval_prompt', FACT_RETRIEVAL_PROMPT), - "vector_store": { - "provider": "qdrant", - "config": { - "path": self.path, + 'is_infer': + self.compress, + 'llm': + llm, + 'custom_fact_extraction_prompt': + getattr(self.config.memory, 'fact_retrieval_prompt', + FACT_RETRIEVAL_PROMPT), + 'vector_store': { + 'provider': 'qdrant', + 'config': { + 'path': self.path, # "on_disk": self.persist - "on_disk": True + 'on_disk': True } }, - "embedder": embedder + 'embedder': + embedder } #logger.info(f'Memory config: {mem0_config}') memory = Mem0Memory.from_config(mem0_config) @@ -207,29 +239,32 @@ def _init_memory(self) -> Mem0Memory | None: print(f'res: {res}') return memory + async def main(): import os import json cfg = { - "memory": { - "conversation_id": "default_id", - "persist": True, - "compress": True, - "is_retrieve": True, - "history_mode": "add", + 'memory': { + 'conversation_id': 'default_id', + 'persist': True, + 'compress': True, + 'is_retrieve': True, + 'history_mode': 'add', # "embedding_model": "text-embedding-v4", - "llm": { - "provider": "openai", - "model": "qwen3-235b-a22b-instruct-2507", - "openai_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", - "api_key": os.getenv("DASHSCOPE_API_KEY"), + 'llm': { + 'provider': 'openai', + 'model': 'qwen3-235b-a22b-instruct-2507', + 'openai_base_url': + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'api_key': os.getenv('DASHSCOPE_API_KEY'), }, - "embedder": { - "provider": "openai", - "config": { - "api_key": os.getenv("DASHSCOPE_API_KEY"), - "openai_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", - "model": "text-embedding-v4", + 'embedder': { + 'provider': 'openai', + 'config': { + 'api_key': os.getenv('DASHSCOPE_API_KEY'), + 'openai_base_url': + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'model': 'text-embedding-v4', } } # "vector_store": { @@ -244,10 +279,15 @@ async def main(): with open('openai_format_test_case1.json', 'r') as f: data = json.load(f) config = OmegaConf.create(cfg) - memory = DefaultMemory(config, path='./output', cache_messages=data, history_mode='add') - res = memory.run(messages = [{'role': 'user', 'content': '使用bun会对新项目的影响大吗,有哪些新特性'}]) + memory = DefaultMemory( + config, path='./output', cache_messages=data, history_mode='add') + res = memory.run(messages=[{ + 'role': 'user', + 'content': '使用bun会对新项目的影响大吗,有哪些新特性' + }]) print(res) + if __name__ == '__main__': import asyncio - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/ms_agent/utils/prompts.py b/ms_agent/utils/prompts.py index c72880031..9f27cb2e2 100644 --- a/ms_agent/utils/prompts.py +++ b/ms_agent/utils/prompts.py @@ -1,6 +1,5 @@ from datetime import datetime - FACT_RETRIEVAL_PROMPT = f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. Types of Information to Remember: From e250d3cf87cb5ba7f83ceb39b8228c0df797e601 Mon Sep 17 00:00:00 2001 From: suluyan Date: Tue, 2 Sep 2025 17:20:17 +0800 Subject: [PATCH 03/17] update --- ms_agent/agent/llm_agent.py | 15 ++-- ms_agent/agent/memory/__init__.py | 2 +- ms_agent/agent/memory/default_memory.py | 110 ++++++++++-------------- ms_agent/agent/memory/utils.py | 3 +- tests/memory/test_default_memory.py | 91 ++++++++++++++++++++ 5 files changed, 150 insertions(+), 71 deletions(-) create mode 100644 tests/memory/test_default_memory.py diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 0f305f76c..4273035a8 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -230,12 +230,15 @@ async def _prepare_memory(self, config, runtime, cache_messages = self._read_history( messages, **kwargs) if hasattr(self.config, 'memory'): - for _memory in (self.config.memory or []): - assert _memory.name in memory_mapping, ( - f'{_memory.name} not in memory_mapping, ' - f'which supports: {list(memory_mapping.keys())}') - self.memory_tools.append(memory_mapping[_memory.name]( - self.config, cache_messages, conversation_id=self.task)) + memory_type = getattr(self.config.memory, 'name', 'default_memory') + assert memory_type in memory_mapping, ( + f'{memory_type} not in memory_mapping, ' + f'which supports: {list(memory_mapping.keys())}') + + self.memory_tools.append(memory_mapping[memory_type]( + self.config, + cache_messages if isinstance(cache_messages, list) else None, + conversation_id=self.task)) return config, runtime, messages async def _prepare_planer(self): diff --git a/ms_agent/agent/memory/__init__.py b/ms_agent/agent/memory/__init__.py index 3352d6d42..33db1cb01 100644 --- a/ms_agent/agent/memory/__init__.py +++ b/ms_agent/agent/memory/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .base import Memory -from .utils import memory_mapping +from .utils import memory_mapping, DefaultMemory diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 241166e9f..6fe17ed59 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -20,9 +20,10 @@ def __init__(self, conversation_id: Optional[str] = None, persist: bool = False, path: str = None, - history_mode: Literal['add', 'overwrite'] = 'overwrite', + history_mode: Literal['add', 'overwrite'] = 'add', current_memory_cache_position: int = 0): super().__init__(config) + cache_messages = [message.to_dict() for message in cache_messages] if cache_messages else [] self.cache_messages = cache_messages self.conversation_id: Optional[str] = conversation_id or getattr( config.memory, 'conversation_id', None) @@ -33,7 +34,8 @@ def __init__(self, self.embedder: Optional[str] = getattr(config.memory, 'embedder', None) self.is_retrieve: Optional[bool] = getattr(config.memory, 'is_retrieve', None) - self.path: Optional[str] = path or getattr(config.memory, 'path', None) + self.path: Optional[str] = path or getattr(config.memory, 'path', None) or getattr(self.config, 'output_dir', 'output') + print(f'path: {self.path}') self.history_mode = history_mode or getattr(config.memory, 'history_mode', None) self.current_memory_cache_position = current_memory_cache_position @@ -70,7 +72,7 @@ def _ignore_role(msgs): filtered = [] indices = [] # 每个 filtered 消息对应的原始索引 for idx, msg in enumerate(msgs): - if ignore_role and msg.get('role') in ignore_role: + if ignore_role and getattr(msg, 'role') in ignore_role: continue filtered.append(msg) indices.append(idx) @@ -92,12 +94,11 @@ def _ignore_role(msgs): is_common = True # 比较其他字段(除了忽略的字段) - all_keys = set(current_cache_msg.keys()).union( - set(current_msg.keys())) + all_keys = ['role', 'content', 'reasoning_content', 'tool_calls'] for key in all_keys: if key in ignore_fields: continue - if current_cache_msg.get(key) != current_msg.get(key): + if getattr(current_cache_msg, key, '') != getattr(current_msg, key, ''): is_common = False break @@ -121,60 +122,51 @@ def rollback(self, common_prefix_messages, cache_message_idx): self.memory.delete_all(user_id=self.conversation_id) self.memory.add(common_prefix_messages, user_id=self.conversation_id) - def run(self, messages, ignore_role=None, ignore_fields=None): - print( - f'ahahahah?1 : {self.memory.get_all(user_id=self.conversation_id)}' - ) - if not self.cache_messages: - self.cache_messages = messages - common_prefix_messages, messages_idx, cache_message_idx\ + def add(self, messages: List[Message]) -> None: + messages_dict = [] + for message in messages: + if isinstance(message, Message): + messages_dict.append(message.to_dict()) + else: + messages_dict.append(message) + self.memory.add(messages_dict, user_id=self.conversation_id) + self.cache_messages.extend(messages_dict) + res = self.memory.get_all(user_id=self.conversation_id) + logger.info(f'Add memory done, current memory infos: {"; ".join([item["memory"] for item in res["results"]])}') + + def search(self, query: str) -> str: + relevant_memories = self.memory.search( + query, user_id=self.conversation_id, limit=3) + memories_str = '\n'.join(f"- {entry['memory']}" + for entry in relevant_memories['results']) + return memories_str + + async def run(self, messages, ignore_role=None, ignore_fields=None): + if not self.is_retrieve or not self._should_update_memory(messages): + return messages + common_prefix_messages, messages_idx, cache_message_idx \ = self._find_messages_common_prefix(messages, ignore_role=ignore_role, ignore_fields=ignore_fields) - print( - f'ahahahah?2 : {self.memory.get_all(user_id=self.conversation_id)}' - ) - if not self.is_retrieve or not self._should_update_memory(messages): - return messages - print( - f'ahahahah?3 : {self.memory.get_all(user_id=self.conversation_id)}' - ) - if self.history_mode == 'add': - print( - f'ahahahah?4 : {self.memory.get_all(user_id=self.conversation_id)}' - ) - self.memory.add(messages, user_id=self.conversation_id) - res = self.memory.get_all(user_id=self.conversation_id) - print(f'res: {res}') - else: - print( - f'ahahahah?5 : {self.memory.get_all(user_id=self.conversation_id)}' - ) + if self.history_mode == 'overwrite': if cache_message_idx < len(self.cache_messages): self.rollback(common_prefix_messages, cache_message_idx) - self.cache_messages = messages - print(f'messages: {messages}') - self.memory.add( - messages[messages_idx:], user_id=self.conversation_id) - res = self.memory.get_all(user_id=self.conversation_id) - print(f'res: {res}') - print(f'messages[-1]["content"]: {messages[-1]["content"]}') - relevant_memories = self.memory.search( - messages[-1]['content'], user_id=self.conversation_id, limit=3) - memories_str = '\n'.join(f"- {entry['memory']}" - for entry in relevant_memories['results']) - print(f'memories_str: {memories_str}') + self.add(messages[max(messages_idx, 0):]) + else: + self.add(messages) + + query = getattr(messages[-1], 'content') + memories_str = self.search(query) # 将memory对应的messages段删除,并添加相关的memory_str信息 - if messages[0].get('role') == 'system': - system_prompt = messages[0][ - 'content'] + f'\nUser Memories: {memories_str}' + if getattr(messages[0], 'role') == 'system': + system_prompt = getattr(messages[0], + 'content') + f'\nUser Memories: {memories_str}' else: system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\nUser Memories: {memories_str}' new_messages = [{ 'role': 'system', 'content': system_prompt }] + messages[messages_idx:] - return new_messages def _init_memory(self) -> Mem0Memory | None: @@ -232,11 +224,11 @@ def _init_memory(self) -> Mem0Memory | None: 'embedder': embedder } - #logger.info(f'Memory config: {mem0_config}') + logger.info(f'Memory config: {mem0_config}') memory = Mem0Memory.from_config(mem0_config) - memory.add(self.cache_messages, user_id=self.conversation_id) - res = memory.get_all(user_id=self.conversation_id) - print(f'res: {res}') + if self.cache_messages: + memory.add(self.cache_messages, user_id=self.conversation_id) + print('current memory:', memory.get_all(user_id=self.conversation_id)) return memory @@ -250,7 +242,6 @@ async def main(): 'compress': True, 'is_retrieve': True, 'history_mode': 'add', - # "embedding_model": "text-embedding-v4", 'llm': { 'provider': 'openai', 'model': 'qwen3-235b-a22b-instruct-2507', @@ -267,24 +258,17 @@ async def main(): 'model': 'text-embedding-v4', } } - # "vector_store": { - # "provider": "qdrant", - # "config": { - # "path": "/Users/luyan/workspace/mem0/storage", - # "on_disk": False - # } - # } } } with open('openai_format_test_case1.json', 'r') as f: data = json.load(f) config = OmegaConf.create(cfg) memory = DefaultMemory( - config, path='./output', cache_messages=data, history_mode='add') - res = memory.run(messages=[{ + config, path='./output', cache_messages=None, history_mode='add') + res = await memory.run(messages=[Message({ 'role': 'user', 'content': '使用bun会对新项目的影响大吗,有哪些新特性' - }]) + })]) print(res) diff --git a/ms_agent/agent/memory/utils.py b/ms_agent/agent/memory/utils.py index 44e5d1261..354cda7c0 100644 --- a/ms_agent/agent/memory/utils.py +++ b/ms_agent/agent/memory/utils.py @@ -1,2 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -memory_mapping = {} +from .default_memory import DefaultMemory +memory_mapping = {'default_memory': DefaultMemory} diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py new file mode 100644 index 000000000..961daf5a1 --- /dev/null +++ b/tests/memory/test_default_memory.py @@ -0,0 +1,91 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import json +import math +import os +import unittest + +from ms_agent.agent import LLMAgent + +from ms_agent.agent.memory.default_memory import DefaultMemory +from ms_agent.llm.utils import Message, Tool +from omegaconf import DictConfig, OmegaConf + +from modelscope.utils.test_utils import test_level + + +class TestDefaultMemory(unittest.TestCase): + def setUp(self) -> None: + self.config = { + 'memory': { + 'conversation_id': 'default_id', + 'persist': True, + 'compress': True, + 'is_retrieve': True, + 'history_mode': 'add', + 'llm': { + 'provider': 'openai', + 'model': 'qwen3-235b-a22b-instruct-2507', + 'openai_base_url': + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'api_key': os.getenv('DASHSCOPE_API_KEY'), + }, + 'embedder': { + 'provider': 'openai', + 'config': { + 'api_key': os.getenv('DASHSCOPE_API_KEY'), + 'openai_base_url': + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'model': 'text-embedding-v4', + } + } + } + } + history_file = os.getenv('TEST_MEMORY_LONG_HISTORY_MESSAGES', 'openai_format_test_case1.json') + with open(history_file, 'r') as f: + data = json.load(f) + self.history_messages = data + + @unittest.skip#Unless(test_level() >= 0, 'skip test in current test level') + def test_default(self): + config = OmegaConf.create({}) + memory = DefaultMemory(config) + memory.add(self.history_messages) + res = memory.search(self.query) + print(res) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_agent_use_memory(self): + import os + import yaml + import asyncio + current_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = f'{current_dir}/../../ms_agent/agent/agent.yaml' + with open(default_config_path, 'r', encoding='utf-8') as file: + config = yaml.safe_load(file) + config['memory'] = self.config['memory'] + config['local_dir'] = current_dir + config['llm']['modelscope_api_key'] = os.getenv('MODELSCOPE_API_KEY') + async def main(): + agent = LLMAgent(config=OmegaConf.create(config)) + res = await agent.run('使用bun会对新项目的影响大吗,有哪些新特性') + print(res) + + asyncio.run(main()) + + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_compress_persist_add(self): + # 使用压缩的持久能记录用户历史偏好的,不记录tool的 + pass + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_compress_persist_remove(self): + # 中间节点开始retry + pass + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_diff_base_api(self): + pass + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 123ca3265d779d15ecc3beb79a1bd276e0b92b62 Mon Sep 17 00:00:00 2001 From: suluyan Date: Wed, 3 Sep 2025 18:51:16 +0800 Subject: [PATCH 04/17] agent runable --- ms_agent/agent/memory/default_memory.py | 6 +----- tests/memory/test_default_memory.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 6fe17ed59..1e125e46c 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -163,10 +163,7 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): 'content') + f'\nUser Memories: {memories_str}' else: system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\nUser Memories: {memories_str}' - new_messages = [{ - 'role': 'system', - 'content': system_prompt - }] + messages[messages_idx:] + new_messages = [Message(role='system', content=system_prompt)] + messages[messages_idx:] return new_messages def _init_memory(self) -> Mem0Memory | None: @@ -228,7 +225,6 @@ def _init_memory(self) -> Mem0Memory | None: memory = Mem0Memory.from_config(mem0_config) if self.cache_messages: memory.add(self.cache_messages, user_id=self.conversation_id) - print('current memory:', memory.get_all(user_id=self.conversation_id)) return memory diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index 961daf5a1..cf1e961ab 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -66,7 +66,7 @@ def test_agent_use_memory(self): config['local_dir'] = current_dir config['llm']['modelscope_api_key'] = os.getenv('MODELSCOPE_API_KEY') async def main(): - agent = LLMAgent(config=OmegaConf.create(config)) + agent = LLMAgent(config=OmegaConf.create(config), task='default_id') res = await agent.run('使用bun会对新项目的影响大吗,有哪些新特性') print(res) From dba0c396d5788e380d43b17289d91aecb983f0b6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 7 Sep 2025 22:16:14 +0800 Subject: [PATCH 05/17] test agent default memory runabl --- ms_agent/agent/agent.yaml | 2 +- ms_agent/agent/base.py | 6 +- ms_agent/agent/memory/__init__.py | 2 +- ms_agent/agent/memory/default_memory.py | 104 +++++++++++++------ tests/memory/test_default_memory.py | 126 ++++++++++++++++-------- 5 files changed, 166 insertions(+), 74 deletions(-) diff --git a/ms_agent/agent/agent.yaml b/ms_agent/agent/agent.yaml index e29cf3f34..eb8990fc0 100644 --- a/ms_agent/agent/agent.yaml +++ b/ms_agent/agent/agent.yaml @@ -1,6 +1,6 @@ llm: service: modelscope - model: Qwen/Qwen3-235B-A22B + model: Qwen/Qwen3-235B-A22B-Instruct-2507 modelscope_api_key: modelscope_base_url: https://api-inference.modelscope.cn/v1 diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index cbb9db911..b16345020 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -9,7 +9,7 @@ from ms_agent.config import Config from ms_agent.config.config import ConfigLifecycleHandler from ms_agent.llm import Message -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf DEFAULT_YAML = os.path.join( os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') @@ -43,10 +43,10 @@ def __init__(self, trust_remote_code: bool = False): if config_dir_or_id is not None: self.config: DictConfig = Config.from_task(config_dir_or_id, env) - elif config is not None: - self.config: DictConfig = config else: self.config: DictConfig = Config.from_task(DEFAULT_YAML) + if config is not None and isinstance(config, DictConfig): + self.config = OmegaConf.merge(self.config, config) if tag is None: self.tag = getattr(config, 'tag', None) or self.DEFAULT_TAG diff --git a/ms_agent/agent/memory/__init__.py b/ms_agent/agent/memory/__init__.py index 33db1cb01..cbbb992be 100644 --- a/ms_agent/agent/memory/__init__.py +++ b/ms_agent/agent/memory/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .base import Memory -from .utils import memory_mapping, DefaultMemory +from .utils import DefaultMemory, memory_mapping diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 1e125e46c..36be5e933 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -1,9 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os from copy import deepcopy +from functools import partial, wraps from typing import Any, Dict, List, Literal, Optional, Set, Tuple -from langchain.chains.question_answering.map_reduce_prompt import messages -from mem0 import Memory as Mem0Memory from ms_agent.agent.memory import Memory from ms_agent.llm.utils import Message from ms_agent.utils.logger import logger @@ -18,26 +18,31 @@ def __init__(self, config: DictConfig, cache_messages: Optional[List[Message]] = None, conversation_id: Optional[str] = None, - persist: bool = False, + persist: bool = True, path: str = None, history_mode: Literal['add', 'overwrite'] = 'add', current_memory_cache_position: int = 0): super().__init__(config) - cache_messages = [message.to_dict() for message in cache_messages] if cache_messages else [] + cache_messages = [message.to_dict() for message in cache_messages + ] if cache_messages else [] self.cache_messages = cache_messages self.conversation_id: Optional[str] = conversation_id or getattr( config.memory, 'conversation_id', None) self.persist: Optional[bool] = persist or getattr( - config.memory, 'persist', None) + config.memory, 'persist', True) self.compress: Optional[bool] = getattr(config.memory, 'compress', - None) - self.embedder: Optional[str] = getattr(config.memory, 'embedder', None) + True) self.is_retrieve: Optional[bool] = getattr(config.memory, - 'is_retrieve', None) - self.path: Optional[str] = path or getattr(config.memory, 'path', None) or getattr(self.config, 'output_dir', 'output') - print(f'path: {self.path}') + 'is_retrieve', True) + self.path: Optional[str] = path or getattr( + config.memory, 'path', None) or getattr(self.config, 'output_dir', + 'output') self.history_mode = history_mode or getattr(config.memory, - 'history_mode', None) + 'history_mode') + self.ignore_role: List[str] = getattr(config.memory, 'ignore_role', + ['tool', 'system']) + self.ignore_fields: List[str] = getattr(config.memory, 'ignore_fields', + ['reasoning_content']) self.current_memory_cache_position = current_memory_cache_position self.memory = self._init_memory() @@ -98,7 +103,8 @@ def _ignore_role(msgs): for key in all_keys: if key in ignore_fields: continue - if getattr(current_cache_msg, key, '') != getattr(current_msg, key, ''): + if getattr(current_cache_msg, key, '') != getattr( + current_msg, key, ''): is_common = False break @@ -132,7 +138,9 @@ def add(self, messages: List[Message]) -> None: self.memory.add(messages_dict, user_id=self.conversation_id) self.cache_messages.extend(messages_dict) res = self.memory.get_all(user_id=self.conversation_id) - logger.info(f'Add memory done, current memory infos: {"; ".join([item["memory"] for item in res["results"]])}') + logger.info( + f'Add memory done, current memory infos: {"; ".join([item["memory"] for item in res["results"]])}' + ) def search(self, query: str) -> str: relevant_memories = self.memory.search( @@ -144,6 +152,7 @@ def search(self, query: str) -> str: async def run(self, messages, ignore_role=None, ignore_fields=None): if not self.is_retrieve or not self._should_update_memory(messages): return messages + common_prefix_messages, messages_idx, cache_message_idx \ = self._find_messages_common_prefix(messages, ignore_role=ignore_role, @@ -159,21 +168,56 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): memories_str = self.search(query) # 将memory对应的messages段删除,并添加相关的memory_str信息 if getattr(messages[0], 'role') == 'system': - system_prompt = getattr(messages[0], - 'content') + f'\nUser Memories: {memories_str}' + system_prompt = getattr( + messages[0], 'content') + f'\nUser Memories: {memories_str}' else: system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\nUser Memories: {memories_str}' - new_messages = [Message(role='system', content=system_prompt)] + messages[messages_idx:] + new_messages = [Message(role='system', content=system_prompt) + ] + messages[messages_idx:] return new_messages - def _init_memory(self) -> Mem0Memory | None: + def _init_memory(self): + from mem0.memory import utils as mem0_utils + parse_messages_origin = mem0_utils.parse_messages + + @wraps(parse_messages_origin) + def patched_parse_messages(messages, ignore_role): + print('hello!') + response = '' + for msg in messages: + if 'system' not in ignore_role and msg['role'] == 'system': + response += f"system: {msg['content']}\n" + if msg['role'] == 'user': + response += f"user: {msg['content']}\n" + if msg['role'] == 'assistant' and msg['content'] is not None: + response += f"assistant: {msg['content']}\n" + if 'tool' not in ignore_role and msg['role'] == 'tool': + response += f"tool: {msg['content']}\n" + return response + + patched_func = partial( + patched_parse_messages, + ignore_role=self.ignore_role, + ) + + mem0_utils.parse_messages = patched_func + + from mem0 import Memory as Mem0Memory + if not self.is_retrieve: return - if self.embedder is None: - # TODO: set default - raise ValueError('embedder must be set when is_retrieve=True.') - embedder = self.embedder + embedder: Optional[str] = getattr( + self.config.memory, 'embedder', + OmegaConf.create({ + 'provider': 'openai', + 'config': { + 'api_key': os.getenv('DASHSCOPE_API_KEY'), + 'openai_base_url': + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'model': 'text-embedding-v4', + } + })) llm = {} if self.compress: @@ -187,12 +231,12 @@ def _init_memory(self) -> Mem0Memory | None: else: llm_config = self.config.llm model = llm_config.model - provider = llm_config.service - openai_base_url = getattr(llm_config, f'{provider}_base_url', + service = llm_config.service + openai_base_url = getattr(llm_config, f'{service}_base_url', None) - openai_api_key = getattr(llm_config, f'{provider}_api_key', + openai_api_key = getattr(llm_config, f'{service}_api_key', None) - + provider = 'openai' llm = { 'provider': provider, 'config': { @@ -261,10 +305,12 @@ async def main(): config = OmegaConf.create(cfg) memory = DefaultMemory( config, path='./output', cache_messages=None, history_mode='add') - res = await memory.run(messages=[Message({ - 'role': 'user', - 'content': '使用bun会对新项目的影响大吗,有哪些新特性' - })]) + res = await memory.run(messages=[ + Message({ + 'role': 'user', + 'content': '使用bun会对新项目的影响大吗,有哪些新特性' + }) + ]) print(res) diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index cf1e961ab..d0fbe59c1 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -1,21 +1,27 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import json import math import os import unittest +import json from ms_agent.agent import LLMAgent - from ms_agent.agent.memory.default_memory import DefaultMemory -from ms_agent.llm.utils import Message, Tool +from ms_agent.llm.utils import Message, ToolCall +from ms_agent.utils.utils import get_default_config from omegaconf import DictConfig, OmegaConf from modelscope.utils.test_utils import test_level class TestDefaultMemory(unittest.TestCase): + def setUp(self) -> None: - self.config = { + self.config_default_memory = OmegaConf.create( + {'memory': { + 'name': 'default_memory' + }}) + + self.config = OmegaConf.create({ 'memory': { 'conversation_id': 'default_id', 'persist': True, @@ -26,7 +32,7 @@ def setUp(self) -> None: 'provider': 'openai', 'model': 'qwen3-235b-a22b-instruct-2507', 'openai_base_url': - 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'api_key': os.getenv('DASHSCOPE_API_KEY'), }, 'embedder': { @@ -34,58 +40,98 @@ def setUp(self) -> None: 'config': { 'api_key': os.getenv('DASHSCOPE_API_KEY'), 'openai_base_url': - 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'model': 'text-embedding-v4', } } } + }) + self.mcp_config = { + 'mcpServers': { + 'fetch': { + 'type': 'sse', + 'url': os.getenv('MCP_SERVER_FETCH_URL'), + } + } } - history_file = os.getenv('TEST_MEMORY_LONG_HISTORY_MESSAGES', 'openai_format_test_case1.json') - with open(history_file, 'r') as f: - data = json.load(f) - self.history_messages = data - - @unittest.skip#Unless(test_level() >= 0, 'skip test in current test level') - def test_default(self): - config = OmegaConf.create({}) - memory = DefaultMemory(config) - memory.add(self.history_messages) - res = memory.search(self.query) - print(res) + self.tool_history = [ + Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园,仅给出名字即可'), + Message( + role='assistant', + content= + '\n用户希望找到北京市朝阳区最炫酷的运动公园,关键词是‘炫酷’,说明用户追求的不仅是功能性,更看重科技感、设计感、互动体验和潮流元素。因此,我需要搜索具备未来感、融合高科技、潮流文化或创新设施的运动场所。\n\n为了解决这个问题,我将采取以下步骤:\n1. 使用awesome_map-search工具搜索北京市朝阳区的运动公园,重点关注‘炫酷’‘科技感’‘潮流’等关键词\n2. 筛选出最具特色、评价最高、视觉冲击力强的公园\n3. 提供运动公园名称。\n\n现在我将调用awesome_map-search工具进行搜索,该工具专为地理+趣味性信息检索设计,支持语义化查询,尤其擅长发现‘宝藏地点’。\n', + tool_calls=[ + ToolCall( + id='call_xyz789CoolPark', + type='function', + tool_name='awesome_map-search', + arguments= + '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地", "max_results": 1, "result_type": "detailed", "include_features": true}' + ) + ]), + Message( + role='tool', + content= + '[{"park_name": "量子跃动·朝阳未来运动穹顶", "address": "北京市朝阳区酒仙桥路8888号", "features": ["反重力悬浮跑道,采用磁悬浮缓震技术,跑步如在月球上跳跃", "AR极限攀岩墙,支持末日火山、星际迷宫等虚拟场景", "DJ动感骑行舱,踩得越快音乐越炸,实时生成能量波形图", "AI教练机器人\'铁燃\',可定制训练计划并用东北话激励用户", "夜光太极湖与漂浮瑜伽平台,湖水含环保夜光藻类,夜晚泛蓝光"], "special_events": ["每周五晚举办炫酷运动会:激光躲避球、荧光舞步跑、机器人拔河"], "tips": ["需完成3道运动谜题解锁入场", "禁止穿拖鞋跑步,否则AI会播放《凉凉》", "建议携带充电宝,拍照太频繁易没电"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "免费"}]' + ), + Message( + role='assistant', + content= + '\n用户希望找到北京市朝阳区最炫酷的运动公园。我已经通过awesome_map-search工具获取了相关信息。\n\n## 🌟 北京市朝阳区最炫酷运动公园揭晓:「量子跃动·朝阳未来运动穹顶」' + ), + Message(role='user', content='好的,拜拜') + ] - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_agent_use_memory(self): - import os - import yaml + @unittest.skip # Unless(test_level() >= 0, 'skip test in current test level') + def test_default_memory(self): + import uuid import asyncio - current_dir = os.path.dirname(os.path.abspath(__file__)) - default_config_path = f'{current_dir}/../../ms_agent/agent/agent.yaml' - with open(default_config_path, 'r', encoding='utf-8') as file: - config = yaml.safe_load(file) - config['memory'] = self.config['memory'] - config['local_dir'] = current_dir - config['llm']['modelscope_api_key'] = os.getenv('MODELSCOPE_API_KEY') + async def main(): - agent = LLMAgent(config=OmegaConf.create(config), task='default_id') - res = await agent.run('使用bun会对新项目的影响大吗,有哪些新特性') + random_id = str(uuid.uuid4()) + default_memory = OmegaConf.create( + {'memory': { + 'name': 'default_memory' + }}) + agent1 = LLMAgent(config=default_memory, task=random_id) + agent1.config.callbacks.remove('input_callback') # noqa + await agent1.run('我是素食主义者,我每天早上喝咖啡') + del agent1 + + agent2 = LLMAgent(config=default_memory, task=random_id) + agent2.config.callbacks.remove('input_callback') # noqa + res = await agent2.run('请帮我准备明天的三餐食谱') print(res) + assert ('素' in res[-1].content and '咖啡' in res[-1].content) asyncio.run(main()) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_compress_persist_add(self): - # 使用压缩的持久能记录用户历史偏好的,不记录tool的 - pass + def test_agent_tool(self): + import uuid + import asyncio + + async def main(): + random_id = str(uuid.uuid4()) + config = OmegaConf.create({'memory': {'ignore_role': ['system']}}) + agent1 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent1.config.callbacks.remove('input_callback') # noqa + await agent1.run(self.tool_history) + del agent1 + + agent2 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent2.config.callbacks.remove('input_callback') # noqa + res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点') + print(res) + assert ('酒仙桥路8888号' in res[-1].content) + + asyncio.run(main()) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_compress_persist_remove(self): + def test_overwrite_with_tool(self): # 中间节点开始retry pass - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_diff_base_api(self): - pass if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From d6910bd446dac38fb8e3e8960f72e218a6ff1fd3 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 8 Sep 2025 15:30:08 +0800 Subject: [PATCH 06/17] support ignore_role --- .pre-commit-config.yaml | 5 +- ms_agent/agent/memory/default_memory.py | 109 +++++++----------------- ms_agent/agent/memory/utils.py | 1 + tests/memory/test_default_memory.py | 62 ++++---------- 4 files changed, 50 insertions(+), 127 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 558ddc5a8..d05512566 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ + + repos: - repo: https://github.com/pycqa/flake8.git rev: 4.0.0 @@ -7,7 +9,8 @@ repos: (?x)^( thirdparty/| examples/| - tests/run.py + tests/run.py| + ms_agent/utils/prompts.py )$ - repo: https://github.com/PyCQA/isort.git rev: 4.3.21 diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 36be5e933..121208b3b 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -47,6 +47,7 @@ def __init__(self, self.memory = self._init_memory() def _should_update_memory(self, messages: List[Message]) -> bool: + # TODO: Avoid unnecessary frequent updates and reduce the number of update operations return True def _find_messages_common_prefix( @@ -56,15 +57,15 @@ def _find_messages_common_prefix( ignore_fields: Optional[Set[str]] = {'reasoning_content'}, ) -> Tuple[List[Dict], int, int]: """ - 比对 messages 和缓存messages的差异,并提取最长公共前缀。 + Compare the differences between messages and cached messages, and extract the longest common prefix. Args: - messages: 本次 List[Dict],符合 OpenAI API 格式 - ignore_role: 是否忽略 role="system"、或者role="tool" 的message - ignore_fields: 可选,要忽略比较的字段名集合,如 {"reasoning_content"} + messages: Current list of message dictionaries in OpenAI API format. + ignore_role: Whether to ignore messages with role="system" or role="tool". + ignore_fields: Optional set of field names to exclude from comparison, e.g., {"reasoning_content"}. Returns: - 最长公共前缀(List[Dict]) + The longest common prefix as a list of dictionaries. """ if not messages or not isinstance(messages, list): return [], -1, -1 @@ -72,10 +73,11 @@ def _find_messages_common_prefix( if ignore_fields is None: ignore_fields = set() - # 预处理:根据 ignore_role 过滤消息 + # Preprocessing: filter messages based on ignore_role def _ignore_role(msgs): filtered = [] - indices = [] # 每个 filtered 消息对应的原始索引 + indices = [ + ] # The original index corresponding to each filtered message for idx, msg in enumerate(msgs): if ignore_role and getattr(msg, 'role') in ignore_role: continue @@ -87,7 +89,7 @@ def _ignore_role(msgs): filtered_cache_messages, cache_indices = _ignore_role( self.cache_messages) - # 找最短长度,避免越界 + # Find the shortest length to avoid out-of-bounds access min_length = min( len(msgs) for msgs in [filtered_messages, filtered_cache_messages]) common_prefix = [] @@ -98,7 +100,7 @@ def _ignore_role(msgs): current_msg = filtered_messages[idx] is_common = True - # 比较其他字段(除了忽略的字段) + # Compare other fields except the ignored ones all_keys = ['role', 'content', 'reasoning_content', 'tool_calls'] for key in all_keys: if key in ignore_fields: @@ -111,7 +113,7 @@ def _ignore_role(msgs): if not is_common: break - # 添加当前消息的深拷贝到结果中(保留原始结构) + # Add a deep copy of the current message to the result (preserve original structure) common_prefix.append(deepcopy(current_msg)) if len(common_prefix) == 0: @@ -120,11 +122,11 @@ def _ignore_role(msgs): return common_prefix, indices[idx], cache_indices[idx] def rollback(self, common_prefix_messages, cache_message_idx): - # 支持retry机制,将memory回退到 self.cache_messages的第idx 条message + # Support retry mechanism: roll back memory to the idx-th message in self.cache_messages if self.history_mode == 'add': - # 只有覆盖更新模式才支持回退;回退涉及删除 + # Only overwrite update mode supports rollback; rollback involves deletion return - # TODO: 真正的回退 + # TODO: Implement actual rollback logic self.memory.delete_all(user_id=self.conversation_id) self.memory.add(common_prefix_messages, user_id=self.conversation_id) @@ -166,23 +168,23 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): query = getattr(messages[-1], 'content') memories_str = self.search(query) - # 将memory对应的messages段删除,并添加相关的memory_str信息 + # Remove the messages section corresponding to memory, and add the related memory_str information if getattr(messages[0], 'role') == 'system': system_prompt = getattr( messages[0], 'content') + f'\nUser Memories: {memories_str}' else: - system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\nUser Memories: {memories_str}' + system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' \ + f'User Memories: {memories_str}' new_messages = [Message(role='system', content=system_prompt) ] + messages[messages_idx:] return new_messages def _init_memory(self): - from mem0.memory import utils as mem0_utils - parse_messages_origin = mem0_utils.parse_messages + import mem0 + parse_messages_origin = mem0.memory.main.parse_messages @wraps(parse_messages_origin) def patched_parse_messages(messages, ignore_role): - print('hello!') response = '' for msg in messages: if 'system' not in ignore_role and msg['role'] == 'system': @@ -200,9 +202,7 @@ def patched_parse_messages(messages, ignore_role): ignore_role=self.ignore_role, ) - mem0_utils.parse_messages = patched_func - - from mem0 import Memory as Mem0Memory + mem0.memory.main.parse_messages = patched_func if not self.is_retrieve: return @@ -247,73 +247,22 @@ def patched_parse_messages(messages, ignore_role): } mem0_config = { - 'is_infer': - self.compress, - 'llm': - llm, - 'custom_fact_extraction_prompt': - getattr(self.config.memory, 'fact_retrieval_prompt', - FACT_RETRIEVAL_PROMPT), + 'is_infer': self.compress, + 'llm': llm, 'vector_store': { 'provider': 'qdrant', 'config': { 'path': self.path, - # "on_disk": self.persist - 'on_disk': True + 'on_disk': self.persist } }, - 'embedder': - embedder + 'embedder': embedder } logger.info(f'Memory config: {mem0_config}') - memory = Mem0Memory.from_config(mem0_config) + # Prompt content is too long, default logging reduces readability + mem0_config['custom_fact_extraction_prompt'] = getattr( + self.config.memory, 'fact_retrieval_prompt', FACT_RETRIEVAL_PROMPT) + memory = mem0.Memory.from_config(mem0_config) if self.cache_messages: memory.add(self.cache_messages, user_id=self.conversation_id) return memory - - -async def main(): - import os - import json - cfg = { - 'memory': { - 'conversation_id': 'default_id', - 'persist': True, - 'compress': True, - 'is_retrieve': True, - 'history_mode': 'add', - 'llm': { - 'provider': 'openai', - 'model': 'qwen3-235b-a22b-instruct-2507', - 'openai_base_url': - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - 'api_key': os.getenv('DASHSCOPE_API_KEY'), - }, - 'embedder': { - 'provider': 'openai', - 'config': { - 'api_key': os.getenv('DASHSCOPE_API_KEY'), - 'openai_base_url': - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - 'model': 'text-embedding-v4', - } - } - } - } - with open('openai_format_test_case1.json', 'r') as f: - data = json.load(f) - config = OmegaConf.create(cfg) - memory = DefaultMemory( - config, path='./output', cache_messages=None, history_mode='add') - res = await memory.run(messages=[ - Message({ - 'role': 'user', - 'content': '使用bun会对新项目的影响大吗,有哪些新特性' - }) - ]) - print(res) - - -if __name__ == '__main__': - import asyncio - asyncio.run(main()) diff --git a/ms_agent/agent/memory/utils.py b/ms_agent/agent/memory/utils.py index 354cda7c0..22a9f025a 100644 --- a/ms_agent/agent/memory/utils.py +++ b/ms_agent/agent/memory/utils.py @@ -1,3 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .default_memory import DefaultMemory + memory_mapping = {'default_memory': DefaultMemory} diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index d0fbe59c1..5c0c38f26 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -7,7 +7,6 @@ from ms_agent.agent import LLMAgent from ms_agent.agent.memory.default_memory import DefaultMemory from ms_agent.llm.utils import Message, ToolCall -from ms_agent.utils.utils import get_default_config from omegaconf import DictConfig, OmegaConf from modelscope.utils.test_utils import test_level @@ -16,63 +15,34 @@ class TestDefaultMemory(unittest.TestCase): def setUp(self) -> None: - self.config_default_memory = OmegaConf.create( - {'memory': { - 'name': 'default_memory' - }}) - - self.config = OmegaConf.create({ - 'memory': { - 'conversation_id': 'default_id', - 'persist': True, - 'compress': True, - 'is_retrieve': True, - 'history_mode': 'add', - 'llm': { - 'provider': 'openai', - 'model': 'qwen3-235b-a22b-instruct-2507', - 'openai_base_url': - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - 'api_key': os.getenv('DASHSCOPE_API_KEY'), - }, - 'embedder': { - 'provider': 'openai', - 'config': { - 'api_key': os.getenv('DASHSCOPE_API_KEY'), - 'openai_base_url': - 'https://dashscope.aliyuncs.com/compatible-mode/v1', - 'model': 'text-embedding-v4', - } - } - } - }) - self.mcp_config = { - 'mcpServers': { - 'fetch': { - 'type': 'sse', - 'url': os.getenv('MCP_SERVER_FETCH_URL'), - } - } - } self.tool_history = [ - Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园,仅给出名字即可'), + Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园'), Message( role='assistant', content= - '\n用户希望找到北京市朝阳区最炫酷的运动公园,关键词是‘炫酷’,说明用户追求的不仅是功能性,更看重科技感、设计感、互动体验和潮流元素。因此,我需要搜索具备未来感、融合高科技、潮流文化或创新设施的运动场所。\n\n为了解决这个问题,我将采取以下步骤:\n1. 使用awesome_map-search工具搜索北京市朝阳区的运动公园,重点关注‘炫酷’‘科技感’‘潮流’等关键词\n2. 筛选出最具特色、评价最高、视觉冲击力强的公园\n3. 提供运动公园名称。\n\n现在我将调用awesome_map-search工具进行搜索,该工具专为地理+趣味性信息检索设计,支持语义化查询,尤其擅长发现‘宝藏地点’。\n', + '\n用户希望找到北京市朝阳区最炫酷的运动公园,关键词是‘炫酷’,说明用户追求的不仅是功能性,更看重科技感、设计感、互动体验' + '和潮流元素。因此,我需要搜索具备未来感、融合高科技、潮流文化或创新设施的运动场所。\n\n为了解决这个问题,我将采取以下步' + '骤:\n1. 使用awesome_map-search工具搜索北京市朝阳区的运动公园,重点关注‘炫酷’‘科技感’‘潮流’等关键词\n2. 筛选出最' + '具特色、评价最高、视觉冲击力强的公园\n3. 提供运动公园名称。\n\n现在我将调用awesome_map-search工具进行搜索,该工具' + '专为地理+趣味性信息检索设计,支持语义化查询,尤其擅长发现‘宝藏地点’。\n', tool_calls=[ ToolCall( id='call_xyz789CoolPark', type='function', tool_name='awesome_map-search', arguments= - '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地", "max_results": 1, "result_type": "detailed", "include_features": true}' - ) + '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地", "max_results": 1, "result_type":' + '"detailed", "include_features": true}') ]), Message( role='tool', content= - '[{"park_name": "量子跃动·朝阳未来运动穹顶", "address": "北京市朝阳区酒仙桥路8888号", "features": ["反重力悬浮跑道,采用磁悬浮缓震技术,跑步如在月球上跳跃", "AR极限攀岩墙,支持末日火山、星际迷宫等虚拟场景", "DJ动感骑行舱,踩得越快音乐越炸,实时生成能量波形图", "AI教练机器人\'铁燃\',可定制训练计划并用东北话激励用户", "夜光太极湖与漂浮瑜伽平台,湖水含环保夜光藻类,夜晚泛蓝光"], "special_events": ["每周五晚举办炫酷运动会:激光躲避球、荧光舞步跑、机器人拔河"], "tips": ["需完成3道运动谜题解锁入场", "禁止穿拖鞋跑步,否则AI会播放《凉凉》", "建议携带充电宝,拍照太频繁易没电"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "免费"}]' + '[{"park_name": "量子跃动·朝阳未来运动穹顶", "address": "北京市朝阳区酒仙桥路8888号", "features": ["反重力悬' + '浮跑道,采用磁悬浮缓震技术,跑步如在月球上跳跃", "AR极限攀岩墙,支持末日火山、星际迷宫等虚拟场景", "DJ动感骑行舱,踩' + '得越快音乐越炸,实时生成能量波形图", "AI教练机器人\'铁燃\',可定制训练计划并用东北话激励用户", "夜光太极湖与漂浮瑜伽' + '平台,湖水含环保夜光藻类,夜晚泛蓝光"], "special_events": ["每周五晚举办炫酷运动会:激光躲避球、荧光舞步跑、机器人' + '拔河"], "tips": ["需完成3道运动谜题解锁入场", "禁止穿拖鞋跑步,否则AI会播放《凉凉》", "建议携带充电宝,拍照太频繁易' + '没电"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "免费"}]' ), Message( role='assistant', @@ -82,7 +52,7 @@ def setUp(self) -> None: Message(role='user', content='好的,拜拜') ] - @unittest.skip # Unless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_default_memory(self): import uuid import asyncio @@ -129,7 +99,7 @@ async def main(): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_overwrite_with_tool(self): - # 中间节点开始retry + # Retry starting from an intermediate node pass From 845f17d562b46f3de0ecf7eee615a61f770d216d Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 8 Sep 2025 15:40:10 +0800 Subject: [PATCH 07/17] minor fix --- tests/memory/test_default_memory.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index 5c0c38f26..5ad02b70c 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -16,7 +16,7 @@ class TestDefaultMemory(unittest.TestCase): def setUp(self) -> None: self.tool_history = [ - Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园'), + Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园。我标记一下,下次去。'), Message( role='assistant', content= @@ -59,10 +59,7 @@ def test_default_memory(self): async def main(): random_id = str(uuid.uuid4()) - default_memory = OmegaConf.create( - {'memory': { - 'name': 'default_memory' - }}) + default_memory = OmegaConf.create({'memory': {}}) agent1 = LLMAgent(config=default_memory, task=random_id) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run('我是素食主义者,我每天早上喝咖啡') From 228ca13d09ea780c881d6e78711180c578a381b7 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 18 Sep 2025 16:52:43 +0800 Subject: [PATCH 08/17] feat: modify history messages --- ms_agent/agent/llm_agent.py | 11 +- ms_agent/agent/memory/default_memory.py | 406 +++++++++++++++++------- tests/memory/test_default_memory.py | 63 +++- 3 files changed, 363 insertions(+), 117 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 4273035a8..b27f04494 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -236,10 +236,11 @@ async def _prepare_memory(self, f'which supports: {list(memory_mapping.keys())}') self.memory_tools.append(memory_mapping[memory_type]( - self.config, - cache_messages if isinstance(cache_messages, list) else None, - conversation_id=self.task)) - return config, runtime, messages + self.config, conversation_id=self.task)) + + if messages != cache_messages: + runtime.should_stop = False + return config, runtime, cache_messages async def _prepare_planer(self): """Load and initialize the planer component from the config.""" @@ -317,7 +318,7 @@ def _log_output(content: str, tag: str): for _line in line.split('\\n'): logger.info(f'[{tag}] {_line}') - @async_retry(max_attempts=2, delay=1.0) + #@async_retry(max_attempts=2, delay=1.0) async def _step( self, messages: List[Message], tag: str) -> AsyncGenerator[List[Message], Any]: # type: ignore diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 121208b3b..3a67e56c5 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -1,9 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import hashlib import os from copy import deepcopy from functools import partial, wraps from typing import Any, Dict, List, Literal, Optional, Set, Tuple +import json +import json5 from ms_agent.agent.memory import Memory from ms_agent.llm.utils import Message from ms_agent.utils.logger import logger @@ -11,21 +14,69 @@ from omegaconf import DictConfig, OmegaConf +class MemoryMapping: + memory_id: str = None + memory: str = None + valid: bool = None + enable_idxs: List[int] = [] + disable_idx: int = -1 + + def __init__(self, memory_id: str, value: str, enable_idxs: int + or List[int]): + self.memory_id = memory_id + self.value = value + self.valid = True + if isinstance(enable_idxs, int): + enable_idxs = [enable_idxs] + self.enable_idxs = enable_idxs + + def udpate_idxs(self, enable_idxs: int or List[int]): + if isinstance(enable_idxs, int): + enable_idxs = [enable_idxs] + self.enable_idxs.extend(enable_idxs) + + def disable(self, disable_idx: int): + self.valid = False + self.disable_idx = disable_idx + + def try_enable(self, expired_disable_idx: int): + if expired_disable_idx == self.disable_idx: + self.valid = True + self.disable_idx = -1 + + def get(self): + return self.value + + def to_dict(self) -> Dict: + return { + 'memory_id': self.memory_id, + 'value': self.value, + 'valid': self.valid, + 'enable_idxs': self.enable_idxs.copy(), # 返回副本防止外部修改 + 'disable_idx': self.disable_idx + } + + @classmethod + def from_dict(cls, data: Dict) -> 'MemoryMapping': + instance = cls( + memory_id=data['memory_id'], + value=data['value'], + enable_idxs=data['enable_idxs']) + instance.valid = data['valid'] + instance.disable_idx = data.get('disable_idx', -1) # 兼容旧数据 + return instance + + class DefaultMemory(Memory): """The memory refine tool""" def __init__(self, config: DictConfig, - cache_messages: Optional[List[Message]] = None, conversation_id: Optional[str] = None, persist: bool = True, path: str = None, - history_mode: Literal['add', 'overwrite'] = 'add', - current_memory_cache_position: int = 0): + history_mode: Literal['add', 'overwrite'] = None): super().__init__(config) - cache_messages = [message.to_dict() for message in cache_messages - ] if cache_messages else [] - self.cache_messages = cache_messages self.conversation_id: Optional[str] = conversation_id or getattr( config.memory, 'conversation_id', None) self.persist: Optional[bool] = persist or getattr( @@ -38,99 +89,122 @@ def __init__(self, config.memory, 'path', None) or getattr(self.config, 'output_dir', 'output') self.history_mode = history_mode or getattr(config.memory, - 'history_mode') + 'history_mode', 'add') self.ignore_role: List[str] = getattr(config.memory, 'ignore_role', ['tool', 'system']) self.ignore_fields: List[str] = getattr(config.memory, 'ignore_fields', ['reasoning_content']) - self.current_memory_cache_position = current_memory_cache_position - self.memory = self._init_memory() + self.memory = self._init_memory_obj() + self.init_cache_messages() + + def init_cache_messages(self): + self.load_cache() + if len(self.cache_messages) and not len(self.memory_snapshot): + new_blocks = self._split_into_blocks(self.cache_messages) + for messages in new_blocks: + self.max_msg_id += 1 + self.add(messages, msg_id=self.max_msg_id) + + def save_cache(self): + """ + 将 self.max_msg_id, self.cache_messages, self.memory_snapshot + 保存到 self.path/cache_messages.json + """ + cache_file = os.path.join(self.path, 'cache_messages.json') - def _should_update_memory(self, messages: List[Message]) -> bool: - # TODO: Avoid unnecessary frequent updates and reduce the number of update operations - return True + # 确保目录存在 + os.makedirs(self.path, exist_ok=True) - def _find_messages_common_prefix( - self, - messages: List[Dict], - ignore_role: Optional[Set[str]] = {'system'}, - ignore_fields: Optional[Set[str]] = {'reasoning_content'}, - ) -> Tuple[List[Dict], int, int]: - """ - Compare the differences between messages and cached messages, and extract the longest common prefix. + data = { + 'max_msg_id': self.max_msg_id, + 'cache_messages': { + str(k): ([msg.to_dict() for msg in msg_list], _hash) + for k, (msg_list, _hash) in self.cache_messages.items() + }, + 'memory_snapshot': [mm.to_dict() for mm in self.memory_snapshot] + } - Args: - messages: Current list of message dictionaries in OpenAI API format. - ignore_role: Whether to ignore messages with role="system" or role="tool". - ignore_fields: Optional set of field names to exclude from comparison, e.g., {"reasoning_content"}. + with open(cache_file, 'w', encoding='utf-8') as f: + json5.dump(data, f, indent=2, ensure_ascii=False) - Returns: - The longest common prefix as a list of dictionaries. + def load_cache(self): """ - if not messages or not isinstance(messages, list): - return [], -1, -1 - - if ignore_fields is None: - ignore_fields = set() - - # Preprocessing: filter messages based on ignore_role - def _ignore_role(msgs): - filtered = [] - indices = [ - ] # The original index corresponding to each filtered message - for idx, msg in enumerate(msgs): - if ignore_role and getattr(msg, 'role') in ignore_role: - continue - filtered.append(msg) - indices.append(idx) - return filtered, indices - - filtered_messages, indices = _ignore_role(messages) - filtered_cache_messages, cache_indices = _ignore_role( - self.cache_messages) - - # Find the shortest length to avoid out-of-bounds access - min_length = min( - len(msgs) for msgs in [filtered_messages, filtered_cache_messages]) - common_prefix = [] + 从 self.path/cache_messages.json 加载数据到 + self.max_msg_id, self.cache_messages, self.memory_snapshot + """ + cache_file = os.path.join(self.path, 'cache_messages.json') - idx = 0 - for idx in range(min_length): - current_cache_msg = filtered_cache_messages[idx] - current_msg = filtered_messages[idx] - is_common = True - - # Compare other fields except the ignored ones - all_keys = ['role', 'content', 'reasoning_content', 'tool_calls'] - for key in all_keys: - if key in ignore_fields: - continue - if getattr(current_cache_msg, key, '') != getattr( - current_msg, key, ''): - is_common = False - break - - if not is_common: - break - - # Add a deep copy of the current message to the result (preserve original structure) - common_prefix.append(deepcopy(current_msg)) - - if len(common_prefix) == 0: - return [], -1, -1 - - return common_prefix, indices[idx], cache_indices[idx] - - def rollback(self, common_prefix_messages, cache_message_idx): - # Support retry mechanism: roll back memory to the idx-th message in self.cache_messages - if self.history_mode == 'add': - # Only overwrite update mode supports rollback; rollback involves deletion + if not os.path.exists(cache_file): + # 如果文件不存在,初始化默认值并返回 + self.max_msg_id = -1 + self.cache_messages = {} + self.memory_snapshot = [] return - # TODO: Implement actual rollback logic - self.memory.delete_all(user_id=self.conversation_id) - self.memory.add(common_prefix_messages, user_id=self.conversation_id) - def add(self, messages: List[Message]) -> None: + try: + with open(cache_file, 'r', encoding='utf-8') as f: + data = json5.load(f) + + self.max_msg_id = data.get('max_msg_id', -1) + + # 解析 cache_messages + cache_messages = {} + raw_cache_msgs = data.get('cache_messages', {}) + for k, (msg_list, timestamp) in raw_cache_msgs.items(): + msg_objs = [Message(**msg_dict) for msg_dict in msg_list] + cache_messages[int(k)] = (msg_objs, timestamp) + self.cache_messages = cache_messages + + # 解析 memory_snapshot + self.memory_snapshot = [ + MemoryMapping.from_dict(d) + for d in data.get('memory_snapshot', []) + ] + + except (json.JSONDecodeError, KeyError, Exception) as e: + logger.warning(f'Failed to load cache: {e}') + # 出错时回退到默认状态 + self.max_msg_id = -1 + self.cache_messages = {} + self.memory_snapshot = [] + + def delete_single(self, msg_id: int): + messages_to_delete = self.cache_messages.get(msg_id, None) + if messages_to_delete is None: + return + self.cache_messages.pop(msg_id, None) + if msg_id == self.max_msg_id: + self.max_msg_id = max(self.cache_messages.keys()) + + idx = 0 + while idx < len(self.memory_snapshot): + + enable_ids = self.memory_snapshot[idx].enable_idxs + disable_id = self.memory_snapshot[idx].disable_idx + if msg_id == disable_id: + self.memory_snapshot[idx].try_enable(msg_id) + self.memory._create_memory( + data=self.memory_snapshot[idx].value, + existing_embeddings={}, + metadata={'user_id': self.conversation_id}) + if msg_id in enable_ids: + if len(enable_ids) > 1: + self.memory_snapshot[idx].enable_idxs.remove(msg_id) + else: + self.memory.delete(self.memory_snapshot[idx].memory_id) + self.memory_snapshot.pop(idx) + idx -= 1 # pop后下一条成为当前idx + + idx += 1 + res = self.memory.get_all(user_id=self.conversation_id) # sorted + res = [(item['id'], item['memory']) for item in res['results']] + logger.info(f'Roll back success. All memory info:') + for item in res: + logger.info(item[1]) + + def add(self, messages: List[Message], msg_id: int) -> None: + self.cache_messages[msg_id] = messages, self._hash_block(messages) + messages_dict = [] for message in messages: if isinstance(message, Message): @@ -138,11 +212,34 @@ def add(self, messages: List[Message]) -> None: else: messages_dict.append(message) self.memory.add(messages_dict, user_id=self.conversation_id) - self.cache_messages.extend(messages_dict) - res = self.memory.get_all(user_id=self.conversation_id) - logger.info( - f'Add memory done, current memory infos: {"; ".join([item["memory"] for item in res["results"]])}' - ) + + self.max_msg_id = max(self.max_msg_id, msg_id) + res = self.memory.get_all(user_id=self.conversation_id) # sorted + res = [(item['id'], item['memory']) for item in res['results']] + logger.info(f'Add memory success. All memory info:') + for item in res: + logger.info(item[1]) + valids = [] + unmatched = [] + for id, memory in res: + matched = False + for item in self.memory_snapshot: + if id == item.memory_id: + if item.value == memory and item.valid: + matched = True + valids.append(id) + break + else: + if item.valid: + item.disable(msg_id) + if not matched: + unmatched.append((id, memory)) + for item in self.memory_snapshot: + if item.memory_id not in valids: + item.disable(msg_id) + for (id, memory) in unmatched: + m = MemoryMapping(memory_id=id, value=memory, enable_idxs=msg_id) + self.memory_snapshot.append(m) def search(self, query: str) -> str: relevant_memories = self.memory.search( @@ -151,20 +248,115 @@ def search(self, query: str) -> str: for entry in relevant_memories['results']) return memories_str + def _split_into_blocks(self, + messages: List[Message]) -> List[List[Message]]: + """ + Split messages into blocks where each block starts with a 'user' message + and includes all following non-user messages until the next 'user' (exclusive). + + The very first messages before the first 'user' (e.g., system) are attached to the first user block. + If no user message exists, all messages go into one block. + """ + if not messages: + return [] + + blocks: List[List[Message]] = [] + current_block: List[Message] = [] + + # Handle leading non-user messages (like system) + have_user = False + for msg in messages: + if msg.role != 'user': + current_block.append(msg) + else: + if have_user: + blocks.append(current_block) + current_block = [msg] + else: + current_block.append(msg) + have_user = True + + # Append the last block + if current_block: + blocks.append(current_block) + + return blocks + + def _hash_block(self, block: List[Message]) -> str: + """Compute sha256 hash of a message block for comparison""" + data = [message.to_dict() for message in block] + allow_role = ['user', 'system', 'assistant', 'tool'] + allow_role = [ + role for role in allow_role if role not in self.ignore_role + ] + allow_fields = ['reasoning_content', 'content', 'tool_calls', 'role'] + allow_fields = [ + field for field in allow_fields if field not in self.ignore_fields + ] + + data = [{ + field: value + for field, value in msg.items() if field in allow_fields + } for msg in data if msg['role'] in allow_role] + + block_data = json5.dumps(data) + return hashlib.sha256(block_data.encode('utf-8')).hexdigest() + + def _analyze_messages( + self, + messages: List[Message]) -> Tuple[List[List[Message]], List[int]]: + """ + Analyze incoming messages against cache. + + Returns: + should_add_messages: blocks to add (not in cache or hash changed) + should_delete: list of msg_id to delete (in cache but not in new blocks) + """ + new_blocks = self._split_into_blocks(messages) + self.cache_messages = dict(sorted(self.cache_messages.items())) + + cache_messages = [(key, value) + for key, value in self.cache_messages.items()] + first_unmatched_idx = -1 + for idx in range(len(new_blocks)): + block_hash = self._hash_block(new_blocks[idx]) + if idx < len(cache_messages) - 1 and str(block_hash) == str( + cache_messages[idx][1][1]): + continue + first_unmatched_idx = idx + break + should_delete = [ + item[0] for item in cache_messages[first_unmatched_idx:] + ] if first_unmatched_idx != -1 else [] + should_add_messages = new_blocks[first_unmatched_idx:] + + return should_add_messages, should_delete + + def _get_user_message(self, block: List[Message]) -> Optional[Message]: + """Helper: get the user message from a block, if exists""" + for msg in block: + if msg.role == 'user': + return msg + return None + + def _should_update_memory(self, messages: List[Message]) -> bool: + # TODO: Avoid unnecessary frequent updates and reduce the number of update operations + return True + async def run(self, messages, ignore_role=None, ignore_fields=None): if not self.is_retrieve or not self._should_update_memory(messages): return messages - - common_prefix_messages, messages_idx, cache_message_idx \ - = self._find_messages_common_prefix(messages, - ignore_role=ignore_role, - ignore_fields=ignore_fields) - if self.history_mode == 'overwrite': - if cache_message_idx < len(self.cache_messages): - self.rollback(common_prefix_messages, cache_message_idx) - self.add(messages[max(messages_idx, 0):]) - else: - self.add(messages) + should_add_messages, should_delete = self._analyze_messages(messages) + + if should_delete: + if self.history_mode == 'overwrite': + for msg_id in should_delete: + self.delete_single(msg_id=msg_id) + if should_add_messages: + for messages in should_add_messages: + self.max_msg_id += 1 + self.add(messages, msg_id=self.max_msg_id) + self.save_cache() query = getattr(messages[-1], 'content') memories_str = self.search(query) @@ -175,11 +367,13 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): else: system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' \ f'User Memories: {memories_str}' + diff_idx = len(messages) - sum( + [len(block) for block in should_add_messages]) new_messages = [Message(role='system', content=system_prompt) - ] + messages[messages_idx:] + ] + messages[diff_idx:] return new_messages - def _init_memory(self): + def _init_memory_obj(self): import mem0 parse_messages_origin = mem0.memory.main.parse_messages @@ -263,6 +457,4 @@ def patched_parse_messages(messages, ignore_role): mem0_config['custom_fact_extraction_prompt'] = getattr( self.config.memory, 'fact_retrieval_prompt', FACT_RETRIEVAL_PROMPT) memory = mem0.Memory.from_config(mem0_config) - if self.cache_messages: - memory.add(self.cache_messages, user_id=self.conversation_id) return memory diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index 5ad02b70c..648b01107 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -16,7 +16,7 @@ class TestDefaultMemory(unittest.TestCase): def setUp(self) -> None: self.tool_history = [ - Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园。我标记一下,下次去。'), + Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园。记着该地点,下次去。'), Message( role='assistant', content= @@ -64,7 +64,7 @@ async def main(): agent1.config.callbacks.remove('input_callback') # noqa await agent1.run('我是素食主义者,我每天早上喝咖啡') del agent1 - + print('========== 数据准备结束,开始测试 ===========') agent2 = LLMAgent(config=default_memory, task=random_id) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run('请帮我准备明天的三餐食谱') @@ -85,7 +85,7 @@ async def main(): agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(self.tool_history) del agent1 - + print('========== 数据准备结束,开始测试 ===========') agent2 = LLMAgent(config=OmegaConf.create(config), task=random_id) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点') @@ -96,8 +96,61 @@ async def main(): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_overwrite_with_tool(self): - # Retry starting from an intermediate node - pass + import uuid + import asyncio + + async def main(): + tool_history1 = self.tool_history[:-1] + [ + Message(role='user', content='你说的这家运动公园已经停业了。'), + Message( + role='assistant', + content= + '用户指出“量子跃动·朝阳未来运动穹顶”已停业。今天是2045年5月7日,需要重新搜索当前仍在运营的最炫酷运动公园。我将调用awesome_map-search工具,增加“2045年在营”等时间相关关键词,确保结果准确且时效性强。', + tool_calls=[ + ToolCall( + id='call_xyz2045NewPark', + type='function', + tool_name='awesome_map-search', + arguments= + '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地 2045年在营", "max_results": 1, "result_type": "detailed", "include_features": true}' + ) + ]), + Message( + role='tool', + content= + '[{"park_name": "星核动力·朝阳元宇宙运动矩阵", "address": "北京市朝阳区奥体南路99号", "features": ["全息投影跑道,每一步触发星际粒子 trail", "意念控制攀岩墙,脑波越专注吸附力越强", "重力可调训练舱,模拟火星/月球/深海环境", "AI虚拟教练‘Neo’支持跨次元形象定制", "夜间悬浮滑板池,地面磁力驱动实现无轮滑行"], "special_events": ["每日黄昏举行‘意识觉醒跑’:用脑机接口同步节奏,集体生成光影风暴"], "tips": ["需提前预约神经适配测试", "禁止情绪剧烈波动,否则系统会启动冷静模式播放白噪音", "建议穿导电运动服提升交互体验"], "rating": 9.8, "opening_hours": "5:30 - 24:00", "admission": "免费(脑纹注册入场)"}]' + ), + Message( + role='assistant', + content= + '2045年最新的最炫酷的运动公园是:星核动力·朝阳元宇宙运动矩阵。位于北京市朝阳区奥体南路99号,融合脑机接口、全息投影与重力调控技术,打造沉浸式未来运动体验。现已开放预约,支持脑纹注册免费入场。' + ), + Message(role='user', content='好的,谢谢。'), + ] + tool_history2 = self.tool_history[:-1] + [ + Message(role='user', content='北京市朝阳区最炫酷的运动公园的地点?') + ] + random_id = str(uuid.uuid4()) + config = OmegaConf.create({ + 'memory': { + 'ignore_role': ['system'], + 'history_mode': 'overwrite' + }, + 'output_dir': f'output/{random_id}' + }) + agent1 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent1.config.callbacks.remove('input_callback') # noqa + await agent1.run(tool_history1) + del agent1 + print('========== 数据准备结束,开始测试 ===========') + agent2 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent2.config.callbacks.remove('input_callback') # noqa + res = await agent2.run(tool_history2) + print(res) + assert ('酒仙桥路8888号' in res[-1].content + and '奥体南路' not in res[-1].content) + + asyncio.run(main()) if __name__ == '__main__': From 9f9ed652a6744a1269404a19572e8b38a66e7367 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 18 Sep 2025 17:12:20 +0800 Subject: [PATCH 09/17] minor fix --- ms_agent/agent/memory/default_memory.py | 3 ++- tests/memory/test_default_memory.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index 3a67e56c5..d946898ca 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -216,7 +216,8 @@ def add(self, messages: List[Message], msg_id: int) -> None: self.max_msg_id = max(self.max_msg_id, msg_id) res = self.memory.get_all(user_id=self.conversation_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] - logger.info(f'Add memory success. All memory info:') + if len(res): + logger.info(f'Add memory success. All memory info:') for item in res: logger.info(item[1]) valids = [] diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index 648b01107..bc26a4743 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -52,14 +52,14 @@ def setUp(self) -> None: Message(role='user', content='好的,拜拜') ] - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skip#Unless(test_level() >= 0, 'skip test in current test level') def test_default_memory(self): import uuid import asyncio async def main(): random_id = str(uuid.uuid4()) - default_memory = OmegaConf.create({'memory': {}}) + default_memory = OmegaConf.create({'memory': {}, 'output_dir': f'output/{random_id}'}) agent1 = LLMAgent(config=default_memory, task=random_id) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run('我是素食主义者,我每天早上喝咖啡') @@ -73,14 +73,14 @@ async def main(): asyncio.run(main()) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skip#Unless(test_level() >= 0, 'skip test in current test level') def test_agent_tool(self): import uuid import asyncio async def main(): random_id = str(uuid.uuid4()) - config = OmegaConf.create({'memory': {'ignore_role': ['system']}}) + config = OmegaConf.create({'memory': {'ignore_role': ['system']}, 'output_dir': f'output/{random_id}'}) agent1 = LLMAgent(config=OmegaConf.create(config), task=random_id) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(self.tool_history) From c1fb41ed3e8c0a92fff7d3b5251d8575371df331 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 18 Sep 2025 18:11:05 +0800 Subject: [PATCH 10/17] fix typo --- ms_agent/agent/llm_agent.py | 2 +- ms_agent/agent/memory/default_memory.py | 37 +++++++++++------------ tests/memory/test_default_memory.py | 39 ++++++++++++++++++------- 3 files changed, 49 insertions(+), 29 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index b27f04494..c26b7203f 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -318,7 +318,7 @@ def _log_output(content: str, tag: str): for _line in line.split('\\n'): logger.info(f'[{tag}] {_line}') - #@async_retry(max_attempts=2, delay=1.0) + @async_retry(max_attempts=2, delay=1.0) async def _step( self, messages: List[Message], tag: str) -> AsyncGenerator[List[Message], Any]: # type: ignore diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index d946898ca..ec79228a4 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -52,7 +52,8 @@ def to_dict(self) -> Dict: 'memory_id': self.memory_id, 'value': self.value, 'valid': self.valid, - 'enable_idxs': self.enable_idxs.copy(), # 返回副本防止外部修改 + 'enable_idxs': self.enable_idxs.copy( + ), # Return a copy to prevent external modification 'disable_idx': self.disable_idx } @@ -63,7 +64,8 @@ def from_dict(cls, data: Dict) -> 'MemoryMapping': value=data['value'], enable_idxs=data['enable_idxs']) instance.valid = data['valid'] - instance.disable_idx = data.get('disable_idx', -1) # 兼容旧数据 + instance.disable_idx = data.get('disable_idx', + -1) # Compatible with old data return instance @@ -107,12 +109,11 @@ def init_cache_messages(self): def save_cache(self): """ - 将 self.max_msg_id, self.cache_messages, self.memory_snapshot - 保存到 self.path/cache_messages.json + Save self.max_msg_id, self.cache_messages, and self.memory_snapshot to self.path/cache_messages.json """ cache_file = os.path.join(self.path, 'cache_messages.json') - # 确保目录存在 + # Ensure the directory exists os.makedirs(self.path, exist_ok=True) data = { @@ -129,13 +130,12 @@ def save_cache(self): def load_cache(self): """ - 从 self.path/cache_messages.json 加载数据到 - self.max_msg_id, self.cache_messages, self.memory_snapshot + Load data from self.path/cache_messages.json into self.max_msg_id, self.cache_messages, and self.memory_snapshot """ cache_file = os.path.join(self.path, 'cache_messages.json') if not os.path.exists(cache_file): - # 如果文件不存在,初始化默认值并返回 + # If the file does not exist, initialize default values and return. self.max_msg_id = -1 self.cache_messages = {} self.memory_snapshot = [] @@ -147,7 +147,7 @@ def load_cache(self): self.max_msg_id = data.get('max_msg_id', -1) - # 解析 cache_messages + # Parse cache_messages cache_messages = {} raw_cache_msgs = data.get('cache_messages', {}) for k, (msg_list, timestamp) in raw_cache_msgs.items(): @@ -155,7 +155,7 @@ def load_cache(self): cache_messages[int(k)] = (msg_objs, timestamp) self.cache_messages = cache_messages - # 解析 memory_snapshot + # Parse memory_snapshot self.memory_snapshot = [ MemoryMapping.from_dict(d) for d in data.get('memory_snapshot', []) @@ -163,7 +163,7 @@ def load_cache(self): except (json.JSONDecodeError, KeyError, Exception) as e: logger.warning(f'Failed to load cache: {e}') - # 出错时回退到默认状态 + # Fall back to default state when an error occurs self.max_msg_id = -1 self.cache_messages = {} self.memory_snapshot = [] @@ -193,14 +193,9 @@ def delete_single(self, msg_id: int): else: self.memory.delete(self.memory_snapshot[idx].memory_id) self.memory_snapshot.pop(idx) - idx -= 1 # pop后下一条成为当前idx + idx -= 1 # After pop, the next item becomes the current idx idx += 1 - res = self.memory.get_all(user_id=self.conversation_id) # sorted - res = [(item['id'], item['memory']) for item in res['results']] - logger.info(f'Roll back success. All memory info:') - for item in res: - logger.info(item[1]) def add(self, messages: List[Message], msg_id: int) -> None: self.cache_messages[msg_id] = messages, self._hash_block(messages) @@ -217,7 +212,7 @@ def add(self, messages: List[Message], msg_id: int) -> None: res = self.memory.get_all(user_id=self.conversation_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] if len(res): - logger.info(f'Add memory success. All memory info:') + logger.info('Add memory success. All memory info:') for item in res: logger.info(item[1]) valids = [] @@ -353,6 +348,12 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): if self.history_mode == 'overwrite': for msg_id in should_delete: self.delete_single(msg_id=msg_id) + res = self.memory.get_all( + user_id=self.conversation_id) # sorted + res = [(item['id'], item['memory']) for item in res['results']] + logger.info('Roll back success. All memory info:') + for item in res: + logger.info(item[1]) if should_add_messages: for messages in should_add_messages: self.max_msg_id += 1 diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index bc26a4743..9c795b168 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -52,14 +52,22 @@ def setUp(self) -> None: Message(role='user', content='好的,拜拜') ] - @unittest.skip#Unless(test_level() >= 0, 'skip test in current test level') + def tearDown(self): + import shutil + shutil.rmtree('output', ignore_errors=True) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_default_memory(self): import uuid import asyncio async def main(): random_id = str(uuid.uuid4()) - default_memory = OmegaConf.create({'memory': {}, 'output_dir': f'output/{random_id}'}) + default_memory = OmegaConf.create({ + 'memory': {}, + 'output_dir': + f'output/{random_id}' + }) agent1 = LLMAgent(config=default_memory, task=random_id) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run('我是素食主义者,我每天早上喝咖啡') @@ -73,14 +81,19 @@ async def main(): asyncio.run(main()) - @unittest.skip#Unless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_agent_tool(self): import uuid import asyncio async def main(): random_id = str(uuid.uuid4()) - config = OmegaConf.create({'memory': {'ignore_role': ['system']}, 'output_dir': f'output/{random_id}'}) + config = OmegaConf.create({ + 'memory': { + 'ignore_role': ['system'] + }, + 'output_dir': f'output/{random_id}' + }) agent1 = LLMAgent(config=OmegaConf.create(config), task=random_id) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(self.tool_history) @@ -105,26 +118,32 @@ async def main(): Message( role='assistant', content= - '用户指出“量子跃动·朝阳未来运动穹顶”已停业。今天是2045年5月7日,需要重新搜索当前仍在运营的最炫酷运动公园。我将调用awesome_map-search工具,增加“2045年在营”等时间相关关键词,确保结果准确且时效性强。', + '用户指出“量子跃动·朝阳未来运动穹顶”已停业。今天是2045年5月7日,需要重新搜索当前仍在运营的最炫酷运动公园。我将调用' + 'awesome_map-search工具,增加“2045年在营”等时间相关关键词,确保结果准确且时效性强。', tool_calls=[ ToolCall( id='call_xyz2045NewPark', type='function', tool_name='awesome_map-search', arguments= - '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地 2045年在营", "max_results": 1, "result_type": "detailed", "include_features": true}' + '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地 2045年在营", "max_results": 1, ' + '"result_type": "detailed", "include_features": true}' ) ]), Message( role='tool', content= - '[{"park_name": "星核动力·朝阳元宇宙运动矩阵", "address": "北京市朝阳区奥体南路99号", "features": ["全息投影跑道,每一步触发星际粒子 trail", "意念控制攀岩墙,脑波越专注吸附力越强", "重力可调训练舱,模拟火星/月球/深海环境", "AI虚拟教练‘Neo’支持跨次元形象定制", "夜间悬浮滑板池,地面磁力驱动实现无轮滑行"], "special_events": ["每日黄昏举行‘意识觉醒跑’:用脑机接口同步节奏,集体生成光影风暴"], "tips": ["需提前预约神经适配测试", "禁止情绪剧烈波动,否则系统会启动冷静模式播放白噪音", "建议穿导电运动服提升交互体验"], "rating": 9.8, "opening_hours": "5:30 - 24:00", "admission": "免费(脑纹注册入场)"}]' - ), + '[{"park_name": "星核动力·朝阳元宇宙运动矩阵", "address": "北京市朝阳区奥体南路99号", "features": ["全息投影' + '跑道,每一步触发星际粒子 trail", "意念控制攀岩墙,脑波越专注吸附力越强", "重力可调训练舱,模拟火星/月球/深海环境",' + '"AI虚拟教练‘Neo’支持跨次元形象定制", "夜间悬浮滑板池,地面磁力驱动实现无轮滑行"], "special_events": ["每日黄昏' + '举行‘意识觉醒跑’:用脑机接口同步节奏,集体生成光影风暴"], "tips": ["需提前预约神经适配测试", "禁止情绪剧烈波动,否' + '则系统会启动冷静模式播放白噪音", "建议穿导电运动服提升交互体验"], "rating": 9.8, "opening_hours": "5:30 - 2' + '4:00", "admission": "免费(脑纹注册入场)"}]'), Message( role='assistant', content= - '2045年最新的最炫酷的运动公园是:星核动力·朝阳元宇宙运动矩阵。位于北京市朝阳区奥体南路99号,融合脑机接口、全息投影与重力调控技术,打造沉浸式未来运动体验。现已开放预约,支持脑纹注册免费入场。' - ), + '2045年最新的最炫酷的运动公园是:星核动力·朝阳元宇宙运动矩阵。位于北京市朝阳区奥体南路99号,融合脑机接口、全息投影与' + '重力调控技术,打造沉浸式未来运动体验。现已开放预约,支持脑纹注册免费入场。'), Message(role='user', content='好的,谢谢。'), ] tool_history2 = self.tool_history[:-1] + [ From f91203c6913307be5663fe40cf66cf1f91cb7bf6 Mon Sep 17 00:00:00 2001 From: suluyan Date: Fri, 19 Sep 2025 11:25:58 +0800 Subject: [PATCH 11/17] fix time update --- ms_agent/agent/memory/default_memory.py | 4 ++-- ms_agent/utils/prompts.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/agent/memory/default_memory.py index ec79228a4..b3a138972 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/agent/memory/default_memory.py @@ -10,7 +10,7 @@ from ms_agent.agent.memory import Memory from ms_agent.llm.utils import Message from ms_agent.utils.logger import logger -from ms_agent.utils.prompts import FACT_RETRIEVAL_PROMPT +from ms_agent.utils.prompts import get_fact_retrieval_prompt from omegaconf import DictConfig, OmegaConf @@ -457,6 +457,6 @@ def patched_parse_messages(messages, ignore_role): logger.info(f'Memory config: {mem0_config}') # Prompt content is too long, default logging reduces readability mem0_config['custom_fact_extraction_prompt'] = getattr( - self.config.memory, 'fact_retrieval_prompt', FACT_RETRIEVAL_PROMPT) + self.config.memory, 'fact_retrieval_prompt', get_fact_retrieval_prompt()) memory = mem0.Memory.from_config(mem0_config) return memory diff --git a/ms_agent/utils/prompts.py b/ms_agent/utils/prompts.py index 9f27cb2e2..511af228b 100644 --- a/ms_agent/utils/prompts.py +++ b/ms_agent/utils/prompts.py @@ -1,6 +1,7 @@ from datetime import datetime -FACT_RETRIEVAL_PROMPT = f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. +def get_fact_retrieval_prompt(): + return f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. Types of Information to Remember: 1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment. From 27d85028e45cc7066fe5278c3fafc4f1bd7e9b79 Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 21 Sep 2025 18:17:05 +0800 Subject: [PATCH 12/17] fix comment & adjust for conficts in advance --- ms_agent/agent/llm_agent.py | 59 +++++++++++----- ms_agent/{agent => }/memory/__init__.py | 0 ms_agent/{agent => }/memory/base.py | 0 ms_agent/{agent => }/memory/default_memory.py | 70 ++++++------------- ms_agent/{agent => }/memory/utils.py | 0 ms_agent/utils/prompts.py | 1 + tests/memory/test_default_memory.py | 47 +++++++------ 7 files changed, 89 insertions(+), 88 deletions(-) rename ms_agent/{agent => }/memory/__init__.py (100%) rename ms_agent/{agent => }/memory/base.py (100%) rename ms_agent/{agent => }/memory/default_memory.py (85%) rename ms_agent/{agent => }/memory/utils.py (100%) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index c26b7203f..03e2e80d1 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -10,6 +10,7 @@ from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, Tool +from ms_agent.memory import Memory, memory_mapping from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager @@ -19,7 +20,6 @@ from ..utils.utils import read_history, save_history from .base import Agent -from .memory import Memory, memory_mapping from .plan.base import Planer from .plan.utils import planer_mapping from .runtime import Runtime @@ -223,24 +223,46 @@ async def _prepare_messages( messages = await self.rag.query(messages[1].content) return messages - async def _prepare_memory(self, - messages: Optional[List[Message]] = None, - **kwargs): - """Load and initialize memory components from the config.""" - config, runtime, cache_messages = self._read_history( - messages, **kwargs) - if hasattr(self.config, 'memory'): - memory_type = getattr(self.config.memory, 'name', 'default_memory') - assert memory_type in memory_mapping, ( - f'{memory_type} not in memory_mapping, ' - f'which supports: {list(memory_mapping.keys())}') + async def _prepare_memory(self): + """ + Prepare memory + + Initializes and appends memory tool instances based on the configuration provided in self.config. + Args: + self: The instance of the class containing this method. Expected to have: + - config: An object that may contain a memory attribute, which is a list of memory configurations. + + Returns: + None - self.memory_tools.append(memory_mapping[memory_type]( - self.config, conversation_id=self.task)) + Raises: + AssertionError: If a specified memory type in the config does not exist in memory_mapping. + """ + if hasattr(self.config, 'memory'): + for _memory in (self.config.memory or []): + memory_type = getattr(_memory, 'name', 'default_memory') + assert memory_type in memory_mapping, ( + f'{memory_type} not in memory_mapping, ' + f'which supports: {list(memory_mapping.keys())}') + + # Use LLM config if no special configuration is specified + llm_config = getattr(_memory, 'llm', None) + if llm_config is None: + service = self.config.llm.service + config_dict = { + 'model': + self.config.llm.model, + 'provider': + 'openai', + 'openai_base_url': + getattr(self.config.llm, f'{service}_base_url', None), + 'openai_api_key': + getattr(self.config.llm, f'{service}_api_key', None), + } + llm_config_obj = OmegaConf.create(config_dict) + setattr(_memory, 'llm', llm_config_obj) - if messages != cache_messages: - runtime.should_stop = False - return config, runtime, cache_messages + self.memory_tools.append(memory_mapping[memory_type](_memory)) async def _prepare_planer(self): """Load and initialize the planer component from the config.""" @@ -478,9 +500,10 @@ async def _run(self, messages: Union[List[Message], str], self._prepare_llm() self._prepare_runtime() await self._prepare_tools() + await self._prepare_memory() await self._prepare_planer() await self._prepare_rag() - self.config, self.runtime, messages = await self._prepare_memory( + config, runtime, cache_messages = self._read_history( messages, **kwargs) self.runtime.tag = self.tag diff --git a/ms_agent/agent/memory/__init__.py b/ms_agent/memory/__init__.py similarity index 100% rename from ms_agent/agent/memory/__init__.py rename to ms_agent/memory/__init__.py diff --git a/ms_agent/agent/memory/base.py b/ms_agent/memory/base.py similarity index 100% rename from ms_agent/agent/memory/base.py rename to ms_agent/memory/base.py diff --git a/ms_agent/agent/memory/default_memory.py b/ms_agent/memory/default_memory.py similarity index 85% rename from ms_agent/agent/memory/default_memory.py rename to ms_agent/memory/default_memory.py index b3a138972..fbb337427 100644 --- a/ms_agent/agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -7,7 +7,7 @@ import json import json5 -from ms_agent.agent.memory import Memory +from ms_agent.memory import Memory from ms_agent.llm.utils import Message from ms_agent.utils.logger import logger from ms_agent.utils.prompts import get_fact_retrieval_prompt @@ -72,30 +72,16 @@ def from_dict(cls, data: Dict) -> 'MemoryMapping': class DefaultMemory(Memory): """The memory refine tool""" - def __init__(self, - config: DictConfig, - conversation_id: Optional[str] = None, - persist: bool = True, - path: str = None, - history_mode: Literal['add', 'overwrite'] = None): + def __init__(self, config: DictConfig): super().__init__(config) - self.conversation_id: Optional[str] = conversation_id or getattr( - config.memory, 'conversation_id', None) - self.persist: Optional[bool] = persist or getattr( - config.memory, 'persist', True) - self.compress: Optional[bool] = getattr(config.memory, 'compress', - True) - self.is_retrieve: Optional[bool] = getattr(config.memory, - 'is_retrieve', True) - self.path: Optional[str] = path or getattr( - config.memory, 'path', None) or getattr(self.config, 'output_dir', - 'output') - self.history_mode = history_mode or getattr(config.memory, - 'history_mode', 'add') - self.ignore_role: List[str] = getattr(config.memory, 'ignore_role', - ['tool', 'system']) - self.ignore_fields: List[str] = getattr(config.memory, 'ignore_fields', - ['reasoning_content']) + self.user_id: Optional[str] = getattr(self.config, 'user_id', None) + self.persist: Optional[bool] = getattr(config, 'persist', True) + self.compress: Optional[bool] = getattr(config, 'compress', True) + self.is_retrieve: Optional[bool] = getattr(config, 'is_retrieve', True) + self.path: Optional[str] = getattr(self.config, 'path', 'output') + self.history_mode = getattr(config, 'history_mode', 'add') + self.ignore_role: List[str] = getattr(config, 'ignore_role', ['tool', 'system']) + self.ignore_fields: List[str] = getattr(config, 'ignore_fields', ['reasoning_content']) self.memory = self._init_memory_obj() self.init_cache_messages() @@ -186,7 +172,7 @@ def delete_single(self, msg_id: int): self.memory._create_memory( data=self.memory_snapshot[idx].value, existing_embeddings={}, - metadata={'user_id': self.conversation_id}) + metadata={'user_id': self.user_id}) if msg_id in enable_ids: if len(enable_ids) > 1: self.memory_snapshot[idx].enable_idxs.remove(msg_id) @@ -206,10 +192,10 @@ def add(self, messages: List[Message], msg_id: int) -> None: messages_dict.append(message.to_dict()) else: messages_dict.append(message) - self.memory.add(messages_dict, user_id=self.conversation_id) + self.memory.add(messages_dict, user_id=self.user_id) self.max_msg_id = max(self.max_msg_id, msg_id) - res = self.memory.get_all(user_id=self.conversation_id) # sorted + res = self.memory.get_all(user_id=self.user_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] if len(res): logger.info('Add memory success. All memory info:') @@ -239,7 +225,7 @@ def add(self, messages: List[Message], msg_id: int) -> None: def search(self, query: str) -> str: relevant_memories = self.memory.search( - query, user_id=self.conversation_id, limit=3) + query, user_id=self.user_id, limit=3) memories_str = '\n'.join(f"- {entry['memory']}" for entry in relevant_memories['results']) return memories_str @@ -349,7 +335,7 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): for msg_id in should_delete: self.delete_single(msg_id=msg_id) res = self.memory.get_all( - user_id=self.conversation_id) # sorted + user_id=self.user_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] logger.info('Roll back success. All memory info:') for item in res: @@ -404,7 +390,7 @@ def patched_parse_messages(messages, ignore_role): return embedder: Optional[str] = getattr( - self.config.memory, 'embedder', + self.config, 'embedder', OmegaConf.create({ 'provider': 'openai', 'config': { @@ -417,22 +403,12 @@ def patched_parse_messages(messages, ignore_role): llm = {} if self.compress: - llm_config = getattr(self.config.memory, 'llm', None) - if llm_config is not None: - # follow mem0 config - model = llm_config.get('model') - provider = llm_config.get('provider', 'openai') - openai_base_url = llm_config.get('openai_base_url', None) - openai_api_key = llm_config.get('api_key', None) - else: - llm_config = self.config.llm - model = llm_config.model - service = llm_config.service - openai_base_url = getattr(llm_config, f'{service}_base_url', - None) - openai_api_key = getattr(llm_config, f'{service}_api_key', - None) - provider = 'openai' + llm_config = getattr(self.config, 'llm', None) + # follow mem0 config + model = llm_config.get('model') + provider = llm_config.get('provider', 'openai') + openai_base_url = llm_config.get('openai_base_url', None) + openai_api_key = llm_config.get('openai_api_key', None) llm = { 'provider': provider, 'config': { @@ -457,6 +433,6 @@ def patched_parse_messages(messages, ignore_role): logger.info(f'Memory config: {mem0_config}') # Prompt content is too long, default logging reduces readability mem0_config['custom_fact_extraction_prompt'] = getattr( - self.config.memory, 'fact_retrieval_prompt', get_fact_retrieval_prompt()) + self.config, 'fact_retrieval_prompt', get_fact_retrieval_prompt()) memory = mem0.Memory.from_config(mem0_config) return memory diff --git a/ms_agent/agent/memory/utils.py b/ms_agent/memory/utils.py similarity index 100% rename from ms_agent/agent/memory/utils.py rename to ms_agent/memory/utils.py diff --git a/ms_agent/utils/prompts.py b/ms_agent/utils/prompts.py index 511af228b..fb57044b3 100644 --- a/ms_agent/utils/prompts.py +++ b/ms_agent/utils/prompts.py @@ -1,5 +1,6 @@ from datetime import datetime + def get_fact_retrieval_prompt(): return f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index 9c795b168..a83c94f2d 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -1,13 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import math -import os import unittest -import json from ms_agent.agent import LLMAgent -from ms_agent.agent.memory.default_memory import DefaultMemory from ms_agent.llm.utils import Message, ToolCall -from omegaconf import DictConfig, OmegaConf +from omegaconf import OmegaConf from modelscope.utils.test_utils import test_level @@ -64,16 +60,17 @@ def test_default_memory(self): async def main(): random_id = str(uuid.uuid4()) default_memory = OmegaConf.create({ - 'memory': {}, - 'output_dir': - f'output/{random_id}' + 'memory': [{ + 'path': f'output/{random_id}', + 'user_id': random_id + }], }) - agent1 = LLMAgent(config=default_memory, task=random_id) + agent1 = LLMAgent(config=default_memory) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run('我是素食主义者,我每天早上喝咖啡') del agent1 print('========== 数据准备结束,开始测试 ===========') - agent2 = LLMAgent(config=default_memory, task=random_id) + agent2 = LLMAgent(config=default_memory) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run('请帮我准备明天的三餐食谱') print(res) @@ -89,19 +86,21 @@ def test_agent_tool(self): async def main(): random_id = str(uuid.uuid4()) config = OmegaConf.create({ - 'memory': { - 'ignore_role': ['system'] - }, - 'output_dir': f'output/{random_id}' + 'memory': [{ + 'ignore_role': ['system'], + 'user_id': random_id, + 'path': f'output/{random_id}' + }] }) - agent1 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent1 = LLMAgent(config=OmegaConf.create(config)) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(self.tool_history) del agent1 print('========== 数据准备结束,开始测试 ===========') - agent2 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent2 = LLMAgent(config=OmegaConf.create(config)) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点') + del agent2 print(res) assert ('酒仙桥路8888号' in res[-1].content) @@ -150,21 +149,23 @@ async def main(): Message(role='user', content='北京市朝阳区最炫酷的运动公园的地点?') ] random_id = str(uuid.uuid4()) - config = OmegaConf.create({ + config = OmegaConf.create([{ 'memory': { 'ignore_role': ['system'], - 'history_mode': 'overwrite' - }, - 'output_dir': f'output/{random_id}' - }) - agent1 = LLMAgent(config=OmegaConf.create(config), task=random_id) + 'history_mode': 'overwrite', + 'path': f'output/{random_id}', + 'user_id': random_id, + } + }]) + agent1 = LLMAgent(config=OmegaConf.create(config)) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(tool_history1) del agent1 print('========== 数据准备结束,开始测试 ===========') - agent2 = LLMAgent(config=OmegaConf.create(config), task=random_id) + agent2 = LLMAgent(config=OmegaConf.create(config)) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run(tool_history2) + del agent2 print(res) assert ('酒仙桥路8888号' in res[-1].content and '奥体南路' not in res[-1].content) From 9051d16f405635ef9b8603f4e00ca2b9a65ea322 Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 21 Sep 2025 18:28:55 +0800 Subject: [PATCH 13/17] fix bugs --- ms_agent/memory/default_memory.py | 8 +++++--- tests/memory/test_default_memory.py | 2 -- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index fbb337427..cbe58f1d6 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -349,16 +349,18 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): query = getattr(messages[-1], 'content') memories_str = self.search(query) # Remove the messages section corresponding to memory, and add the related memory_str information + remain_idx = len(messages) - sum([len(block) for block in should_add_messages]) if getattr(messages[0], 'role') == 'system': system_prompt = getattr( messages[0], 'content') + f'\nUser Memories: {memories_str}' + if remain_idx < 1: + remain_idx = 1 else: system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' \ f'User Memories: {memories_str}' - diff_idx = len(messages) - sum( - [len(block) for block in should_add_messages]) + new_messages = [Message(role='system', content=system_prompt) - ] + messages[diff_idx:] + ] + messages[remain_idx:] return new_messages def _init_memory_obj(self): diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index a83c94f2d..0907c108b 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -100,7 +100,6 @@ async def main(): agent2 = LLMAgent(config=OmegaConf.create(config)) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点') - del agent2 print(res) assert ('酒仙桥路8888号' in res[-1].content) @@ -165,7 +164,6 @@ async def main(): agent2 = LLMAgent(config=OmegaConf.create(config)) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run(tool_history2) - del agent2 print(res) assert ('酒仙桥路8888号' in res[-1].content and '奥体南路' not in res[-1].content) From f1a92c6f3f3e1f16aaef6aad638024daa496d315 Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 21 Sep 2025 18:40:12 +0800 Subject: [PATCH 14/17] minor fix --- ms_agent/agent/llm_agent.py | 5 +++-- ms_agent/memory/default_memory.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 03e2e80d1..34582f4ea 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -503,10 +503,11 @@ async def _run(self, messages: Union[List[Message], str], await self._prepare_memory() await self._prepare_planer() await self._prepare_rag() - config, runtime, cache_messages = self._read_history( - messages, **kwargs) self.runtime.tag = self.tag + self.config, self.runtime, messages = self._read_history( + messages, **kwargs) + if self.runtime.round == 0: # 0 means no history messages = await self._prepare_messages(messages) diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index cbe58f1d6..7bddcafb6 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -7,8 +7,8 @@ import json import json5 -from ms_agent.memory import Memory from ms_agent.llm.utils import Message +from ms_agent.memory import Memory from ms_agent.utils.logger import logger from ms_agent.utils.prompts import get_fact_retrieval_prompt from omegaconf import DictConfig, OmegaConf @@ -80,8 +80,10 @@ def __init__(self, config: DictConfig): self.is_retrieve: Optional[bool] = getattr(config, 'is_retrieve', True) self.path: Optional[str] = getattr(self.config, 'path', 'output') self.history_mode = getattr(config, 'history_mode', 'add') - self.ignore_role: List[str] = getattr(config, 'ignore_role', ['tool', 'system']) - self.ignore_fields: List[str] = getattr(config, 'ignore_fields', ['reasoning_content']) + self.ignore_role: List[str] = getattr(config, 'ignore_role', + ['tool', 'system']) + self.ignore_fields: List[str] = getattr(config, 'ignore_fields', + ['reasoning_content']) self.memory = self._init_memory_obj() self.init_cache_messages() @@ -334,8 +336,7 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): if self.history_mode == 'overwrite': for msg_id in should_delete: self.delete_single(msg_id=msg_id) - res = self.memory.get_all( - user_id=self.user_id) # sorted + res = self.memory.get_all(user_id=self.user_id) # sorted res = [(item['id'], item['memory']) for item in res['results']] logger.info('Roll back success. All memory info:') for item in res: @@ -349,7 +350,8 @@ async def run(self, messages, ignore_role=None, ignore_fields=None): query = getattr(messages[-1], 'content') memories_str = self.search(query) # Remove the messages section corresponding to memory, and add the related memory_str information - remain_idx = len(messages) - sum([len(block) for block in should_add_messages]) + remain_idx = len(messages) - sum( + [len(block) for block in should_add_messages]) if getattr(messages[0], 'role') == 'system': system_prompt = getattr( messages[0], 'content') + f'\nUser Memories: {memories_str}' From 4a30dd8c28f56e4dda327dddcd9cda75ea08f6a4 Mon Sep 17 00:00:00 2001 From: suluyan Date: Sun, 21 Sep 2025 19:14:59 +0800 Subject: [PATCH 15/17] add en test case --- ms_agent/agent/agent.yaml | 16 ++- tests/memory/test_default_memory.py | 136 +++++++++++++------ tests/memory/test_default_memory_zh.py | 175 +++++++++++++++++++++++++ 3 files changed, 279 insertions(+), 48 deletions(-) create mode 100644 tests/memory/test_default_memory_zh.py diff --git a/ms_agent/agent/agent.yaml b/ms_agent/agent/agent.yaml index eb8990fc0..85efe4248 100644 --- a/ms_agent/agent/agent.yaml +++ b/ms_agent/agent/agent.yaml @@ -31,20 +31,22 @@ prompt: 6. Do not call tools carelessly. Show your thoughts **as detailed as possible**. - For requests that require performing a specific task or retrieving information, you must use the following format: + 7. Respond in the same language the user uses. If the user switches, switch accordingly. + + For requests that require performing a specific task or retrieving information, you must use the following format in user language: ``` - 用户需要 ... - 针对该需求,我进行了详细拆解和规划,需要按照如下步骤来解决问题: + The user needs to ... + I have analyzed this request in detail and broken it down into the following steps: ... ``` If you have tools which may help you to solve problems, follow this format to answer: ``` - 用户需要 ... - 针对该需求,我进行了详细拆解和规划,需要按照如下步骤来解决问题: + The user needs to ... + I have analyzed this request in detail and broken it down into the following steps: ... - 首先我应当选择...工具,由于该工具..., 该工具的入参需要... + First, I should use the [Tool Name] because [explain relevance]. The required input parameters are: ... ... - 我仔细查看了工具返回值,该工具的返回值符合/不符合我的要求,我接下来需要... + I have carefully reviewed the tool's output. The result does/does not fully meet my expectations. Next, I need to ... ``` max_chat_round: 9999 diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py index 0907c108b..410f37bc5 100644 --- a/tests/memory/test_default_memory.py +++ b/tests/memory/test_default_memory.py @@ -12,40 +12,58 @@ class TestDefaultMemory(unittest.TestCase): def setUp(self) -> None: self.tool_history = [ - Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园。记着该地点,下次去。'), + Message( + role='user', + content= + 'Help me find the coolest sports park in Chaoyang District, Beijing. Remember this location for next' + 'time.'), Message( role='assistant', content= - '\n用户希望找到北京市朝阳区最炫酷的运动公园,关键词是‘炫酷’,说明用户追求的不仅是功能性,更看重科技感、设计感、互动体验' - '和潮流元素。因此,我需要搜索具备未来感、融合高科技、潮流文化或创新设施的运动场所。\n\n为了解决这个问题,我将采取以下步' - '骤:\n1. 使用awesome_map-search工具搜索北京市朝阳区的运动公园,重点关注‘炫酷’‘科技感’‘潮流’等关键词\n2. 筛选出最' - '具特色、评价最高、视觉冲击力强的公园\n3. 提供运动公园名称。\n\n现在我将调用awesome_map-search工具进行搜索,该工具' - '专为地理+趣味性信息检索设计,支持语义化查询,尤其擅长发现‘宝藏地点’。\n', + '\nThe user wants to find the coolest sports park in Chaoyang District, Beijing. The keyword "cool" ' + 'suggests they are not just looking for functionality, but also value high-tech design, interactive ' + 'experiences, and trendy elements. Therefore, I need to search for venues that have a futuristic feel, ' + 'integrate advanced technology, youth culture, or innovative facilities.\n\nTo solve this, I will take ' + 'the following steps:\n' + '1. Use the awesome_map-search tool to search for sports parks in Chaoyang District, Beijing, focusing ' + 'on keywords like "cool", "high-tech", "trendy", etc.\n' + '2. Filter out the most distinctive, highly-rated parks with strong visual impact\n' + '3. Provide the name of the sports park.\n\nNow I will call the awesome_map-search tool. This tool is ' + 'designed specifically for geo + fun information retrieval, ' + 'supports semantic queries, and excels at discovering "hidden gem" locations.', tool_calls=[ ToolCall( id='call_xyz789CoolPark', type='function', tool_name='awesome_map-search', arguments= - '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地", "max_results": 1, "result_type":' + '{"query": "Beijing Chaoyang District coolest sports park high-tech trendy must-visit spot", ' + '"max_results": 1, "result_type":' '"detailed", "include_features": true}') ]), Message( role='tool', content= - '[{"park_name": "量子跃动·朝阳未来运动穹顶", "address": "北京市朝阳区酒仙桥路8888号", "features": ["反重力悬' - '浮跑道,采用磁悬浮缓震技术,跑步如在月球上跳跃", "AR极限攀岩墙,支持末日火山、星际迷宫等虚拟场景", "DJ动感骑行舱,踩' - '得越快音乐越炸,实时生成能量波形图", "AI教练机器人\'铁燃\',可定制训练计划并用东北话激励用户", "夜光太极湖与漂浮瑜伽' - '平台,湖水含环保夜光藻类,夜晚泛蓝光"], "special_events": ["每周五晚举办炫酷运动会:激光躲避球、荧光舞步跑、机器人' - '拔河"], "tips": ["需完成3道运动谜题解锁入场", "禁止穿拖鞋跑步,否则AI会播放《凉凉》", "建议携带充电宝,拍照太频繁易' - '没电"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "免费"}]' + '[{"park_name": "Quantum Leap · Chaoyang Future Sports Dome", "address": "No. 8888 Jiuxianqiao Road, ' + 'Chaoyang District, Beijing", "features": ["Anti-gravity floating track using magnetic levitation ' + 'cushioning tech — running feels like jumping on the moon", "AR extreme climbing wall supporting ' + 'virtual scenes such as volcanic apocalypse and interstellar maze", "DJ-powered cycling pods: the ' + 'faster you pedal, the more intense the music, with real-time energy waveform visualization", ' + '"AI coach robot \'Tie Ran\', which creates custom workout plans and motivates users in Northeastern ' + 'Chinese dialect", "Glow-in-the-dark Tai Chi lake with floating yoga platforms; water contains ' + 'eco-friendly bioluminescent algae glowing blue at night"], "special_events": ["Cool Sports Event ' + 'every Friday night: laser dodgeball, glow-in-the-dark dance run, robot tug-of-war"], "tips": ["Must ' + 'solve 3 fitness puzzles to unlock entry", "Bring a power bank — taking photos too frequently drains ' + 'battery quickly"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "Free"}]' ), Message( role='assistant', content= - '\n用户希望找到北京市朝阳区最炫酷的运动公园。我已经通过awesome_map-search工具获取了相关信息。\n\n## 🌟 北京市朝阳区最炫酷运动公园揭晓:「量子跃动·朝阳未来运动穹顶」' + '\nThe user wanted to find the coolest sports park in Chaoyang District, Beijing. I have retrieved ' + 'relevant information via the awesome_map-search tool.\n\n## 🌟 The Coolest Sports Park in Chaoyang ' + 'District, Beijing Revealed: 「Quantum Leap · Chaoyang Future Sports Dome」' ), - Message(role='user', content='好的,拜拜') + Message(role='user', content='Alright, bye!') ] def tearDown(self): @@ -63,18 +81,24 @@ async def main(): 'memory': [{ 'path': f'output/{random_id}', 'user_id': random_id - }], + }] }) agent1 = LLMAgent(config=default_memory) agent1.config.callbacks.remove('input_callback') # noqa - await agent1.run('我是素食主义者,我每天早上喝咖啡') + await agent1.run( + 'I am a vegetarian and I drink coffee every morning.') del agent1 - print('========== 数据准备结束,开始测试 ===========') + print( + '========== Data preparation completed, starting test ===========' + ) agent2 = LLMAgent(config=default_memory) agent2.config.callbacks.remove('input_callback') # noqa - res = await agent2.run('请帮我准备明天的三餐食谱') + res = await agent2.run( + 'Please help me plan tomorrow’s three meals.') print(res) - assert ('素' in res[-1].content and '咖啡' in res[-1].content) + assert ('vegetarian' in res[-1].content.lower() + or 'vegan' in res[-1].content.lower() + ) and 'coffee' in res[-1].content.lower() asyncio.run(main()) @@ -95,13 +119,19 @@ async def main(): agent1 = LLMAgent(config=OmegaConf.create(config)) agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(self.tool_history) + agent1.memory_tools[0].memory.vector_store.client.close() del agent1 - print('========== 数据准备结束,开始测试 ===========') + print( + '========== Data preparation completed, starting test ===========' + ) agent2 = LLMAgent(config=OmegaConf.create(config)) agent2.config.callbacks.remove('input_callback') # noqa - res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点') + res = await agent2.run( + 'What is the location of the coolest sports park in Chaoyang District, Beijing?' + ) print(res) - assert ('酒仙桥路8888号' in res[-1].content) + assert 'Jiuxianqiao Road 8888' in res[ + -1].content or 'No. 8888 Jiuxianqiao Road' in res[-1].content asyncio.run(main()) @@ -112,40 +142,60 @@ def test_overwrite_with_tool(self): async def main(): tool_history1 = self.tool_history[:-1] + [ - Message(role='user', content='你说的这家运动公园已经停业了。'), + Message( + role='user', + content= + 'The sports park you mentioned has already closed down.'), Message( role='assistant', content= - '用户指出“量子跃动·朝阳未来运动穹顶”已停业。今天是2045年5月7日,需要重新搜索当前仍在运营的最炫酷运动公园。我将调用' - 'awesome_map-search工具,增加“2045年在营”等时间相关关键词,确保结果准确且时效性强。', + 'The user mentioned that "Quantum Leap · Chaoyang Future Sports Dome" has shut down. Today is ' + 'May 7, 2045. I need to search again for the currently operating coolest sports park. I will use ' + 'the awesome_map-search tool with updated time-sensitive keywords such as "open in 2045" to ensure ' + 'accuracy and timeliness.', tool_calls=[ ToolCall( id='call_xyz2045NewPark', type='function', tool_name='awesome_map-search', arguments= - '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地 2045年在营", "max_results": 1, ' - '"result_type": "detailed", "include_features": true}' + '{"query": "Beijing Chaoyang District coolest sports park high-tech trendy must-visit spot ' + 'open in 2045", "max_results": 1, "result_type": "detailed", "include_features": true}' ) ]), Message( role='tool', content= - '[{"park_name": "星核动力·朝阳元宇宙运动矩阵", "address": "北京市朝阳区奥体南路99号", "features": ["全息投影' - '跑道,每一步触发星际粒子 trail", "意念控制攀岩墙,脑波越专注吸附力越强", "重力可调训练舱,模拟火星/月球/深海环境",' - '"AI虚拟教练‘Neo’支持跨次元形象定制", "夜间悬浮滑板池,地面磁力驱动实现无轮滑行"], "special_events": ["每日黄昏' - '举行‘意识觉醒跑’:用脑机接口同步节奏,集体生成光影风暴"], "tips": ["需提前预约神经适配测试", "禁止情绪剧烈波动,否' - '则系统会启动冷静模式播放白噪音", "建议穿导电运动服提升交互体验"], "rating": 9.8, "opening_hours": "5:30 - 2' - '4:00", "admission": "免费(脑纹注册入场)"}]'), + '[{"park_name": "Stellar Core Dynamics · Chaoyang Metaverse Sports Matrix", ' + '"address": "No. 99 Aoti South Road, Chaoyang District, Beijing", ' + '"features": ["Holographic projection tracks that trigger stardust trails with each step", ' + '"Mind-controlled climbing wall — stronger focus increases adhesion", ' + '"Gravity-adjustable training pods simulating Mars, Moon, or deep-sea environments", ' + '"AI virtual coach \'Neo\' with customizable cross-dimensional avatars", ' + '"Nighttime hoverboard pool using magnetic ground propulsion for wheel-free gliding"], ' + '"special_events": ["Daily twilight \'Consciousness Awakening Run\': synchronized rhythm via ' + 'brain-computer interface, generating collective lightstorm"], ' + '"tips": ["Neural compatibility test required in advance", ' + '"Avoid extreme emotional fluctuations, otherwise system activates calming white noise mode", ' + '"Wearing conductive sportswear recommended for better interaction"], "rating": 9.8, ' + '"opening_hours": "5:30 - 24:00", "admission": "Free (entry via brainprint registration)"}]' + ), Message( role='assistant', content= - '2045年最新的最炫酷的运动公园是:星核动力·朝阳元宇宙运动矩阵。位于北京市朝阳区奥体南路99号,融合脑机接口、全息投影与' - '重力调控技术,打造沉浸式未来运动体验。现已开放预约,支持脑纹注册免费入场。'), - Message(role='user', content='好的,谢谢。'), + 'The latest and coolest sports park in 2045 is: Stellar Core Dynamics · Chaoyang Metaverse Sports ' + 'Matrix. Located at No. 99 Aoti South Road, Chaoyang District, Beijing, it integrates ' + 'brain-computer interfaces, holographic projections, and gravity control technology to deliver an ' + 'immersive futuristic fitness experience. Now open for reservations, free entry via brainprint ' + 'registration.'), + Message(role='user', content='Got it, thanks.'), ] tool_history2 = self.tool_history[:-1] + [ - Message(role='user', content='北京市朝阳区最炫酷的运动公园的地点?') + Message( + role='user', + content= + 'What is the location of the coolest sports park in Chaoyang District, Beijing?' + ) ] random_id = str(uuid.uuid4()) config = OmegaConf.create([{ @@ -160,13 +210,17 @@ async def main(): agent1.config.callbacks.remove('input_callback') # noqa await agent1.run(tool_history1) del agent1 - print('========== 数据准备结束,开始测试 ===========') + print( + '========== Data preparation completed, starting test ===========' + ) agent2 = LLMAgent(config=OmegaConf.create(config)) agent2.config.callbacks.remove('input_callback') # noqa res = await agent2.run(tool_history2) print(res) - assert ('酒仙桥路8888号' in res[-1].content - and '奥体南路' not in res[-1].content) + # Assert old info remains due to overwrite mode, new info not persisted + assert ('Jiuxianqiao Road 8888' in res[-1].content + or 'No. 8888 Jiuxianqiao Road' in res[-1].content + ) and 'Aoti South Road' not in res[-1].content asyncio.run(main()) diff --git a/tests/memory/test_default_memory_zh.py b/tests/memory/test_default_memory_zh.py new file mode 100644 index 000000000..fed6e4d08 --- /dev/null +++ b/tests/memory/test_default_memory_zh.py @@ -0,0 +1,175 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from ms_agent.agent import LLMAgent +from ms_agent.llm.utils import Message, ToolCall +from omegaconf import OmegaConf + +from modelscope.utils.test_utils import test_level + + +class TestDefaultMemory(unittest.TestCase): + + def setUp(self) -> None: + self.tool_history = [ + Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园。记着该地点,下次去。'), + Message( + role='assistant', + content= + '\n用户希望找到北京市朝阳区最炫酷的运动公园,关键词是‘炫酷’,说明用户追求的不仅是功能性,更看重科技感、设计感、互动体验' + '和潮流元素。因此,我需要搜索具备未来感、融合高科技、潮流文化或创新设施的运动场所。\n\n为了解决这个问题,我将采取以下步' + '骤:\n1. 使用awesome_map-search工具搜索北京市朝阳区的运动公园,重点关注‘炫酷’‘科技感’‘潮流’等关键词\n2. 筛选出最' + '具特色、评价最高、视觉冲击力强的公园\n3. 提供运动公园名称。\n\n现在我将调用awesome_map-search工具进行搜索,该工具' + '专为地理+趣味性信息检索设计,支持语义化查询,尤其擅长发现‘宝藏地点’。\n', + tool_calls=[ + ToolCall( + id='call_xyz789CoolPark', + type='function', + tool_name='awesome_map-search', + arguments= + '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地", "max_results": 1, "result_type":' + '"detailed", "include_features": true}') + ]), + Message( + role='tool', + content= + '[{"park_name": "量子跃动·朝阳未来运动穹顶", "address": "北京市朝阳区酒仙桥路8888号", "features": ["反重力悬' + '浮跑道,采用磁悬浮缓震技术,跑步如在月球上跳跃", "AR极限攀岩墙,支持末日火山、星际迷宫等虚拟场景", "DJ动感骑行舱,踩' + '得越快音乐越炸,实时生成能量波形图", "AI教练机器人\'铁燃\',可定制训练计划并用东北话激励用户", "夜光太极湖与漂浮瑜伽' + '平台,湖水含环保夜光藻类,夜晚泛蓝光"], "special_events": ["每周五晚举办炫酷运动会:激光躲避球、荧光舞步跑、机器人' + '拔河"], "tips": ["需完成3道运动谜题解锁入场", "禁止穿拖鞋跑步,否则AI会播放《凉凉》", "建议携带充电宝,拍照太频繁易' + '没电"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "免费"}]' + ), + Message( + role='assistant', + content= + '\n用户希望找到北京市朝阳区最炫酷的运动公园。我已经通过awesome_map-search工具获取了相关信息。\n\n## 🌟 北京市朝阳区最炫酷运动公园揭晓:「量子跃动·朝阳未来运动穹顶」' + ), + Message(role='user', content='好的,拜拜') + ] + + def tearDown(self): + import shutil + shutil.rmtree('output', ignore_errors=True) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_default_memory(self): + import uuid + import asyncio + + async def main(): + random_id = str(uuid.uuid4()) + default_memory = OmegaConf.create({ + 'memory': [{ + 'path': f'output/{random_id}', + 'user_id': random_id + }], + }) + agent1 = LLMAgent(config=default_memory) + agent1.config.callbacks.remove('input_callback') # noqa + await agent1.run('我是素食主义者,我每天早上喝咖啡') + del agent1 + print('========== 数据准备结束,开始测试 ===========') + agent2 = LLMAgent(config=default_memory) + agent2.config.callbacks.remove('input_callback') # noqa + res = await agent2.run('请帮我准备明天的三餐食谱') + print(res) + assert ('素' in res[-1].content and '咖啡' in res[-1].content) + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_agent_tool(self): + import uuid + import asyncio + + async def main(): + random_id = str(uuid.uuid4()) + config = OmegaConf.create({ + 'memory': [{ + 'ignore_role': ['system'], + 'user_id': random_id, + 'path': f'output/{random_id}' + }] + }) + agent1 = LLMAgent(config=OmegaConf.create(config)) + agent1.config.callbacks.remove('input_callback') # noqa + await agent1.run(self.tool_history) + del agent1 + print('========== 数据准备结束,开始测试 ===========') + agent2 = LLMAgent(config=OmegaConf.create(config)) + agent2.config.callbacks.remove('input_callback') # noqa + res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点') + print(res) + assert ('酒仙桥路8888号' in res[-1].content) + + asyncio.run(main()) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_overwrite_with_tool(self): + import uuid + import asyncio + + async def main(): + tool_history1 = self.tool_history[:-1] + [ + Message(role='user', content='你说的这家运动公园已经停业了。'), + Message( + role='assistant', + content= + '用户指出“量子跃动·朝阳未来运动穹顶”已停业。今天是2045年5月7日,需要重新搜索当前仍在运营的最炫酷运动公园。我将调用' + 'awesome_map-search工具,增加“2045年在营”等时间相关关键词,确保结果准确且时效性强。', + tool_calls=[ + ToolCall( + id='call_xyz2045NewPark', + type='function', + tool_name='awesome_map-search', + arguments= + '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地 2045年在营", "max_results": 1, ' + '"result_type": "detailed", "include_features": true}' + ) + ]), + Message( + role='tool', + content= + '[{"park_name": "星核动力·朝阳元宇宙运动矩阵", "address": "北京市朝阳区奥体南路99号", "features": ["全息投影' + '跑道,每一步触发星际粒子 trail", "意念控制攀岩墙,脑波越专注吸附力越强", "重力可调训练舱,模拟火星/月球/深海环境",' + '"AI虚拟教练‘Neo’支持跨次元形象定制", "夜间悬浮滑板池,地面磁力驱动实现无轮滑行"], "special_events": ["每日黄昏' + '举行‘意识觉醒跑’:用脑机接口同步节奏,集体生成光影风暴"], "tips": ["需提前预约神经适配测试", "禁止情绪剧烈波动,否' + '则系统会启动冷静模式播放白噪音", "建议穿导电运动服提升交互体验"], "rating": 9.8, "opening_hours": "5:30 - 2' + '4:00", "admission": "免费(脑纹注册入场)"}]'), + Message( + role='assistant', + content= + '2045年最新的最炫酷的运动公园是:星核动力·朝阳元宇宙运动矩阵。位于北京市朝阳区奥体南路99号,融合脑机接口、全息投影与' + '重力调控技术,打造沉浸式未来运动体验。现已开放预约,支持脑纹注册免费入场。'), + Message(role='user', content='好的,谢谢。'), + ] + tool_history2 = self.tool_history[:-1] + [ + Message(role='user', content='北京市朝阳区最炫酷的运动公园的地点?') + ] + random_id = str(uuid.uuid4()) + config = OmegaConf.create([{ + 'memory': { + 'ignore_role': ['system'], + 'history_mode': 'overwrite', + 'path': f'output/{random_id}', + 'user_id': random_id, + } + }]) + agent1 = LLMAgent(config=OmegaConf.create(config)) + agent1.config.callbacks.remove('input_callback') # noqa + await agent1.run(tool_history1) + del agent1 + print('========== 数据准备结束,开始测试 ===========') + agent2 = LLMAgent(config=OmegaConf.create(config)) + agent2.config.callbacks.remove('input_callback') # noqa + res = await agent2.run(tool_history2) + print(res) + assert ('酒仙桥路8888号' in res[-1].content + and '奥体南路' not in res[-1].content) + + asyncio.run(main()) + + +if __name__ == '__main__': + unittest.main() From 0b87e2adc1668634710182021d7535ca0530139d Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 22 Sep 2025 14:58:37 +0800 Subject: [PATCH 16/17] fix conflicts --- ms_agent/agent/llm_agent.py | 15 ++++--- ms_agent/agent/memory/utils.py | 6 --- ms_agent/memory/mem0ai.py | 4 +- ms_agent/memory/utils.py | 6 ++- ms_agent/utils/__init__.py | 2 +- ms_agent/utils/prompt.py | 77 ++++++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 16 deletions(-) delete mode 100644 ms_agent/agent/memory/utils.py diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 7229474a3..7e09eedd6 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -12,6 +12,7 @@ from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, Tool from ms_agent.memory import Memory, memory_mapping +from ms_agent.memory.mem0ai import Mem0Memory from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager @@ -267,15 +268,14 @@ async def _prepare_memory(self): llm_config_obj = OmegaConf.create(config_dict) setattr(_memory, 'llm', llm_config_obj) - self.memory_tools.append(memory_mapping[memory_type](_memory)) - if _memory.name == 'mem0': + if memory_type == 'mem0': from ms_agent.memory.mem0ai import SharedMemoryManager shared_memory = SharedMemoryManager.get_shared_memory( _memory) self.memory_tools.append(shared_memory) else: self.memory_tools.append( - memory_mapping[_memory.name](_memory)) + memory_mapping[memory_type](_memory)) async def _prepare_planer(self): """Load and initialize the planer component from the config.""" @@ -497,7 +497,9 @@ def _save_history(self, messages: List[Message], **kwargs): user_id = memory_config.user_id break for memory_tool in self.memory_tools: - memory_tool._add_memories_from_conversation(messages, user_id) + if isinstance(memory_tool, Mem0Memory): + memory_tool._add_memories_from_conversation( + messages, user_id) if not self.task or self.task == 'subtask': return @@ -520,8 +522,9 @@ def _save_memory(self, messages: List[Message], **kwargs): if self.memory_tools: agent_id = self.tag for memory_tool in self.memory_tools: - memory_tool._add_memories_from_procedural( - messages, 'subagent', agent_id, 'procedural_memory') + if isinstance(memory_tool, Mem0Memory): + memory_tool._add_memories_from_procedural( + messages, 'subagent', agent_id, 'procedural_memory') return async def _run(self, messages: Union[List[Message], str], diff --git a/ms_agent/agent/memory/utils.py b/ms_agent/agent/memory/utils.py deleted file mode 100644 index 1f7b3cb89..000000000 --- a/ms_agent/agent/memory/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from .mem0ai import Mem0Memory - -memory_mapping = { - 'mem0': Mem0Memory, -} diff --git a/ms_agent/memory/mem0ai.py b/ms_agent/memory/mem0ai.py index 03e7fd628..77fb30d9b 100644 --- a/ms_agent/memory/mem0ai.py +++ b/ms_agent/memory/mem0ai.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from ms_agent.llm.utils import Message -from ms_agent.utils import get_fact_retrieval_prompt, get_logger +from ms_agent.utils import get_code_fact_retrieval_prompt, get_logger from omegaconf import DictConfig from .base import Memory @@ -104,7 +104,7 @@ def _initialize_memory(self): # Monkey patch Mem0's parse_messages function to handle tool messages mem0.memory.main.parse_messages = self.patched_parse_messages # Also update the imported reference in utils module - mem0.memory.utils.FACT_RETRIEVAL_PROMPT = get_fact_retrieval_prompt( + mem0.memory.utils.FACT_RETRIEVAL_PROMPT = get_code_fact_retrieval_prompt( ) embedding_model = 'text-embedding-3-small' diff --git a/ms_agent/memory/utils.py b/ms_agent/memory/utils.py index 22a9f025a..ea00f78ad 100644 --- a/ms_agent/memory/utils.py +++ b/ms_agent/memory/utils.py @@ -1,4 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .default_memory import DefaultMemory +from .mem0ai import Mem0Memory -memory_mapping = {'default_memory': DefaultMemory} +memory_mapping = { + 'default_memory': DefaultMemory, + 'mem0': Mem0Memory, +} diff --git a/ms_agent/utils/__init__.py b/ms_agent/utils/__init__.py index f994f113e..9f95fe2a3 100644 --- a/ms_agent/utils/__init__.py +++ b/ms_agent/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .llm_utils import async_retry, retry from .logger import get_logger -from .prompt import get_fact_retrieval_prompt +from .prompt import get_code_fact_retrieval_prompt, get_fact_retrieval_prompt from .utils import assert_package_exist, enhance_error, strtobool diff --git a/ms_agent/utils/prompt.py b/ms_agent/utils/prompt.py index 01e027f5d..b573539c6 100644 --- a/ms_agent/utils/prompt.py +++ b/ms_agent/utils/prompt.py @@ -5,6 +5,83 @@ def get_fact_retrieval_prompt(): + return f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. + +Types of Information to Remember: +1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment. +2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates. +3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared. +4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services. +5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information. +6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information. +7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares. + +Tool Interaction Processing Instructions (Additional Responsibilities): +When tool calls and their results are included in the conversation, perform the following in addition to fact extraction: + +1. Extract and Organize Factual Information from Tool Outputs: + - Parse the returned data from successful tool calls (e.g., weather, calendar, search, maps). + - Identify and store objective, user-relevant facts derived from these results (e.g., "It will rain in Paris on 2025-08-25", "The restaurant Little Italy is located at 123 Main St"). + - Integrate these into the "facts" list only if they reflect new, meaningful information about the user's context or environment. +2. Analyze and Summarize Error-Prone Tools: + - Identify tools that frequently fail, time out, or return inconsistent results. + - For such tools, generate a brief internal summary noting the pattern of failure (e.g., "Search tool often returns incomplete results for restaurant queries"). + - This summary does not go into the JSON output but informs future handling (e.g., suggesting alternative tools or double-checking outputs). +3. Identify and Log Tools That Cannot Be Called: + - If a tool was intended but not invoked (e.g., due to missing permissions, unavailability, or misconfiguration), note this in a separate internal log. + - Examples: "Calendar tool unavailable — cannot retrieve user's meeting schedule", "Location access denied — weather tool cannot auto-detect city". + - Include a user-facing reminder if relevant: add a fact like "Could not access calendar due to permission restrictions" only if it impacts user understanding. +4. Ensure Clarity and Non-Disclosure: + - Do not expose tool names, system architecture, or internal logs in the output. + - If asked why information is missing, respond: "I tried to retrieve it from publicly available sources, but the information may not be accessible right now." + +Here are some few-shot examples: +Input: Hi. +Output: {{"facts" : []}} + +Input: There are branches in trees. +Output: {{"facts" : []}} + +Input: Hi, I am looking for a restaurant in San Francisco. +Output: {{"facts" : ["Looking for a restaurant in San Francisco"]}} + +Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project. +Output: {{"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}} + +Input: Hi, my name is John. I am a software engineer. +Output: {{"facts" : ["Name is John", "Is a Software engineer"]}} + +Input: My favourite movies are Inception and Interstellar. +Output: {{"facts" : ["Favourite movies are Inception and Interstellar"]}} + +Input (with tool call): What's the weather like in Tokyo today? +[Tool Call: get_weather(location="Tokyo", date="2025-08-22") → Result: {{"status": "success", "data": {{"temp": 32°C, "condition": "Sunny", "humidity": 65%}}}}] +Output: {{"facts": ["It is 32°C and sunny in Tokyo today", "Humidity level in Tokyo is 65%"]}} + +Input (with failed tool): Check my calendar for tomorrow's meetings. +[Tool Call: get_calendar(date="2025-08-23") → Failed: "Access denied – calendar not connected"] +Output: {{"facts": ["Could not access calendar due to connection issues"]}} + +Input (with unreliable tool pattern): Search for vegan restaurants near Central Park. +[Tool Call: search(query="vegan restaurants near Central Park") → Returns incomplete/no results multiple times] +Output: {{"facts": ["Searching for vegan restaurants near Central Park yielded limited results"]}} +(Internal note: Search tool shows low reliability for location-based queries — consider fallback sources.) + +Final Output Rules: + - Today's date is {datetime.now().strftime("%Y-%m-%d")}. + - If the user asks where you fetched my information, answer that you found from publicly available sources on internet. + - Return only a JSON object with key "facts" and value as a list of strings. + - Do not include anything from the example prompts or system instructions. + - Do not reveal tool usage, internal logs, or model behavior. + - If no relevant personal or environmental facts are found, return: {{"facts": []}} + - Extract facts only from user and assistant messages — ignore system-level instructions. + - Detect the input language and record facts in the same language. + +Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation, process any tool call results, and return them in the JSON format as shown above. +""" + + +def get_code_fact_retrieval_prompt(): return f"""You are a Code Development Information Organizer, specialized in accurately storing development facts, project details, and technical preferences from coding conversations. Your primary role is to extract relevant pieces of technical information that will be useful for future code generation and development tasks. Below are the types of information you need to focus on and the detailed instructions on how to handle the input data. Types of Information to Remember: From 3b1343ee872904585fc54cededc6f544a0f20e34 Mon Sep 17 00:00:00 2001 From: suluyan Date: Mon, 22 Sep 2025 16:38:00 +0800 Subject: [PATCH 17/17] minor fix --- ms_agent/memory/default_memory.py | 2 +- ms_agent/utils/prompts.py | 78 ------------------------------- 2 files changed, 1 insertion(+), 79 deletions(-) delete mode 100644 ms_agent/utils/prompts.py diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py index 7bddcafb6..71611ae17 100644 --- a/ms_agent/memory/default_memory.py +++ b/ms_agent/memory/default_memory.py @@ -9,8 +9,8 @@ import json5 from ms_agent.llm.utils import Message from ms_agent.memory import Memory +from ms_agent.utils import get_fact_retrieval_prompt from ms_agent.utils.logger import logger -from ms_agent.utils.prompts import get_fact_retrieval_prompt from omegaconf import DictConfig, OmegaConf diff --git a/ms_agent/utils/prompts.py b/ms_agent/utils/prompts.py deleted file mode 100644 index fb57044b3..000000000 --- a/ms_agent/utils/prompts.py +++ /dev/null @@ -1,78 +0,0 @@ -from datetime import datetime - - -def get_fact_retrieval_prompt(): - return f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. - -Types of Information to Remember: -1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment. -2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates. -3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared. -4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services. -5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information. -6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information. -7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares. - -Tool Interaction Processing Instructions (Additional Responsibilities): -When tool calls and their results are included in the conversation, perform the following in addition to fact extraction: - -1. Extract and Organize Factual Information from Tool Outputs: - - Parse the returned data from successful tool calls (e.g., weather, calendar, search, maps). - - Identify and store objective, user-relevant facts derived from these results (e.g., "It will rain in Paris on 2025-08-25", "The restaurant Little Italy is located at 123 Main St"). - - Integrate these into the "facts" list only if they reflect new, meaningful information about the user's context or environment. -2. Analyze and Summarize Error-Prone Tools: - - Identify tools that frequently fail, time out, or return inconsistent results. - - For such tools, generate a brief internal summary noting the pattern of failure (e.g., "Search tool often returns incomplete results for restaurant queries"). - - This summary does not go into the JSON output but informs future handling (e.g., suggesting alternative tools or double-checking outputs). -3. Identify and Log Tools That Cannot Be Called: - - If a tool was intended but not invoked (e.g., due to missing permissions, unavailability, or misconfiguration), note this in a separate internal log. - - Examples: "Calendar tool unavailable — cannot retrieve user's meeting schedule", "Location access denied — weather tool cannot auto-detect city". - - Include a user-facing reminder if relevant: add a fact like "Could not access calendar due to permission restrictions" only if it impacts user understanding. -4. Ensure Clarity and Non-Disclosure: - - Do not expose tool names, system architecture, or internal logs in the output. - - If asked why information is missing, respond: "I tried to retrieve it from publicly available sources, but the information may not be accessible right now." - -Here are some few-shot examples: -Input: Hi. -Output: {{"facts" : []}} - -Input: There are branches in trees. -Output: {{"facts" : []}} - -Input: Hi, I am looking for a restaurant in San Francisco. -Output: {{"facts" : ["Looking for a restaurant in San Francisco"]}} - -Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project. -Output: {{"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}} - -Input: Hi, my name is John. I am a software engineer. -Output: {{"facts" : ["Name is John", "Is a Software engineer"]}} - -Input: My favourite movies are Inception and Interstellar. -Output: {{"facts" : ["Favourite movies are Inception and Interstellar"]}} - -Input (with tool call): What's the weather like in Tokyo today? -[Tool Call: get_weather(location="Tokyo", date="2025-08-22") → Result: {{"status": "success", "data": {{"temp": 32°C, "condition": "Sunny", "humidity": 65%}}}}] -Output: {{"facts": ["It is 32°C and sunny in Tokyo today", "Humidity level in Tokyo is 65%"]}} - -Input (with failed tool): Check my calendar for tomorrow's meetings. -[Tool Call: get_calendar(date="2025-08-23") → Failed: "Access denied – calendar not connected"] -Output: {{"facts": ["Could not access calendar due to connection issues"]}} - -Input (with unreliable tool pattern): Search for vegan restaurants near Central Park. -[Tool Call: search(query="vegan restaurants near Central Park") → Returns incomplete/no results multiple times] -Output: {{"facts": ["Searching for vegan restaurants near Central Park yielded limited results"]}} -(Internal note: Search tool shows low reliability for location-based queries — consider fallback sources.) - -Final Output Rules: - - Today's date is {datetime.now().strftime("%Y-%m-%d")}. - - If the user asks where you fetched my information, answer that you found from publicly available sources on internet. - - Return only a JSON object with key "facts" and value as a list of strings. - - Do not include anything from the example prompts or system instructions. - - Do not reveal tool usage, internal logs, or model behavior. - - If no relevant personal or environmental facts are found, return: {{"facts": []}} - - Extract facts only from user and assistant messages — ignore system-level instructions. - - Detect the input language and record facts in the same language. - -Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation, process any tool call results, and return them in the JSON format as shown above. -"""