diff --git a/src/providers/google-vertex-ai/chatComplete.ts b/src/providers/google-vertex-ai/chatComplete.ts index 55b9f48c4..6a016f4a6 100644 --- a/src/providers/google-vertex-ai/chatComplete.ts +++ b/src/providers/google-vertex-ai/chatComplete.ts @@ -6,7 +6,6 @@ import { ContentType, Message, Params, - Tool, ToolCall, SYSTEM_MESSAGE_ROLES, MESSAGE_ROLES, @@ -46,27 +45,16 @@ import type { GoogleGenerateContentResponse, VertexLlamaChatCompleteStreamChunk, VertexLLamaChatCompleteResponse, - GoogleSearchRetrievalTool, } from './types'; import { getMimeType, + googleTools, recursivelyDeleteUnsupportedParameters, - transformGeminiToolParameters, + transformGoogleTools, transformInputAudioPart, transformVertexLogprobs, } from './utils'; -export const buildGoogleSearchRetrievalTool = (tool: Tool) => { - const googleSearchRetrievalTool: GoogleSearchRetrievalTool = { - googleSearchRetrieval: {}, - }; - if (tool.function.parameters?.dynamicRetrievalConfig) { - googleSearchRetrievalTool.googleSearchRetrieval.dynamicRetrievalConfig = - tool.function.parameters.dynamicRetrievalConfig; - } - return googleSearchRetrievalTool; -}; - export const VertexGoogleChatCompleteConfig: ProviderConfig = { // https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions model: { @@ -296,27 +284,9 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = { // these are not supported by google recursivelyDeleteUnsupportedParameters(tool.function?.parameters); delete tool.function?.strict; - - if (['googleSearch', 'google_search'].includes(tool.function.name)) { - const timeRangeFilter = tool.function.parameters?.timeRangeFilter; - tools.push({ - googleSearch: { - // allow null - ...(timeRangeFilter !== undefined && { timeRangeFilter }), - }, - }); - } else if ( - ['googleSearchRetrieval', 'google_search_retrieval'].includes( - tool.function.name - ) - ) { - tools.push(buildGoogleSearchRetrievalTool(tool)); + if (googleTools.includes(tool.function.name)) { + tools.push(...transformGoogleTools(tool)); } else { - if (tool.function?.parameters) { - tool.function.parameters = transformGeminiToolParameters( - tool.function.parameters - ); - } functionDeclarations.push(tool.function); } } @@ -359,11 +329,11 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = { param: 'generationConfig', transform: (params: Params) => transformGenerationConfig(params), }, - seed: { + modalities: { param: 'generationConfig', transform: (params: Params) => transformGenerationConfig(params), }, - modalities: { + seed: { param: 'generationConfig', transform: (params: Params) => transformGenerationConfig(params), }, diff --git a/src/providers/google-vertex-ai/utils.ts b/src/providers/google-vertex-ai/utils.ts index 5365370b7..2dfcda64e 100644 --- a/src/providers/google-vertex-ai/utils.ts +++ b/src/providers/google-vertex-ai/utils.ts @@ -3,6 +3,7 @@ import { GoogleResponseCandidate, GoogleBatchRecord, GoogleFinetuneRecord, + GoogleSearchRetrievalTool, } from './types'; import { generateErrorResponse } from '../utils'; import { @@ -13,7 +14,7 @@ import { import { ErrorResponse, FinetuneRequest, Logprobs } from '../types'; import { Context } from 'hono'; import { env } from 'hono/adapter'; -import { ContentType, JsonSchema } from '../../types/requestBody'; +import { ContentType, JsonSchema, Tool } from '../../types/requestBody'; /** * Encodes an object as a Base64 URL-encoded string. @@ -729,3 +730,51 @@ export const transformInputAudioPart = (c: ContentType) => { }, }; }; + +export const googleTools = [ + 'googleSearch', + 'google_search', + 'googleSearchRetrieval', + 'google_search_retrieval', + 'computerUse', + 'computer_use', +]; + +export const transformGoogleTools = (tool: Tool) => { + const tools: any = []; + if (['googleSearch', 'google_search'].includes(tool.function.name)) { + const timeRangeFilter = tool.function.parameters?.timeRangeFilter; + tools.push({ + googleSearch: { + // allow null + ...(timeRangeFilter !== undefined && { timeRangeFilter }), + }, + }); + } else if ( + ['googleSearchRetrieval', 'google_search_retrieval'].includes( + tool.function.name + ) + ) { + tools.push(buildGoogleSearchRetrievalTool(tool)); + } else if (['computerUse', 'computer_use'].includes(tool.function.name)) { + tools.push({ + computerUse: { + environment: tool.function.parameters?.environment, + excludedPredefinedFunctions: + tool.function.parameters?.excluded_predefined_functions, + }, + }); + } + return tools; +}; + +export const buildGoogleSearchRetrievalTool = (tool: Tool) => { + const googleSearchRetrievalTool: GoogleSearchRetrievalTool = { + googleSearchRetrieval: {}, + }; + if (tool.function.parameters?.dynamicRetrievalConfig) { + googleSearchRetrievalTool.googleSearchRetrieval.dynamicRetrievalConfig = + tool.function.parameters.dynamicRetrievalConfig; + } + return googleSearchRetrievalTool; +}; diff --git a/src/providers/google/chatComplete.ts b/src/providers/google/chatComplete.ts index d85a0351d..aa5821df9 100644 --- a/src/providers/google/chatComplete.ts +++ b/src/providers/google/chatComplete.ts @@ -9,11 +9,12 @@ import { SYSTEM_MESSAGE_ROLES, MESSAGE_ROLES, } from '../../types/requestBody'; -import { buildGoogleSearchRetrievalTool } from '../google-vertex-ai/chatComplete'; import { getMimeType, + googleTools, recursivelyDeleteUnsupportedParameters, transformGeminiToolParameters, + transformGoogleTools, transformInputAudioPart, transformVertexLogprobs, } from '../google-vertex-ai/utils'; @@ -374,15 +375,8 @@ export const GoogleChatCompleteConfig: ProviderConfig = { // these are not supported by google recursivelyDeleteUnsupportedParameters(tool.function?.parameters); delete tool.function?.strict; - - if (['googleSearch', 'google_search'].includes(tool.function.name)) { - tools.push({ googleSearch: {} }); - } else if ( - ['googleSearchRetrieval', 'google_search_retrieval'].includes( - tool.function.name - ) - ) { - tools.push(buildGoogleSearchRetrievalTool(tool)); + if (googleTools.includes(tool.function.name)) { + tools.push(...transformGoogleTools(tool)); } else { if (tool.function?.parameters) { tool.function.parameters = transformGeminiToolParameters(