Skip to content

Commit

Permalink
Accept agent type in more places (#4829)
Browse files Browse the repository at this point in the history
* Accept agenttype in more places

* remove type hint
  • Loading branch information
jackgerrits authored Dec 27, 2024
1 parent 90a44b5 commit 5bd91fb
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Type, TypeVar, overload

from ._agent_type import AgentType
from ._base_agent import BaseAgent, subscription_factory
from ._subscription_context import SubscriptionInstantiationContext
from ._type_subscription import TypeSubscription
Expand All @@ -16,7 +17,7 @@ class DefaultSubscription(TypeSubscription):
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context.
"""

def __init__(self, topic_type: str = "default", agent_type: str | None = None):
def __init__(self, topic_type: str = "default", agent_type: str | AgentType | None = None):
if agent_type is None:
try:
agent_type = SubscriptionInstantiationContext.agent_type().type
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid

from ._agent_id import AgentId
from ._agent_type import AgentType
from ._subscription import Subscription
from ._topic import TopicId
from .exceptions import CantHandleException
Expand Down Expand Up @@ -30,9 +31,12 @@ class TypePrefixSubscription(Subscription):
agent_type (str): Agent type to handle this subscription
"""

def __init__(self, topic_type_prefix: str, agent_type: str):
def __init__(self, topic_type_prefix: str, agent_type: str | AgentType):
self._topic_type_prefix = topic_type_prefix
self._agent_type = agent_type
if isinstance(agent_type, AgentType):
self._agent_type = agent_type.type
else:
self._agent_type = agent_type
self._id = str(uuid.uuid4())

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid

from ._agent_id import AgentId
from ._agent_type import AgentType
from ._subscription import Subscription
from ._topic import TopicId
from .exceptions import CantHandleException
Expand Down Expand Up @@ -29,9 +30,12 @@ class TypeSubscription(Subscription):
agent_type (str): Agent type to handle this subscription
"""

def __init__(self, topic_type: str, agent_type: str):
def __init__(self, topic_type: str, agent_type: str | AgentType):
self._topic_type = topic_type
self._agent_type = agent_type
if isinstance(agent_type, AgentType):
self._agent_type = agent_type.type
else:
self._agent_type = agent_type
self._id = str(uuid.uuid4())

@property
Expand Down

0 comments on commit 5bd91fb

Please sign in to comment.