Skip to content

Commit

Permalink
Memory component base (#5380)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

Currently the way to accomplish RAG behavior with agent chat,
specifically assistant agents is with the memory interface, however
there is no way to configure it via the declarative API.

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [ ] I've included any doc changes needed for
https://microsoft.github.io/autogen/. See
https://microsoft.github.io/autogen/docs/Contribute#documentation to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Victor Dibia <[email protected]>
  • Loading branch information
EItanya and victordibia authored Feb 6, 2025
1 parent be3c60b commit 172a16a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class AssistantAgentConfig(BaseModel):
tools: List[ComponentModel] | None
handoffs: List[HandoffBase | str] | None = None
model_context: ComponentModel | None = None
memory: List[ComponentModel] | None = None
description: str
system_message: str | None = None
model_client_stream: bool = False
Expand Down Expand Up @@ -591,6 +592,7 @@ def _to_config(self) -> AssistantAgentConfig:
tools=[tool.dump_component() for tool in self._tools],
handoffs=list(self._handoffs.values()),
model_context=self._model_context.dump_component(),
memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
description=self.description,
system_message=self._system_messages[0].content
if self._system_messages and isinstance(self._system_messages[0].content, str)
Expand All @@ -609,6 +611,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
handoffs=config.handoffs,
model_context=None,
memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None,
description=config.description,
system_message=config.system_message,
model_client_stream=config.model_client_stream,
Expand Down
12 changes: 9 additions & 3 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import FunctionCall, Image
from autogen_core import ComponentModel, FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import (
Expand Down Expand Up @@ -754,7 +754,12 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
"test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2]
)

result = await agent.run(task="test task")
# Test dump and load component with memory
agent_config: ComponentModel = agent.dump_component()
assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent"
agent2 = AssistantAgent.load_component(agent_config)

result = await agent2.run(task="test task")
assert len(result.messages) > 0
memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None)
assert memory_event is not None
Expand Down Expand Up @@ -795,9 +800,10 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_context=model_context,
memory=[ListMemory(name="test_memory")],
)

agent_config = agent.dump_component()
agent_config: ComponentModel = agent.dump_component()
assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent"

agent2 = AssistantAgent.load_component(agent_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel, ConfigDict

from .._cancellation_token import CancellationToken
from .._component_config import ComponentBase
from .._image import Image
from ..model_context import ChatCompletionContext

Expand Down Expand Up @@ -49,7 +50,7 @@ class UpdateContextResult(BaseModel):
memories: MemoryQueryResult


class Memory(ABC):
class Memory(ABC, ComponentBase[BaseModel]):
"""Protocol defining the interface for memory implementations.
A memory is the storage for data that can be used to enrich or modify the model context.
Expand All @@ -64,6 +65,8 @@ class Memory(ABC):
See :class:`~autogen_core.memory.ListMemory` for an example implementation.
"""

component_type = "memory"

@abstractmethod
async def update_context(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from typing import Any, List

from pydantic import BaseModel
from typing_extensions import Self

from .._cancellation_token import CancellationToken
from .._component_config import Component
from ..model_context import ChatCompletionContext
from ..models import SystemMessage
from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult


class ListMemory(Memory):
class ListMemoryConfig(BaseModel):
"""Configuration for ListMemory component."""

name: str | None = None
"""Optional identifier for this memory instance."""
memory_contents: List[MemoryContent] = []
"""List of memory contents stored in this memory instance."""


class ListMemory(Memory, Component[ListMemoryConfig]):
"""Simple chronological list-based memory implementation.
This memory implementation stores contents in a list and retrieves them in
Expand Down Expand Up @@ -53,9 +66,13 @@ async def main() -> None:
"""

def __init__(self, name: str | None = None) -> None:
component_type = "memory"
component_provider_override = "autogen_core.memory.ListMemory"
component_config_schema = ListMemoryConfig

def __init__(self, name: str | None = None, memory_contents: List[MemoryContent] | None = None) -> None:
self._name = name or "default_list_memory"
self._contents: List[MemoryContent] = []
self._contents: List[MemoryContent] = memory_contents if memory_contents is not None else []

@property
def name(self) -> str:
Expand Down Expand Up @@ -146,3 +163,10 @@ async def clear(self) -> None:
async def close(self) -> None:
"""Cleanup resources if needed."""
pass

@classmethod
def _from_config(cls, config: ListMemoryConfig) -> Self:
return cls(name=config.name, memory_contents=config.memory_contents)

def _to_config(self) -> ListMemoryConfig:
return ListMemoryConfig(name=self.name, memory_contents=self._contents)
30 changes: 29 additions & 1 deletion python/packages/autogen-core/tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

import pytest
from autogen_core import CancellationToken
from autogen_core import CancellationToken, ComponentModel
from autogen_core.memory import (
ListMemory,
Memory,
Expand All @@ -23,6 +23,34 @@ def test_memory_protocol_attributes() -> None:
assert hasattr(Memory, "close")


def test_memory_component_load_config_from_base_model() -> None:
"""Test that Memory component can be loaded from a BaseModel."""
config = ComponentModel(
provider="autogen_core.memory.ListMemory",
config={
"name": "test_memory",
"memory_contents": [MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)],
},
)
memory = Memory.load_component(config)
assert isinstance(memory, ListMemory)
assert memory.name == "test_memory"
assert len(memory.content) == 1


def test_memory_component_dump_config_to_base_model() -> None:
"""Test that Memory component can be dumped to a BaseModel."""
memory = ListMemory(
name="test_memory", memory_contents=[MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)]
)
config = memory.dump_component()
assert isinstance(config, ComponentModel)
assert config.provider == "autogen_core.memory.ListMemory"
assert config.component_type == "memory"
assert config.config["name"] == "test_memory"
assert len(config.config["memory_contents"]) == 1


def test_memory_abc_implementation() -> None:
"""Test that Memory ABC is properly implemented."""

Expand Down

0 comments on commit 172a16a

Please sign in to comment.