diff --git a/core/providers/ollama/chat.go b/core/providers/ollama/chat.go new file mode 100644 index 000000000..3fdb9db89 --- /dev/null +++ b/core/providers/ollama/chat.go @@ -0,0 +1,223 @@ +// Package ollama implements the Ollama provider using native Ollama APIs. +// This file contains converters for chat completion requests and responses. +package ollama + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +// ToOllamaChatRequest converts a Bifrost chat request to Ollama native format. +func ToOllamaChatRequest(bifrostReq *schemas.BifrostChatRequest) *OllamaChatRequest { + if bifrostReq == nil || bifrostReq.Input == nil { + return nil + } + + ollamaReq := &OllamaChatRequest{ + Model: bifrostReq.Model, + Messages: convertMessagesToOllama(bifrostReq.Input), + } + + // Convert parameters + if bifrostReq.Params != nil { + options := &OllamaOptions{} + hasOptions := false + + // Map standard parameters + if bifrostReq.Params.MaxCompletionTokens != nil { + options.NumPredict = bifrostReq.Params.MaxCompletionTokens + hasOptions = true + } + if bifrostReq.Params.Temperature != nil { + options.Temperature = bifrostReq.Params.Temperature + hasOptions = true + } + if bifrostReq.Params.TopP != nil { + options.TopP = bifrostReq.Params.TopP + hasOptions = true + } + if bifrostReq.Params.PresencePenalty != nil { + options.PresencePenalty = bifrostReq.Params.PresencePenalty + hasOptions = true + } + if bifrostReq.Params.FrequencyPenalty != nil { + options.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty + hasOptions = true + } + if bifrostReq.Params.Stop != nil { + options.Stop = bifrostReq.Params.Stop + hasOptions = true + } + if bifrostReq.Params.Seed != nil { + options.Seed = bifrostReq.Params.Seed + hasOptions = true + } + + // Handle extra parameters for Ollama-specific fields + if bifrostReq.Params.ExtraParams != nil { + // Top-k sampling + if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { + options.TopK = topK + hasOptions = true + } + + // Context window size + if numCtx, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["num_ctx"]); ok { + options.NumCtx = numCtx + hasOptions = true + } + + // Repeat penalty + if repeatPenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["repeat_penalty"]); ok { + options.RepeatPenalty = repeatPenalty + hasOptions = true + } + + // Repeat last N + if repeatLastN, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["repeat_last_n"]); ok { + options.RepeatLastN = repeatLastN + hasOptions = true + } + + // Mirostat sampling + if mirostat, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["mirostat"]); ok { + options.Mirostat = mirostat + hasOptions = true + } + if mirostatEta, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["mirostat_eta"]); ok { + options.MirostatEta = mirostatEta + hasOptions = true + } + if mirostatTau, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["mirostat_tau"]); ok { + options.MirostatTau = mirostatTau + hasOptions = true + } + + // TFS-Z sampling + if tfsZ, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["tfs_z"]); ok { + options.TfsZ = tfsZ + hasOptions = true + } + + // Typical-P sampling + if typicalP, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["typical_p"]); ok { + options.TypicalP = typicalP + hasOptions = true + } + + // Performance options + if numBatch, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["num_batch"]); ok { + options.NumBatch = numBatch + hasOptions = true + } + if numGPU, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["num_gpu"]); ok { + options.NumGPU = numGPU + hasOptions = true + } + if numThread, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["num_thread"]); ok { + options.NumThread = numThread + hasOptions = true + } + + // Keep-alive duration + if keepAlive, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["keep_alive"]); ok { + ollamaReq.KeepAlive = keepAlive + } + + // Enable thinking mode (for thinking-specific models) + if think, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["think"]); ok { + ollamaReq.Think = think + } + } + + if hasOptions { + ollamaReq.Options = options + } + + // Handle response format (JSON mode) + if bifrostReq.Params.ResponseFormat != nil { + if rf, ok := (*bifrostReq.Params.ResponseFormat).(map[string]interface{}); ok { + if t, exists := rf["type"]; exists && t == "json_object" { + ollamaReq.Format = "json" + } else if schema, exists := rf["json_schema"]; exists { + // Pass JSON schema directly for structured output + ollamaReq.Format = schema + } + } + } + + // Convert tools + if bifrostReq.Params.Tools != nil { + ollamaReq.Tools = convertToolsToOllama(bifrostReq.Params.Tools) + } + } + + return ollamaReq +} + +// ToBifrostChatRequest converts an Ollama chat request to Bifrost format. +// This is used for passthrough/reverse conversion scenarios. +func (r *OllamaChatRequest) ToBifrostChatRequest() *schemas.BifrostChatRequest { + if r == nil { + return nil + } + + provider, model := schemas.ParseModelString(r.Model, schemas.Ollama) + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: provider, + Model: model, + Input: convertMessagesFromOllama(r.Messages), + } + + // Convert options to parameters + if r.Options != nil { + params := &schemas.ChatParameters{ + ExtraParams: make(map[string]interface{}), + } + + if r.Options.NumPredict != nil { + params.MaxCompletionTokens = r.Options.NumPredict + } + if r.Options.Temperature != nil { + params.Temperature = r.Options.Temperature + } + if r.Options.TopP != nil { + params.TopP = r.Options.TopP + } + if r.Options.Stop != nil { + params.Stop = r.Options.Stop + } + if r.Options.PresencePenalty != nil { + params.PresencePenalty = r.Options.PresencePenalty + } + if r.Options.FrequencyPenalty != nil { + params.FrequencyPenalty = r.Options.FrequencyPenalty + } + if r.Options.Seed != nil { + params.Seed = r.Options.Seed + } + + // Map Ollama-specific parameters to ExtraParams + if r.Options.TopK != nil { + params.ExtraParams["top_k"] = *r.Options.TopK + } + if r.Options.NumCtx != nil { + params.ExtraParams["num_ctx"] = *r.Options.NumCtx + } + if r.Options.RepeatPenalty != nil { + params.ExtraParams["repeat_penalty"] = *r.Options.RepeatPenalty + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ChatParameters{} + } + bifrostReq.Params.Tools = convertToolsFromOllama(r.Tools) + } + + return bifrostReq +} diff --git a/core/providers/ollama/embedding.go b/core/providers/ollama/embedding.go new file mode 100644 index 000000000..42784013d --- /dev/null +++ b/core/providers/ollama/embedding.go @@ -0,0 +1,118 @@ +// Package ollama implements the Ollama provider using native Ollama APIs. +// This file contains converters for embedding requests and responses. +package ollama + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +func ToOllamaEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *OllamaEmbeddingRequest { + if bifrostReq == nil { + return nil + } + + ollamaReq := &OllamaEmbeddingRequest{ + Model: bifrostReq.Model, + } + + // Handle input - Bifrost uses EmbeddingInput type + if bifrostReq.Input != nil { + if bifrostReq.Input.Text != nil { + ollamaReq.Input = *bifrostReq.Input.Text + } else if bifrostReq.Input.Texts != nil { + ollamaReq.Input = bifrostReq.Input.Texts + } + } + + // Handle extra parameters from Params + if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { + // Truncate option + if truncate, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["truncate"]); ok { + ollamaReq.Truncate = truncate + } + + // Keep-alive duration + if keepAlive, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["keep_alive"]); ok { + ollamaReq.KeepAlive = keepAlive + } + + // Model options + options := &OllamaOptions{} + hasOptions := false + + if numCtx, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["num_ctx"]); ok { + options.NumCtx = numCtx + hasOptions = true + } + + if hasOptions { + ollamaReq.Options = options + } + } + + return ollamaReq +} + +// ToBifrostEmbeddingRequest converts an Ollama embedding request to Bifrost format. +// This is used for passthrough/reverse conversion scenarios. +func (r *OllamaEmbeddingRequest) ToBifrostEmbeddingRequest() *schemas.BifrostEmbeddingRequest { + if r == nil { + return nil + } + + provider, model := schemas.ParseModelString(r.Model, schemas.Ollama) + + bifrostReq := &schemas.BifrostEmbeddingRequest{ + Provider: provider, + Model: model, + } + + // Convert input to EmbeddingInput + if r.Input != nil { + input := &schemas.EmbeddingInput{} + converted := false + switch v := r.Input.(type) { + case string: + input.Text = &v + converted = true + case []string: + input.Texts = v + converted = true + case []interface{}: + ss := make([]string, 0, len(v)) + for _, it := range v { + s, ok := it.(string) + if !ok { + converted = false + break + } + ss = append(ss, s) + } + if len(ss) > 0 { + input.Texts = ss + converted = true + } + } + if converted { + bifrostReq.Input = input + } + } + + // Map Ollama-specific options back to extra params + if r.Truncate != nil || r.KeepAlive != nil || (r.Options != nil && r.Options.NumCtx != nil) { + bifrostReq.Params = &schemas.EmbeddingParameters{ + ExtraParams: make(map[string]interface{}), + } + if r.Truncate != nil { + bifrostReq.Params.ExtraParams["truncate"] = *r.Truncate + } + if r.KeepAlive != nil { + bifrostReq.Params.ExtraParams["keep_alive"] = *r.KeepAlive + } + if r.Options != nil && r.Options.NumCtx != nil { + bifrostReq.Params.ExtraParams["num_ctx"] = *r.Options.NumCtx + } + } + + return bifrostReq +} diff --git a/core/providers/ollama/models.go b/core/providers/ollama/models.go new file mode 100644 index 000000000..6d429719c --- /dev/null +++ b/core/providers/ollama/models.go @@ -0,0 +1,67 @@ +// Package ollama implements the Ollama provider using native Ollama APIs. +// This file contains converters for list models requests and responses. +package ollama + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +// ToOllamaModel converts a Bifrost model to Ollama format. +// Note: Ollama's /api/tags endpoint is GET-only and doesn't need a request body. +// This function is included for completeness and potential future use. +func ToOllamaModel(bifrostModel *schemas.Model) *OllamaModel { + if bifrostModel == nil { + return nil + } + + return &OllamaModel{ + Name: bifrostModel.ID, + Model: bifrostModel.ID, + } +} + +// ToBifrostModel converts an Ollama model to Bifrost format. +func (m *OllamaModel) ToBifrostModel() *schemas.Model { + if m == nil { + return nil + } + + created := m.ModifiedAt.Unix() + ownedBy := "ollama" + + return &schemas.Model{ + ID: m.Name, + Created: &created, + OwnedBy: &ownedBy, + } +} + +// GetModelInfo returns formatted model information for display. +func (m *OllamaModel) GetModelInfo() map[string]interface{} { + if m == nil { + return nil + } + + info := map[string]interface{}{ + "name": m.Name, + "model": m.Model, + "modified_at": m.ModifiedAt, + "size": m.Size, + "digest": m.Digest, + } + + if m.Details.Family != "" { + info["family"] = m.Details.Family + } + if m.Details.ParameterSize != "" { + info["parameter_size"] = m.Details.ParameterSize + } + if m.Details.QuantizationLevel != "" { + info["quantization_level"] = m.Details.QuantizationLevel + } + if m.Details.Format != "" { + info["format"] = m.Details.Format + } + + return info +} diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index 62012151d..ae79fa668 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -1,20 +1,36 @@ -// Package providers implements various LLM providers and their utility functions. -// This file contains the Ollama provider implementation. +// Package ollama implements the Ollama provider using native Ollama APIs. +// This file contains the main provider implementation for Ollama's native API. +// +// Ollama API Documentation: https://github.com/ollama/ollama/blob/main/docs/api.md +// +// Supported endpoints: +// - /api/chat - Chat completion +// - /api/embed - Embeddings +// - /api/tags - List models +// +// Key differences from OpenAI-compatible API: +// - Native endpoints instead of /v1/* paths +// - Newline-delimited JSON streaming instead of SSE +// - Different request/response structure +// - Options object for model parameters package ollama import ( + "bufio" "context" "fmt" + "net/http" "strings" + "sync" "time" - "github.com/maximhq/bifrost/core/providers/openai" + "github.com/bytedance/sonic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// OllamaProvider implements the Provider interface for Ollama's API. +// OllamaProvider implements the Provider interface for Ollama's native API. type OllamaProvider struct { logger schemas.Logger // Logger for provider operations client *fasthttp.Client // HTTP client for API requests @@ -23,6 +39,48 @@ type OllamaProvider struct { sendBackRawResponse bool // Whether to include raw response in BifrostResponse } +// Response pools for efficient memory usage +var ( + ollamaChatResponsePool = sync.Pool{ + New: func() interface{} { + return &OllamaChatResponse{} + }, + } + ollamaEmbeddingResponsePool = sync.Pool{ + New: func() interface{} { + return &OllamaEmbeddingResponse{} + }, + } +) + +// acquireOllamaChatResponse gets an Ollama chat response from the pool. +func acquireOllamaChatResponse() *OllamaChatResponse { + resp := ollamaChatResponsePool.Get().(*OllamaChatResponse) + *resp = OllamaChatResponse{} // Reset the struct + return resp +} + +// releaseOllamaChatResponse returns an Ollama chat response to the pool. +func releaseOllamaChatResponse(resp *OllamaChatResponse) { + if resp != nil { + ollamaChatResponsePool.Put(resp) + } +} + +// acquireOllamaEmbeddingResponse gets an Ollama embedding response from the pool. +func acquireOllamaEmbeddingResponse() *OllamaEmbeddingResponse { + resp := ollamaEmbeddingResponsePool.Get().(*OllamaEmbeddingResponse) + *resp = OllamaEmbeddingResponse{} // Reset the struct + return resp +} + +// releaseOllamaEmbeddingResponse returns an Ollama embedding response to the pool. +func releaseOllamaEmbeddingResponse(resp *OllamaEmbeddingResponse) { + if resp != nil { + ollamaEmbeddingResponsePool.Put(resp) + } +} + // NewOllamaProvider creates a new Ollama provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. @@ -37,19 +95,20 @@ func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* MaxConnWaitTimeout: 10 * time.Second, } - // // Pre-warm response pools - // for range config.ConcurrencyAndBufferSize.Concurrency { - // ollamaResponsePool.Put(&schemas.BifrostResponse{}) - // } + // Pre-warm response pools + for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { + ollamaChatResponsePool.Put(&OllamaChatResponse{}) + ollamaEmbeddingResponsePool.Put(&OllamaEmbeddingResponse{}) + } // Configure proxy if provided - client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) + providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") - // BaseURL is required for Ollama + // Set default BaseURL for local Ollama if not provided if config.NetworkConfig.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for ollama provider") + config.NetworkConfig.BaseURL = "http://localhost:11434" } return &OllamaProvider{ @@ -66,102 +125,409 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { return schemas.Ollama } -// ListModels performs a list models request to Ollama's API. +// completeRequest sends a request to Ollama's native API and handles the response. +// It constructs the API URL, sets up authentication, and processes the response. +// Returns the response body or an error if the request fails. +func (provider *OllamaProvider) completeRequest(ctx context.Context, jsonData []byte, url string, key string) ([]byte, time.Duration, *schemas.BifrostError) { + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + + // Uses Authorization: Bearer for Ollama Cloud / authenticated instances. + if key != "" { + req.Header.Set("Authorization", "Bearer "+key) + } + + req.SetBody(jsonData) + + // Send the request + latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, latency, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) + return nil, latency, parseOllamaError(resp, provider.GetProviderKey()) + } + + body, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + } + + // Copy body before releasing response + bodyCopy := append([]byte(nil), body...) + + return bodyCopy, latency, nil +} + +// parseOllamaError parses an error response from Ollama's API. +func parseOllamaError(resp *fasthttp.Response, providerType schemas.ModelProvider) *schemas.BifrostError { + statusCode := resp.StatusCode() + body := resp.Body() + + var errorResp OllamaError + if err := sonic.Unmarshal(body, &errorResp); err == nil && errorResp.Error != "" { + return providerUtils.NewProviderAPIError(errorResp.Error, nil, statusCode, providerType, nil, nil) + } + + return providerUtils.NewProviderAPIError(string(body), nil, statusCode, providerType, nil, nil) +} + +// ListModels performs a list models request to Ollama's native API. +// Uses the /api/tags endpoint to fetch available models. func (provider *OllamaProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + // Use first key if available, otherwise empty (for local Ollama) + var key schemas.Key + if len(keys) > 0 { + key = keys[0] } - return openai.HandleOpenAIListModelsRequest( - ctx, - provider.client, - request, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"), - keys, - provider.networkConfig.ExtraHeaders, - provider.GetProviderKey(), - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), - providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), - provider.logger, - ) + + return provider.listModelsByKey(ctx, key, request) +} + +// listModelsByKey performs a list models request for a single key. +func (provider *OllamaProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + + // Build URL - Ollama uses GET /api/tags + // Use GetPathFromContext to support path overrides + req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/api/tags")) + req.Header.SetMethod(http.MethodGet) + + // Set API key if provided + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) + } + + // Make request + latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + return nil, parseOllamaError(resp, provider.GetProviderKey()) + } + + // Decode response body (handles gzip, etc.) + body, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + } + + // Parse response + var ollamaResponse OllamaListModelsResponse + _, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &ollamaResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Convert to Bifrost format + response := ollamaResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models) + response.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil } -// TextCompletion performs a text completion request to the Ollama API. +// TextCompletion is not directly supported by Ollama's native API. +// Use ChatCompletion instead for text generation. func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - return openai.HandleOpenAITextCompletionRequest( - ctx, - provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), - request, - key, - provider.networkConfig.ExtraHeaders, - provider.GetProviderKey(), - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), - providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), - provider.logger, - ) + return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } -// TextCompletionStream performs a streaming text completion request to Ollama's API. -// It formats the request, sends it to Ollama, and processes the response. -// Returns a channel of BifrostStream objects or an error if the request fails. +// TextCompletionStream is not directly supported by Ollama's native API. +// Use ChatCompletionStream instead for text generation. func (provider *OllamaProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - return openai.HandleOpenAITextCompletionStreaming( - ctx, - provider.client, - provider.networkConfig.BaseURL+"/v1/completions", - request, - nil, - provider.networkConfig.ExtraHeaders, - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), - providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), - provider.GetProviderKey(), - postHookRunner, - nil, - provider.logger, - ) + return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } -// ChatCompletion performs a chat completion request to the Ollama API. +// ChatCompletion performs a chat completion request to Ollama's native API. +// Uses the /api/chat endpoint with stream=false. func (provider *OllamaProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return openai.HandleOpenAIChatCompletionRequest( + // Convert to Ollama format + jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, - provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, - key, - provider.networkConfig.ExtraHeaders, - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), - providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), - provider.GetProviderKey(), - provider.logger, + func() (any, error) { + ollamaReq := ToOllamaChatRequest(request) + if ollamaReq != nil { + ollamaReq.Stream = schemas.Ptr(false) // Non-streaming request + } + return ollamaReq, nil + }, + provider.GetProviderKey()) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Make request + responseBody, latency, bifrostErr := provider.completeRequest( + ctx, + jsonData, + provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/api/chat"), + key.Value, ) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Parse response + response := acquireOllamaChatResponse() + defer releaseOllamaChatResponse(response) + + _, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Convert to Bifrost format + bifrostResponse := response.ToBifrostChatResponse(request.Model) + + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + return bifrostResponse, nil } -// ChatCompletionStream performs a streaming chat completion request to the Ollama API. -// It supports real-time streaming of responses using Server-Sent Events (SSE). -// Uses Ollama's OpenAI-compatible streaming format. -// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. +// ChatCompletionStream performs a streaming chat completion request to Ollama's native API. +// Uses newline-delimited JSON streaming format (not SSE). func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use shared OpenAI-compatible streaming logic - return openai.HandleOpenAIChatCompletionStreaming( + // Check if the request is a redirect from ResponsesStream to ChatCompletionStream + isResponsesToChatCompletionsFallback := false + var responsesStreamState *schemas.ChatToResponsesStreamState + if ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) != nil { + isResponsesToChatCompletionsFallbackValue, ok := ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback).(bool) + if ok && isResponsesToChatCompletionsFallbackValue { + isResponsesToChatCompletionsFallback = true + responsesStreamState = schemas.AcquireChatToResponsesStreamState() + defer schemas.ReleaseChatToResponsesStreamState(responsesStreamState) + } + } + + // Convert to Ollama format with streaming enabled + jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, - provider.client, - provider.networkConfig.BaseURL+"/v1/chat/completions", request, - nil, - provider.networkConfig.ExtraHeaders, - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), - providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), - schemas.Ollama, - postHookRunner, - nil, - nil, - nil, - provider.logger, - ) + func() (any, error) { + ollamaReq := ToOllamaChatRequest(request) + if ollamaReq != nil { + ollamaReq.Stream = schemas.Ptr(true) // Enable streaming + } + return ollamaReq, nil + }, + provider.GetProviderKey()) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + resp.StreamBody = true // Enable streaming + defer fasthttp.ReleaseRequest(req) + + // Set headers + providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + req.SetRequestURI(provider.networkConfig.BaseURL + "/api/chat") + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) + } + + req.SetBody(jsonData) + + // Make the request with context support + // NOTE: fasthttp does not natively support context cancellation for streaming requests. + // MakeRequestWithContext only cancels waiting for the initial request, not the ongoing stream. + // The scanner loop below includes context cancellation checks to exit early when ctx is cancelled. + _, bifrostErr = providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + defer providerUtils.ReleaseStreamingResponse(resp) + return nil, bifrostErr + } + + // Check for HTTP errors + if resp.StatusCode() != fasthttp.StatusOK { + defer providerUtils.ReleaseStreamingResponse(resp) + return nil, parseOllamaError(resp, provider.GetProviderKey()) + } + + // Create response channel + responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + + // Start streaming in a goroutine + go func() { + defer close(responseChan) + defer providerUtils.ReleaseStreamingResponse(resp) + + if resp.BodyStream() == nil { + bifrostErr := providerUtils.NewBifrostOperationError( + "Provider returned an empty response", + fmt.Errorf("provider returned an empty response"), + provider.GetProviderKey(), + ) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + return + } + + scanner := bufio.NewScanner(resp.BodyStream()) + // Increase buffer size for large responses + buf := make([]byte, 0, 1024*1024) + scanner.Buffer(buf, 10*1024*1024) + + chunkIndex := 0 + startTime := time.Now() + lastChunkTime := startTime + + for { + // Check for context cancellation before attempting to scan + select { + case <-ctx.Done(): + // Context was cancelled - exit the goroutine + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Type: schemas.Ptr(schemas.RequestCancelled), + Message: fmt.Sprintf("Stream cancelled or timed out by context: %v", ctx.Err()), + Error: ctx.Err(), + }, + } + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + return + default: + // Continue to scanner.Scan() + } + + // Attempt to scan next line + if !scanner.Scan() { + // Scanner reached end of stream or encountered an error + break + } + + line := scanner.Text() + + // Skip empty lines + if line == "" { + continue + } + + // Parse the JSON chunk (Ollama uses newline-delimited JSON) + var streamChunk OllamaStreamResponse + if err := sonic.Unmarshal([]byte(line), &streamChunk); err != nil { + provider.logger.Warn(fmt.Sprintf("Failed to parse Ollama stream chunk: %v", err)) + continue + } + + // Convert to Bifrost format + bifrostResponse, isDone := streamChunk.ToBifrostStreamResponse(chunkIndex) + if bifrostResponse != nil { + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.ChunkIndex = chunkIndex + chunkLatencyMs := time.Since(lastChunkTime).Milliseconds() + bifrostResponse.ExtraFields.Latency = chunkLatencyMs + + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResponse.ExtraFields.RawResponse = line + } + + lastChunkTime = time.Now() + chunkIndex++ + + // Handle Responses API fallback conversion + if isResponsesToChatCompletionsFallback { + // Convert chat completion stream to responses stream + spreadResponses := bifrostResponse.ToBifrostResponsesStreamResponse(responsesStreamState) + for _, responsesResponse := range spreadResponses { + if responsesResponse == nil { + continue + } + + // Update ExtraFields for Responses API + responsesResponse.ExtraFields.RequestType = schemas.ResponsesStreamRequest + responsesResponse.ExtraFields.Provider = provider.GetProviderKey() + responsesResponse.ExtraFields.ModelRequested = request.Model + responsesResponse.ExtraFields.ChunkIndex = responsesResponse.SequenceNumber + + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + responsesResponse.ExtraFields.RawResponse = line + } + + // Send response chunk + if isDone && responsesResponse.Type == schemas.ResponsesStreamResponseTypeCompleted { + responsesResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, responsesResponse, nil, nil), responseChan) + return + } + + responsesResponse.ExtraFields.Latency = chunkLatencyMs + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, responsesResponse, nil, nil), responseChan) + } + } else { + // Regular chat completion stream + if isDone { + bifrostResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, bifrostResponse, nil, nil, nil), responseChan) + return + } + + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, bifrostResponse, nil, nil, nil), responseChan) + } + } + } + + if err := scanner.Err(); err != nil { + provider.logger.Warn(fmt.Sprintf("Error reading Ollama stream: %v", err)) + requestType := schemas.ChatCompletionStreamRequest + if isResponsesToChatCompletionsFallback { + requestType = schemas.ResponsesStreamRequest + } + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, requestType, provider.GetProviderKey(), request.Model, provider.logger) + } + }() + + return responseChan, nil } -// Responses performs a responses request to the Ollama API. +// Responses performs a responses request to Ollama's API. +// Falls back to ChatCompletion with conversion. func (provider *OllamaProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { @@ -176,7 +542,8 @@ func (provider *OllamaProvider) Responses(ctx context.Context, key schemas.Key, return response, nil } -// ResponsesStream performs a streaming responses request to the Ollama API. +// ResponsesStream performs a streaming responses request to Ollama's API. +// Falls back to ChatCompletionStream with conversion. func (provider *OllamaProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( @@ -187,20 +554,54 @@ func (provider *OllamaProvider) ResponsesStream(ctx context.Context, postHookRun ) } -// Embedding performs an embedding request to the Ollama API. +// Embedding performs an embedding request to Ollama's native API. +// Uses the /api/embed endpoint. func (provider *OllamaProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - return openai.HandleOpenAIEmbeddingRequest( + // Convert to Ollama format + jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, - provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, - key, - provider.networkConfig.ExtraHeaders, - provider.GetProviderKey(), - providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), - providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), - provider.logger, + func() (any, error) { return ToOllamaEmbeddingRequest(request), nil }, + provider.GetProviderKey()) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Make request + responseBody, latency, bifrostErr := provider.completeRequest( + ctx, + jsonData, + provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/api/embed"), + key.Value, ) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Parse response + response := acquireOllamaEmbeddingResponse() + defer releaseOllamaEmbeddingResponse(response) + + _, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Convert to Bifrost format + bifrostResponse := response.ToBifrostEmbeddingResponse(request.Model) + + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest + bifrostResponse.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + return bifrostResponse, nil } // Speech is not supported by the Ollama provider. diff --git a/core/providers/ollama/ollama_test.go b/core/providers/ollama/ollama_test.go index 090676261..1cdf3aa0d 100644 --- a/core/providers/ollama/ollama_test.go +++ b/core/providers/ollama/ollama_test.go @@ -6,10 +6,21 @@ import ( "testing" "github.com/maximhq/bifrost/core/internal/testutil" - "github.com/maximhq/bifrost/core/schemas" ) +// TestOllama runs comprehensive tests against a local or remote Ollama instance. +// +// Environment variables: +// - OLLAMA_BASE_URL: Required. The base URL of the Ollama instance (e.g., "http://localhost:11434") +// - OLLAMA_API_KEY: Optional. API key for authenticated Ollama Cloud instances +// - OLLAMA_MODEL: Optional. Model to test with (default: "llama3.2:latest") +// - OLLAMA_EMBEDDING_MODEL: Optional. Embedding model to test with (default: "nomic-embed-text:latest") +// +// The tests use Ollama's native API endpoints: +// - /api/chat for chat completion +// - /api/embed for embeddings +// - /api/tags for listing models func TestOllama(t *testing.T) { t.Parallel() if strings.TrimSpace(os.Getenv("OLLAMA_BASE_URL")) == "" { @@ -22,13 +33,24 @@ func TestOllama(t *testing.T) { } defer cancel() + // Get model names from environment or use defaults + chatModel := os.Getenv("OLLAMA_MODEL") + if chatModel == "" { + chatModel = "llama3.2:latest" + } + + embeddingModel := os.Getenv("OLLAMA_EMBEDDING_MODEL") + if embeddingModel == "" { + embeddingModel = "nomic-embed-text:latest" + } + testConfig := testutil.ComprehensiveTestConfig{ Provider: schemas.Ollama, - ChatModel: "llama3.1:latest", - TextModel: "", // Ollama doesn't support text completion in newer models - EmbeddingModel: "", // Ollama doesn't support embedding + ChatModel: chatModel, + TextModel: "", // Text completion uses chat endpoint in native API + EmbeddingModel: embeddingModel, Scenarios: testutil.TestScenarios{ - TextCompletion: false, // Not supported + TextCompletion: false, // Not supported - use chat instead SimpleChat: true, CompletionStream: true, MultiTurnConversation: true, @@ -37,16 +59,74 @@ func TestOllama(t *testing.T) { MultipleToolCalls: true, End2EndToolCalling: true, AutomaticFunctionCall: true, + ImageURL: false, // Ollama expects base64 images + ImageBase64: true, // Multimodal models support base64 images + MultipleImages: false, + CompleteEnd2End: true, + Embedding: true, // Native API supports embeddings + ListModels: true, + }, + } + + t.Run("OllamaTests", func(t *testing.T) { + testutil.RunAllComprehensiveTests(t, client, ctx, testConfig) + }) + client.Shutdown() +} + +// TestOllamaCloud tests Ollama Cloud with API key authentication. +// This test is separate to allow testing against Ollama Cloud specifically. +// +// Environment variables: +// - OLLAMA_CLOUD_URL: Required. The Ollama Cloud URL +// - OLLAMA_API_KEY: Required. API key for Ollama Cloud +// - OLLAMA_CLOUD_MODEL: Optional. Model to test with +func TestOllamaCloud(t *testing.T) { + t.Parallel() + cloudURL := os.Getenv("OLLAMA_CLOUD_URL") + apiKey := os.Getenv("OLLAMA_API_KEY") + + if cloudURL == "" || apiKey == "" { + t.Skip("Skipping Ollama Cloud tests because OLLAMA_CLOUD_URL or OLLAMA_API_KEY is not set") + } + + client, ctx, cancel, err := testutil.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + + // Get model name from environment or use default + chatModel := os.Getenv("OLLAMA_CLOUD_MODEL") + if chatModel == "" { + chatModel = "llama3.2:latest" + } + + testConfig := testutil.ComprehensiveTestConfig{ + Provider: schemas.Ollama, + ChatModel: chatModel, + TextModel: "", + EmbeddingModel: "", + Scenarios: testutil.TestScenarios{ + TextCompletion: false, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, + MultipleToolCalls: false, // May not be supported in cloud + End2EndToolCalling: true, + AutomaticFunctionCall: false, ImageURL: false, ImageBase64: false, MultipleImages: false, CompleteEnd2End: true, - Embedding: false, + Embedding: false, // May not be available ListModels: true, }, } - t.Run("OllamaTests", func(t *testing.T) { + t.Run("OllamaCloudTests", func(t *testing.T) { testutil.RunAllComprehensiveTests(t, client, ctx, testConfig) }) client.Shutdown() diff --git a/core/providers/ollama/types.go b/core/providers/ollama/types.go new file mode 100644 index 000000000..965572119 --- /dev/null +++ b/core/providers/ollama/types.go @@ -0,0 +1,520 @@ +// Package ollama implements the Ollama provider using native Ollama APIs. +// This file contains the type definitions for Ollama's native API. +package ollama + +import ( + "encoding/json" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// ==================== REQUEST TYPES ==================== + +// OllamaChatRequest represents an Ollama chat completion request using native API. +// See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion +type OllamaChatRequest struct { + Model string `json:"model"` // Required: Name of the model to use + Messages []OllamaMessage `json:"messages"` // Required: Messages for the chat + Tools []OllamaTool `json:"tools,omitempty"` // Optional: List of tools the model may use + Think *bool `json:"think,omitempty"` // Optional: Enable thinking (default: false) + Format interface{} `json:"format,omitempty"` // Optional: Format of the response (e.g., "json" or JSON schema) + Options *OllamaOptions `json:"options,omitempty"` // Optional: Model parameters + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming (default: true) + KeepAlive *string `json:"keep_alive,omitempty"` // Optional: How long to keep model loaded (e.g., "5m", "0" to unload) +} + +// OllamaMessage represents a message in Ollama format. +type OllamaMessage struct { + Role string `json:"role"` // "system", "user", "assistant", or "tool" + Content string `json:"content"` // Message content + Thinking *string `json:"thinking,omitempty"` // Optional: Thinking content + Images []string `json:"images,omitempty"` // Optional: Base64 encoded images for multimodal models + ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"` // Optional: Tool calls made by the assistant + ToolName *string `json:"tool_name,omitempty"` // Optional: Tool name +} + +// OllamaToolCall represents a tool call in Ollama format. +type OllamaToolCall struct { + Function OllamaToolCallFunction `json:"function"` +} + +// OllamaToolCallFunction represents the function details of a tool call. +type OllamaToolCallFunction struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// OllamaTool represents a tool definition in Ollama format. +type OllamaTool struct { + Type string `json:"type"` // "function" + Function OllamaToolFunction `json:"function"` +} + +// OllamaToolFunction represents a function definition for tools. +type OllamaToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters *schemas.ToolFunctionParameters `json:"parameters,omitempty"` +} + +// OllamaOptions represents model parameters for Ollama requests. +// See: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values +type OllamaOptions struct { + // Generation parameters + NumPredict *int `json:"num_predict,omitempty"` // Maximum number of tokens to generate (similar to max_tokens) + Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-2.0) + TopP *float64 `json:"top_p,omitempty"` // Top-p sampling + TopK *int `json:"top_k,omitempty"` // Top-k sampling + Seed *int `json:"seed,omitempty"` // Random seed for reproducibility + Stop []string `json:"stop,omitempty"` // Stop sequences + + // Penalty parameters + RepeatPenalty *float64 `json:"repeat_penalty,omitempty"` // Repetition penalty + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty + RepeatLastN *int `json:"repeat_last_n,omitempty"` // Last N tokens for repeat penalty + + // Context and performance + NumCtx *int `json:"num_ctx,omitempty"` // Context window size + NumBatch *int `json:"num_batch,omitempty"` // Batch size for prompt processing + NumGPU *int `json:"num_gpu,omitempty"` // Number of layers to offload to GPU + NumThread *int `json:"num_thread,omitempty"` // Number of threads + + // Advanced parameters + Mirostat *int `json:"mirostat,omitempty"` // Mirostat sampling (0, 1, or 2) + MirostatEta *float64 `json:"mirostat_eta,omitempty"` // Mirostat learning rate + MirostatTau *float64 `json:"mirostat_tau,omitempty"` // Mirostat target entropy + TfsZ *float64 `json:"tfs_z,omitempty"` // Tail-free sampling + TypicalP *float64 `json:"typical_p,omitempty"` // Typical p sampling + + // Low-level parameters + UseMlock *bool `json:"use_mlock,omitempty"` // Lock model in memory + UseMmap *bool `json:"use_mmap,omitempty"` // Use memory mapping + Numa *bool `json:"numa,omitempty"` // Enable NUMA support +} + +// ==================== RESPONSE TYPES ==================== + +// OllamaChatResponse represents an Ollama chat completion response. +type OllamaChatResponse struct { + Model string `json:"model"` // Model used for generation + CreatedAt string `json:"created_at"` // Timestamp when response was created + Message *OllamaMessage `json:"message,omitempty"` // Generated message + Done bool `json:"done"` // Whether generation is complete + DoneReason *string `json:"done_reason,omitempty"` // Reason for completion ("stop", "length", "load", "unload") + TotalDuration *int64 `json:"total_duration,omitempty"` // Total time in nanoseconds + LoadDuration *int64 `json:"load_duration,omitempty"` // Time to load model in nanoseconds + PromptEvalCount *int `json:"prompt_eval_count,omitempty"` // Number of tokens in prompt + PromptEvalDuration *int64 `json:"prompt_eval_duration,omitempty"` // Time to evaluate prompt in nanoseconds + EvalCount *int `json:"eval_count,omitempty"` // Number of tokens generated + EvalDuration *int64 `json:"eval_duration,omitempty"` // Time to generate response in nanoseconds +} + +// ==================== EMBEDDING TYPES ==================== + +// OllamaEmbeddingRequest represents an Ollama embedding request. +// See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings +type OllamaEmbeddingRequest struct { + Model string `json:"model"` // Required: Name of the embedding model + Input interface{} `json:"input"` // Required: Text to embed (string or []string) + Truncate *bool `json:"truncate,omitempty"` // Optional: Truncate input to fit context length + Options *OllamaOptions `json:"options,omitempty"` // Optional: Model parameters + KeepAlive *string `json:"keep_alive,omitempty"` // Optional: How long to keep model loaded +} + +// OllamaEmbeddingResponse represents an Ollama embedding response. +type OllamaEmbeddingResponse struct { + Model string `json:"model"` // Model used for embedding + Embeddings [][]float64 `json:"embeddings"` // Generated embeddings + TotalDuration *int64 `json:"total_duration,omitempty"` // Total time in nanoseconds + LoadDuration *int64 `json:"load_duration,omitempty"` // Time to load model in nanoseconds + PromptEvalCount *int `json:"prompt_eval_count,omitempty"` // Number of tokens processed +} + +// ==================== LIST MODELS TYPES ==================== + +// OllamaListModelsResponse represents the response from /api/tags endpoint. +type OllamaListModelsResponse struct { + Models []OllamaModel `json:"models"` +} + +// OllamaModel represents a model in Ollama's list. +type OllamaModel struct { + Name string `json:"name"` // Model name (e.g., "llama3.2:latest") + Model string `json:"model"` // Model identifier + ModifiedAt time.Time `json:"modified_at"` // Last modified timestamp + Size int64 `json:"size"` // Model size in bytes + Digest string `json:"digest"` // Model digest + Details OllamaModelDetails `json:"details"` // Model details +} + +// OllamaModelDetails contains detailed information about a model. +type OllamaModelDetails struct { + ParentModel string `json:"parent_model,omitempty"` // Parent model name + Format string `json:"format"` // Model format (e.g., "gguf") + Family string `json:"family"` // Model family (e.g., "llama") + Families []string `json:"families,omitempty"` // Additional families + ParameterSize string `json:"parameter_size"` // Parameter count (e.g., "8B") + QuantizationLevel string `json:"quantization_level"` // Quantization (e.g., "Q4_0") +} + +// ==================== ERROR TYPES ==================== + +// OllamaError represents an error response from Ollama's API. +type OllamaError struct { + Error string `json:"error"` +} + +// ==================== STREAMING TYPES ==================== + +// OllamaStreamResponse represents a single streaming chunk from Ollama. +// It's the same structure as OllamaChatResponse but used during streaming. +type OllamaStreamResponse struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Message *OllamaMessage `json:"message,omitempty"` + Done bool `json:"done"` + DoneReason *string `json:"done_reason,omitempty"` + TotalDuration *int64 `json:"total_duration,omitempty"` + LoadDuration *int64 `json:"load_duration,omitempty"` + PromptEvalCount *int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration *int64 `json:"prompt_eval_duration,omitempty"` + EvalCount *int `json:"eval_count,omitempty"` + EvalDuration *int64 `json:"eval_duration,omitempty"` +} + +// ==================== HELPER METHODS ==================== + +// ToBifrostChatResponse converts an Ollama chat response to Bifrost format. +func (r *OllamaChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse { + if r == nil { + return nil + } + + // Parse timestamp + created := int(time.Now().Unix()) + if r.CreatedAt != "" { + if t, err := time.Parse(time.RFC3339Nano, r.CreatedAt); err == nil { + created = int(t.Unix()) + } + } + + response := &schemas.BifrostChatResponse{ + Model: model, + Created: created, + Object: "chat.completion", + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.Ollama, + }, + } + + // Build the choice + if r.Message != nil { + var toolCalls []schemas.ChatAssistantMessageToolCall + if len(r.Message.ToolCalls) > 0 { + for i, tc := range r.Message.ToolCalls { + args, _ := json.Marshal(tc.Function.Arguments) + toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ + Index: uint16(i), + Type: schemas.Ptr("function"), + ID: schemas.Ptr(tc.Function.Name), // Ollama doesn't provide IDs, use name + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &tc.Function.Name, + Arguments: string(args), + }, + }) + } + } + + var assistantMessage *schemas.ChatAssistantMessage + if len(toolCalls) > 0 { + assistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + } + } + + // Handle thinking content for non-streaming responses + // Store thinking in tool call ExtraContent (similar to how we preserve it in message conversion) + if r.Message.Thinking != nil && *r.Message.Thinking != "" { + if assistantMessage == nil { + assistantMessage = &schemas.ChatAssistantMessage{} + } + // If we have tool calls, store thinking in the first one's ExtraContent + // Otherwise, create a placeholder tool call to preserve thinking + if len(assistantMessage.ToolCalls) > 0 { + if assistantMessage.ToolCalls[0].ExtraContent == nil { + assistantMessage.ToolCalls[0].ExtraContent = make(map[string]interface{}) + } + assistantMessage.ToolCalls[0].ExtraContent["ollama"] = map[string]interface{}{ + "thinking": *r.Message.Thinking, + } + } else { + // Create placeholder tool call to preserve thinking + assistantMessage.ToolCalls = []schemas.ChatAssistantMessageToolCall{ + { + Index: 0, + Type: schemas.Ptr("function"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("_thinking_placeholder"), + Arguments: "{}", + }, + ExtraContent: map[string]interface{}{ + "ollama": map[string]interface{}{ + "thinking": *r.Message.Thinking, + }, + }, + }, + } + } + } + + choice := schemas.BifrostResponseChoice{ + Index: 0, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRole(r.Message.Role), + Content: &schemas.ChatMessageContent{ + ContentStr: &r.Message.Content, + }, + ChatAssistantMessage: assistantMessage, + }, + }, + FinishReason: r.mapFinishReason(), + } + response.Choices = []schemas.BifrostResponseChoice{choice} + } + + // Map usage + response.Usage = r.toUsage() + + return response +} + +// mapFinishReason maps Ollama's done_reason to Bifrost format. +func (r *OllamaChatResponse) mapFinishReason() *string { + if r.DoneReason == nil { + if r.Done { + return schemas.Ptr("stop") + } + return nil + } + + switch *r.DoneReason { + case "stop": + return schemas.Ptr("stop") + case "length": + return schemas.Ptr("length") + case "load", "unload": + return schemas.Ptr("stop") + default: + return r.DoneReason + } +} + +// toUsage converts Ollama usage info to Bifrost format. +func (r *OllamaChatResponse) toUsage() *schemas.BifrostLLMUsage { + usage := &schemas.BifrostLLMUsage{} + + if r.PromptEvalCount != nil { + usage.PromptTokens = *r.PromptEvalCount + } + if r.EvalCount != nil { + usage.CompletionTokens = *r.EvalCount + } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return usage +} + +// ToBifrostStreamResponse converts an Ollama streaming chunk to Bifrost format. +func (r *OllamaStreamResponse) ToBifrostStreamResponse(chunkIndex int) (*schemas.BifrostChatResponse, bool) { + if r == nil { + return nil, false + } + + response := &schemas.BifrostChatResponse{ + Model: r.Model, + Object: "chat.completion.chunk", + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionStreamRequest, + Provider: schemas.Ollama, + ChunkIndex: chunkIndex, + }, + } + + // Parse timestamp + if r.CreatedAt != "" { + if t, err := time.Parse(time.RFC3339Nano, r.CreatedAt); err == nil { + response.Created = int(t.Unix()) + } + } + + // Build delta content + if r.Message != nil { + var toolCalls []schemas.ChatAssistantMessageToolCall + if len(r.Message.ToolCalls) > 0 { + for i, tc := range r.Message.ToolCalls { + args, _ := json.Marshal(tc.Function.Arguments) + toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ + Index: uint16(i), + Type: schemas.Ptr("function"), + ID: schemas.Ptr(tc.Function.Name), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &tc.Function.Name, + Arguments: string(args), + }, + }) + } + } + + delta := &schemas.ChatStreamResponseChoiceDelta{} + + if r.Message.Role != "" { + role := string(r.Message.Role) + delta.Role = &role + } + + if r.Message.Content != "" { + delta.Content = &r.Message.Content + } + + // Handle thinking content (for thinking-specific models) + // Ollama may send thinking incrementally in streaming chunks, similar to content + if r.Message.Thinking != nil && *r.Message.Thinking != "" { + delta.Reasoning = r.Message.Thinking + } + + if len(toolCalls) > 0 { + delta.ToolCalls = toolCalls + } + + // Always create a choice if we have any delta content (content, thinking, tool calls, or role) + hasDelta := delta.Role != nil || delta.Content != nil || delta.Reasoning != nil || len(delta.ToolCalls) > 0 + if hasDelta { + choice := schemas.BifrostResponseChoice{ + Index: 0, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ + Delta: delta, + }, + } + + // Set finish reason on final chunk + if r.Done { + if r.DoneReason != nil { + switch *r.DoneReason { + case "stop": + choice.FinishReason = schemas.Ptr("stop") + case "length": + choice.FinishReason = schemas.Ptr("length") + default: + choice.FinishReason = schemas.Ptr("stop") + } + } else { + choice.FinishReason = schemas.Ptr("stop") + } + } + + response.Choices = []schemas.BifrostResponseChoice{choice} + } + } + + // Add usage on final chunk + if r.Done { + usage := &schemas.BifrostLLMUsage{} + if r.PromptEvalCount != nil { + usage.PromptTokens = *r.PromptEvalCount + } + if r.EvalCount != nil { + usage.CompletionTokens = *r.EvalCount + } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + response.Usage = usage + } + + return response, r.Done +} + +// ToBifrostEmbeddingResponse converts an Ollama embedding response to Bifrost format. +func (r *OllamaEmbeddingResponse) ToBifrostEmbeddingResponse(model string) *schemas.BifrostEmbeddingResponse { + if r == nil { + return nil + } + + response := &schemas.BifrostEmbeddingResponse{ + Model: model, + Object: "list", + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.EmbeddingRequest, + Provider: schemas.Ollama, + }, + } + + // Convert embeddings to Bifrost format + for i, embedding := range r.Embeddings { + // Convert []float64 to []float32 + embeddingFloat32 := make([]float32, len(embedding)) + for j, v := range embedding { + embeddingFloat32[j] = float32(v) + } + + response.Data = append(response.Data, schemas.EmbeddingData{ + Object: "embedding", + Embedding: schemas.EmbeddingStruct{ + EmbeddingArray: embeddingFloat32, + }, + Index: i, + }) + } + + // Convert usage + if r.PromptEvalCount != nil { + response.Usage = &schemas.BifrostLLMUsage{ + PromptTokens: *r.PromptEvalCount, + TotalTokens: *r.PromptEvalCount, + } + } + + return response +} + +// ToBifrostListModelsResponse converts an Ollama list models response to Bifrost format. +func (r *OllamaListModelsResponse) ToBifrostListModelsResponse(providerName schemas.ModelProvider, configuredModels []string) *schemas.BifrostListModelsResponse { + if r == nil { + return nil + } + + response := &schemas.BifrostListModelsResponse{ + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ListModelsRequest, + Provider: providerName, + }, + } + + // Create a set of configured models for quick lookup + configuredSet := make(map[string]bool) + for _, m := range configuredModels { + configuredSet[m] = true + } + + for _, model := range r.Models { + // Filter models if configuredModels is non-empty + if len(configuredModels) > 0 && !configuredSet[model.Name] { + continue + } + + created := model.ModifiedAt.Unix() + ownedBy := "ollama" + + bifrostModel := schemas.Model{ + ID: model.Name, + Created: &created, + OwnedBy: &ownedBy, + } + + response.Data = append(response.Data, bifrostModel) + } + + return response +} diff --git a/core/providers/ollama/utils.go b/core/providers/ollama/utils.go new file mode 100644 index 000000000..f06bad205 --- /dev/null +++ b/core/providers/ollama/utils.go @@ -0,0 +1,469 @@ +// Package ollama implements the Ollama provider using native Ollama APIs. +// This file contains utility functions for converting between Bifrost and Ollama formats. +package ollama + +import ( + "encoding/base64" + "encoding/json" + "log" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// convertMessagesToOllama converts Bifrost messages to Ollama format. +// Ollama has specific semantics for tool calls: +// - Tool calls only appear on assistant messages +// - Assistant messages with tool_calls are function invocation requests and must have NO content or images +// - Tool responses must be separate messages with role="tool" and tool_name set +// - Ollama correlates tool calls and responses by function name directly, not by opaque IDs + +// NOTE: Ollama does not provide tool call IDs. When multiple calls to the same function occur +// in a single turn, tool responses are correlated by function name only. This is a lossy conversion +// but accurately reflects Ollama's native semantics. Bifrost allows toolCallId to be optional, +// so IDs are intentionally omitted. Do not generate synthetic tool call IDs. +func convertMessagesToOllama(messages []schemas.ChatMessage) []OllamaMessage { + var ollamaMessages []OllamaMessage + + for _, msg := range messages { + ollamaMsg := OllamaMessage{ + Role: mapRoleToOllama(msg.Role), + } + + if ollamaMsg.Role == "" { + continue // Skip unsupported roles + } + + // Check if this is an assistant message with tool calls + hasToolCalls := msg.Role == schemas.ChatMessageRoleAssistant && msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil + + // Convert content - but NOT for assistant messages with tool_calls + // In Ollama, assistant messages with tool_calls are function invocation requests + // and must contain no content or images, exactly as shown in native /api/chat behavior + if !hasToolCalls { + ollamaMsg.Content, ollamaMsg.Images = convertContentToOllama(msg.Content) + } else { + // Assistant message with tool_calls: no content or images + ollamaMsg.Content = "" + ollamaMsg.Images = nil + } + + // Handle tool calls - ONLY on assistant messages per Ollama semantics + if hasToolCalls { + // Filter out thinking placeholder tool calls before converting + var realToolCalls []schemas.ChatAssistantMessageToolCall + var thinkingContent *string + for _, tc := range msg.ChatAssistantMessage.ToolCalls { + // Check if this is a thinking placeholder + if tc.Function.Name != nil && *tc.Function.Name == "_thinking_placeholder" { + // Extract thinking from ExtraContent + if tc.ExtraContent != nil { + if ollamaData, ok := tc.ExtraContent["ollama"].(map[string]interface{}); ok { + if thinking, ok := ollamaData["thinking"].(string); ok && thinking != "" { + thinkingContent = &thinking + } + } + } + continue // Skip the placeholder tool call + } + // Extract thinking from tool call's ExtraContent if present + if tc.ExtraContent != nil { + if ollamaData, ok := tc.ExtraContent["ollama"].(map[string]interface{}); ok { + if thinking, ok := ollamaData["thinking"].(string); ok && thinking != "" { + thinkingContent = &thinking + } + } + } + realToolCalls = append(realToolCalls, tc) + } + if len(realToolCalls) > 0 { + ollamaMsg.ToolCalls = convertToolCallsToOllama(realToolCalls) + } + // Set thinking if we found it + if thinkingContent != nil { + ollamaMsg.Thinking = thinkingContent + } + } + + // Handle tool response messages - must set tool_name per Ollama semantics + // Ollama uses tool_name (function name) to correlate, not tool_call_id + // We ignore ToolCallID since Ollama doesn't support it + if msg.Role == schemas.ChatMessageRoleTool && msg.ChatToolMessage != nil { + if msg.Name != nil { + ollamaMsg.ToolName = msg.Name + } else { + log.Printf("Error in Tool message without Name field - Ollama requires tool_name field") + } + } + + if ollamaMsg.Role == "tool" && ollamaMsg.ToolName == nil { + continue // Skip invalid tool messages that would be silently ignored by Ollama + } + ollamaMessages = append(ollamaMessages, ollamaMsg) + } + + return ollamaMessages +} + +// NOTE: Ollama does not provide tool call IDs. When multiple calls to the same function occur +// in a single turn, tool responses are correlated by function name only. This is a lossy conversion +// but accurately reflects Ollama's native semantics. Bifrost allows toolCallId to be optional, +// so IDs are intentionally omitted. Do not generate synthetic tool call IDs. +func convertMessagesFromOllama(messages []OllamaMessage) []schemas.ChatMessage { + var bifrostMessages []schemas.ChatMessage + + for _, msg := range messages { + bifrostMsg := schemas.ChatMessage{ + Role: schemas.ChatMessageRole(msg.Role), + } + + // Check if this is an assistant message with tool calls + hasToolCalls := msg.Role == "assistant" && len(msg.ToolCalls) > 0 + + // Set content - but NOT for assistant messages with tool_calls + // In Ollama, assistant messages with tool_calls are function invocation requests + // and contain no content or images + if !hasToolCalls { + bifrostMsg.Content = &schemas.ChatMessageContent{ + ContentStr: &msg.Content, + } + } + // If hasToolCalls is true, Content remains nil (no content for function invocation requests) + + // Handle assistant messages with tool calls + // Ollama doesn't provide tool call IDs - ID field is optional in Bifrost, so we don't set it + if hasToolCalls { + var toolCalls []schemas.ChatAssistantMessageToolCall + for i, tc := range msg.ToolCalls { + args, _ := json.Marshal(tc.Function.Arguments) + toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{ + Index: uint16(i), + Type: schemas.Ptr("function"), + // ID is intentionally not set - Ollama doesn't provide tool call IDs + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &tc.Function.Name, + Arguments: string(args), + }, + }) + } + bifrostMsg.ChatAssistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + } + } + + // Handle thinking content for assistant messages + // Store thinking in the first tool call's ExtraContent (if tool calls exist) or create assistant message + // This preserves thinking for passthrough scenarios + if msg.Role == "assistant" && msg.Thinking != nil && *msg.Thinking != "" { + if bifrostMsg.ChatAssistantMessage == nil { + bifrostMsg.ChatAssistantMessage = &schemas.ChatAssistantMessage{} + } + // Store thinking in the first tool call's ExtraContent if tool calls exist + // Otherwise, we'll need to store it somewhere - but ChatAssistantMessage doesn't have ExtraContent + // So we'll store it in the first tool call's ExtraContent, or create a dummy tool call if none exist + if len(bifrostMsg.ChatAssistantMessage.ToolCalls) > 0 { + if bifrostMsg.ChatAssistantMessage.ToolCalls[0].ExtraContent == nil { + bifrostMsg.ChatAssistantMessage.ToolCalls[0].ExtraContent = make(map[string]interface{}) + } + bifrostMsg.ChatAssistantMessage.ToolCalls[0].ExtraContent["ollama"] = map[string]interface{}{ + "thinking": *msg.Thinking, + } + } else { + // No tool calls - create a dummy tool call to store thinking + // This is a workaround since ChatAssistantMessage doesn't have ExtraContent + bifrostMsg.ChatAssistantMessage.ToolCalls = []schemas.ChatAssistantMessageToolCall{ + { + Index: 0, + Type: schemas.Ptr("function"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("_thinking_placeholder"), + Arguments: "{}", + }, + ExtraContent: map[string]interface{}{ + "ollama": map[string]interface{}{ + "thinking": *msg.Thinking, + }, + }, + }, + } + } + } + + // Handle tool response messages + // Ollama uses tool_name (function name) to correlate, not tool_call_id + // Since ToolCallID is optional in Bifrost, we don't set it for Ollama + if msg.Role == "tool" && msg.ToolName != nil { + bifrostMsg.ChatToolMessage = &schemas.ChatToolMessage{ + // ToolCallID is intentionally not set - Ollama doesn't use tool call IDs + } + bifrostMsg.Name = msg.ToolName + } + + // Handle images - but NOT for assistant messages with tool_calls + // Assistant messages with tool_calls are function invocation requests and have no content/images + if !hasToolCalls && len(msg.Images) > 0 { + var contentBlocks []schemas.ChatContentBlock + + // Add text content if present + if msg.Content != "" { + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: &msg.Content, + }) + } + + // Add images + for _, img := range msg.Images { + dataURL := "data:image/jpeg;base64," + img + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: dataURL, + }, + }) + } + + bifrostMsg.Content = &schemas.ChatMessageContent{ + ContentBlocks: contentBlocks, + } + } + + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + + return bifrostMessages +} + +// ==================== ROLE MAPPING UTILITIES ==================== + +// mapRoleToOllama maps Bifrost roles to Ollama roles. +func mapRoleToOllama(role schemas.ChatMessageRole) string { + switch role { + case schemas.ChatMessageRoleDeveloper: + return "system" // Ollama doesn't support developer role, map to system + case schemas.ChatMessageRoleSystem: + return "system" + case schemas.ChatMessageRoleUser: + return "user" + case schemas.ChatMessageRoleAssistant: + return "assistant" + case schemas.ChatMessageRoleTool: + return "tool" + default: + return "" // Unsupported + } +} + +// ==================== CONTENT CONVERSION UTILITIES ==================== + +// convertContentToOllama extracts text and images from Bifrost content. +// Returns the combined text content and a slice of raw base64-encoded images. +// Note: Ollama expects raw base64 strings WITHOUT data URL prefixes. +func convertContentToOllama(content *schemas.ChatMessageContent) (string, []string) { + if content == nil { + return "", nil + } + + // Simple string content - no images + if content.ContentStr != nil { + return *content.ContentStr, nil + } + + // Content blocks - may contain text and/or images + if content.ContentBlocks == nil { + return "", nil + } + + var textParts []string + var images []string + + for _, block := range content.ContentBlocks { + switch block.Type { + case schemas.ChatContentBlockTypeText: + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + + case schemas.ChatContentBlockTypeImage: + // Extract base64 image data + // Note: ImageURLStruct.URL can be: + // 1. A data URL: "data:image/jpeg;base64," + // 2. Raw base64: "" + // 3. HTTP(S) URL: "https://..." (not supported by Ollama) + if block.ImageURLStruct != nil && block.ImageURLStruct.URL != "" { + imageData := extractBase64Image(block.ImageURLStruct.URL) + if imageData != "" { + images = append(images, imageData) + } + // extractBase64Image logs warnings for unsupported formats + } + } + } + + return strings.Join(textParts, "\n"), images +} + +// ==================== IMAGE UTILITIES ==================== + +// extractBase64Image extracts raw base64 data from various image URL formats. +// Ollama expects raw base64 strings without data URL prefixes. +// +// Supported formats: +// - data:image/jpeg;base64, -> extracts +// - data:image/png;base64, -> extracts +// - -> returns as-is +// - http(s)://... -> logs warning, returns empty (not supported) +func extractBase64Image(url string) string { + if url == "" { + return "" + } + + // Handle data URLs: data:image/jpeg;base64, + // Must strip the prefix to get raw base64 that Ollama expects + if strings.HasPrefix(url, "data:") { + // Find the comma that separates the metadata from the base64 data + commaIndex := strings.Index(url, ",") + if commaIndex != -1 && commaIndex < len(url)-1 { + // Extract everything after the comma (the raw base64 data) + base64Data := url[commaIndex+1:] + // Validate it's actually base64 + if isValidBase64(base64Data) { + return base64Data + } + log.Printf("Data URL contains invalid base64 data: %s", url[:min(50, len(url))]) + return "" + } + log.Printf("Malformed data URL (no comma separator): %s", url[:min(50, len(url))]) + return "" + } + + // Check if it's a regular HTTP(S) URL + if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { + log.Printf("Ollama does not support HTTP(S) image URLs. Please convert to base64: %s", url[:min(100, len(url))]) + return "" + } + + // Assume it's raw base64 - validate and return + if isValidBase64(url) { + return url + } + + log.Printf("Image URL is neither a valid data URL nor base64: %s", url[:min(50, len(url))]) + return "" +} + +// isValidBase64 checks if a string is valid base64 encoded data. +// This is more robust than just checking if it decodes, as it also validates +// that the string contains only valid base64 characters. +func isValidBase64(s string) bool { + if len(s) < 4 { + return false + } + + // Try to decode - this validates both format and content + decoded, err := base64.StdEncoding.DecodeString(s) + if err != nil { + // Try with padding issues fixed + decoded, err = base64.RawStdEncoding.DecodeString(s) + if err != nil { + return false + } + } + + // Sanity check: decoded data should be non-empty for images + return len(decoded) > 0 +} + +// min returns the minimum of two integers. +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// ==================== TOOL CONVERSION UTILITIES ==================== + +// convertToolCallsToOllama converts Bifrost tool calls to Ollama format. +// Ollama tool calls don't require an ID field - they use function name for correlation +func convertToolCallsToOllama(toolCalls []schemas.ChatAssistantMessageToolCall) []OllamaToolCall { + var ollamaToolCalls []OllamaToolCall + + for _, tc := range toolCalls { + var args map[string]interface{} + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + log.Printf("Failed to unmarshal tool call arguments: %v. Raw arguments: %s", err, tc.Function.Arguments) + args = map[string]interface{}{ + "_raw_arguments": tc.Function.Arguments, + } + } + } + if args == nil { + args = make(map[string]interface{}) + } + + name := "" + if tc.Function.Name != nil { + name = *tc.Function.Name + } + + ollamaToolCalls = append(ollamaToolCalls, OllamaToolCall{ + Function: OllamaToolCallFunction{ + Name: name, + Arguments: args, + }, + }) + } + + return ollamaToolCalls +} + +// convertToolsToOllama converts Bifrost tools to Ollama format. +func convertToolsToOllama(tools []schemas.ChatTool) []OllamaTool { + var ollamaTools []OllamaTool + + for _, tool := range tools { + if tool.Function == nil { + continue + } + + ollamaTool := OllamaTool{ + Type: "function", + Function: OllamaToolFunction{ + Name: tool.Function.Name, + }, + } + + if tool.Function.Description != nil { + ollamaTool.Function.Description = *tool.Function.Description + } + + if tool.Function.Parameters != nil { + ollamaTool.Function.Parameters = tool.Function.Parameters + } + + ollamaTools = append(ollamaTools, ollamaTool) + } + + return ollamaTools +} + +// convertToolsFromOllama converts Ollama tools to Bifrost format. +func convertToolsFromOllama(tools []OllamaTool) []schemas.ChatTool { + var bifrostTools []schemas.ChatTool + + for _, tool := range tools { + bifrostTool := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: tool.Function.Name, + Description: &tool.Function.Description, + Parameters: tool.Function.Parameters, + }, + } + bifrostTools = append(bifrostTools, bifrostTool) + } + + return bifrostTools +} diff --git a/core/providers/ollama/utils_test.go b/core/providers/ollama/utils_test.go new file mode 100644 index 000000000..44a874d6c --- /dev/null +++ b/core/providers/ollama/utils_test.go @@ -0,0 +1,481 @@ +package ollama + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestExtractBase64Image(t *testing.T) { + tests := []struct { + name string + input string + expected string + wantWarn bool + }{ + { + name: "data URL with JPEG", + input: "", + expected: "/9j/4AAQSkZJRg==", + wantWarn: false, + }, + { + name: "data URL with PNG", + input: "", + expected: "iVBORw0KGgoAAAANSUhEUg==", + wantWarn: false, + }, + { + name: "raw base64", + input: "iVBORw0KGgoAAAANSUhEUg==", + expected: "iVBORw0KGgoAAAANSUhEUg==", + wantWarn: false, + }, + { + name: "HTTP URL", + input: "https://example.com/image.jpg", + expected: "", + wantWarn: true, + }, + { + name: "HTTPS URL", + input: "https://example.com/image.png", + expected: "", + wantWarn: true, + }, + { + name: "empty string", + input: "", + expected: "", + wantWarn: false, + }, + { + name: "malformed data URL - no comma", + input: "data:image/jpeg;base64", + expected: "", + wantWarn: true, + }, + { + name: "malformed data URL - empty after comma", + input: "data:image/jpeg;base64,", + expected: "", + wantWarn: true, + }, + { + name: "invalid base64", + input: "not-valid-base64!@#$%", + expected: "", + wantWarn: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractBase64Image(tt.input) + if result != tt.expected { + t.Errorf("extractBase64Image(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestIsValidBase64(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "valid base64 - standard", + input: "iVBORw0KGgoAAAANSUhEUg==", + want: true, + }, + { + name: "valid base64 - JPEG header", + input: "/9j/4AAQSkZJRg==", + want: true, + }, + { + name: "valid base64 - no padding", + input: "SGVsbG8gV29ybGQ", + want: true, + }, + { + name: "invalid - too short", + input: "abc", + want: false, + }, + { + name: "invalid - special characters", + input: "abc!@#$%", + want: false, + }, + { + name: "empty string", + input: "", + want: false, + }, + { + name: "URL", + input: "https://example.com", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidBase64(tt.input) + if result != tt.want { + t.Errorf("isValidBase64(%q) = %v, want %v", tt.input, result, tt.want) + } + }) + } +} + +func TestMin(t *testing.T) { + tests := []struct { + a, b int + want int + }{ + {5, 10, 5}, + {10, 5, 5}, + {5, 5, 5}, + {0, 10, 0}, + {-5, 5, -5}, + } + + for _, tt := range tests { + result := min(tt.a, tt.b) + if result != tt.want { + t.Errorf("min(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.want) + } + } +} + +func TestConvertMessagesToOllama_ToolCalls(t *testing.T) { + t.Run("assistant message with tool calls", func(t *testing.T) { + functionName := "getWeather" + messages := []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("I'll check the weather for you."), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Index: 0, + Type: schemas.Ptr("function"), + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &functionName, + Arguments: `{"location":"San Francisco"}`, + }, + }, + }, + }, + }, + } + + result := convertMessagesToOllama(messages) + + if len(result) != 1 { + t.Fatalf("Expected 1 message, got %d", len(result)) + } + + msg := result[0] + if msg.Role != "assistant" { + t.Errorf("Expected role 'assistant', got %q", msg.Role) + } + + if len(msg.ToolCalls) != 1 { + t.Fatalf("Expected 1 tool call, got %d", len(msg.ToolCalls)) + } + + if msg.ToolCalls[0].Function.Name != "getWeather" { + t.Errorf("Expected function name 'getWeather', got %q", msg.ToolCalls[0].Function.Name) + } + + if msg.ToolName != nil { + t.Errorf("ToolName should be nil for assistant messages, got %q", *msg.ToolName) + } + }) + + t.Run("tool response message with correct mapping", func(t *testing.T) { + functionName := "getWeather" + // First: assistant makes a tool call + // Second: tool response references that call by tool_call_id + messages := []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("I'll check the weather."), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Index: 0, + Type: schemas.Ptr("function"), + ID: schemas.Ptr("call_abc123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &functionName, + Arguments: `{"location":"Tokyo"}`, + }, + }, + }, + }, + }, + { + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(`{"temperature": 72, "condition": "sunny"}`), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call_abc123"), // References the tool call + }, + }, + } + + result := convertMessagesToOllama(messages) + + if len(result) != 2 { + t.Fatalf("Expected 2 messages, got %d", len(result)) + } + + // Verify assistant message + assistantMsg := result[0] + if assistantMsg.Role != "assistant" { + t.Errorf("Expected role 'assistant', got %q", assistantMsg.Role) + } + + // Verify tool response message + toolMsg := result[1] + if toolMsg.Role != "tool" { + t.Errorf("Expected role 'tool', got %q", toolMsg.Role) + } + + if toolMsg.ToolName == nil { + t.Fatal("ToolName should be set for tool messages") + } + + // CRITICAL: tool_name should be "getWeather" (from the mapping), NOT "call_abc123" + if *toolMsg.ToolName != "getWeather" { + t.Errorf("Expected tool_name 'getWeather', got %q", *toolMsg.ToolName) + } + + if len(toolMsg.ToolCalls) != 0 { + t.Errorf("Tool response messages should not have tool_calls") + } + }) + + t.Run("tool response without prior assistant message", func(t *testing.T) { + // Edge case: tool response arrives without a prior tool call in the conversation + messages := []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleTool, + Name: schemas.Ptr("getWeather"), // Fallback to Name field + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(`{"temperature": 72}`), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call_unknown"), + }, + }, + } + + result := convertMessagesToOllama(messages) + + if len(result) != 1 { + t.Fatalf("Expected 1 message, got %d", len(result)) + } + + msg := result[0] + if msg.ToolName == nil { + t.Fatal("ToolName should be set using Name field as fallback") + } + + if *msg.ToolName != "getWeather" { + t.Errorf("Expected tool_name 'getWeather' from Name field, got %q", *msg.ToolName) + } + }) + + t.Run("multiple tool calls and responses", func(t *testing.T) { + weatherFunc := "getWeather" + timeFunc := "getTime" + messages := []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("I'll check both."), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_weather"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &weatherFunc, + Arguments: `{"location":"NYC"}`, + }, + }, + { + ID: schemas.Ptr("call_time"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &timeFunc, + Arguments: `{"timezone":"EST"}`, + }, + }, + }, + }, + }, + { + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(`{"temp": 65}`), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call_weather"), + }, + }, + { + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(`{"time": "3pm"}`), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call_time"), + }, + }, + } + + result := convertMessagesToOllama(messages) + + if len(result) != 3 { + t.Fatalf("Expected 3 messages, got %d", len(result)) + } + + // Check first tool response + if result[1].ToolName == nil || *result[1].ToolName != "getWeather" { + t.Errorf("Expected first tool response to have tool_name 'getWeather'") + } + + // Check second tool response + if result[2].ToolName == nil || *result[2].ToolName != "getTime" { + t.Errorf("Expected second tool response to have tool_name 'getTime'") + } + }) + + t.Run("tool calls on non-assistant message should be ignored", func(t *testing.T) { + functionName := "someFunction" + messages := []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &functionName, + }, + }, + }, + }, + }, + } + + result := convertMessagesToOllama(messages) + + if len(result) != 1 { + t.Fatalf("Expected 1 message, got %d", len(result)) + } + + // Tool calls should not be present for non-assistant messages + if len(result[0].ToolCalls) != 0 { + t.Errorf("User messages should not have tool_calls in Ollama format") + } + }) +} + +func TestConvertMessagesFromOllama_ToolCalls(t *testing.T) { + t.Run("assistant message with tool calls", func(t *testing.T) { + messages := []OllamaMessage{ + { + Role: "assistant", + Content: "I'll check the weather for you.", + ToolCalls: []OllamaToolCall{ + { + Function: OllamaToolCallFunction{ + Name: "getWeather", + Arguments: map[string]interface{}{ + "location": "San Francisco", + }, + }, + }, + }, + }, + } + + result := convertMessagesFromOllama(messages) + + if len(result) != 1 { + t.Fatalf("Expected 1 message, got %d", len(result)) + } + + msg := result[0] + if msg.Role != schemas.ChatMessageRoleAssistant { + t.Errorf("Expected role 'assistant', got %q", msg.Role) + } + + if msg.ChatAssistantMessage == nil { + t.Fatal("ChatAssistantMessage should not be nil") + } + + if len(msg.ChatAssistantMessage.ToolCalls) != 1 { + t.Fatalf("Expected 1 tool call, got %d", len(msg.ChatAssistantMessage.ToolCalls)) + } + + toolCall := msg.ChatAssistantMessage.ToolCalls[0] + if toolCall.Function.Name == nil || *toolCall.Function.Name != "getWeather" { + t.Errorf("Expected function name 'getWeather'") + } + }) + + t.Run("tool response message", func(t *testing.T) { + toolName := "getWeather" + messages := []OllamaMessage{ + { + Role: "tool", + Content: `{"temperature": 72, "condition": "sunny"}`, + ToolName: &toolName, + }, + } + + result := convertMessagesFromOllama(messages) + + if len(result) != 1 { + t.Fatalf("Expected 1 message, got %d", len(result)) + } + + msg := result[0] + if msg.Role != schemas.ChatMessageRoleTool { + t.Errorf("Expected role 'tool', got %q", msg.Role) + } + + if msg.ChatToolMessage == nil { + t.Fatal("ChatToolMessage should not be nil") + } + + if msg.ChatToolMessage.ToolCallID == nil { + t.Fatal("ToolCallID should be set") + } + + if *msg.ChatToolMessage.ToolCallID != "getWeather" { + t.Errorf("Expected tool_call_id 'getWeather', got %q", *msg.ChatToolMessage.ToolCallID) + } + + if msg.Name == nil || *msg.Name != "getWeather" { + t.Errorf("Expected Name 'getWeather'") + } + }) +}