Added Provider (vertex) for sonnet-3-7 #5

Open · wants to merge 21 commits into base: main

Changes from all commits
6 changes: 3 additions & 3 deletions compextAI-executor/anthropic_models.py
@@ -10,9 +10,9 @@ def get_client(api_key):
 def get_instructor_client(api_key):
     return instructor.from_anthropic(Anthropic(api_key=api_key))
 
-def chat_completion(api_key, system_prompt, model, messages, temperature, timeout, max_tokens, response_format, tools):
+def chat_completion(api_keys:dict, system_prompt, model, messages, temperature, timeout, max_tokens, response_format, tools):
     if response_format is None or response_format == {}:
-        client = get_client(api_key)
+        client = get_client(api_keys["anthropic"])
         response = client.messages.create(
             model=model,
             system=system_prompt if system_prompt else NOT_GIVEN,
@@ -24,7 +24,7 @@ def chat_completion(api_key, system_prompt, model, messages, temperature, timeout, max_tokens, response_format, tools):
         )
         llm_response = response.model_dump_json()
     else:
-        client = get_instructor_client(api_key)
+        client = get_instructor_client(api_keys["anthropic"])
         response_model = create_pydantic_model_from_dict(
             response_format["json_schema"]["name"],
             response_format["json_schema"]["schema"]
19 changes: 14 additions & 5 deletions compextAI-executor/api.py
@@ -6,8 +6,10 @@
 import openai_models as openai
 import anthropic_models as anthropic
 import json
+import litellm_base as litellm
 app = fastapi.FastAPI()
+
 
 @app.get("/")
 def read_root():
     return {"pong"}
@@ -16,7 +18,7 @@ class ChatCompletionRequest(BaseModel):
     """
     Request body for the chat completion endpoint.
     """
-    api_key: str
+    api_keys: dict
     model: str
     messages: list[dict]
     temperature: float = 0.5
@@ -30,8 +32,7 @@ class ChatCompletionRequest(BaseModel):
 @app.post("/chatcompletion/openai")
 def chat_completion_openai(request: ChatCompletionRequest):
     try:
-        response = openai.chat_completion(request.api_key, request.model, request.messages, request.temperature, request.timeout, request.max_completion_tokens, request.response_format, request.tools)
-        print(response)
+        response = openai.chat_completion(request.api_keys, request.model, request.messages, request.temperature, request.timeout, request.max_completion_tokens, request.response_format, request.tools)
         return JSONResponse(status_code=200, content=json.loads(response))
     except Exception as e:
         print(e)
@@ -40,8 +41,16 @@ def chat_completion_openai(request: ChatCompletionRequest):
 @app.post("/chatcompletion/anthropic")
 def chat_completion_anthropic(request: ChatCompletionRequest):
     try:
-        response = anthropic.chat_completion(request.api_key, request.system_prompt, request.model, request.messages, request.temperature, request.timeout, request.max_tokens, request.response_format, request.tools)
-        print(response)
+        response = anthropic.chat_completion(request.api_keys, request.system_prompt, request.model, request.messages, request.temperature, request.timeout, request.max_tokens, request.response_format, request.tools)
         return JSONResponse(status_code=200, content=json.loads(response))
     except Exception as e:
         print(e)
         return JSONResponse(status_code=500, content={"error": str(e)})
 
+@app.post("/chatcompletion/litellm")
+def chat_completion_litellm(request: ChatCompletionRequest):
+    try:
+        response = litellm.chat_completion(request.api_keys, request.model, request.messages, request.temperature, request.timeout, request.max_completion_tokens, request.response_format, request.tools)
+        return JSONResponse(status_code=200, content=json.loads(response))
+    except Exception as e:
+        print(e)
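With api_key: str widened to api_keys: dict, a client now sends every provider key in one request body and each backend indexes the one it needs. A minimal sketch of a request against the new endpoint; the host, port, and key values are placeholder assumptions, and the field names follow ChatCompletionRequest above:

import requests

# Hypothetical executor address; only the keys the chosen model needs must be real.
resp = requests.post(
    "http://localhost:8000/chatcompletion/litellm",
    json={
        "api_keys": {"openai": "sk-...", "anthropic": "sk-ant-..."},
        "model": "claude-3-7-sonnet",
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.5,
        "timeout": 600,
        "max_completion_tokens": 1024,
    },
)
print(resp.status_code, resp.json())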
159 changes: 159 additions & 0 deletions compextAI-executor/litellm_base.py
@@ -0,0 +1,159 @@
from litellm import Router
from litellm.utils import token_counter, get_model_info
import json
import litellm


litellm.vertex_location = "us-east5"
litellm.vertex_project = "dashwave"

# litellm.set_verbose = True

AZURE_LOCATION = "eastus"
AZURE_VERSION = "2024-08-01-preview"

def get_model_list(api_keys:dict):
    return [
        {
            "model_name": "gpt4",
            "litellm_params": {
                "model": "gpt-4",
                "api_key": api_keys.get("openai", "")
            }
        },
        {
            "model_name": "o1",
            "litellm_params": {
                "model": "o1",
                "api_key": api_keys.get("openai", "")
            }
        },
        {
            "model_name": "o1-preview",
            "litellm_params": {
                "model": "o1-preview",
                "api_key": api_keys.get("openai", "")
            }
        },
        {
            "model_name": "o1-mini",
            "litellm_params": {
                "model": "o1-mini",
                "api_key": api_keys.get("openai", "")
            }
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "azure/gpt-4o",
                "api_key": api_keys.get("azure", ""),
                "api_base": api_keys.get("azure_endpoint", ""),
                "api_version": AZURE_VERSION
            }
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": api_keys.get("openai", "")
            }
        },
        {
            "model_name": "claude-3-5-sonnet",
            "litellm_params": {
                "model": "vertex_ai/claude-3-5-sonnet-v2@20241022",
                "vertex_credentials": json.dumps(api_keys.get("google_service_account_creds", {})),
            }
        },
        {
            "model_name": "claude-3-5-sonnet",
            "litellm_params": {
                "model": "claude-3-5-sonnet-20240620",
                "api_key": api_keys.get("anthropic", "")
            }
        },
        {
            # https://docs.litellm.ai/docs/providers/anthropic#usage---thinking--reasoning_content
            "model_name": "claude-3-7-sonnet",
            "litellm_params": {
                "model": "claude-3-7-sonnet-20250219",
                "api_key": api_keys.get("anthropic", "")
            }
        },
        {
            # https://console.cloud.google.com/vertex-ai/publishers/anthropic/model-garden/claude-3-7-sonnet?hl=en&project=dashwave
            "model_name": "claude-3-7-sonnet",
            "litellm_params": {
                "model": "vertex_ai/claude-3-7-sonnet@20250219",
                "vertex_credentials": json.dumps(api_keys.get("google_service_account_creds", {})),
            }
        },
    ]
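# Note: two entries above intentionally share a model_name (e.g. the Anthropic
# and Vertex deployments of "claude-3-7-sonnet"): the litellm Router treats
# same-named entries as interchangeable deployments and picks between them via
# the latency-based routing strategy configured below.
#
# Illustrative api_keys shape expected by get_model_list (values are placeholders):
# {
#     "openai": "sk-...",
#     "azure": "<azure-api-key>",
#     "azure_endpoint": "https://<resource>.openai.azure.com",
#     "anthropic": "sk-ant-...",
#     "google_service_account_creds": {...},  # parsed service-account JSON
# }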

def get_model_identifier(model_name:str):
    # Map model aliases to canonical identifiers. More specific names are
    # checked first so that e.g. "o1-preview" is not swallowed by the bare
    # "o1" match (the original if/elif order had that bug).
    if "claude-3-5-sonnet" in model_name:
        model_name = "claude-3-5-sonnet-20240620"
    elif "claude-3-7-sonnet" in model_name:
        model_name = "claude-3-7-sonnet-20250219"
    elif "gpt-4o" in model_name:
        model_name = "gpt-4o"
    elif "gpt-4" in model_name:
        model_name = "gpt-4"
    elif "o1-preview" in model_name:
        model_name = "o1-preview"
    elif "o1-mini" in model_name:
        model_name = "o1-mini"
    elif "o1" in model_name:
        model_name = "o1"
    return model_name
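# Examples of the mapping above:
#   get_model_identifier("vertex_ai/claude-3-7-sonnet@20250219") -> "claude-3-7-sonnet-20250219"
#   get_model_identifier("o1-preview-2024-09-12")                -> "o1-preview"
# The canonical names are what litellm's get_model_info() resolves below.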

def get_model_info_from_model_name(model_name:str):
    model_name = get_model_identifier(model_name)
    model_info = get_model_info(model_name)
    return model_info

router = Router(
    routing_strategy="latency-based-routing",
    routing_strategy_args={
        "ttl": 10,
        "lowest_latency_buffer": 0.5
    },
    enable_pre_call_checks=True,
    redis_host="redis",
    redis_port=6379,
    redis_password="mysecretpassword",
    cache_responses=True,
    cooldown_time=3600
)

def chat_completion(api_keys:dict, model_name:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
    router.set_model_list(get_model_list(api_keys))

    model_info = get_model_info_from_model_name(model_name)
    max_allowed_input_tokens = model_info["max_input_tokens"]

    # Trim the conversation from the top until it fits the model's context
    # window. The explicit break for the already-fitting case is required:
    # without it, a conversation within the limit would spin forever.
    while True:
        messages_tokens = token_counter(
            model=get_model_identifier(model_name),
            messages=messages,
        )
        if messages_tokens <= max_allowed_input_tokens:
            break

        # remove all the messages from the top until the second user message
        user_msg_indices = [i for i, message in enumerate(messages) if message["role"] == "user"]
        if len(user_msg_indices) > 1:
            messages = messages[user_msg_indices[1]:]
        else:
            break

    response = router.completion(
        model=model_name,
        messages=messages,
        temperature=temperature,
        timeout=timeout,
        max_completion_tokens=max_completion_tokens if max_completion_tokens else None,
        response_format=response_format if response_format else None,
        tools=tools if tools else None
    )

    return response.model_dump_json()
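A minimal smoke-test sketch for the module above; the credentials, model alias, and a reachable Redis instance (per the Router config) are assumptions:

if __name__ == "__main__":
    # Placeholder key; response_format and tools omitted so the call goes
    # straight through router.completion().
    out = chat_completion(
        api_keys={"anthropic": "sk-ant-..."},
        model_name="claude-3-7-sonnet",
        messages=[{"role": "user", "content": "ping"}],
        temperature=0.2,
        timeout=60,
        max_completion_tokens=256,
        response_format=None,
        tools=None,
    )
    print(out)  # JSON string produced by response.model_dump_json()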
6 changes: 3 additions & 3 deletions compextAI-executor/openai_models.py
@@ -14,9 +14,9 @@ def get_instructor_client(api_key):
         api_key=api_key
     ))
 
-def chat_completion(api_key:str, model:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
+def chat_completion(api_keys:dict, model:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
     if response_format is None or response_format == {}:
-        client = get_client(api_key)
+        client = get_client(api_keys["openai"])
         response = client.chat.completions.create(
             model=model,
             messages=messages,
@@ -27,7 +27,7 @@ def chat_completion(api_key:str, model:str, messages:list, temperature:float, timeout:int, max_completion_tokens:int, response_format:dict, tools:list[dict]):
         )
         llm_response = response.model_dump_json()
     else:
-        client = get_instructor_client(api_key)
+        client = get_instructor_client(api_keys["openai"])
         response_model = create_pydantic_model_from_dict(
             response_format["json_schema"]["name"],
             response_format["json_schema"]["schema"]
4 changes: 3 additions & 1 deletion compextAI-executor/requirements.txt
@@ -23,7 +23,6 @@ markdown-it-py==3.0.0
 MarkupSafe==3.0.2
 mdurl==0.1.2
 multidict==6.1.0
-openai==1.54.1
 propcache==0.2.1
 pydantic==2.9.2
 pydantic_core==2.23.4
@@ -40,3 +39,6 @@ typing_extensions==4.12.2
 urllib3==2.2.3
 uvicorn==0.32.0
 yarl==1.18.3
+litellm==1.63.0
+google-cloud-aiplatform>=1.38.0
+redis==5.2.1
2 changes: 1 addition & 1 deletion compextAI-executor/utils_test.py
@@ -46,4 +46,4 @@
 
 model = create_pydantic_model_from_dict(schema["json_schema"]["name"], schema["json_schema"]["schema"])
 # print all the fields
-print(model.model_fields)
+# print(model.model_fields)
41 changes: 26 additions & 15 deletions compextAI-server/controllers/execute.go
@@ -9,6 +9,7 @@ import (
 	"github.com/burnerlee/compextAI/constants"
 	"github.com/burnerlee/compextAI/internal/logger"
 	"github.com/burnerlee/compextAI/internal/providers/chat"
+	"github.com/burnerlee/compextAI/internal/providers/chat/litellm"
 	"github.com/burnerlee/compextAI/models"
 	"gorm.io/gorm"
 )
@@ -25,10 +26,19 @@ func ExecuteThread(db *gorm.DB, req *ExecuteThreadRequest) (interface{}, error)
 		threadExecutionParamsTemplate.SystemPrompt = req.ThreadExecutionSystemPrompt
 	}
 
-	chatProvider, err := chat.GetChatCompletionsProvider(threadExecutionParamsTemplate.Model)
-	if err != nil {
-		logger.GetLogger().Errorf("Error getting chat provider: %s: %v", threadExecutionParamsTemplate.Model, err)
-		return nil, err
+	var chatProvider chat.ChatCompletionsProvider
+	if threadExecutionParamsTemplate.UseLiteLLM {
+		chatProvider, err = chat.GetChatCompletionsProvider(litellm.LITELLM_IDENTIFIER)
+		if err != nil {
+			logger.GetLogger().Errorf("Error getting litellm chat provider: %v", err)
+			return nil, err
+		}
+	} else {
+		chatProvider, err = chat.GetChatCompletionsProvider(threadExecutionParamsTemplate.Model)
+		if err != nil {
+			logger.GetLogger().Errorf("Error getting chat provider: %s: %v", threadExecutionParamsTemplate.Model, err)
+			return nil, err
+		}
 	}
 
 	var messages []*models.Message
@@ -93,6 +103,7 @@ func ExecuteThread(db *gorm.DB, req *ExecuteThreadRequest) (interface{}, error)
 		logger.GetLogger().Errorf("Error marshalling execution message content: %v", err)
 		return nil, err
 	}
+
 	if err := models.CreateMessage(db, &models.Message{
 		ThreadID: req.ThreadID,
 		Role:     "execution",
@@ -137,16 +148,6 @@ func ExecuteThread(db *gorm.DB, req *ExecuteThreadRequest) (interface{}, error)
 }
 
 func handleThreadExecutionError(db *gorm.DB, threadExecution *models.ThreadExecution, execErr error) {
-	errJson, jsonErr := json.Marshal(struct {
-		Error string `json:"error"`
-	}{
-		Error: execErr.Error(),
-	})
-	if jsonErr != nil {
-		logger.GetLogger().Errorf("Error marshalling error: %v", jsonErr)
-		return
-	}
-
 	executionTime := time.Since(threadExecution.CreatedAt).Seconds()
 
 	updatedThreadExecution := models.ThreadExecution{
@@ -155,10 +156,20 @@ func handleThreadExecutionError(db *gorm.DB, threadExecution *models.ThreadExecution, execErr error) {
 			Identifier: threadExecution.Identifier,
 		},
 		Status:        models.ThreadExecutionStatus_FAILED,
-		Output:        errJson,
 		ExecutionTime: uint(executionTime),
 	}
+	errJson, jsonErr := json.Marshal(struct {
+		Error string `json:"error"`
+	}{
+		Error: execErr.Error(),
+	})
+	if jsonErr != nil {
+		logger.GetLogger().Errorf("Error marshalling error: %v", jsonErr)
+		models.UpdateThreadExecution(db, &updatedThreadExecution)
+		return
+	}
+
+	updatedThreadExecution.Output = errJson
 	models.UpdateThreadExecution(db, &updatedThreadExecution)
 }
23 changes: 19 additions & 4 deletions compextAI-server/controllers/messages.go
@@ -39,11 +39,26 @@ func CreateMessages(db *gorm.DB, req *CreateMessageRequest) ([]*models.Message, error) {
 			return nil, fmt.Errorf("failed to marshal content: %w", err)
 		}
 
+		toolCallsJsonBlob, err := json.Marshal(message.ToolCalls)
+		if err != nil {
+			tx.Rollback()
+			return nil, fmt.Errorf("failed to marshal tool calls: %w", err)
+		}
+
+		functionCallJsonBlob, err := json.Marshal(message.FunctionCall)
+		if err != nil {
+			tx.Rollback()
+			return nil, fmt.Errorf("failed to marshal function call: %w", err)
+		}
+
 		message := &models.Message{
-			ThreadID:   req.ThreadID,
-			ContentMap: contentJsonBlob,
-			Role:       message.Role,
-			Metadata:   metadataJsonBlob,
+			ThreadID:     req.ThreadID,
+			ContentMap:   contentJsonBlob,
+			Role:         message.Role,
+			Metadata:     metadataJsonBlob,
+			ToolCallID:   message.ToolCallID,
+			ToolCalls:    toolCallsJsonBlob,
+			FunctionCall: functionCallJsonBlob,
 		}
 
 		if err := models.CreateMessage(tx, message); err != nil {
9 changes: 6 additions & 3 deletions compextAI-server/controllers/messages.types.go
@@ -6,7 +6,10 @@ type CreateMessageRequest struct {
 }
 
 type CreateMessage struct {
-	Content  interface{}            `json:"content"`
-	Role     string                 `json:"role"`
-	Metadata map[string]interface{} `json:"metadata"`
+	Content      interface{}            `json:"content"`
+	Role         string                 `json:"role"`
+	ToolCallID   string                 `json:"tool_call_id"`
+	Metadata     map[string]interface{} `json:"metadata"`
+	ToolCalls    interface{}            `json:"tool_calls"`
+	FunctionCall interface{}            `json:"function_call"`
 }
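For context, an illustrative message payload the extended CreateMessage shape can now carry, shown as a Python dict. The inner tool-call structure follows the OpenAI convention and is an assumption here, since the server stores tool_calls and function_call as opaque JSON blobs:

message_payload = {
    "role": "assistant",
    "content": "",
    "metadata": {},
    "tool_calls": [  # assumed OpenAI-style shape
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"},
        }
    ],
}

tool_result_payload = {
    "role": "tool",
    "tool_call_id": "call_1",  # links the result back to the call above
    "content": "18°C, clear",
}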