9 changes: 9 additions & 0 deletions core/providers/cohere/embedding.go
@@ -167,6 +167,15 @@ func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.Tokens.OutputTokens)
}
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
} else if response.Meta.BilledUnits != nil {
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
if response.Meta.BilledUnits.InputTokens != nil {
bifrostResponse.Usage.PromptTokens = int(*response.Meta.BilledUnits.InputTokens)
}
if response.Meta.BilledUnits.OutputTokens != nil {
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.BilledUnits.OutputTokens)
}
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
}
}
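For context on the fallback above: some Cohere embed responses omit `meta.tokens` and report usage only under `meta.billed_units`. A minimal, self-contained sketch of the selection logic, using trimmed stand-in structs (the pointer-to-float64 fields are an assumption based on the casts in the diff; the real types live in `core/providers/cohere`):

```go
package main

import "fmt"

// tokenCounts is a trimmed stand-in for Cohere's usage metadata;
// pointer-to-float64 fields are an assumption based on the casts above.
type tokenCounts struct {
	InputTokens  *float64
	OutputTokens *float64
}

type meta struct {
	Tokens      *tokenCounts
	BilledUnits *tokenCounts
}

// usageFromMeta mirrors the diff's fallback order: prefer meta.tokens,
// then fall back to meta.billed_units, else report no usage at all.
func usageFromMeta(m meta) (prompt, completion int, ok bool) {
	src := m.Tokens
	if src == nil {
		src = m.BilledUnits
	}
	if src == nil {
		return 0, 0, false
	}
	if src.InputTokens != nil {
		prompt = int(*src.InputTokens)
	}
	if src.OutputTokens != nil {
		completion = int(*src.OutputTokens)
	}
	return prompt, completion, true
}

func main() {
	in := 42.0
	// A response that only carries billed_units, as some embed responses do.
	p, c, ok := usageFromMeta(meta{BilledUnits: &tokenCounts{InputTokens: &in}})
	fmt.Println(p, c, ok) // 42 0 true
}
```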

150 changes: 106 additions & 44 deletions core/providers/gemini/embedding.go
@@ -1,56 +1,70 @@
package gemini

import (
"strings"

"github.com/maximhq/bifrost/core/schemas"
)

// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's embedding request format
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiEmbeddingRequest {
// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's batch embedding request format
// The returned GeminiBatchEmbeddingRequest wraps one GeminiEmbeddingRequest per input text for the batchEmbedContents endpoint
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest {
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
return nil
}

embeddingInput := bifrostReq.Input
// Get the text to embed
var text string

// Collect all texts to embed
var texts []string
if embeddingInput.Text != nil {
text = *embeddingInput.Text
} else if len(embeddingInput.Texts) > 0 {
// Take the first text if multiple texts are provided
text = strings.Join(embeddingInput.Texts, " ")
texts = append(texts, *embeddingInput.Text)
}
if text == "" {
if len(embeddingInput.Texts) > 0 {
texts = append(texts, embeddingInput.Texts...)
}

if len(texts) == 0 {
return nil
}
// Create the Gemini embedding request
request := &GeminiEmbeddingRequest{
Model: bifrostReq.Model,
Content: &Content{
Parts: []*Part{
{
Text: text,

// Create batch embedding request with one request per text
batchRequest := &GeminiBatchEmbeddingRequest{
Requests: make([]GeminiEmbeddingRequest, len(texts)),
}

// Create individual embedding requests for each text
for i, text := range texts {
embeddingReq := GeminiEmbeddingRequest{
Model: "models/" + bifrostReq.Model,
Content: &Content{
Parts: []*Part{
{
Text: text,
},
},
},
},
}
// Add parameters if available
if bifrostReq.Params != nil {
if bifrostReq.Params.Dimensions != nil {
request.OutputDimensionality = bifrostReq.Params.Dimensions
}

// Handle extra parameters
if bifrostReq.Params.ExtraParams != nil {
if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok {
request.TaskType = taskType
// Add parameters if available
if bifrostReq.Params != nil {
if bifrostReq.Params.Dimensions != nil {
embeddingReq.OutputDimensionality = bifrostReq.Params.Dimensions
}
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
request.Title = title

// Handle extra parameters
if bifrostReq.Params.ExtraParams != nil {
if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok {
embeddingReq.TaskType = taskType
}
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
embeddingReq.Title = title
}
}
}

batchRequest.Requests[i] = embeddingReq
}
return request

return batchRequest
}
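To make the fan-out above concrete, here is a small sketch that mirrors the converter and prints the batch body it would produce for two texts. The struct shapes are trimmed stand-ins for the types in this PR, and the JSON tags and model name are assumptions:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-ins for the PR's request types; json tags are assumptions
// based on Gemini's batchEmbedContents wire format.
type part struct {
	Text string `json:"text,omitempty"`
}

type content struct {
	Parts []part `json:"parts,omitempty"`
}

type embeddingRequest struct {
	Model   string   `json:"model,omitempty"`
	Content *content `json:"content,omitempty"`
}

type batchEmbeddingRequest struct {
	Requests []embeddingRequest `json:"requests,omitempty"`
}

func main() {
	texts := []string{"What is machine learning?", "Deep learning uses neural networks"}
	batch := batchEmbeddingRequest{Requests: make([]embeddingRequest, len(texts))}
	for i, t := range texts {
		// One request per input text, each with the "models/" prefix,
		// matching ToGeminiEmbeddingRequest above. The model name here
		// is a hypothetical placeholder.
		batch.Requests[i] = embeddingRequest{
			Model:   "models/gemini-embedding-001",
			Content: &content{Parts: []part{{Text: t}}},
		}
	}
	out, _ := json.MarshalIndent(batch, "", "  ")
	fmt.Println(string(out))
}
```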

// ToGeminiEmbeddingResponse converts a BifrostResponse with embedding data to Gemini's embedding response format
@@ -99,6 +113,50 @@ func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *G
return geminiResp
}

// ToBifrostEmbeddingResponse converts a Gemini embedding response to BifrostEmbeddingResponse format
func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse {
if geminiResp == nil || len(geminiResp.Embeddings) == 0 {
return nil
}

bifrostResp := &schemas.BifrostEmbeddingResponse{
Data: make([]schemas.EmbeddingData, len(geminiResp.Embeddings)),
Model: model,
Object: "list",
}

// Convert each embedding from Gemini format to Bifrost format
for i, geminiEmbedding := range geminiResp.Embeddings {
embeddingData := schemas.EmbeddingData{
Index: i,
Object: "embedding",
Embedding: schemas.EmbeddingStruct{
EmbeddingArray: geminiEmbedding.Values,
},
}

bifrostResp.Data[i] = embeddingData
}

// Convert usage metadata if available
if geminiResp.Metadata != nil || (len(geminiResp.Embeddings) > 0 && geminiResp.Embeddings[0].Statistics != nil) {
bifrostResp.Usage = &schemas.BifrostLLMUsage{}

// Use statistics from the first embedding if available
if geminiResp.Embeddings[0].Statistics != nil {
bifrostResp.Usage.PromptTokens = int(geminiResp.Embeddings[0].Statistics.TokenCount)
} else if geminiResp.Metadata != nil {
// Fall back to metadata if statistics are not available
bifrostResp.Usage.PromptTokens = int(geminiResp.Metadata.BillableCharacterCount)
}

// Set total tokens same as prompt tokens for embeddings
bifrostResp.Usage.TotalTokens = bifrostResp.Usage.PromptTokens
}

return bifrostResp
}
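A sketch of the usage fallback in `ToBifrostEmbeddingResponse` above: per-embedding statistics take precedence, and batch-level metadata (a billable character count rather than a true token count) is the fallback. Field names follow the diff; the numeric types are assumptions:

```go
package main

import "fmt"

// Trimmed stand-ins for the response types used above.
type statistics struct{ TokenCount int32 }
type metadata struct{ BillableCharacterCount int32 }

type embedding struct {
	Values     []float32
	Statistics *statistics
}

type response struct {
	Embeddings []embedding
	Metadata   *metadata
}

// promptTokens mirrors the fallback order in ToBifrostEmbeddingResponse:
// statistics on the first embedding win; batch metadata is the fallback.
func promptTokens(r response) (int, bool) {
	if len(r.Embeddings) > 0 && r.Embeddings[0].Statistics != nil {
		return int(r.Embeddings[0].Statistics.TokenCount), true
	}
	if r.Metadata != nil {
		return int(r.Metadata.BillableCharacterCount), true
	}
	return 0, false
}

func main() {
	r := response{
		Embeddings: []embedding{{Values: []float32{0.1, 0.2}}},
		Metadata:   &metadata{BillableCharacterCount: 12},
	}
	fmt.Println(promptTokens(r)) // 12 true
}
```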

// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format
func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest() *schemas.BifrostEmbeddingRequest {
if request == nil {
@@ -114,25 +172,29 @@ func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest() *schemas.Bif
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
}

// SDK requests contain multiple embedding requests with the same parameters but different text fields
if len(request.Requests) > 0 {
embeddingRequest := request.Requests[0]
if embeddingRequest.Content != nil {
var texts []string
for _, part := range embeddingRequest.Content.Parts {
if part != nil && part.Text != "" {
texts = append(texts, part.Text)
var texts []string
for _, req := range request.Requests {
if req.Content != nil && len(req.Content.Parts) > 0 {
for _, part := range req.Content.Parts {
if part != nil && part.Text != "" {
texts = append(texts, part.Text)
}
}
}
if len(texts) > 0 {
bifrostReq.Input = &schemas.EmbeddingInput{}
if len(texts) == 1 {
bifrostReq.Input.Text = &texts[0]
} else {
bifrostReq.Input.Texts = texts
if len(texts) > 0 {
bifrostReq.Input = &schemas.EmbeddingInput{}
if len(texts) == 1 {
bifrostReq.Input.Text = &texts[0]
} else {
bifrostReq.Input.Texts = texts
}
}
}
}

embeddingRequest := request.Requests[0]

// Convert parameters
if embeddingRequest.OutputDimensionality != nil || embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
bifrostReq.Params = &schemas.EmbeddingParameters{}
93 changes: 82 additions & 11 deletions core/providers/gemini/gemini.go
@@ -13,7 +13,6 @@ import (
"time"

"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
@@ -879,19 +878,91 @@ func (provider *GeminiProvider) Embedding(ctx context.Context, key schemas.Key,
if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil {
return nil, err
}
// Use the shared embedding request handler
return openai.HandleOpenAIEmbeddingRequest(

providerName := provider.GetProviderKey()

// Convert Bifrost request to Gemini batch embedding request format
jsonData, err := providerUtils.CheckContextAndGetRequestBody(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/openai/embeddings"),
request,
key,
provider.networkConfig.ExtraHeaders,
provider.GetProviderKey(),
func() (any, error) { return ToGeminiEmbeddingRequest(request), nil },
providerName)
if err != nil {
return nil, err
}

// Create request
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)

// Set any extra headers from network config
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

// Use Gemini's batchEmbedContents endpoint
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/models/"+request.Model+":batchEmbedContents"))
req.Header.SetMethod(http.MethodPost)
req.Header.SetContentType("application/json")
if key.Value != "" {
req.Header.Set("x-goog-api-key", key.Value)
}

req.SetBody(jsonData)

// Make request
latency, bifrostErr := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
if bifrostErr != nil {
return nil, bifrostErr
}

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
Provider: providerName,
Model: request.Model,
RequestType: schemas.EmbeddingRequest,
})
}

body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
if decodeErr != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
}

// Parse Gemini's batch embedding response
var geminiResponse GeminiEmbeddingResponse
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &geminiResponse, jsonData,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.logger,
)
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, bifrostErr
}

// Convert to Bifrost format
bifrostResponse := ToBifrostEmbeddingResponse(&geminiResponse, request.Model)
if bifrostResponse == nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal,
fmt.Errorf("failed to convert Gemini embedding response to Bifrost format"), providerName)
}

bifrostResponse.ExtraFields.Provider = providerName
bifrostResponse.ExtraFields.ModelRequested = request.Model
bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
bifrostResponse.ExtraFields.RawRequest = rawRequest
}

// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
bifrostResponse.ExtraFields.RawResponse = rawResponse
}

return bifrostResponse, nil
}
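For reference, the shape of the wire call this method now makes, sketched with `net/http` rather than the provider's `fasthttp` client. The `:batchEmbedContents` path and `x-goog-api-key` header follow the diff; the base URL and model name are assumptions:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"os"
)

func main() {
	// Assumed public Gemini API base URL; Bifrost uses its configured
	// networkConfig.BaseURL instead.
	base := "https://generativelanguage.googleapis.com/v1beta"
	model := "gemini-embedding-001" // hypothetical model name

	body := []byte(`{"requests":[{"model":"models/` + model +
		`","content":{"parts":[{"text":"hello"}]}}]}`)

	req, err := http.NewRequest(http.MethodPost,
		base+"/models/"+model+":batchEmbedContents", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", os.Getenv("GEMINI_API_KEY"))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(out))
}
```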

// Speech performs a speech synthesis request to the Gemini API.
4 changes: 4 additions & 0 deletions core/providers/gemini/types.go
@@ -964,6 +964,10 @@ const (
ThinkingLevelHigh ThinkingLevel = "HIGH"
)

// GeminiBatchEmbeddingRequest represents the request body for Gemini's batchEmbedContents endpoint.
type GeminiBatchEmbeddingRequest struct {
Requests []GeminiEmbeddingRequest `json:"requests,omitempty"`
}

// GeminiEmbeddingRequest represents a single embedding request in a batch.
type GeminiEmbeddingRequest struct {
Content *Content `json:"content,omitempty"`
21 changes: 21 additions & 0 deletions docs/integrations/langchain-sdk.mdx
@@ -545,6 +545,27 @@ doc_embeddings = embeddings.embed_documents([
**This means `OpenAIEmbeddings` only works reliably with OpenAI embedding models.** Using it with other providers (e.g., `model="cohere/embed-v4.0"`) will fail because those providers cannot process int array inputs.
</Warning>

### Cross-Provider Embeddings

For embedding models from other providers (Cohere, Bedrock, Gemini, etc.), you can use `GoogleGenerativeAIEmbeddings` from the `langchain_google_genai` package. This class sends raw text strings rather than pre-tokenized int arrays, so it works across multiple providers:

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Works with any provider's embedding models
embeddings = GoogleGenerativeAIEmbeddings(
model="cohere/cohere-embed-v4.0", # or bedrock/..., gemini/..., etc.
base_url="http://localhost:8080/langchain",
api_key="dummy-key"
)

query_embedding = embeddings.embed_query("What is machine learning?")
doc_embeddings = embeddings.embed_documents([
"Machine learning is a subset of AI",
"Deep learning uses neural networks"
])
```

---

## Supported Features