Skip to content

Commit

Permalink
Make ChatCompletionCache support component config (#5658)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

This PR makes makes ChatCompletionCache   support component config

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed? 

Ensures we have a path to serializing ChatCompletionCache , similar to
the ChatCompletion client that it wraps.

This PR does the following

- Makes CacheStore serializable first (part of this includes converting
from Protocol to base class). Makes it's derivatives serializable as
well (diskcache, redis)
- Makes ChatCompletionCache serializable 
- Adds some tests

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number

<!-- For example: "Closes #1234" -->

Closes #5141

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed. 


cc @nour-bouzid
  • Loading branch information
victordibia authored Feb 24, 2025
1 parent a226966 commit 170b8cc
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 10 deletions.
30 changes: 27 additions & 3 deletions python/packages/autogen-core/src/autogen_core/_cache_store.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from typing import Dict, Generic, Optional, Protocol, TypeVar
from abc import ABC, abstractmethod
from typing import Dict, Generic, Optional, TypeVar

from pydantic import BaseModel
from typing_extensions import Self

from ._component_config import Component, ComponentBase

T = TypeVar("T")


class CacheStore(Protocol, Generic[T]):
class CacheStore(ABC, Generic[T], ComponentBase[BaseModel]):
"""
This protocol defines the basic interface for store/cache operations.
Sub-classes should handle the lifecycle of underlying storage.
"""

component_type = "cache_store"

@abstractmethod
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
"""
Retrieve an item from the store.
Expand All @@ -24,6 +33,7 @@ def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
"""
...

@abstractmethod
def set(self, key: str, value: T) -> None:
"""
Set an item in the store.
Expand All @@ -35,7 +45,14 @@ def set(self, key: str, value: T) -> None:
...


class InMemoryStore(CacheStore[T]):
class InMemoryStoreConfig(BaseModel):
pass


class InMemoryStore(CacheStore[T], Component[InMemoryStoreConfig]):
component_provider_override = "autogen_core.InMemoryStore"
component_config_schema = InMemoryStoreConfig

def __init__(self) -> None:
self.store: Dict[str, T] = {}

Expand All @@ -44,3 +61,10 @@ def get(self, key: str, default: Optional[T] = None) -> Optional[T]:

def set(self, key: str, value: T) -> None:
self.store[key] = value

def _to_config(self) -> InMemoryStoreConfig:
return InMemoryStoreConfig()

@classmethod
def _from_config(cls, config: InMemoryStoreConfig) -> Self:
return cls()
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from typing import Any, Optional, TypeVar, cast

import diskcache
from autogen_core import CacheStore
from autogen_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self

T = TypeVar("T")


class DiskCacheStore(CacheStore[T]):
class DiskCacheStoreConfig(BaseModel):
"""Configuration for DiskCacheStore"""

directory: str # Path where cache is stored
# Could add other diskcache.Cache parameters like size_limit, etc.


class DiskCacheStore(CacheStore[T], Component[DiskCacheStoreConfig]):
"""
A typed CacheStore implementation that uses diskcache as the underlying storage.
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
Expand All @@ -16,6 +25,9 @@ class DiskCacheStore(CacheStore[T]):
The user is responsible for managing the DiskCache instance's lifetime.
"""

component_config_schema = DiskCacheStoreConfig
component_provider_override = "autogen_ext.cache_store.diskcache.DiskCacheStore"

def __init__(self, cache_instance: diskcache.Cache): # type: ignore[no-any-unimported]
self.cache = cache_instance

Expand All @@ -24,3 +36,11 @@ def get(self, key: str, default: Optional[T] = None) -> Optional[T]:

def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value)) # type: ignore[reportUnknownMemberType]

def _to_config(self) -> DiskCacheStoreConfig:
# Get directory from cache instance
return DiskCacheStoreConfig(directory=self.cache.directory)

@classmethod
def _from_config(cls, config: DiskCacheStoreConfig) -> Self:
return cls(cache_instance=diskcache.Cache(config.directory)) # type: ignore[no-any-return]
57 changes: 54 additions & 3 deletions python/packages/autogen-ext/src/autogen_ext/cache_store/redis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from typing import Any, Optional, TypeVar, cast
from typing import Any, Dict, Optional, TypeVar, cast

import redis
from autogen_core import CacheStore
from autogen_core import CacheStore, Component
from pydantic import BaseModel
from typing_extensions import Self

T = TypeVar("T")


class RedisStore(CacheStore[T]):
class RedisStoreConfig(BaseModel):
"""Configuration for RedisStore"""

host: str = "localhost"
port: int = 6379
db: int = 0
# Add other relevant redis connection parameters
username: Optional[str] = None
password: Optional[str] = None
ssl: bool = False
socket_timeout: Optional[float] = None


class RedisStore(CacheStore[T], Component[RedisStoreConfig]):
"""
A typed CacheStore implementation that uses redis as the underlying storage.
See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
Expand All @@ -16,6 +31,9 @@ class RedisStore(CacheStore[T]):
The user is responsible for managing the Redis instance's lifetime.
"""

component_config_schema = RedisStoreConfig
component_provider_override = "autogen_ext.cache_store.redis.RedisStore"

def __init__(self, redis_instance: redis.Redis):
self.cache = redis_instance

Expand All @@ -27,3 +45,36 @@ def get(self, key: str, default: Optional[T] = None) -> Optional[T]:

def set(self, key: str, value: T) -> None:
self.cache.set(key, cast(Any, value))

def _to_config(self) -> RedisStoreConfig:
# Extract connection info from redis instance
connection_pool = self.cache.connection_pool
connection_kwargs: Dict[str, Any] = connection_pool.connection_kwargs # type: ignore[reportUnknownMemberType]

username = connection_kwargs.get("username")
password = connection_kwargs.get("password")
socket_timeout = connection_kwargs.get("socket_timeout")

return RedisStoreConfig(
host=str(connection_kwargs.get("host", "localhost")),
port=int(connection_kwargs.get("port", 6379)),
db=int(connection_kwargs.get("db", 0)),
username=str(username) if username is not None else None,
password=str(password) if password is not None else None,
ssl=bool(connection_kwargs.get("ssl", False)),
socket_timeout=float(socket_timeout) if socket_timeout is not None else None,
)

@classmethod
def _from_config(cls, config: RedisStoreConfig) -> Self:
# Create new redis instance from config
redis_instance = redis.Redis(
host=config.host,
port=config.port,
db=config.db,
username=config.username,
password=config.password,
ssl=config.ssl,
socket_timeout=config.socket_timeout,
)
return cls(redis_instance=redis_instance)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import CacheStore, CancellationToken, InMemoryStore
from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
Expand All @@ -13,11 +13,20 @@
RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from typing_extensions import Self

CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]


class ChatCompletionCache(ChatCompletionClient):
class ChatCompletionCacheConfig(BaseModel):
""" """

client: ComponentModel
store: Optional[ComponentModel] = None


class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheConfig]):
"""
A wrapper around a :class:`~autogen_ext.models.cache.ChatCompletionClient` that caches
creation results from an underlying client.
Expand Down Expand Up @@ -77,6 +86,10 @@ async def main():
Defaults to using in-memory cache.
"""

component_type = "chat_completion_cache"
component_provider_override = "autogen_ext.models.cache.ChatCompletionCache"
component_config_schema = ChatCompletionCacheConfig

def __init__(
self,
client: ChatCompletionClient,
Expand Down Expand Up @@ -213,3 +226,17 @@ def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[To

def total_usage(self) -> RequestUsage:
return self.client.total_usage()

def _to_config(self) -> ChatCompletionCacheConfig:
return ChatCompletionCacheConfig(
client=self.client.dump_component(),
store=self.store.dump_component() if not isinstance(self.store, InMemoryStore) else None,
)

@classmethod
def _from_config(cls, config: ChatCompletionCacheConfig) -> Self:
client = ChatCompletionClient.load_component(config.client)
store: Optional[CacheStore[CHAT_CACHE_VALUE_TYPE]] = (
CacheStore.load_component(config.store) if config.store else InMemoryStore()
)
return cls(client=client, store=store)
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ def test_diskcache_with_different_instances() -> None:

store_2.set(test_key, test_value_2)
assert store_2.get(test_key) == test_value_2

# test serialization
store_1_config = store_1.dump_component()
loaded_store_1: DiskCacheStore[int] = DiskCacheStore.load_component(store_1_config)
assert loaded_store_1.get(test_key) == test_value_1
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ def test_redis_with_different_instances() -> None:
redis_instance_2.set.assert_called_with(test_key, test_value_2)
redis_instance_2.get.return_value = test_value_2
assert store_2.get(test_key) == test_value_2

# test serialization
store_1_config = store_1.dump_component()
assert store_1_config.component_type == "cache_store"
assert store_1_config.component_version == 1
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,8 @@ async def test_cache_create_stream() -> None:
assert not original.cached
else:
raise ValueError(f"Unexpected types : {type(original)} and {type(cached)}")

# test serialization
# cached_client_config = cached_client.dump_component()
# loaded_client = ChatCompletionCache.load_component(cached_client_config)
# assert loaded_client.client == cached_client.client

1 comment on commit 170b8cc

@Zochory
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should so much have a Changelog page in the documentation imo to always be aware of all these small but sometimes very useful or important changes !

Please sign in to comment.