Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions core/providers/vertex/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,30 @@ func getRequestBodyForAnthropicResponses(ctx context.Context, request *schemas.B

return jsonBody, nil
}

// getCompleteURLForGeminiEndpoint returns the full Vertex AI URL to call for a
// Gemini-style request.
//
// The method suffix is ":streamGenerateContent" when isStreaming is true and
// ":generateContent" otherwise.
//
// Two URL families are produced:
//   - When schemas.IsAllDigitsASCII(deployment) reports true, the deployment is
//     treated as a custom/fine-tuned endpoint: the v1beta1 ".../endpoints/<deployment>"
//     path is used and the project is identified by projectNumber.
//   - Otherwise the deployment is treated as a Gemini publisher model: the v1
//     ".../publishers/google/models/<deployment>" path is used and the project is
//     identified by projectID.
//
// In both families the "global" region uses the un-prefixed
// aiplatform.googleapis.com host, while any other region uses the
// "<region>-aiplatform.googleapis.com" host and the region as the location.
//
// No validation is performed here; callers are expected to have checked that
// the project identifier they rely on is non-empty.
func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID string, projectNumber string, isStreaming bool) string {
	action := ":generateContent"
	if isStreaming {
		action = ":streamGenerateContent"
	}

	isCustomEndpoint := schemas.IsAllDigitsASCII(deployment)
	isGlobal := region == "global"

	switch {
	case isCustomEndpoint && isGlobal:
		// Custom/fine-tuned endpoint, global location: keyed by projectNumber.
		return fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s%s", projectNumber, deployment, action)
	case isCustomEndpoint:
		// Custom/fine-tuned endpoint, regional host: keyed by projectNumber.
		return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s%s", region, projectNumber, region, deployment, action)
	case isGlobal:
		// Gemini publisher model, global location: keyed by projectID.
		return fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s%s", projectID, deployment, action)
	default:
		// Gemini publisher model, regional host: keyed by projectID.
		return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s%s", region, projectID, region, deployment, action)
	}
}
66 changes: 27 additions & 39 deletions core/providers/vertex/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
if err := sonic.Unmarshal(reqBytes, &requestBody); err != nil {
return nil, fmt.Errorf("failed to unmarshal request body: %w", err)
}
} else if schemas.IsGeminiModel(deployment) {
} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
reqBody := gemini.ToGeminiChatCompletionRequest(request)
if reqBody == nil {
return nil, fmt.Errorf("chat completion input is not provided")
Expand Down Expand Up @@ -382,9 +382,9 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value))
}
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s/chat/completions", projectNumber, deployment)
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s/chat/completions", region, projectNumber, region, deployment)
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment)
}
} else if schemas.IsAnthropicModel(deployment) {
// Claude models use Anthropic publisher
Expand Down Expand Up @@ -502,7 +502,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.
}

return response, nil
} else if schemas.IsGeminiModel(deployment) {
} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
geminiResponse := gemini.GenerateContentResponse{}

rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
Expand Down Expand Up @@ -675,7 +675,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo
RequestType: schemas.ChatCompletionStreamRequest,
},
)
} else if schemas.IsGeminiModel(deployment) {
} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
// Use Gemini-style streaming for Gemini models
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
Expand All @@ -701,14 +701,15 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value))
}

// Construct the URL for Gemini streaming
var completeURL string
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", projectID, deployment)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:streamGenerateContent", region, projectID, region, deployment)
// For custom/fine-tuned models, validate projectNumber is set
projectNumber := key.VertexKeyConfig.ProjectNumber
if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" {
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
}

// Construct the URL for Gemini streaming
completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, true)

// Add alt=sse parameter
if authQuery != "" {
completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery)
Expand Down Expand Up @@ -757,21 +758,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo
authQuery := ""
// Determine the URL based on model type
var completeURL string
if schemas.IsAllDigitsASCII(deployment) {
// Custom Fine-tuned models use OpenAPI endpoint
projectNumber := key.VertexKeyConfig.ProjectNumber
if projectNumber == "" {
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
}
if key.Value != "" {
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value))
}
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s/chat/completions", projectNumber, deployment)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s/chat/completions", region, projectNumber, region, deployment)
}
} else if schemas.IsMistralModel(deployment) {
if schemas.IsMistralModel(deployment) {
// Mistral models use mistralai publisher with streamRawPredict
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, deployment)
Expand Down Expand Up @@ -947,7 +934,7 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key,
}

return response, nil
} else if schemas.IsGeminiModel(deployment) {
} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
Expand Down Expand Up @@ -981,13 +968,14 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key,
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value))
}

var url string
if region == "global" {
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment)
} else {
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment)
// For custom/fine-tuned models, validate projectNumber is set
projectNumber := key.VertexKeyConfig.ProjectNumber
if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" {
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
}

url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, false)

// Create HTTP request for streaming
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
Expand Down Expand Up @@ -1169,7 +1157,7 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun
RequestType: schemas.ResponsesStreamRequest,
},
)
} else if schemas.IsGeminiModel(deployment) {
} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
region := key.VertexKeyConfig.Region
if region == "" {
return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
Expand Down Expand Up @@ -1205,14 +1193,14 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value))
}

// Construct the URL for Gemini streaming
var completeURL string
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:streamGenerateContent", projectID, deployment)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:streamGenerateContent", region, projectID, region, deployment)
// For custom/fine-tuned models, validate projectNumber is set
projectNumber := key.VertexKeyConfig.ProjectNumber
if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" {
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
}

// Construct the URL for Gemini streaming
completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, true)
// Add alt=sse parameter
if authQuery != "" {
completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery)
Expand Down