Skip to content

Commit cded812

Browse files
isthaison and KevinHuSh authored
Feat: add OpenAI compatible API for agent (#6329)
### What problem does this PR solve?

Add an OpenAI-compatible chat completions API for agents.

_Briefly describe what this PR aims to solve. Include background context that will help reviewers understand the purpose of the PR._

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
1 parent 2acb023 commit cded812

4 files changed

Lines changed: 433 additions & 17 deletions

File tree

api/apps/sdk/session.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717
import re
1818
import time
1919

20+
import tiktoken
2021
from flask import Response, jsonify, request
21-
22+
from api.db.services.conversation_service import ConversationService, iframe_completion
23+
from api.db.services.conversation_service import completion as rag_completion
24+
from api.db.services.canvas_service import completion as agent_completion, completionOpenAI
2225
from agent.canvas import Canvas
2326
from api.db import LLMType, StatusEnum
2427
from api.db.db_models import APIToken
2528
from api.db.services.api_service import API4ConversationService
2629
from api.db.services.canvas_service import UserCanvasService
27-
from api.db.services.canvas_service import completion as agent_completion
28-
from api.db.services.conversation_service import ConversationService, iframe_completion
29-
from api.db.services.conversation_service import completion as rag_completion
3030
from api.db.services.dialog_service import DialogService, ask, chat
3131
from api.db.services.file_service import FileService
3232
from api.db.services.knowledgebase_service import KnowledgebaseService
33-
from api.db.services.llm_service import LLMBundle
3433
from api.utils import get_uuid
35-
from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request
34+
from api.utils.api_utils import get_result, token_required, get_data_openai, get_error_data_result, validate_request
35+
from api.db.services.llm_service import LLMBundle
36+
3637

3738

3839
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@@ -71,14 +72,11 @@ def create_agent_session(tenant_id, agent_id):
7172
req = request.form
7273
files = request.files
7374
user_id = request.args.get("user_id", "")
74-
7575
e, cvs = UserCanvasService.get_by_id(agent_id)
7676
if not e:
7777
return get_error_data_result("Agent not found.")
78-
7978
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
8079
return get_error_data_result("You cannot access the agent.")
81-
8280
if not isinstance(cvs.dsl, str):
8381
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
8482

@@ -352,6 +350,40 @@ def streamed_response_generator(chat_id, dia, msg):
352350
}
353351
return jsonify(response)
354352

