|
24 | 24 | from api.db.services.common_service import CommonService |
25 | 25 | from api.db.services.conversation_service import structure_answer |
26 | 26 | from api.utils import get_uuid |
| 27 | +from api.utils.api_utils import get_data_openai |
| 28 | +import tiktoken |
27 | 29 | from peewee import fn |
28 | | - |
29 | 30 | class CanvasTemplateService(CommonService): |
30 | 31 | model = CanvasTemplate |
31 | 32 |
|
@@ -100,14 +101,14 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id, |
100 | 101 | ] |
101 | 102 | if keywords: |
102 | 103 | angents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( |
103 | | - ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == |
| 104 | + ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == |
104 | 105 | TenantPermission.TEAM.value)) | ( |
105 | 106 | cls.model.user_id == user_id)), |
106 | 107 | (fn.LOWER(cls.model.title).contains(keywords.lower())) |
107 | 108 | ) |
108 | 109 | else: |
109 | 110 | angents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( |
110 | | - ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == |
| 111 | + ((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == |
111 | 112 | TenantPermission.TEAM.value)) | ( |
112 | 113 | cls.model.user_id == user_id)) |
113 | 114 | ) |
@@ -154,8 +155,6 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw |
154 | 155 | "dsl": cvs.dsl |
155 | 156 | } |
156 | 157 | API4ConversationService.save(**conv) |
157 | | - |
158 | | - |
159 | 158 | conv = API4Conversation(**conv) |
160 | 159 | else: |
161 | 160 | e, conv = API4ConversationService.get_by_id(session_id) |
@@ -221,3 +220,206 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw |
221 | 220 | API4ConversationService.append_message(conv.id, conv.to_dict()) |
222 | 221 | yield result |
223 | 222 | break |
| 223 | +def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): |
| 224 | + """Main function for OpenAI-compatible completions, structured similarly to the completion function.""" |
| 225 | + tiktokenenc = tiktoken.get_encoding("cl100k_base") |
| 226 | + e, cvs = UserCanvasService.get_by_id(agent_id) |
| 227 | + |
| 228 | + if not e: |
| 229 | + yield get_data_openai( |
| 230 | + id=session_id, |
| 231 | + model=agent_id, |
| 232 | + content="**ERROR**: Agent not found." |
| 233 | + ) |
| 234 | + return |
| 235 | + |
| 236 | + if cvs.user_id != tenant_id: |
| 237 | + yield get_data_openai( |
| 238 | + id=session_id, |
| 239 | + model=agent_id, |
| 240 | + content="**ERROR**: You do not own the agent" |
| 241 | + ) |
| 242 | + return |
| 243 | + |
| 244 | + if not isinstance(cvs.dsl, str): |
| 245 | + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) |
| 246 | + |
| 247 | + canvas = Canvas(cvs.dsl, tenant_id) |
| 248 | + canvas.reset() |
| 249 | + message_id = str(uuid4()) |
| 250 | + |
| 251 | + # Handle new session creation |
| 252 | + if not session_id: |
| 253 | + query = canvas.get_preset_param() |
| 254 | + if query: |
| 255 | + for ele in query: |
| 256 | + if not ele["optional"]: |
| 257 | + if not kwargs.get(ele["key"]): |
| 258 | + yield get_data_openai( |
| 259 | + id=None, |
| 260 | + model=agent_id, |
| 261 | + content=f"`{ele['key']}` is required", |
| 262 | + completion_tokens=len(tiktokenenc.encode(f"`{ele['key']}` is required")), |
| 263 | + prompt_tokens=len(tiktokenenc.encode(question if question else "")) |
| 264 | + ) |
| 265 | + return |
| 266 | + ele["value"] = kwargs[ele["key"]] |
| 267 | + if ele["optional"]: |
| 268 | + if kwargs.get(ele["key"]): |
| 269 | + ele["value"] = kwargs[ele['key']] |
| 270 | + else: |
| 271 | + if "value" in ele: |
| 272 | + ele.pop("value") |
| 273 | + |
| 274 | + cvs.dsl = json.loads(str(canvas)) |
| 275 | + session_id = get_uuid() |
| 276 | + conv = { |
| 277 | + "id": session_id, |
| 278 | + "dialog_id": cvs.id, |
| 279 | + "user_id": kwargs.get("user_id", "") if isinstance(kwargs, dict) else "", |
| 280 | + "message": [{"role": "assistant", "content": canvas.get_prologue(), "created_at": time.time()}], |
| 281 | + "source": "agent", |
| 282 | + "dsl": cvs.dsl |
| 283 | + } |
| 284 | + API4ConversationService.save(**conv) |
| 285 | + conv = API4Conversation(**conv) |
| 286 | + |
| 287 | + # Handle existing session |
| 288 | + else: |
| 289 | + e, conv = API4ConversationService.get_by_id(session_id) |
| 290 | + if not e: |
| 291 | + yield get_data_openai( |
| 292 | + id=session_id, |
| 293 | + model=agent_id, |
| 294 | + content="**ERROR**: Session not found!" |
| 295 | + ) |
| 296 | + return |
| 297 | + |
| 298 | + canvas = Canvas(json.dumps(conv.dsl), tenant_id) |
| 299 | + canvas.messages.append({"role": "user", "content": question, "id": message_id}) |
| 300 | + canvas.add_user_input(question) |
| 301 | + |
| 302 | + if not conv.message: |
| 303 | + conv.message = [] |
| 304 | + conv.message.append({ |
| 305 | + "role": "user", |
| 306 | + "content": question, |
| 307 | + "id": message_id |
| 308 | + }) |
| 309 | + |
| 310 | + if not conv.reference: |
| 311 | + conv.reference = [] |
| 312 | + conv.reference.append({"chunks": [], "doc_aggs": []}) |
| 313 | + |
| 314 | + # Process request based on stream mode |
| 315 | + final_ans = {"reference": [], "content": ""} |
| 316 | + prompt_tokens = len(tiktokenenc.encode(str(question))) |
| 317 | + |
| 318 | + if stream: |
| 319 | + try: |
| 320 | + completion_tokens = 0 |
| 321 | + for ans in canvas.run(stream=True): |
| 322 | + if ans.get("running_status"): |
| 323 | + completion_tokens += len(tiktokenenc.encode(ans.get("content", ""))) |
| 324 | + yield "data: " + json.dumps( |
| 325 | + get_data_openai( |
| 326 | + id=session_id, |
| 327 | + model=agent_id, |
| 328 | + content=ans["content"], |
| 329 | + object="chat.completion.chunk", |
| 330 | + completion_tokens=completion_tokens, |
| 331 | + prompt_tokens=prompt_tokens |
| 332 | + ), |
| 333 | + ensure_ascii=False |
| 334 | + ) + "\n\n" |
| 335 | + continue |
| 336 | + |
| 337 | + for k in ans.keys(): |
| 338 | + final_ans[k] = ans[k] |
| 339 | + |
| 340 | + completion_tokens += len(tiktokenenc.encode(final_ans.get("content", ""))) |
| 341 | + yield "data: " + json.dumps( |
| 342 | + get_data_openai( |
| 343 | + id=session_id, |
| 344 | + model=agent_id, |
| 345 | + content=final_ans["content"], |
| 346 | + object="chat.completion.chunk", |
| 347 | + finish_reason="stop", |
| 348 | + completion_tokens=completion_tokens, |
| 349 | + prompt_tokens=prompt_tokens |
| 350 | + ), |
| 351 | + ensure_ascii=False |
| 352 | + ) + "\n\n" |
| 353 | + |
| 354 | + # Update conversation |
| 355 | + canvas.messages.append({"role": "assistant", "content": final_ans["content"], "created_at": time.time(), "id": message_id}) |
| 356 | + canvas.history.append(("assistant", final_ans["content"])) |
| 357 | + if final_ans.get("reference"): |
| 358 | + canvas.reference.append(final_ans["reference"]) |
| 359 | + conv.dsl = json.loads(str(canvas)) |
| 360 | + API4ConversationService.append_message(conv.id, conv.to_dict()) |
| 361 | + |
| 362 | + yield "data: [DONE]\n\n" |
| 363 | + |
| 364 | + except Exception as e: |
| 365 | + traceback.print_exc() |
| 366 | + conv.dsl = json.loads(str(canvas)) |
| 367 | + API4ConversationService.append_message(conv.id, conv.to_dict()) |
| 368 | + yield "data: " + json.dumps( |
| 369 | + get_data_openai( |
| 370 | + id=session_id, |
| 371 | + model=agent_id, |
| 372 | + content="**ERROR**: " + str(e), |
| 373 | + finish_reason="stop", |
| 374 | + completion_tokens=len(tiktokenenc.encode("**ERROR**: " + str(e))), |
| 375 | + prompt_tokens=prompt_tokens |
| 376 | + ), |
| 377 | + ensure_ascii=False |
| 378 | + ) + "\n\n" |
| 379 | + yield "data: [DONE]\n\n" |
| 380 | + |
| 381 | + else: # Non-streaming mode |
| 382 | + try: |
| 383 | + all_answer_content = "" |
| 384 | + for answer in canvas.run(stream=False): |
| 385 | + if answer.get("running_status"): |
| 386 | + continue |
| 387 | + |
| 388 | + final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" |
| 389 | + final_ans["reference"] = answer.get("reference", []) |
| 390 | + all_answer_content += final_ans["content"] |
| 391 | + |
| 392 | + final_ans["content"] = all_answer_content |
| 393 | + |
| 394 | + # Update conversation |
| 395 | + canvas.messages.append({"role": "assistant", "content": final_ans["content"], "created_at": time.time(), "id": message_id}) |
| 396 | + canvas.history.append(("assistant", final_ans["content"])) |
| 397 | + if final_ans.get("reference"): |
| 398 | + canvas.reference.append(final_ans["reference"]) |
| 399 | + conv.dsl = json.loads(str(canvas)) |
| 400 | + API4ConversationService.append_message(conv.id, conv.to_dict()) |
| 401 | + |
| 402 | + # Return the response in OpenAI format |
| 403 | + yield get_data_openai( |
| 404 | + id=session_id, |
| 405 | + model=agent_id, |
| 406 | + content=final_ans["content"], |
| 407 | + finish_reason="stop", |
| 408 | + completion_tokens=len(tiktokenenc.encode(final_ans["content"])), |
| 409 | + prompt_tokens=prompt_tokens, |
| 410 | + param=canvas.get_preset_param() # Added param info like in completion |
| 411 | + ) |
| 412 | + |
| 413 | + except Exception as e: |
| 414 | + traceback.print_exc() |
| 415 | + conv.dsl = json.loads(str(canvas)) |
| 416 | + API4ConversationService.append_message(conv.id, conv.to_dict()) |
| 417 | + yield get_data_openai( |
| 418 | + id=session_id, |
| 419 | + model=agent_id, |
| 420 | + content="**ERROR**: " + str(e), |
| 421 | + finish_reason="stop", |
| 422 | + completion_tokens=len(tiktokenenc.encode("**ERROR**: " + str(e))), |
| 423 | + prompt_tokens=prompt_tokens |
| 424 | + ) |
| 425 | + |
0 commit comments