From 561db9595976d6a6eb6225a24040d6c9c7eeadac Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Thu, 27 Mar 2025 11:01:28 +0530 Subject: [PATCH] feat: cohere provider added --- bifrost.go | 2 + interfaces/provider.go | 38 +++++-- providers/cohere.go | 245 +++++++++++++++++++++++++++++++++++++++++ tests/account.go | 13 +++ tests/cohere_test.go | 78 +++++++++++++ tests/setup.go | 5 + 6 files changed, 371 insertions(+), 10 deletions(-) create mode 100644 providers/cohere.go create mode 100644 tests/cohere_test.go diff --git a/bifrost.go b/bifrost.go index f3c570c34..ae7454a31 100644 --- a/bifrost.go +++ b/bifrost.go @@ -45,6 +45,8 @@ func createProviderFromProviderKey(providerKey interfaces.SupportedModelProvider return providers.NewAnthropicProvider(config), nil case interfaces.Bedrock: return providers.NewBedrockProvider(config), nil + case interfaces.Cohere: + return providers.NewCohereProvider(config), nil default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } diff --git a/interfaces/provider.go b/interfaces/provider.go index a3cf5548f..4d08b3250 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -10,6 +10,13 @@ type LLMUsage struct { Latency *float64 `json:"latency"` } +type BilledLLMUsage struct { + PromptTokens *float64 `json:"prompt_tokens"` + CompletionTokens *float64 `json:"completion_tokens"` + SearchUnits *float64 `json:"search_units"` + Classifications *float64 `json:"classifications"` +} + // LLMInteractionCost represents cost information for LLM interactions type LLMInteractionCost struct { Input float64 `json:"input"` @@ -64,6 +71,14 @@ type ToolCall struct { Function *FunctionCall `json:"function"` } +type Citation struct { + Start *int `json:"start"` + End *int `json:"end"` + Text *string `json:"text"` + Sources *interface{} `json:"sources"` + Type *string `json:"type"` +} + // ModelChatMessageRole represents the role of a chat message type ModelChatMessageRole string @@ -81,6 +96,7 @@ type CompletionResponseChoice struct { Content string `json:"content"` Image json.RawMessage `json:"image"` ToolCalls *[]ToolCall `json:"tool_calls"` + Citations *[]Citation `json:"citation"` } // CompletionResultChoice represents a choice in the completion result @@ -120,16 +136,18 @@ type CompletionResult struct { Message string `json:"message"` Type string `json:"type"` } `json:"error"` - ID string `json:"id"` - Choices []CompletionResultChoice `json:"choices"` - ToolCallResult *interface{} `json:"tool_call_result"` - ToolCallResults *ToolCallResults `json:"toolCallResults"` - Provider SupportedModelProvider `json:"provider"` - Usage LLMUsage `json:"usage"` - Cost *LLMInteractionCost `json:"cost"` - Model string `json:"model"` - Created string `json:"created"` - Params *interface{} `json:"modelParams"` + ID string `json:"id"` + Choices []CompletionResultChoice `json:"choices"` + ChatHistory *[]CompletionResponseChoice `json:"chat_history"` + ToolCallResult *interface{} `json:"tool_call_result"` + ToolCallResults *ToolCallResults `json:"toolCallResults"` + Provider SupportedModelProvider `json:"provider"` + Usage LLMUsage `json:"usage"` + BilledUsage *BilledLLMUsage `json:"billed_usage"` + Cost *LLMInteractionCost `json:"cost"` + Model string `json:"model"` + Created string `json:"created"` + Params *interface{} `json:"modelParams"` Trace *struct { Input interface{} `json:"input"` Output interface{} `json:"output"` diff --git a/providers/cohere.go b/providers/cohere.go new file mode 100644 index 000000000..04062671f --- /dev/null +++ b/providers/cohere.go @@ -0,0 +1,245 @@ +package providers + +import ( + "bifrost/interfaces" + "bytes" + "encoding/json" + "fmt" + "net/http" + "slices" + "time" +) + +// CohereParameterDefinition represents a parameter definition for a Cohere tool +type CohereParameterDefinition struct { + Type string `json:"type"` + Description *string `json:"description,omitempty"` + Required bool `json:"required"` +} + +// CohereTool represents a tool definition for Cohere API +type CohereTool struct { + Name string `json:"name"` + Description string `json:"description"` + ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` +} + +// CohereChatResponse represents the response from Cohere's chat API +type CohereChatResponse struct { + ResponseID string `json:"response_id"` + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ChatHistory []struct { + Role interfaces.ModelChatMessageRole `json:"role"` + Message string `json:"message"` + } `json:"chat_history"` + FinishReason string `json:"finish_reason"` + Meta struct { + APIVersion struct { + Version string `json:"version"` + } `json:"api_version"` + BilledUnits struct { + InputTokens float64 `json:"input_tokens"` + OutputTokens float64 `json:"output_tokens"` + } `json:"billed_units"` + Tokens struct { + InputTokens float64 `json:"input_tokens"` + OutputTokens float64 `json:"output_tokens"` + } `json:"tokens"` + } `json:"meta"` +} + +// OpenAIProvider implements the Provider interface for OpenAI +type CohereProvider struct { + client *http.Client +} + +// NewOpenAIProvider creates a new OpenAI provider instance +func NewCohereProvider(config *interfaces.ProviderConfig) *CohereProvider { + return &CohereProvider{ + client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}, + } +} + +func (provider *CohereProvider) GetProviderKey() interfaces.SupportedModelProvider { + return interfaces.Cohere +} + +func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { + return nil, fmt.Errorf("text completion is not supported by Cohere") +} + +func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { + startTime := time.Now() + + // Get the last message and chat history + lastMessage := messages[len(messages)-1] + chatHistory := messages[:len(messages)-1] + + // Transform chat history + var cohereHistory []map[string]interface{} + for _, msg := range chatHistory { + cohereHistory = append(cohereHistory, map[string]interface{}{ + "role": msg.Role, + "message": msg.Content, + }) + } + + preparedParams := PrepareParams(params) + + // Prepare request body + requestBody := MergeConfig(map[string]interface{}{ + "message": lastMessage.Content, + "chat_history": cohereHistory, + "model": model, + }, preparedParams) + + // Add tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + var tools []CohereTool + for _, tool := range *params.Tools { + parameterDefinitions := make(map[string]CohereParameterDefinition) + if tool.Function.Parameters != nil { + paramsMap, ok := tool.Function.Parameters.(map[string]interface{}) + if ok { + if properties, ok := paramsMap["properties"].(map[string]interface{}); ok { + for name, prop := range properties { + propMap, ok := prop.(map[string]interface{}) + if ok { + paramDef := CohereParameterDefinition{ + Required: slices.Contains(paramsMap["required"].([]string), name), + } + + if typeStr, ok := propMap["type"].(string); ok { + paramDef.Type = typeStr + } + + if desc, ok := propMap["description"].(string); ok { + paramDef.Description = &desc + } + + parameterDefinitions[name] = paramDef + } + } + } + } + } + + tools = append(tools, CohereTool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + ParameterDefinitions: parameterDefinitions, + }) + } + requestBody["tools"] = tools + } + + // Marshal request body + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + // Create request + req, err := http.NewRequest("POST", "https://api.cohere.ai/v1/chat", bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + // Add headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+key) + + // Make request + resp, err := provider.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + // Handle error response + if resp.StatusCode != http.StatusOK { + var errorResp struct { + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil { + return nil, fmt.Errorf("error decoding error response: %v", err) + } + return nil, fmt.Errorf("cohere error: %s", errorResp.Message) + } + + // Decode response + var response CohereChatResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("error decoding response: %v", err) + } + + // Transform tool calls if present + var toolCalls *[]interfaces.ToolCall + + // Calculate latency + latency := time.Since(startTime).Seconds() + + // Get role and content from the last message in chat history + var role interfaces.ModelChatMessageRole + var content string + if len(response.ChatHistory) > 0 { + lastMsg := response.ChatHistory[len(response.ChatHistory)-1] + role = lastMsg.Role + content = lastMsg.Message + } else { + role = interfaces.ModelChatMessageRole("assistant") + content = response.Text + } + + // Create completion result + result := &interfaces.CompletionResult{ + ID: response.ResponseID, + Choices: []interfaces.CompletionResultChoice{ + { + Index: 0, + Message: interfaces.CompletionResponseChoice{ + Role: role, + Content: content, + ToolCalls: toolCalls, + }, + StopReason: &response.FinishReason, + }, + }, + ChatHistory: convertChatHistory(response.ChatHistory), + Usage: interfaces.LLMUsage{ + PromptTokens: int(response.Meta.Tokens.InputTokens), + CompletionTokens: int(response.Meta.Tokens.OutputTokens), + TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens), + Latency: &latency, + }, + BilledUsage: &interfaces.BilledLLMUsage{ + PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens), + CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens), + }, + Model: model, + Provider: interfaces.Cohere, + } + + return result, nil +} + +// Helper function to convert chat history to the correct type +func convertChatHistory(history []struct { + Role interfaces.ModelChatMessageRole `json:"role"` + Message string `json:"message"` +}) *[]interfaces.CompletionResponseChoice { + converted := make([]interfaces.CompletionResponseChoice, len(history)) + for i, msg := range history { + converted[i] = interfaces.CompletionResponseChoice{ + Role: msg.Role, + Content: msg.Message, + } + } + return &converted +} + +// Helper function to create a pointer to a float64 +func float64Ptr(f float64) *float64 { + return &f +} diff --git a/tests/account.go b/tests/account.go index be0d14c58..3f35e069e 100644 --- a/tests/account.go +++ b/tests/account.go @@ -41,6 +41,14 @@ func (baseAccount *BaseAccount) GetKeysForProvider(provider interfaces.Provider) Weight: 1.0, }, }, nil + case interfaces.Cohere: + return []interfaces.Key{ + { + Value: os.Getenv("COHERE_API_KEY"), + Models: []string{"command-a-03-2025"}, + Weight: 1.0, + }, + }, nil default: return nil, fmt.Errorf("unsupported provider: %s", provider.GetProviderKey()) } @@ -64,6 +72,11 @@ func (baseAccount *BaseAccount) GetConcurrencyAndBufferSizeForProvider(provider Concurrency: 3, BufferSize: 10, }, nil + case interfaces.Cohere: + return &interfaces.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, nil default: return nil, fmt.Errorf("unsupported provider: %s", provider.GetProviderKey()) } diff --git a/tests/cohere_test.go b/tests/cohere_test.go new file mode 100644 index 000000000..0fcb6ec7b --- /dev/null +++ b/tests/cohere_test.go @@ -0,0 +1,78 @@ +package tests + +import ( + "bifrost" + "bifrost/interfaces" + "context" + "fmt" + "testing" + "time" +) + +// setupCohereRequests sends multiple test requests to Cohere +func setupCohereRequests(bifrost *bifrost.Bifrost) { + text := "Hello world!" + + ctx := context.Background() + + // Text completion request + go func() { + result, err := bifrost.TextCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{ + Model: "command-a-03-2025", + Input: interfaces.RequestInput{ + TextInput: &text, + }, + Params: nil, + }, ctx) + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("🐒 Text Completion Result:", result.Choices[0].Message.Content) + } + }() + + // Chat completion requests with different messages and delays + CohereMessages := []string{ + "Hello! How are you today?", + "What's the weather like?", + "Tell me a joke!", + "What's your favorite programming language?", + } + + for i, message := range CohereMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{ + Model: "command-a-03-2025", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: nil, + }, ctx) + if err != nil { + fmt.Printf("Error in Cohere request %d: %v\n", index+1, err) + } else { + fmt.Printf("🐒 Chat Completion Result %d: %s\n", index+1, result.Choices[0].Message.Content) + } + }(message, delay, i) + } +} + +func TestCohere(t *testing.T) { + bifrost, err := getBifrost() + if err != nil { + t.Fatalf("Error initializing bifrost: %v", err) + return + } + + setupCohereRequests(bifrost) + + bifrost.Cleanup() +} diff --git a/tests/setup.go b/tests/setup.go index 7362cc54e..fc1923e15 100644 --- a/tests/setup.go +++ b/tests/setup.go @@ -66,6 +66,11 @@ func getBifrost() (*bifrost.Bifrost, error) { }, }, }, + interfaces.Cohere: { + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + }, } bifrost, err := bifrost.Init(&account, []interfaces.Plugin{plugin}, configs)