Skip to content

Commit 8c20f8a

Browse files
Dttbdjameszyao
authored andcommitted
feat: add more params type support
1 parent cf1e006 commit 8c20f8a

File tree

2 files changed

+82
-41
lines changed

2 files changed

+82
-41
lines changed

taskingai/assistant/assistant.py

+48-36
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,28 @@
3333
DEFAULT_RETRIEVAL_CONFIG = RetrievalConfig(top_k=3, method=RetrievalMethod.USER_MESSAGE)
3434

3535

36+
def _get_assistant_dict_params(
37+
memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None,
38+
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
39+
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
40+
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
41+
):
42+
memory = memory if isinstance(memory, AssistantMemory) else (AssistantMemory(**memory) if memory else None)
43+
tools = [tool if isinstance(tool, AssistantTool) else AssistantTool(**tool) for tool in (tools or [])] or None
44+
retrievals = [
45+
retrieval if isinstance(retrieval, AssistantRetrieval) else AssistantRetrieval(**retrieval)
46+
for retrieval in (retrievals or [])
47+
] or None
48+
retrieval_configs = (
49+
retrieval_configs
50+
if isinstance(retrieval_configs, RetrievalConfig)
51+
else RetrievalConfig(**retrieval_configs)
52+
if retrieval_configs
53+
else None
54+
)
55+
return memory, tools, retrievals, retrieval_configs
56+
57+
3658
def list_assistants(
3759
order: str = "desc",
3860
limit: int = 20,
@@ -118,12 +140,12 @@ async def a_get_assistant(assistant_id: str) -> Assistant:
118140

119141
def create_assistant(
120142
model_id: str,
121-
memory: AssistantMemory,
143+
memory: Union[AssistantMemory, Dict[str, Any]],
122144
name: Optional[str] = None,
123145
description: Optional[str] = None,
124146
system_prompt_template: Optional[List[str]] = None,
125-
tools: Optional[List[AssistantTool]] = None,
126-
retrievals: Optional[List[AssistantRetrieval]] = None,
147+
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
148+
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
127149
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
128150
metadata: Optional[Dict[str, str]] = None,
129151
) -> Assistant:
@@ -140,12 +162,9 @@ def create_assistant(
140162
:param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512.
141163
:return: The created assistant object.
142164
"""
143-
if retrieval_configs:
144-
retrieval_configs = (
145-
retrieval_configs
146-
if isinstance(retrieval_configs, RetrievalConfig)
147-
else RetrievalConfig(**retrieval_configs)
148-
)
165+
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
166+
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
167+
)
149168

150169
body = AssistantCreateRequest(
151170
model_id=model_id,
@@ -164,12 +183,12 @@ def create_assistant(
164183

165184
async def a_create_assistant(
166185
model_id: str,
167-
memory: AssistantMemory,
186+
memory: Union[AssistantMemory, Dict[str, Any]],
168187
name: Optional[str] = None,
169188
description: Optional[str] = None,
170189
system_prompt_template: Optional[List[str]] = None,
171-
tools: Optional[List[AssistantTool]] = None,
172-
retrievals: Optional[List[AssistantRetrieval]] = None,
190+
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
191+
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
173192
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
174193
metadata: Optional[Dict[str, str]] = None,
175194
) -> Assistant:
@@ -186,12 +205,9 @@ async def a_create_assistant(
186205
:param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512.
187206
:return: The created assistant object.
188207
"""
189-
if retrieval_configs:
190-
retrieval_configs = (
191-
retrieval_configs
192-
if isinstance(retrieval_configs, RetrievalConfig)
193-
else RetrievalConfig(**retrieval_configs)
194-
)
208+
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
209+
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
210+
)
195211

196212
body = AssistantCreateRequest(
197213
model_id=model_id,
@@ -214,9 +230,9 @@ def update_assistant(
214230
name: Optional[str] = None,
215231
description: Optional[str] = None,
216232
system_prompt_template: Optional[List[str]] = None,
217-
memory: Optional[AssistantMemory] = None,
218-
tools: Optional[List[AssistantTool]] = None,
219-
retrievals: Optional[List[AssistantRetrieval]] = None,
233+
memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None,
234+
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
235+
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
220236
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
221237
metadata: Optional[Dict[str, str]] = None,
222238
) -> Assistant:
@@ -235,12 +251,10 @@ def update_assistant(
235251
:return: The updated assistant object.
236252
"""
237253

238-
if retrieval_configs:
239-
retrieval_configs = (
240-
retrieval_configs
241-
if isinstance(retrieval_configs, RetrievalConfig)
242-
else RetrievalConfig(**retrieval_configs)
243-
)
254+
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
255+
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
256+
)
257+
244258
body = AssistantUpdateRequest(
245259
model_id=model_id,
246260
name=name,
@@ -262,9 +276,9 @@ async def a_update_assistant(
262276
name: Optional[str] = None,
263277
description: Optional[str] = None,
264278
system_prompt_template: Optional[List[str]] = None,
265-
memory: Optional[AssistantMemory] = None,
266-
tools: Optional[List[AssistantTool]] = None,
267-
retrievals: Optional[List[AssistantRetrieval]] = None,
279+
memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None,
280+
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
281+
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
268282
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
269283
metadata: Optional[Dict[str, str]] = None,
270284
) -> Assistant:
@@ -283,12 +297,10 @@ async def a_update_assistant(
283297
:return: The updated assistant object.
284298
"""
285299

286-
if retrieval_configs:
287-
retrieval_configs = (
288-
retrieval_configs
289-
if isinstance(retrieval_configs, RetrievalConfig)
290-
else RetrievalConfig(**retrieval_configs)
291-
)
300+
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
301+
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
302+
)
303+
292304
body = AssistantUpdateRequest(
293305
model_id=model_id,
294306
name=name,

taskingai/inference/chat_completion.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, List, Dict, Union
1+
from typing import Any, Optional, List, Dict, Union
22
from ..client.stream import Stream, AsyncStream
33

44
from taskingai.client.models import *
@@ -44,12 +44,35 @@ def __init__(self, id: str, content: str):
4444
super().__init__(role=ChatCompletionRole.FUNCTION, id=id, content=content)
4545

4646

47+
def _get_completion_dict_params(
48+
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]],
49+
functions: Optional[List[Union[Function, Dict[str, Any]]]] = None,
50+
):
51+
def _build_message(message: Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]):
52+
if isinstance(message, Dict):
53+
if message["role"] == ChatCompletionRole.SYSTEM.value:
54+
return SystemMessage(**message)
55+
if message["role"] == ChatCompletionRole.USER.value:
56+
return UserMessage(**message)
57+
if message["role"] == ChatCompletionRole.ASSISTANT.value:
58+
return AssistantMessage(**message)
59+
if message["role"] == ChatCompletionRole.FUNCTION.value:
60+
return FunctionMessage(**message)
61+
return message
62+
63+
messages = [_build_message(message) for message in messages]
64+
functions = [
65+
function if isinstance(function, Function) else Function(**function) for function in (functions or [])
66+
] or None
67+
return messages, functions
68+
69+
4770
def chat_completion(
4871
model_id: str,
49-
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]],
72+
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]],
5073
configs: Optional[Dict] = None,
5174
function_call: Optional[str] = None,
52-
functions: Optional[List[Function]] = None,
75+
functions: Optional[List[Union[Function, Dict[str, Any]]]] = None,
5376
stream: bool = False,
5477
) -> Union[ChatCompletion, Stream]:
5578
"""
@@ -63,6 +86,9 @@ def chat_completion(
6386
:param stream: Whether to request in stream mode.
6487
:return: The list of assistants.
6588
"""
89+
90+
messages, functions = _get_completion_dict_params(messages, functions)
91+
6692
# only add non-None parameters
6793
body = ChatCompletionRequest(
6894
model_id=model_id,
@@ -82,10 +108,10 @@ def chat_completion(
82108

83109
async def a_chat_completion(
84110
model_id: str,
85-
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]],
111+
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]],
86112
configs: Optional[Dict] = None,
87113
function_call: Optional[str] = None,
88-
functions: Optional[List[Function]] = None,
114+
functions: Optional[List[Union[Function, Dict[str, Any]]]] = None,
89115
stream: bool = False,
90116
) -> Union[ChatCompletion, AsyncStream]:
91117
"""
@@ -99,6 +125,9 @@ async def a_chat_completion(
99125
:param stream: Whether to request in stream mode.
100126
:return: The list of assistants.
101127
"""
128+
129+
messages, functions = _get_completion_dict_params(messages, functions)
130+
102131
# only add non-None parameters
103132
body = ChatCompletionRequest(
104133
model_id=model_id,

0 commit comments

Comments
 (0)