6 changes: 3 additions & 3 deletions compextAI-executor/anthropic_models.py
@@ -10,9 +10,9 @@ def get_client(api_key):
def get_instructor_client(api_key):
return instructor.from_anthropic(Anthropic(api_key=api_key))

def chat_completion(api_key, system_prompt, model, messages, temperature, timeout, max_tokens, response_format, tools):
def chat_completion(api_keys:dict, system_prompt, model, messages, temperature, timeout, max_tokens, response_format, tools):
if response_format is None or response_format == {}:
client = get_client(api_key)
client = get_client(api_keys["anthropic"])
response = client.messages.create(
model=model,
system=system_prompt if system_prompt else NOT_GIVEN,
@@ -24,7 +24,7 @@ def chat_completion(api_key, system_prompt, model, messages, temperature, timeou
)
llm_response = response.model_dump_json()
else:
client = get_instructor_client(api_key)
client = get_instructor_client(api_keys["anthropic"])
response_model = create_pydantic_model_from_dict(
response_format["json_schema"]["name"],
response_format["json_schema"]["schema"]
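Callers that previously passed a bare key string now supply a dict keyed by provider name; the module looks up api_keys["anthropic"] itself. A minimal caller sketch under that assumption (all key values are placeholders, not part of this diff):

import anthropic_models

api_keys = {
    "anthropic": "sk-ant-...",  # placeholder; read as api_keys["anthropic"] above
    "openai": "sk-...",         # keys for other providers can sit in the same dict
}

llm_response = anthropic_models.chat_completion(
    api_keys,
    system_prompt="You are a helpful assistant.",
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "ping"}],
    temperature=0.5,
    timeout=60,
    max_tokens=1024,
    response_format=None,   # None or {} keeps the plain (non-instructor) path
    tools=None,
)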
19 changes: 14 additions & 5 deletions compextAI-executor/api.py
@@ -6,8 +6,10 @@
import openai_models as openai
import anthropic_models as anthropic
import json
import litellm_base as litellm
app = fastapi.FastAPI()


@app.get("/")
def read_root():
return {"pong"}
@@ -16,7 +18,7 @@ class ChatCompletionRequest(BaseModel):
"""
Request body for the chat completion endpoint.
"""
api_key: str
api_keys: dict
model: str
messages: list[dict]
temperature: float = 0.5
@@ -30,8 +32,7 @@ class ChatCompletionRequest(BaseModel):
@app.post("/chatcompletion/openai")
def chat_completion_openai(request: ChatCompletionRequest):
try:
response = openai.chat_completion(request.api_key, request.model, request.messages, request.temperature, request.timeout, request.max_completion_tokens, request.response_format, request.tools)
print(response)
response = openai.chat_completion(request.api_keys, request.model, request.messages, request.temperature, request.timeout, request.max_completion_tokens, request.response_format, request.tools)
return JSONResponse(status_code=200, content=json.loads(response))
except Exception as e:
print(e)
@@ -40,8 +41,16 @@ def chat_completion_openai(request: ChatCompletionRequest):
@app.post("/chatcompletion/anthropic")
def chat_completion_anthropic(request: ChatCompletionRequest):
try:
response = anthropic.chat_completion(request.api_key, request.system_prompt, request.model, request.messages, request.temperature, request.timeout, request.max_tokens, request.response_format, request.tools)
print(response)
response = anthropic.chat_completion(request.api_keys, request.system_prompt, request.model, request.messages, request.temperature, request.timeout, request.max_tokens, request.response_format, request.tools)
return JSONResponse(status_code=200, content=json.loads(response))
except Exception as e:
print(e)
return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/chatcompletion/litellm")
def chat_completion_litellm(request: ChatCompletionRequest):
try:
response = litellm.chat_completion(request.api_keys, request.model, request.messages, request.temperature, request.timeout, request.max_completion_tokens, request.response_format, request.tools)
return JSONResponse(status_code=200, content=json.loads(response))
except Exception as e:
print(e)
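Taken together, the executor now exposes three routes that share one request shape, with api_keys carrying every provider's credentials in a single dict. A minimal client sketch for the new litellm route — the host, port, and all key values below are illustrative assumptions, not part of this diff:

import requests

payload = {
    "api_keys": {
        "openai": "sk-...",                         # placeholder
        "anthropic": "sk-ant-...",                  # placeholder
        "azure": "...",                             # only needed for the Azure gpt-4o deployment
        "azure_endpoint": "https://<resource>.openai.azure.com",
        "google_service_account_creds": {},         # service-account JSON for the Vertex deployments
    },
    "model": "claude-3-5-sonnet",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.5,
    "timeout": 600,
    "max_completion_tokens": 1024,
    "response_format": {},
    "tools": [],
}

resp = requests.post("http://localhost:8000/chatcompletion/litellm", json=payload)
print(resp.status_code, resp.json())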
159 changes: 159 additions & 0 deletions compextAI-executor/litellm_base.py
@@ -0,0 +1,159 @@
from litellm import Router
from litellm.utils import token_counter, get_model_info
import json
import litellm


litellm.vertex_location = "us-east5"
litellm.vertex_project = "dashwave"

# litellm.set_verbose = True

AZURE_LOCATION = "eastus"
AZURE_VERSION = "2024-08-01-preview"

def get_model_list(api_keys:dict):
return [
{
"model_name": "gpt4",
"litellm_params": {
"model": "gpt-4",
"api_key": api_keys.get("openai", "")
}
},
{
"model_name": "o1",
"litellm_params": {
"model": "o1",
"api_key": api_keys.get("openai", "")
}
},
{
"model_name": "o1-preview",
"litellm_params": {
"model": "o1-preview",
"api_key": api_keys.get("openai", "")
}
},
{
"model_name": "o1-mini",
"litellm_params": {
"model": "o1-mini",
"api_key": api_keys.get("openai", "")
}
},
{
"model_name": "gpt-4o",
"litellm_params": {
"model": "azure/gpt-4o",
"api_key": api_keys.get("azure", ""),
"api_base": api_keys.get("azure_endpoint", ""),
"api_version": AZURE_VERSION
}
},
{
"model_name": "gpt-4o",
"litellm_params": {
"model": "gpt-4o",
"api_key": api_keys.get("openai", "")
}
},
{
"model_name": "claude-3-5-sonnet",
"litellm_params": {
"model": "vertex_ai/claude-3-5-sonnet-v2@20241022",
"vertex_credentials": json.dumps(api_keys.get("google_service_account_creds", {})),
}
},
{
"model_name": "claude-3-5-sonnet",
"litellm_params": {
"model": "claude-3-5-sonnet-20240620",
"api_key": api_keys.get("anthropic", "")
}
},
{
#https://docs.litellm.ai/docs/providers/anthropic#usage---thinking--reasoning_content
"model_name": "claude-3-7-sonnet",
"litellm_params": {
"model": "claude-3-7-sonnet-20250219",
"api_key": api_keys.get("anthropic", "")
}
},
{
#https://console.cloud.google.com/vertex-ai/publishers/anthropic/model-garden/claude-3-7-sonnet?hl=en&project=dashwave
"model_name": "claude-3-7-sonnet",
"litellm_params": {
"model": "vertex_ai/claude-3-7-sonnet@20250219",
"vertex_credentials": json.dumps(api_keys.get("google_service_account_creds", {})),
}
},
]

def get_model_identifier(model_name:str):
    # order matters: match the more specific names before their prefixes,
    # otherwise "o1-preview"/"o1-mini" would be swallowed by the "o1" check
    if "claude-3-5-sonnet" in model_name:
        model_name = "claude-3-5-sonnet-20240620"
    elif "claude-3-7-sonnet" in model_name:
        model_name = "claude-3-7-sonnet-20250219"
    elif "gpt-4o" in model_name:
        model_name = "gpt-4o"
    elif "gpt-4" in model_name:
        model_name = "gpt-4"
    elif "o1-preview" in model_name:
        model_name = "o1-preview"
    elif "o1-mini" in model_name:
        model_name = "o1-mini"
    elif "o1" in model_name:
        model_name = "o1"
    return model_name

def get_model_info_from_model_name(model_name:str):
model_name = get_model_identifier(model_name)
model_info = get_model_info(model_name)
return model_info

router = Router(
routing_strategy="latency-based-routing",
routing_strategy_args={
"ttl": 10,
"lowest_latency_buffer": 0.5
},
enable_pre_call_checks=True,
redis_host="redis",
redis_port=6379,
redis_password="mysecretpassword",
cache_responses=True,
cooldown_time=3600
)

def chat_completion(api_keys:dict, model_name:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
router.set_model_list(get_model_list(api_keys))

model_info = get_model_info_from_model_name(model_name)
max_allowed_input_tokens = model_info["max_input_tokens"]

while True:
messages_tokens = token_counter(
model=get_model_identifier(model_name),
messages=messages,
)
if messages_tokens > max_allowed_input_tokens:
user_msg_indices = [i for i, message in enumerate(messages) if message["role"] == "user"]

# remove all the messages from top until the second user message
if len(user_msg_indices) > 1:
messages = messages[user_msg_indices[1]:]
else:
break

response = router.completion(
model=model_name,
messages=messages,
temperature=temperature,
timeout=timeout,
max_completion_tokens=max_completion_tokens if max_completion_tokens else None,
response_format=response_format if response_format else None,
tools=tools if tools else None
)

return response.model_dump_json()
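Each public model name above can map to several deployments (for example, claude-3-5-sonnet on both Vertex AI and the Anthropic API), and the Router picks among them by recent latency, with Redis-backed response caching and cooldowns. The loop in chat_completion drops the oldest turns — everything before the second user message — whenever the prompt exceeds the model's max_input_tokens. A direct-call sketch, assuming the Redis instance configured above is reachable and using placeholder keys:

import json
import litellm_base

api_keys = {
    "anthropic": "sk-ant-...",  # placeholder; the Vertex deployments would also need
    "openai": "sk-...",         # google_service_account_creds in this dict
}

raw = litellm_base.chat_completion(
    api_keys=api_keys,
    model_name="claude-3-5-sonnet",   # routed to the lowest-latency matching deployment
    messages=[{"role": "user", "content": "Summarize this diff in one line."}],
    temperature=0.2,
    timeout=120,
    max_completion_tokens=512,
    response_format=None,
    tools=None,
)

# litellm serializes an OpenAI-style completion object to JSON
reply = json.loads(raw)["choices"][0]["message"]["content"]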
6 changes: 3 additions & 3 deletions compextAI-executor/openai_models.py
@@ -14,9 +14,9 @@ def get_instructor_client(api_key):
api_key=api_key
))

def chat_completion(api_key:str, model:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
def chat_completion(api_keys:dict, model:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
if response_format is None or response_format == {}:
client = get_client(api_key)
client = get_client(api_keys["openai"])
response = client.chat.completions.create(
model=model,
messages=messages,
@@ -27,7 +27,7 @@ def chat_completion(api_key:str, model:str, messages:list, temperature:float, ti
)
llm_response = response.model_dump_json()
else:
client = get_instructor_client(api_key)
client = get_instructor_client(api_keys["openai"])
response_model = create_pydantic_model_from_dict(
response_format["json_schema"]["name"],
response_format["json_schema"]["schema"]
4 changes: 3 additions & 1 deletion compextAI-executor/requirements.txt
@@ -23,7 +23,6 @@ markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
multidict==6.1.0
openai==1.54.1
propcache==0.2.1
pydantic==2.9.2
pydantic_core==2.23.4
@@ -40,3 +39,6 @@ typing_extensions==4.12.2
urllib3==2.2.3
uvicorn==0.32.0
yarl==1.18.3
litellm==1.63.0
google-cloud-aiplatform>=1.38.0
redis==5.2.1
2 changes: 1 addition & 1 deletion compextAI-executor/utils_test.py
@@ -46,4 +46,4 @@

model = create_pydantic_model_from_dict(schema["json_schema"]["name"], schema["json_schema"]["schema"])
# print all the fields
print(model.model_fields)
# print(model.model_fields)
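For context, the response_format consumed by both provider modules (and exercised by this test) follows the OpenAI json_schema convention; create_pydantic_model_from_dict reads the name and schema keys to build a Pydantic model for the instructor client to validate against. A sketch of such a payload — the field names inside the schema are illustrative:

response_format = {
    "type": "json_schema",               # OpenAI-style wrapper; only the keys below are read here
    "json_schema": {
        "name": "ReviewSummary",         # becomes the Pydantic model's name
        "schema": {
            "type": "object",
            "properties": {
                "summary": {"type": "string"},
                "risk_level": {"type": "string"},
            },
            "required": ["summary", "risk_level"],
        },
    },
}

model = create_pydantic_model_from_dict(
    response_format["json_schema"]["name"],
    response_format["json_schema"]["schema"],
)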
41 changes: 26 additions & 15 deletions compextAI-server/controllers/execute.go
@@ -9,6 +9,7 @@ import (
"github.com/burnerlee/compextAI/constants"
"github.com/burnerlee/compextAI/internal/logger"
"github.com/burnerlee/compextAI/internal/providers/chat"
"github.com/burnerlee/compextAI/internal/providers/chat/litellm"
"github.com/burnerlee/compextAI/models"
"gorm.io/gorm"
)
@@ -25,10 +26,19 @@ func ExecuteThread(db *gorm.DB, req *ExecuteThreadRequest) (interface{}, error)
threadExecutionParamsTemplate.SystemPrompt = req.ThreadExecutionSystemPrompt
}

chatProvider, err := chat.GetChatCompletionsProvider(threadExecutionParamsTemplate.Model)
if err != nil {
logger.GetLogger().Errorf("Error getting chat provider: %s: %v", threadExecutionParamsTemplate.Model, err)
return nil, err
var chatProvider chat.ChatCompletionsProvider
if threadExecutionParamsTemplate.UseLiteLLM {
chatProvider, err = chat.GetChatCompletionsProvider(litellm.LITELLM_IDENTIFIER)
if err != nil {
logger.GetLogger().Errorf("Error getting litellm chat provider: %v", err)
return nil, err
}
} else {
chatProvider, err = chat.GetChatCompletionsProvider(threadExecutionParamsTemplate.Model)
if err != nil {
logger.GetLogger().Errorf("Error getting chat provider: %s: %v", threadExecutionParamsTemplate.Model, err)
return nil, err
}
}

var messages []*models.Message
@@ -93,6 +103,7 @@ func ExecuteThread(db *gorm.DB, req *ExecuteThreadRequest) (interface{}, error)
logger.GetLogger().Errorf("Error marshalling execution message content: %v", err)
return nil, err
}

if err := models.CreateMessage(db, &models.Message{
ThreadID: req.ThreadID,
Role: "execution",
@@ -137,16 +148,6 @@
}

func handleThreadExecutionError(db *gorm.DB, threadExecution *models.ThreadExecution, execErr error) {
errJson, jsonErr := json.Marshal(struct {
Error string `json:"error"`
}{
Error: execErr.Error(),
})
if jsonErr != nil {
logger.GetLogger().Errorf("Error marshalling error: %v", jsonErr)
return
}

executionTime := time.Since(threadExecution.CreatedAt).Seconds()

updatedThreadExecution := models.ThreadExecution{
@@ -155,10 +156,20 @@ func handleThreadExecutionError(db *gorm.DB, threadExecution *models.ThreadExecu
Identifier: threadExecution.Identifier,
},
Status: models.ThreadExecutionStatus_FAILED,
Output: errJson,
ExecutionTime: uint(executionTime),
}
errJson, jsonErr := json.Marshal(struct {
Error string `json:"error"`
}{
Error: execErr.Error(),
})
if jsonErr != nil {
logger.GetLogger().Errorf("Error marshalling error: %v", jsonErr)
models.UpdateThreadExecution(db, &updatedThreadExecution)
return
}

updatedThreadExecution.Output = errJson
models.UpdateThreadExecution(db, &updatedThreadExecution)
}

23 changes: 19 additions & 4 deletions compextAI-server/controllers/messages.go
@@ -39,11 +39,26 @@ func CreateMessages(db *gorm.DB, req *CreateMessageRequest) ([]*models.Message,
return nil, fmt.Errorf("failed to marshal content: %w", err)
}

toolCallsJsonBlob, err := json.Marshal(message.ToolCalls)
if err != nil {
tx.Rollback()
return nil, fmt.Errorf("failed to marshal tool calls: %w", err)
}

functionCallJsonBlob, err := json.Marshal(message.FunctionCall)
if err != nil {
tx.Rollback()
return nil, fmt.Errorf("failed to marshal function call: %w", err)
}

message := &models.Message{
ThreadID: req.ThreadID,
ContentMap: contentJsonBlob,
Role: message.Role,
Metadata: metadataJsonBlob,
ThreadID: req.ThreadID,
ContentMap: contentJsonBlob,
Role: message.Role,
Metadata: metadataJsonBlob,
ToolCallID: message.ToolCallID,
ToolCalls: toolCallsJsonBlob,
FunctionCall: functionCallJsonBlob,
}

if err := models.CreateMessage(tx, message); err != nil {
9 changes: 6 additions & 3 deletions compextAI-server/controllers/messages.types.go
@@ -6,7 +6,10 @@ type CreateMessageRequest struct {
}

type CreateMessage struct {
Content interface{} `json:"content"`
Role string `json:"role"`
Metadata map[string]interface{} `json:"metadata"`
Content interface{} `json:"content"`
Role string `json:"role"`
ToolCallID string `json:"tool_call_id"`
Metadata map[string]interface{} `json:"metadata"`
ToolCalls interface{} `json:"tool_calls"`
FunctionCall interface{} `json:"function_call"`
}
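These fields let a thread carry OpenAI-style tool calls end to end: the server marshals tool_calls and function_call as opaque JSON, and tool_call_id ties a tool result back to the call that produced it. A sketch of two such messages using the JSON names from the struct tags above — the outer request wrapper and the tool-call contents themselves are illustrative:

messages = [
    {
        "role": "assistant",
        "content": None,
        "metadata": {},
        "tool_calls": [
            {
                "id": "call_abc123",     # illustrative id
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{\"city\": \"Berlin\"}"},
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_abc123",   # links this result to the call above
        "content": "{\"temp_c\": 21}",
        "metadata": {},
    },
]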