353+
def _openai_message_text(content):
    """Flatten an OpenAI-style message ``content`` value into plain text.

    Args:
        content: ``None``, a string, or a list of content parts (strings or
            dicts carrying a string under the ``"text"`` key).

    Returns:
        The text as a single string. List parts are joined with newlines and
        parts without text are skipped; any other value is passed through
        ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "\n".join(texts)
    return str(content)


@manager.route('/agents_openai/<agent_id>/chat/completions', methods=['POST'])  # noqa: F821
@validate_request("model", "messages")  # noqa: F821
@token_required
def agents_completion_openai_compatibility(tenant_id, agent_id):
    """Serve an OpenAI-style chat completion request against an agent.

    The JSON body is read for ``messages``, ``model``, ``stream`` and ``id``.
    The most recent ``user`` message becomes the question handed to
    ``completionOpenAI``; the body's ``id`` is forwarded as the session ID.

    Args:
        tenant_id: Caller's tenant ID (presumably injected by
            ``token_required`` -- defined elsewhere).
        agent_id: Agent ID taken from the URL.

    Returns:
        An error result when no messages are given or the tenant has no such
        agent; a JSON completion payload; or a ``text/event-stream`` response
        when streaming.
    """
    req = request.json
    messages = req.get("messages", [])
    if not messages:
        return get_error_data_result("You must provide at least one message.")
    if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
        return get_error_data_result(f"You don't own the agent {agent_id}")

    # Tolerate malformed entries (non-dicts, missing "role") instead of
    # failing with a KeyError/TypeError on client-supplied data.
    filtered_messages = [
        m for m in messages
        if isinstance(m, dict) and m.get("role") in ("user", "assistant")
    ]
    if not filtered_messages:
        notice = "No valid messages found (user or assistant)."
        tiktokenenc = tiktoken.get_encoding("cl100k_base")
        return jsonify(get_data_openai(
            id=agent_id,
            content=notice,
            finish_reason="stop",
            model=req.get("model", ""),
            completion_tokens=len(tiktokenenc.encode(notice)),
            # No usable prompt messages, so nothing was consumed.
            prompt_tokens=0,
        ))

    # Get the last user message as the question, normalised to plain text so
    # that ``None`` or multi-part content does not leak downstream.
    question = next(
        (_openai_message_text(m.get("content")) for m in reversed(filtered_messages) if m["role"] == "user"),
        "",
    )
    session_id = req.get("id", "")

    # NOTE(review): OpenAI's API treats an omitted ``stream`` as false; this
    # endpoint streams by default. Kept as-is for existing clients -- confirm
    # whether that is intended.
    if req.get("stream", True):
        return Response(
            completionOpenAI(tenant_id, agent_id, question, session_id=session_id, stream=True),
            mimetype="text/event-stream",
        )

    # For non-streaming, the generator yields the payload as its first item.
    response = next(completionOpenAI(tenant_id, agent_id, question, session_id=session_id, stream=False))
    return jsonify(response)
386+
355387

356388
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
357389
@token_required
@@ -364,9 +396,7 @@ def agent_completions(tenant_id, agent_id):
364396
dsl = cvs[0].dsl
365397
if not isinstance(dsl, str):
366398
dsl = json.dumps(dsl)
367-
# canvas = Canvas(dsl, tenant_id)
368-
# if canvas.get_preset_param():
369-
# req["question"] = ""
399+
370400
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
371401
if not conv:
372402
return get_error_data_result(f"You don't own the session {req['session_id']}")

api/db/services/canvas_service.py

Lines changed: 207 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
from api.db.services.common_service import CommonService
2525
from api.db.services.conversation_service import structure_answer
2626
from api.utils import get_uuid
27+
from api.utils.api_utils import get_data_openai
28+
import tiktoken
2729
from peewee import fn
28-
2930
class CanvasTemplateService(CommonService):
3031
model = CanvasTemplate
3132

@@ -100,14 +101,14 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
100101
]
101102
if keywords:
102103
angents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
103-
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
104+
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
104105
TenantPermission.TEAM.value)) | (
105106
cls.model.user_id == user_id)),
106107
(fn.LOWER(cls.model.title).contains(keywords.lower()))
107108
)
108109
else:
109110
angents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
110-
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
111+
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission ==
111112
TenantPermission.TEAM.value)) | (
112113
cls.model.user_id == user_id))
113114
)
@@ -154,8 +155,6 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
154155
"dsl": cvs.dsl
155156
}
156157
API4ConversationService.save(**conv)
157-
158-
159158
conv = API4Conversation(**conv)
160159
else:
161160
e, conv = API4ConversationService.get_by_id(session_id)
@@ -221,3 +220,206 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
221220
API4ConversationService.append_message(conv.id, conv.to_dict())
222221
yield result
223222
break
223+
def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
    """Run an agent canvas and yield the answer as OpenAI-style completion payloads.

    Args:
        tenant_id: ID of the caller; must equal the agent's ``user_id``.
        agent_id: ID of the agent (canvas) to run. Also reported as ``model``.
        question: Text of the latest user message.
        session_id: ID of an existing conversation to continue. When falsy, a
            new conversation is created and its generated ID is used.
        stream: When true, every yielded item is a ``data: ...`` server-sent
            event string and the sequence ends with ``data: [DONE]``. When
            false, a single payload dict is yielded.
        **kwargs: Values for the canvas preset parameters (looked up by each
            parameter's ``key``), plus an optional ``user_id`` stored on a
            newly created conversation.

    Yields:
        ``get_data_openai`` payloads: dicts in non-streaming mode, SSE-framed
        JSON strings in streaming mode. Failures are reported as payloads whose
        content starts with ``**ERROR**`` rather than raised.
    """
    tiktokenenc = tiktoken.get_encoding("cl100k_base")

    def sse(payload):
        # Frame one payload as a server-sent event.
        return "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"

    def terminal(payload):
        # Items for a response that ends the exchange immediately. Streaming
        # consumers need strings (a raw dict cannot be written to an event
        # stream), followed by the usual end-of-stream marker.
        if stream:
            return [sse(payload), "data: [DONE]\n\n"]
        return [payload]

    e, cvs = UserCanvasService.get_by_id(agent_id)
    if not e:
        yield from terminal(get_data_openai(
            id=session_id,
            model=agent_id,
            content="**ERROR**: Agent not found."
        ))
        return

    if cvs.user_id != tenant_id:
        yield from terminal(get_data_openai(
            id=session_id,
            model=agent_id,
            content="**ERROR**: You do not own the agent"
        ))
        return

    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

    canvas = Canvas(cvs.dsl, tenant_id)
    canvas.reset()
    message_id = str(uuid4())

    if not session_id:
        # New session: bind the canvas preset parameters from kwargs.
        for ele in canvas.get_preset_param() or []:
            supplied = kwargs.get(ele["key"])
            if not ele["optional"]:
                if not supplied:
                    notice = f"`{ele['key']}` is required"
                    yield from terminal(get_data_openai(
                        id=None,
                        model=agent_id,
                        content=notice,
                        completion_tokens=len(tiktokenenc.encode(notice)),
                        prompt_tokens=len(tiktokenenc.encode(question if question else ""))
                    ))
                    return
                ele["value"] = supplied
            elif supplied:
                ele["value"] = supplied
            else:
                # An optional parameter that was not supplied must not keep a
                # stale value.
                ele.pop("value", None)

        cvs.dsl = json.loads(str(canvas))
        session_id = get_uuid()
        conv = {
            "id": session_id,
            "dialog_id": cvs.id,
            "user_id": kwargs.get("user_id", ""),
            "message": [{"role": "assistant", "content": canvas.get_prologue(), "created_at": time.time()}],
            "source": "agent",
            "dsl": cvs.dsl
        }
        API4ConversationService.save(**conv)
        conv = API4Conversation(**conv)
    else:
        # Existing session: resume the canvas from the stored conversation.
        e, conv = API4ConversationService.get_by_id(session_id)
        if not e:
            yield from terminal(get_data_openai(
                id=session_id,
                model=agent_id,
                content="**ERROR**: Session not found!"
            ))
            return
        canvas = Canvas(json.dumps(conv.dsl), tenant_id)

    # Record the question for new and resumed sessions alike; a new session
    # would otherwise run the canvas without any user input.
    canvas.messages.append({"role": "user", "content": question, "id": message_id})
    canvas.add_user_input(question)

    if not conv.message:
        conv.message = []
    conv.message.append({
        "role": "user",
        "content": question,
        "id": message_id
    })

    if not conv.reference:
        conv.reference = []
    conv.reference.append({"chunks": [], "doc_aggs": []})

    final_ans = {"reference": [], "content": ""}
    prompt_tokens = len(tiktokenenc.encode(str(question)))

    def persist_answer():
        # Store the assistant turn on the canvas and save the conversation.
        canvas.messages.append({"role": "assistant", "content": final_ans["content"], "created_at": time.time(), "id": message_id})
        canvas.history.append(("assistant", final_ans["content"]))
        if final_ans.get("reference"):
            canvas.reference.append(final_ans["reference"])
        conv.dsl = json.loads(str(canvas))
        API4ConversationService.append_message(conv.id, conv.to_dict())

    if stream:
        try:
            completion_tokens = 0
            for ans in canvas.run(stream=True):
                if ans.get("running_status"):
                    # Items flagged ``running_status`` are forwarded as chunks
                    # but are not part of the final answer.
                    completion_tokens += len(tiktokenenc.encode(ans.get("content", "")))
                    yield sse(get_data_openai(
                        id=session_id,
                        model=agent_id,
                        content=ans["content"],
                        object="chat.completion.chunk",
                        completion_tokens=completion_tokens,
                        prompt_tokens=prompt_tokens
                    ))
                    continue

                final_ans.update(ans)

                # NOTE(review): every answer item is emitted with
                # finish_reason="stop" and the running token total grows by
                # the whole content each time -- confirm whether canvas.run
                # yields cumulative or incremental content.
                completion_tokens += len(tiktokenenc.encode(final_ans.get("content", "")))
                yield sse(get_data_openai(
                    id=session_id,
                    model=agent_id,
                    content=final_ans["content"],
                    object="chat.completion.chunk",
                    finish_reason="stop",
                    completion_tokens=completion_tokens,
                    prompt_tokens=prompt_tokens
                ))

            persist_answer()
            yield "data: [DONE]\n\n"

        except Exception as e:
            # Boundary handler: report the failure in-band so the stream still
            # terminates cleanly, and keep whatever canvas state exists.
            traceback.print_exc()
            conv.dsl = json.loads(str(canvas))
            API4ConversationService.append_message(conv.id, conv.to_dict())
            error_text = "**ERROR**: " + str(e)
            yield sse(get_data_openai(
                id=session_id,
                model=agent_id,
                content=error_text,
                finish_reason="stop",
                completion_tokens=len(tiktokenenc.encode(error_text)),
                prompt_tokens=prompt_tokens
            ))
            yield "data: [DONE]\n\n"

    else:  # Non-streaming mode
        try:
            answer_parts = []
            for answer in canvas.run(stream=False):
                if answer.get("running_status"):
                    continue

                content = answer["content"] if "content" in answer else ""
                if not isinstance(content, str):
                    # Content arrives as an iterable of text pieces; joining a
                    # plain string would put a newline between every character.
                    content = "\n".join(content)
                final_ans["reference"] = answer.get("reference", [])
                answer_parts.append(content)

            final_ans["content"] = "".join(answer_parts)

            persist_answer()

            # Return the response in OpenAI format
            yield get_data_openai(
                id=session_id,
                model=agent_id,
                content=final_ans["content"],
                finish_reason="stop",
                completion_tokens=len(tiktokenenc.encode(final_ans["content"])),
                prompt_tokens=prompt_tokens,
                param=canvas.get_preset_param()
            )

        except Exception as e:
            traceback.print_exc()
            conv.dsl = json.loads(str(canvas))
            API4ConversationService.append_message(conv.id, conv.to_dict())
            error_text = "**ERROR**: " + str(e)
            yield get_data_openai(
                id=session_id,
                model=agent_id,
                content=error_text,
                finish_reason="stop",
                completion_tokens=len(tiktokenenc.encode(error_text)),
                prompt_tokens=prompt_tokens
            )
425+

api/utils/api_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,46 @@ def get_parser_config(chunk_method, parser_config):
378378
return parser_config
379379

380380

381+
def get_data_openai(id=None,
                    created=None,
                    model=None,
                    prompt_tokens=0,
                    completion_tokens=0,
                    content=None,
                    finish_reason=None,
                    object="chat.completion",
                    param=None,
                    ):
    """Build an OpenAI-style chat completion payload.

    Args:
        id: Identifier reported in the payload (always rendered as a string).
        created: Unix timestamp to report. ``None`` or a boolean (the
            historical "flag" usage) means "stamp with the current time";
            a value that cannot be converted to ``int`` is treated the same.
        model: Model name reported in the payload.
        prompt_tokens: Prompt token count; ``None`` is counted as 0.
        completion_tokens: Completion token count; ``None`` is counted as 0.
        content: Assistant message text.
        finish_reason: Finish reason for the single choice, or ``None``.
        object: Payload type, e.g. ``"chat.completion"`` or
            ``"chat.completion.chunk"``.
        param: Extra, non-OpenAI field carried through unchanged.

    Returns:
        A dict with ``id``, ``object``, ``created``, ``model``, ``param``,
        ``usage`` and a one-element ``choices`` list.
    """
    # ``created`` must always be an integer timestamp for OpenAI-style
    # clients, and an explicitly supplied timestamp must be honoured.
    if created is None or isinstance(created, bool):
        created_ts = int(time.time())
    else:
        try:
            created_ts = int(created)
        except (TypeError, ValueError):
            created_ts = int(time.time())

    prompt_tokens = prompt_tokens or 0
    completion_tokens = completion_tokens or 0
    total_tokens = prompt_tokens + completion_tokens

    return {
        "id": f"{id}",
        "object": object,
        "created": created_ts,
        "model": model,
        "param": param,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "completion_tokens_details": {
                "reasoning_tokens": 0,
                "accepted_prediction_tokens": 0,
                "rejected_prediction_tokens": 0
            }
        },
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": content
                },
                "logprobs": None,
                "finish_reason": finish_reason,
                "index": 0
            }
        ]
    }
381421
def valid_parser_config(parser_config):
382422
if not parser_config:
383423
return

0 commit comments

Comments
 (0)