Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 59 additions & 5 deletions packages/memory_module/memory_module/core/memory_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
class EpisodicMemoryExtraction(BaseModel):
action: Literal["add", "update", "ignore"] = Field(..., description="Action to take on the extracted fact")
reason_for_action: Optional[str] = Field(
..., description="Reason for the action taken on the extracted fact or the reason it was ignored."
...,
description="Reason for the action taken on the extracted fact or the reason it was ignored.",
)
summary: Optional[str] = Field(
...,
Expand All @@ -89,6 +90,15 @@
)


class RetrievedMemory(BaseModel):
memory_id: str = Field(..., description="The id of the memory that was retrieved")
reason: str = Field(..., description="The reason the memory was retrieved")


class RetrievedMemories(BaseModel):
memories: List[RetrievedMemory] = Field(..., description="The memories that were retrieved")


class MemoryCore(BaseMemoryCore):
"""Implementation of the memory core component."""

Expand Down Expand Up @@ -119,7 +129,8 @@
"""Process multiple messages into semantic memories (general facts, preferences)."""
# make sure there is an author, and only one author
author_id = next(
(message.author_id for message in messages if message.author_id and message.type == "user"), None
(message.author_id for message in messages if message.author_id and message.type == "user"),
None,
)
if not author_id:
logger.error("No author found in messages")
Expand Down Expand Up @@ -166,10 +177,50 @@
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
return await self._retrieve_memories(
user_id, config.query, [config.topic] if config.topic else None, config.limit
stored_memories = await self._retrieve_memories(
user_id,
config.query,
[config.topic] if config.topic else None,
config.limit,
)

if not config.query or not stored_memories:
return stored_memories

# Format memories for LLM context
memories_context = "\n".join(
[
f"<MEMORY id='{memory.id}' created_at='{memory.created_at}'>{memory.content}</MEMORY>"
for memory in stored_memories
]
)

messages = [
{
"role": "system",
"content": """You are a memory retrieval assistant. Analyze the provided memories and select the ones that are most relevant to answering the user's query.

Check failure on line 201 in packages/memory_module/memory_module/core/memory_core.py

View workflow job for this annotation

GitHub Actions / Build, Lint & Test (3.12)

Ruff (E501)

packages/memory_module/memory_module/core/memory_core.py:201:121: E501 Line too long (171 > 120)
Return the memory IDs ordered by relevance.""",
},
{
"role": "user",
"content": f"""Here are the memories:
{memories_context}

Query: {config.query}

Return the list of memory IDs ordered by relevance.""",
},
]

retrieved_memories = await self.lm.completion(messages=messages, response_model=RetrievedMemories)

# Filter and sort memories based on LLM's selection
memory_dict = {memory.id: memory for memory in stored_memories}
filtered_memory_ids = [memory.memory_id for memory in retrieved_memories.memories]
filtered_memories = [memory_dict[id] for id in filtered_memory_ids if id in memory_dict]

return filtered_memories

async def _retrieve_memories(
self,
user_id: Optional[str],
Expand Down Expand Up @@ -215,7 +266,10 @@
# If all messages associated with a memory are removed, remove that memory too
if all(item in message_ids for item in memory.message_attributions):
removed_memory_ids.append(memory.id)
logger.info("memory %s will be removed since all associated messages are removed", memory.id)
logger.info(
"memory %s will be removed since all associated messages are removed",
memory.id,
)

# Remove selected messages and related old memories
await self.storage.remove_memories(removed_memory_ids)
Expand Down
Loading