Skip to content

Commit b62e4d4

Browse files
committed
added short-term memory to the agent
1 parent a44d0a3 commit b62e4d4

File tree

17 files changed

+680
-41
lines changed

17 files changed

+680
-41
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
deploy.py
22
destinations.py
33

4-
# knowledge/
5-
memory/
4+
entity_memory.py
5+
long_term_memory.py
66

77
# default_agents.py
88
train.py

README.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,17 @@ Tasks can be delegated to a team manager, peers in the team, or completely new a
156156
<hr />
157157

158158
## Technologies Used
159-
**Schema, Database, Data Validation**
159+
**Schema, Data Validation**
160160
- [Pydantic](https://docs.pydantic.dev/latest/): Data validation and serialization library for Python.
161161
- [Pydantic_core](https://pypi.org/project/pydantic-core/): Core func packages for Pydantic.
162-
- [Chroma DB](https://docs.trychroma.com/): Vector database for storing and querying usage data.
163-
- [SQLite](https://www.sqlite.org/docs.html): C-language library to implements a small SQL database engine.
164162
- [Upstage](https://console.upstage.ai/docs/getting-started/overview): Document processor for ML tasks. (Use `Document Parser API` to extract data from documents)
165163
- [Docling](https://ds4sd.github.io/docling/): Document parsing
166164

165+
**Storage**
166+
- [mem0ai](https://docs.mem0.ai/quickstart#install-package): Agents' memory storage and management.
167+
- [Chroma DB](https://docs.trychroma.com/): Vector database for storing and querying usage data.
168+
- [SQLite](https://www.sqlite.org/docs.html): C-language library that implements a small SQL database engine.
169+
167170
**LLM-curation**
168171
- [LiteLLM](https://docs.litellm.ai/docs/providers): Curation platform to access LLMs
169172

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
"json-repair>=0.35.0",
4646
"wheel>=0.45.1",
4747
"pdfplumber>=0.11.5",
48+
"mem0ai>=0.1.48",
4849
]
4950
classifiers = [
5051
"Programming Language :: Python",

requirements.txt

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,19 @@ googleapis-common-protos==1.66.0
136136
grpcio==1.70.0
137137
# via
138138
# chromadb
139+
# grpcio-tools
139140
# opentelemetry-exporter-otlp-proto-grpc
141+
# qdrant-client
142+
grpcio-tools==1.70.0
143+
# via qdrant-client
140144
h11==0.14.0
141145
# via
142146
# httpcore
143147
# uvicorn
148+
h2==4.1.0
149+
# via httpx
150+
hpack==4.1.0
151+
# via h2
144152
httpcore==1.0.7
145153
# via httpx
146154
httptools==0.6.4
@@ -151,6 +159,7 @@ httpx==0.27.2
151159
# langsmith
152160
# litellm
153161
# openai
162+
# qdrant-client
154163
huggingface-hub==0.27.0
155164
# via
156165
# docling
@@ -159,6 +168,8 @@ huggingface-hub==0.27.0
159168
# transformers
160169
humanfriendly==10.0
161170
# via coloredlogs
171+
hyperframe==6.1.0
172+
# via h2
162173
idna==3.10
163174
# via
164175
# anyio
@@ -243,6 +254,8 @@ markupsafe==3.0.2
243254
# werkzeug
244255
mdurl==0.1.2
245256
# via markdown-it-py
257+
mem0ai==0.1.48
258+
# via versionhq (pyproject.toml)
246259
mmh3==5.1.0
247260
# via chromadb
248261
monotonic==1.6
@@ -274,6 +287,7 @@ numpy==2.2.1
274287
# onnxruntime
275288
# opencv-python-headless
276289
# pandas
290+
# qdrant-client
277291
# safetensors
278292
# scikit-image
279293
# scipy
@@ -293,6 +307,7 @@ openai==1.58.1
293307
# composio-openai
294308
# langchain-openai
295309
# litellm
310+
# mem0ai
296311
opencv-python-headless==4.11.0.86
297312
# via
298313
# docling-ibm-models
@@ -377,15 +392,20 @@ pillow==10.4.0
377392
# python-pptx
378393
# scikit-image
379394
# torchvision
395+
portalocker==2.10.1
396+
# via qdrant-client
380397
posthog==3.11.0
381-
# via chromadb
398+
# via
399+
# chromadb
400+
# mem0ai
382401
propcache==0.2.1
383402
# via
384403
# aiohttp
385404
# yarl
386405
protobuf==5.29.3
387406
# via
388407
# googleapis-common-protos
408+
# grpcio-tools
389409
# onnxruntime
390410
# opentelemetry-proto
391411
pyasn1==0.6.1
@@ -412,8 +432,10 @@ pydantic==2.10.6
412432
# langchain-core
413433
# langsmith
414434
# litellm
435+
# mem0ai
415436
# openai
416437
# pydantic-settings
438+
# qdrant-client
417439
pydantic-core==2.27.2
418440
# via pydantic
419441
pydantic-settings==2.7.1
@@ -454,7 +476,9 @@ python-dotenv==1.0.1
454476
python-pptx==1.0.2
455477
# via docling
456478
pytz==2024.2
457-
# via pandas
479+
# via
480+
# mem0ai
481+
# pandas
458482
pyyaml==6.0.2
459483
# via
460484
# chromadb
@@ -466,6 +490,8 @@ pyyaml==6.0.2
466490
# langchain-core
467491
# transformers
468492
# uvicorn
493+
qdrant-client==1.13.2
494+
# via mem0ai
469495
referencing==0.35.1
470496
# via
471497
# jsonschema
@@ -528,6 +554,7 @@ sentry-sdk==2.19.2
528554
setuptools==75.6.0
529555
# via
530556
# versionhq (pyproject.toml)
557+
# grpcio-tools
531558
# torch
532559
shapely==2.0.6
533560
# via easyocr
@@ -546,7 +573,9 @@ sniffio==1.3.1
546573
soupsieve==2.6
547574
# via beautifulsoup4
548575
sqlalchemy==2.0.36
549-
# via langchain
576+
# via
577+
# langchain
578+
# mem0ai
550579
starlette==0.41.3
551580
# via fastapi
552581
sympy==1.13.3
@@ -627,6 +656,7 @@ tzdata==2025.1
627656
urllib3==2.3.0
628657
# via
629658
# kubernetes
659+
# qdrant-client
630660
# requests
631661
# sentry-sdk
632662
# types-requests

src/versionhq/agent/model.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from versionhq.llm.model import LLM, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MODEL_NAME
1313
from versionhq.tool.model import Tool, ToolSet
1414
from versionhq.knowledge.model import BaseKnowledgeSource, Knowledge
15+
from versionhq.memory.contextual_memory import ContextualMemory
16+
from versionhq.memory.model import ShortTermMemory, UserMemory
1517
from versionhq._utils.logger import Logger
1618
from versionhq.agent.rpm_controller import RPMController
1719
from versionhq._utils.usage_metrics import UsageMetrics
@@ -95,10 +97,20 @@ class Agent(BaseModel):
9597
backstory: Optional[str] = Field(default=None, description="developer prompt to the llm")
9698
skillsets: Optional[List[str]] = Field(default_factory=list)
9799
tools: Optional[List[Tool | ToolSet | Type[Tool]]] = Field(default_factory=list)
100+
101+
# knowledge
98102
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(default=None)
99103
_knowledge: Optional[Knowledge] = PrivateAttr(default=None)
100-
embedder_config: Optional[Dict[str, Any]] = Field(default=None, description="embedder configuration for the agent's knowledge")
101104

105+
# memory
106+
use_memory: bool = Field(default=False, description="whether to store/use memory when executing the task")
107+
memory_config: Optional[Dict[str, Any]] = Field(default=None, description="configuration for the memory")
108+
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(default=None)
109+
user_memory: Optional[InstanceOf[UserMemory]] = Field(default=None)
110+
# _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
111+
# _user_memory: Optional[InstanceOf[UserMemory]] = PrivateAttr()
112+
113+
embedder_config: Optional[Dict[str, Any]] = Field(default=None, description="embedder configuration for the agent's knowledge")
102114

103115
# prompting
104116
use_developer_prompt: Optional[bool] = Field(default=True, description="Use developer prompt when calling the llm")
@@ -347,14 +359,30 @@ def set_up_knowledge(self) -> Self:
347359
return self
348360

349361

362+
@model_validator(mode="after")
def set_up_memory(self) -> Self:
    """
    Set up the agent's memories: short-term memory (stm) and user memory (um).

    Short-term memory is always created when `use_memory` is True; user memory
    additionally requires a non-None `memory_config` (e.g. a mem0 provider).
    Explicitly provided memory instances are kept as-is.
    """
    if self.use_memory:
        # Keep a caller-supplied instance; otherwise build the default STM.
        self.short_term_memory = self.short_term_memory if self.short_term_memory else ShortTermMemory(agent=self, embedder_config=self.embedder_config)

        # `memory_config` is a declared model field, so hasattr() was always True;
        # test the value directly instead.
        if self.memory_config is not None:
            self.user_memory = self.user_memory if self.user_memory else UserMemory(agent=self)
        else:
            self.user_memory = None

    return self
377+
378+
350379
def _train(self) -> Self:
351380
"""
352381
Fine-tuned the base model using OpenAI train framework.
353382
"""
354383
if not isinstance(self.llm, LLM):
355384
pass
356385

357-
358386
def invoke(
359387
self,
360388
prompts: str,
@@ -440,6 +468,14 @@ def execute_task(self, task, context: Optional[str] = None, task_tools: Optional
440468
if agent_knowledge_context:
441469
task_prompt += agent_knowledge_context
442470

471+
472+
if self.use_memory == True:
473+
contextual_memory = ContextualMemory(memory_config=self.memory_config, stm=self.short_term_memory, um=self.user_memory)
474+
memory = contextual_memory.build_context_for_task(task, context)
475+
if memory.strip() != "":
476+
task_prompt += memory.strip()
477+
478+
443479
# if self.team and self.team._train:
444480
# task_prompt = self._training_handler(task_prompt=task_prompt)
445481
# else:

src/versionhq/memory/__init__.py

Whitespace-only changes.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from typing import Any, Dict, Optional, List
2+
3+
from versionhq.memory.model import ShortTermMemory, UserMemory
4+
5+
6+
class ContextualMemory:
7+
"""
8+
A class to construct context from memories (ShortTermMemory, UserMemory).
9+
The context will be added to the prompt when the agent executes the task.
10+
"""
11+
12+
def __init__(
13+
self,
14+
memory_config: Optional[Dict[str, Any]],
15+
stm: ShortTermMemory,
16+
um: UserMemory,
17+
# ltm: LongTermMemory,
18+
# em: EntityMemory,
19+
):
20+
self.memory_provider = memory_config.get("provider") if memory_config is not None else None
21+
self.stm = stm
22+
self.um = um
23+
24+
25+
def build_context_for_task(self, task, context: List[Any] | str) -> str:
26+
"""
27+
Automatically builds a minimal, highly relevant set of contextual information for a given task.
28+
"""
29+
30+
query = f"{task.description} {context}".strip()
31+
32+
if query == "":
33+
return ""
34+
35+
context = []
36+
context.append(self._fetch_stm_context(query))
37+
if self.memory_provider == "mem0":
38+
context.append(self._fetch_user_context(query))
39+
return "\n".join(filter(None, context))
40+
41+
42+
def _fetch_stm_context(self, query) -> str:
43+
"""
44+
Fetches recent relevant insights from STM related to the task's description and expected_output, formatted as bullet points.
45+
"""
46+
stm_results = self.stm.search(query)
47+
formatted_results = "\n".join(
48+
[
49+
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
50+
for result in stm_results
51+
]
52+
)
53+
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
54+
55+
56+
def _fetch_user_context(self, query: str) -> str:
57+
"""
58+
Fetches and formats relevant user information from User Memory.
59+
"""
60+
61+
user_memories = self.um.search(query)
62+
if not user_memories:
63+
return ""
64+
65+
formatted_memories = "\n".join(f"- {result['memory']}" for result in user_memories)
66+
return f"User memories/preferences:\n{formatted_memories}"
67+
68+
69+
# def _fetch_ltm_context(self, task) -> Optional[str]:
70+
# """
71+
# Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
72+
# formatted as bullet points.
73+
# """
74+
# ltm_results = self.ltm.search(task, latest_n=2)
75+
# if not ltm_results:
76+
# return None
77+
78+
# formatted_results = [
79+
# suggestion
80+
# for result in ltm_results
81+
# for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
82+
# ]
83+
# formatted_results = list(dict.fromkeys(formatted_results))
84+
# formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
85+
86+
# return f"Historical Data:\n{formatted_results}" if ltm_results else ""
87+
88+
# def _fetch_entity_context(self, query) -> str:
89+
# """
90+
# Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
91+
# formatted as bullet points.
92+
# """
93+
# em_results = self.em.search(query)
94+
# formatted_results = "\n".join(
95+
# [
96+
# f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
97+
# for result in em_results
98+
# ] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
99+
# )
100+
# return f"Entities:\n{formatted_results}" if em_results else ""

0 commit comments

Comments
 (0)