diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index c12b962fee12..1049d3245a6a 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -179,5 +179,12 @@ "description": "Provide your Flux API key from your user profile." } ] + }, + { + "name": "Gemini Image Tools", + "pluginKey": "gemini_image_gen", + "description": "Generate high-quality images using Google's Gemini Image Models. Supports both Gemini API (GEMINI_API_KEY) and Vertex AI (service account).", + "icon": "/assets/gemini_image_gen.svg", + "authConfig": [] } ] diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 5692bee8db34..7d4b1597217a 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -10,6 +10,8 @@ const { createSafeUser, mcpToolPattern, loadWebSearchAuth, + createImageToolContext, + GeminiImageGen, } = require('@librechat/api'); const { Tools, @@ -38,11 +40,13 @@ const { } = require('../'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { createMCPTool, createMCPTools } = require('~/server/services/MCP'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { getMCPServerTools } = require('~/server/services/Config'); const { getRoleByName } = require('~/models/Role'); +const { getFiles } = require('~/models/File'); /** * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. 
@@ -179,6 +183,7 @@ const loadTools = async ({ 'azure-ai-search': StructuredACS, traversaal_search: TraversaalSearch, tavily_search_results_json: TavilySearchResults, + gemini_image_gen: GeminiImageGen, }; const customConstructors = { @@ -191,24 +196,10 @@ const loadTools = async ({ const authFields = getAuthFields('image_gen_oai'); const authValues = await loadAuthValues({ userId: user, authFields }); const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? []; - let toolContext = ''; - for (let i = 0; i < imageFiles.length; i++) { - const file = imageFiles[i]; - if (!file) { - continue; - } - if (i === 0) { - toolContext = - 'Image files provided in this request (their image IDs listed in order of appearance) available for image editing:'; - } - toolContext += `\n\t- ${file.file_id}`; - if (i === imageFiles.length - 1) { - toolContext += `\n\nInclude any you need in the \`image_ids\` array when calling \`${EToolResources.image_edit}_oai\`. You may also include previously referenced or generated image IDs.`; - } - } - if (toolContext) { - toolContextMap.image_edit_oai = toolContext; - } + createImageToolContext(imageFiles, toolContextMap, { + toolKey: 'image_edit_oai', + purpose: 'image editing', + }); return createOpenAIImageTools({ ...authValues, isAgent: !!agent, @@ -218,6 +209,26 @@ const loadTools = async ({ imageFiles, }); }, + gemini_image_gen: async (toolContextMap) => { + const authFields = getAuthFields('gemini_image_gen'); + const authValues = await loadAuthValues({ userId: user, authFields }); + const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? 
[]; + createImageToolContext(imageFiles, toolContextMap, { + toolKey: 'gemini_image_gen', + purpose: 'image context', + }); + return new GeminiImageGen({ + ...authValues, + isAgent: !!agent, + req: options.req, + imageFiles, + processFileURL: options.processFileURL, + userId: user, + fileStrategy: options.fileStrategy, + getFiles, + getStrategyFunctions, + }); + }, }; const requestedTools = {}; @@ -240,6 +251,7 @@ const loadTools = async ({ flux: imageGenOptions, dalle: imageGenOptions, 'stable-diffusion': imageGenOptions, + gemini_image_gen: imageGenOptions, }; /** @type {Record} */ diff --git a/api/package.json b/api/package.json index 69354e181263..a57152f9288f 100644 --- a/api/package.json +++ b/api/package.json @@ -41,6 +41,7 @@ "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.27.0", "@google/generative-ai": "^0.24.0", + "@google/genai": "^1.19.0", "@googleapis/youtube": "^20.0.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.79", diff --git a/api/server/services/start/tools.js b/api/server/services/start/tools.js index f139eaac4dcd..078bd79089b8 100644 --- a/api/server/services/start/tools.js +++ b/api/server/services/start/tools.js @@ -5,7 +5,7 @@ const { Calculator } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); const { zodToJsonSchema } = require('zod-to-json-schema'); const { Tools, ImageVisionTool } = require('librechat-data-provider'); -const { getToolkitKey, oaiToolkit, ytToolkit } = require('@librechat/api'); +const { getToolkitKey, oaiToolkit, ytToolkit, GeminiImageGen } = require('@librechat/api'); const { toolkits } = require('~/app/clients/tools/manifest'); /** @@ -84,6 +84,7 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] }) new Calculator(), ...Object.values(oaiToolkit), ...Object.values(ytToolkit), + new GeminiImageGen({ override: true }), ]; for (const toolInstance of basicToolInstances) { const formattedTool = 
formatToOpenAIAssistantTool(toolInstance); diff --git a/client/public/assets/gemini_image_gen.svg b/client/public/assets/gemini_image_gen.svg new file mode 100644 index 000000000000..4430340eb1eb --- /dev/null +++ b/client/public/assets/gemini_image_gen.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index 16de45d476f4..bfa2b28fac65 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -103,7 +103,9 @@ const Part = memo( ); } else if ( isToolCall && - (toolCall.name === 'image_gen_oai' || toolCall.name === 'image_edit_oai') + (toolCall.name === 'image_gen_oai' || + toolCall.name === 'image_edit_oai' || + toolCall.name === 'gemini_image_gen') ) { return ( = 1) { + return localize('com_ui_image_created'); + } + if (progress >= 0.7) { + return localize('com_ui_final_touch'); + } + if (progress >= 0.5) { + return localize('com_ui_adding_details'); + } + if (progress >= 0.3) { + return localize('com_ui_creating_image'); + } + return localize('com_ui_getting_started'); + } + if (progress >= 1) { return localize('com_ui_image_created'); } diff --git a/package-lock.json b/package-lock.json index 8145faae7dac..253e1f47aa9c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -54,6 +54,7 @@ "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.27.0", + "@google/genai": "^1.19.0", "@google/generative-ai": "^0.24.0", "@googleapis/youtube": "^20.0.0", "@keyv/redis": "^4.3.3", @@ -14750,6 +14751,27 @@ "integrity": "sha512-kym7SodPp8/wloecOpcmSnWJsK7M0E5Wg8UcFA+uO4B9s5d0ywXOEro/8HM9x0rW+TljRzul/14UYz3TleT3ig==", "license": "MIT" }, + "node_modules/@google/genai": { + "version": "1.19.0", + "resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.19.0.tgz", + "integrity": 
"sha512-mIMV3M/KfzzFA//0fziK472wKBJ1TdJLhozIUJKTPLyTDN1NotU+hyoHW/N0cfrcEWUK20YA0GxCeHC4z0SbMA==", + "license": "Apache-2.0", + "dependencies": { + "google-auth-library": "^9.14.2", + "ws": "^8.18.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "@modelcontextprotocol/sdk": "^1.11.4" + }, + "peerDependenciesMeta": { + "@modelcontextprotocol/sdk": { + "optional": true + } + } + }, "node_modules/@google/generative-ai": { "version": "0.24.0", "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.24.0.tgz", @@ -45968,7 +45990,6 @@ "version": "8.18.0", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.0.tgz", "integrity": "sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw==", - "devOptional": true, "engines": { "node": ">=10.0.0" }, @@ -46273,6 +46294,7 @@ "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.27.0", + "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.79", "@librechat/agents": "^3.0.34", diff --git a/packages/api/package.json b/packages/api/package.json index c2486211489b..de5357252f6c 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -82,6 +82,7 @@ "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.27.0", + "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.79", "@librechat/agents": "^3.0.34", diff --git a/packages/api/src/tools/gemini/GeminiImageGen.ts b/packages/api/src/tools/gemini/GeminiImageGen.ts new file mode 100644 index 000000000000..02e60eff37ee --- /dev/null +++ b/packages/api/src/tools/gemini/GeminiImageGen.ts @@ -0,0 +1,734 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { z } from 'zod'; +import { v4 as uuidv4 } from 'uuid'; +import { StructuredTool } from '@langchain/core/tools'; +import { GoogleGenAI } from '@google/genai'; +import { FileContext, 
ContentTypes, FileSources } from 'librechat-data-provider'; +import { logger } from '@librechat/data-schemas'; +import type { ServerRequest } from '../../types'; +import type { + MongoFile, + SafetyBlock, + GeminiResponse, + GeminiProvider, + GeminiInlineData, + GeminiContentPart, + SaveBase64ImageParams, + GeminiImageGenFields, + AgentToolReturn, + GetFilesFunction, + GetStrategyFunctionsType, + GetDownloadStreamFunction, +} from './types'; +import { + TOOL_NAME, + DISPLAY_MESSAGE, + TOOL_DESCRIPTION, + DEFAULT_MODEL_ID, + PROMPT_DESCRIPTION, + DESCRIPTION_FOR_MODEL, + IMAGE_IDS_DESCRIPTION, +} from './constants'; + +// ============================================================================= +// UTILITY FUNCTIONS +// ============================================================================= + +/** + * Save base64 image data to storage using processFileURL + * Works with any file storage strategy (local, s3, azure, firebase) + */ +export async function saveBase64ImageToStorage({ + base64Data, + outputFormat, + processFileURL, + fileStrategy, + userId, +}: SaveBase64ImageParams): Promise { + if (!processFileURL || !fileStrategy || !userId) { + logger.warn( + '[GeminiImageGen] Missing required parameters for storage, falling back to data URL', + ); + return null; + } + + try { + const dataURL = `data:image/${outputFormat};base64,${base64Data}`; + const imageName = `gemini-img-${uuidv4()}.${outputFormat}`; + + const result = await processFileURL({ + URL: dataURL, + basePath: 'images', + userId, + fileName: imageName, + fileStrategy, + context: FileContext.image_generation, + }); + + return result.filepath; + } catch (error) { + logger.error('[GeminiImageGen] Error saving image to storage:', error); + return null; + } +} + +/** + * Replace unwanted characters from prompt text + */ +function replaceUnwantedChars(inputString?: string): string { + return inputString?.replaceAll(/[^\w\s\-_.,!?()]/g, '') || ''; +} + +// 
============================================================================= +// GEMINI IMAGE GENERATION TOOL +// ============================================================================= + +// Schema definition +const geminiImageGenSchema = z.object({ + prompt: z.string().max(32000).describe(PROMPT_DESCRIPTION), + image_ids: z.array(z.string()).optional().describe(IMAGE_IDS_DESCRIPTION), +}); + +export class GeminiImageGen extends StructuredTool { + name = TOOL_NAME; + description = TOOL_DESCRIPTION; + description_for_model = DESCRIPTION_FOR_MODEL; + schema = geminiImageGenSchema; + + // Configuration + private readonly overrideConfig: boolean; + private readonly returnMetadata: boolean; + private readonly userId?: string; + private readonly fileStrategy?: string; + private readonly isAgent: boolean; + private readonly req?: ServerRequest; + private readonly processFileURL?: GeminiImageGenFields['processFileURL']; + private readonly imageFiles: MongoFile[]; + private readonly getFiles?: GetFilesFunction; + private readonly getStrategyFunctions?: GetStrategyFunctionsType; + + constructor(fields: GeminiImageGenFields = {}) { + super(); + + this.overrideConfig = fields.override ?? false; + this.returnMetadata = fields.returnMetadata ?? false; + this.userId = fields.userId; + this.fileStrategy = fields.fileStrategy; + this.isAgent = fields.isAgent ?? 
false; + this.req = fields.req; + this.imageFiles = fields.imageFiles || []; + this.getFiles = fields.getFiles; + this.getStrategyFunctions = fields.getStrategyFunctions; + + if (fields.processFileURL) { + this.processFileURL = fields.processFileURL.bind(this); + } + + this.validateConfig(); + } + + // =========================================================================== + // CONFIGURATION METHODS + // =========================================================================== + + /** + * Determine which provider to use based on configuration + */ + private getProvider(): GeminiProvider { + const provider = process.env.GEMINI_IMAGE_PROVIDER?.toLowerCase(); + if (provider === 'gemini' || provider === 'vertex') { + return provider; + } + // Auto-detect: prefer Vertex AI if GOOGLE_SERVICE_KEY_FILE exists and points to a valid file + const keyFile = process.env.GOOGLE_SERVICE_KEY_FILE; + if (keyFile && fs.existsSync(keyFile)) { + return 'vertex'; + } + return 'gemini'; + } + + /** + * Get the model ID to use for image generation + */ + private getModelId(): string { + return process.env.GEMINI_IMAGE_MODEL || DEFAULT_MODEL_ID; + } + + /** + * Get the credentials file path for Vertex AI + */ + private getCredentialsPath(): string { + return process.env.GOOGLE_SERVICE_KEY_FILE || path.join(process.cwd(), 'data', 'auth.json'); + } + + /** + * Validate configuration based on provider + */ + private validateConfig(): void { + if (this.overrideConfig) { + return; + } + + const provider = this.getProvider(); + + if (provider === 'gemini') { + if (!process.env.GEMINI_API_KEY) { + throw new Error( + 'GEMINI_API_KEY environment variable is required when using Gemini API provider. ' + + 'Set GEMINI_IMAGE_PROVIDER=vertex to use Vertex AI with service account instead.', + ); + } + } else { + const credentialsPath = this.getCredentialsPath(); + if (!fs.existsSync(credentialsPath)) { + throw new Error( + `Google service account credentials file not found at: ${credentialsPath}. 
` + + 'Set GEMINI_IMAGE_PROVIDER=gemini and GEMINI_API_KEY to use Gemini API instead.', + ); + } + } + } + + // =========================================================================== + // CLIENT INITIALIZATION + // =========================================================================== + + /** + * Initialize the Gemini client based on the configured provider + */ + private async initializeGeminiClient(): Promise { + const provider = this.getProvider(); + const modelId = this.getModelId(); + + logger.debug(`[GeminiImageGen] Using provider: ${provider}, model: ${modelId}`); + + try { + if (provider === 'gemini') { + const apiKey = process.env.GEMINI_API_KEY; + if (!apiKey) { + throw new Error( + 'GEMINI_API_KEY environment variable is required for Gemini API provider', + ); + } + logger.debug('[GeminiImageGen] Initializing Gemini API client with API key'); + return new GoogleGenAI({ apiKey }); + } + + // Vertex AI with service account + logger.debug('[GeminiImageGen] Initializing Vertex AI client with service account'); + const credentialsPath = this.getCredentialsPath(); + + if (!fs.existsSync(credentialsPath)) { + throw new Error(`Google service account credentials file not found at: ${credentialsPath}`); + } + + let serviceKey; + try { + serviceKey = JSON.parse(fs.readFileSync(credentialsPath, 'utf8')); + } catch (parseError) { + throw new Error( + `Malformed JSON in Google service account credentials file at ${credentialsPath}: ${parseError instanceof Error ? parseError.message : String(parseError)}`, + ); + } + + return new GoogleGenAI({ + vertexai: true, + project: serviceKey.project_id, + location: process.env.GOOGLE_CLOUD_LOCATION || 'global', + }); + } catch (error) { + logger.error('[GeminiImageGen] Error initializing Gemini client:', error); + throw new Error( + `Failed to initialize Gemini client: ${error instanceof Error ? 
error.message : String(error)}`, + ); + } + } + + // =========================================================================== + // SAFETY CHECKING + // =========================================================================== + + /** + * Check if the API response indicates content was blocked by safety filters + */ + private checkForSafetyBlock(response: GeminiResponse): SafetyBlock | null { + try { + if (!response.candidates || response.candidates.length === 0) { + return { reason: 'NO_CANDIDATES', message: 'No candidates returned by Gemini' }; + } + + const candidate = response.candidates[0]; + + // Check finishReason for safety blocks + if (candidate.finishReason) { + const { finishReason } = candidate; + + if (finishReason === 'SAFETY' || finishReason === 'PROHIBITED_CONTENT') { + return { + reason: finishReason, + message: 'Content was blocked by Gemini safety filters', + safetyRatings: candidate.safetyRatings || [], + }; + } + + if (finishReason === 'RECITATION') { + return { + reason: finishReason, + message: 'Content was blocked due to recitation concerns', + }; + } + } + + // Check safety ratings for blocks + if (candidate.safetyRatings) { + for (const rating of candidate.safetyRatings) { + if (rating.probability === 'HIGH' || rating.blocked === true) { + return { + reason: 'SAFETY_RATING', + message: `Content blocked due to ${rating.category} safety concerns`, + category: rating.category, + probability: rating.probability, + }; + } + } + } + + return null; + } catch (error) { + logger.error('[GeminiImageGen] Error checking safety block:', error); + return null; + } + } + + /** + * Create user-friendly error message for safety blocks + */ + private createSafetyErrorMessage(safetyBlock: SafetyBlock): string { + let errorMessage = + 'I cannot generate this image because it was blocked by content safety filters. 
'; + + if (safetyBlock.reason === 'SAFETY' || safetyBlock.reason === 'PROHIBITED_CONTENT') { + errorMessage += + 'The prompt may contain content that violates content policies (such as violence, weapons, or inappropriate content). '; + } else if (safetyBlock.reason === 'RECITATION') { + errorMessage += 'The content may be too similar to copyrighted material. '; + } + + errorMessage += 'Please try rephrasing your request with different, safer content.'; + + if (safetyBlock.category) { + errorMessage += ` (Blocked category: ${safetyBlock.category})`; + } + + return errorMessage; + } + + // =========================================================================== + // IMAGE PROCESSING + // =========================================================================== + + /** + * Save image locally for local file strategy + */ + private async saveImageLocally( + base64Data: string, + outputFormat: string = 'png', + userId?: string, + ): Promise { + try { + const imageName = `gemini-img-${uuidv4()}.${outputFormat}`; + const userDir = path.join(process.cwd(), 'client/public/images', userId || 'default'); + + if (!fs.existsSync(userDir)) { + fs.mkdirSync(userDir, { recursive: true }); + } + + const filePath = path.join(userDir, imageName); + const imageBuffer = Buffer.from(base64Data, 'base64'); + fs.writeFileSync(filePath, new Uint8Array(imageBuffer)); + + const relativeUrl = `/images/${userId || 'default'}/${imageName}`; + logger.debug('[GeminiImageGen] Image saved locally to:', filePath); + + return relativeUrl; + } catch (error) { + logger.error('[GeminiImageGen] Error saving image locally:', error); + throw error; + } + } + + /** + * Convert image files to Gemini inlineData format + */ + private async convertImagesToInlineData( + imageFiles: MongoFile[], + image_ids: string[], + ): Promise { + if (!image_ids || image_ids.length === 0) { + return []; + } + + logger.debug('[GeminiImageGen] Converting images to inlineData format for IDs:', image_ids); + + const 
streamMethods: Record = {}; + const requestFilesMap = Object.fromEntries(imageFiles.map((f) => [f.file_id, { ...f }])); + const orderedFiles: (MongoFile | undefined)[] = new Array(image_ids.length); + const idsToFetch: string[] = []; + const indexOfMissing: Record = Object.create(null); + + // Map existing files and identify missing ones + for (let i = 0; i < image_ids.length; i++) { + const id = image_ids[i]; + const file = requestFilesMap[id]; + + if (file) { + orderedFiles[i] = file; + logger.debug('[GeminiImageGen] Found file in request files:', id); + } else { + idsToFetch.push(id); + indexOfMissing[id] = i; + logger.debug('[GeminiImageGen] Need to fetch file from database:', id); + } + } + + // Fetch missing files from database + if (idsToFetch.length && this.req?.user?.id && this.getFiles) { + logger.debug('[GeminiImageGen] Fetching', idsToFetch.length, 'files from database'); + const fetchedFiles = await this.getFiles( + { + user: this.req.user.id, + file_id: { $in: idsToFetch }, + height: { $exists: true }, + width: { $exists: true }, + }, + {}, + {}, + ); + + logger.debug('[GeminiImageGen] Fetched', fetchedFiles.length, 'files from database'); + for (const file of fetchedFiles) { + requestFilesMap[file.file_id] = file; + orderedFiles[indexOfMissing[file.file_id]] = file; + } + } + + // Convert files to Gemini inlineData format + const inlineDataArray: GeminiInlineData[] = []; + + for (const imageFile of orderedFiles) { + if (!imageFile) { + logger.warn('[GeminiImageGen] Skipping missing image file'); + continue; + } + + try { + const source = imageFile.source || this.fileStrategy; + if (!source) { + logger.error('[GeminiImageGen] No source found for image file:', imageFile.file_id); + continue; + } + + let getDownloadStream: GetDownloadStreamFunction | undefined; + if (streamMethods[source]) { + getDownloadStream = streamMethods[source]; + } else if (this.getStrategyFunctions) { + const functions = this.getStrategyFunctions(source); + getDownloadStream = 
functions.getDownloadStream; + streamMethods[source] = getDownloadStream; + } + + if (!getDownloadStream || !this.req) { + logger.error('[GeminiImageGen] No download stream method found for source:', source); + continue; + } + + const stream = await getDownloadStream(this.req, imageFile.filepath); + if (!stream) { + logger.error( + '[GeminiImageGen] Failed to get download stream for image:', + imageFile.file_id, + ); + continue; + } + + // Convert stream to buffer then to base64 + const chunks: Uint8Array[] = []; + for await (const chunk of stream) { + const buf = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk); + chunks.push(new Uint8Array(buf)); + } + const buffer = Buffer.concat(chunks); + const base64Data = buffer.toString('base64'); + + const mimeType = imageFile.type || 'image/png'; + + inlineDataArray.push({ + inlineData: { + mimeType, + data: base64Data, + }, + }); + + logger.debug('[GeminiImageGen] Converted image to inlineData:', { + file_id: imageFile.file_id, + mimeType, + dataLength: base64Data.length, + }); + } catch (error) { + logger.error('[GeminiImageGen] Error processing image file:', imageFile.file_id, error); + } + } + + logger.debug( + '[GeminiImageGen] Successfully converted', + inlineDataArray.length, + 'images to inlineData', + ); + return inlineDataArray; + } + + // =========================================================================== + // RESPONSE HANDLING + // =========================================================================== + + /** + * Return value in appropriate format based on configuration + */ + private returnValue(value: T): T | [string, T] { + if (this.returnMetadata && typeof value === 'object') { + return value; + } + + if (this.isAgent) { + return [DISPLAY_MESSAGE, value] as [string, T]; + } + + return value; + } + + /** + * Create error response in appropriate format + */ + private createErrorResponse(errorMessage: string): AgentToolReturn | string { + if (this.isAgent) { + const errorResponse = [ + { + 
type: ContentTypes.TEXT, + text: errorMessage, + }, + ]; + return [errorResponse, { content: [], file_ids: [] }]; + } + return this.returnValue(errorMessage) as string; + } + + // =========================================================================== + // MAIN EXECUTION + // =========================================================================== + + async _call( + data: z.infer, + ): Promise { + const { prompt, image_ids } = data; + + if (!prompt) { + throw new Error('Missing required field: prompt'); + } + + logger.debug('[GeminiImageGen] Generating image with prompt:', prompt); + logger.debug('[GeminiImageGen] Image IDs provided:', image_ids); + logger.debug('[GeminiImageGen] Available imageFiles:', this.imageFiles.length); + + // Initialize client + let ai: GoogleGenAI; + try { + ai = await this.initializeGeminiClient(); + } catch (error) { + const errorMsg = error instanceof Error ? error.message : String(error); + logger.error('[GeminiImageGen] Failed to initialize Gemini client:', error); + return this.createErrorResponse(`Failed to initialize Gemini client: ${errorMsg}`); + } + + // Build request contents + const contents: GeminiContentPart[] = [{ text: replaceUnwantedChars(prompt) }]; + + // Add context images if provided + if (image_ids && image_ids.length > 0) { + logger.debug('[GeminiImageGen] Processing context images...'); + const contextImages = await this.convertImagesToInlineData(this.imageFiles, image_ids); + + for (const imageData of contextImages) { + contents.push(imageData); + } + + logger.debug('[GeminiImageGen] Added', contextImages.length, 'context images to request'); + } else { + logger.debug('[GeminiImageGen] No image context provided - text-only generation'); + } + + logger.debug('[GeminiImageGen] Final contents array length:', contents.length); + + // Generate image + let apiResponse: GeminiResponse; + try { + const modelId = this.getModelId(); + apiResponse = (await ai.models.generateContent({ + model: modelId, + contents: 
contents, + config: { + responseModalities: ['TEXT', 'IMAGE'], + }, + })) as GeminiResponse; + + logger.debug(`[GeminiImageGen] Received response from Gemini (model: ${modelId})`); + } catch (error) { + const errorMsg = error instanceof Error ? error.message : String(error); + logger.error('[GeminiImageGen] Problem generating the image:', error); + return this.createErrorResponse( + `Something went wrong when trying to generate the image. The Gemini API may be unavailable:\nError Message: ${errorMsg}`, + ); + } + + // Validate response + if (!apiResponse?.candidates?.[0]) { + return this.createErrorResponse( + 'Something went wrong when trying to generate the image. The Gemini API may be unavailable', + ); + } + + // Check for safety blocks + const safetyBlock = this.checkForSafetyBlock(apiResponse); + if (safetyBlock) { + logger.warn('[GeminiImageGen] Content blocked by safety filters:', safetyBlock); + return this.createErrorResponse(this.createSafetyErrorMessage(safetyBlock)); + } + + // Extract image data + const imageData = apiResponse.candidates[0].content?.parts?.find((part) => part.inlineData) + ?.inlineData?.data; + + if (!imageData) { + return this.handleMissingImageData(apiResponse); + } + + logger.debug('[GeminiImageGen] Successfully extracted image data'); + + // Save image and create response + const imageUrl = await this.saveGeneratedImage(imageData); + + // Build response in OpenAI-compatible format + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { url: imageUrl }, + }, + ]; + + const file_ids = [uuidv4()]; + const textResponse = [ + { + type: ContentTypes.TEXT, + text: + DISPLAY_MESSAGE + + `\n\ngenerated_image_id: "${file_ids[0]}"` + + (image_ids && image_ids.length > 0 + ? 
`\nreferenced_image_ids: ["${image_ids.join('", "')}"]` + : ''), + }, + ]; + + return [textResponse, { content, file_ids }]; + } + + /** + * Handle case where no image data is returned + */ + private handleMissingImageData(apiResponse: GeminiResponse): AgentToolReturn | string { + const candidate = apiResponse.candidates![0]; + + logger.warn('[GeminiImageGen] No image data in response. Candidate details:', { + finishReason: candidate.finishReason, + hasContent: !!candidate.content, + contentParts: candidate.content?.parts?.length || 0, + safetyRatings: candidate.safetyRatings?.length || 0, + contentPartsTypes: candidate.content?.parts?.map((p) => Object.keys(p)) || [], + }); + + // Check for safety ratings + let safetyIssue = null; + if (candidate.safetyRatings && candidate.safetyRatings.length > 0) { + for (const rating of candidate.safetyRatings) { + if ( + rating.probability === 'HIGH' || + rating.probability === 'MEDIUM' || + rating.blocked === true + ) { + safetyIssue = rating; + break; + } + } + } + + let errorMessage: string; + if (safetyIssue) { + errorMessage = `I cannot generate this image because it was blocked by content safety filters. The content was flagged for ${safetyIssue.category} (probability: ${safetyIssue.probability}). Please try rephrasing your request with different, safer content.`; + logger.warn('[GeminiImageGen] Content blocked by safety filter:', safetyIssue); + } else if ( + candidate.finishReason === 'SAFETY' || + candidate.finishReason === 'PROHIBITED_CONTENT' + ) { + errorMessage = + 'I cannot generate this image because it was blocked by content safety filters. Please try rephrasing your request with different, safer content.'; + logger.warn('[GeminiImageGen] Content blocked by finishReason:', candidate.finishReason); + } else { + errorMessage = + 'No image was generated. This might be due to content safety filters blocking the request, or the model being unable to create the requested image. 
Please try rephrasing your prompt with different content.'; + logger.warn('[GeminiImageGen] Unknown reason for missing image data'); + } + + return this.createErrorResponse(errorMessage); + } + + /** + * Save generated image using appropriate strategy + */ + private async saveGeneratedImage(imageData: string): Promise { + let imageUrl = `data:image/png;base64,${imageData}`; + + if (this.fileStrategy === FileSources.local || this.fileStrategy === 'local') { + logger.debug('[GeminiImageGen] Local strategy detected - using direct save + data URL'); + try { + await this.saveImageLocally(imageData, 'png', this.userId); + logger.debug('[GeminiImageGen] Image saved locally successfully'); + } catch (error) { + logger.error('[GeminiImageGen] Local save failed:', error); + } + } else { + logger.debug('[GeminiImageGen] Cloud strategy detected - using OpenAI pattern'); + try { + const storageUrl = await saveBase64ImageToStorage({ + base64Data: imageData, + outputFormat: 'png', + processFileURL: this.processFileURL, + fileStrategy: this.fileStrategy, + userId: this.userId, + }); + + if (storageUrl) { + imageUrl = storageUrl; + logger.debug('[GeminiImageGen] Image saved to storage:', storageUrl); + } else { + logger.warn('[GeminiImageGen] Could not save to storage, using data URL'); + } + } catch (error) { + logger.error('[GeminiImageGen] Error saving image to storage:', error); + logger.warn('[GeminiImageGen] Falling back to data URL'); + } + } + + return imageUrl; + } +} + +export default GeminiImageGen; diff --git a/packages/api/src/tools/gemini/constants.ts b/packages/api/src/tools/gemini/constants.ts new file mode 100644 index 000000000000..becfb5c27922 --- /dev/null +++ b/packages/api/src/tools/gemini/constants.ts @@ -0,0 +1,78 @@ +/** + * Display message shown to users after image generation + * DO NOT MODIFY - kept exactly as original + */ +export const DISPLAY_MESSAGE = + "Gemini displayed an image. 
/**
 * Tool description shown in UI and to users.
 *
 * Explains when `gemini_image_gen` should and should not be used, and notes
 * that generated image IDs are returned so they can be referenced in later
 * requests. The backticks around the tool name are escaped because the text
 * is a template literal.
 *
 * DO NOT MODIFY - kept exactly as original
 */
export const TOOL_DESCRIPTION = `Generates high-quality, original images based on text prompts, with optional image context.

When to use \`gemini_image_gen\`:
- To create entirely new images from detailed text descriptions
- To generate images using existing images as context or inspiration
- When the user requests image generation, creation, or asks to "generate an image"
- When the user asks to "edit", "modify", "change", or "swap" elements in an image (generates new image with changes)

When NOT to use \`gemini_image_gen\`:
- For uploading or saving existing images without modification

Generated image IDs will be returned in the response, so you can refer to them in future requests.`;
/**
 * Prompt field description for schema.
 *
 * Tells the model what to put in the prompt: a detailed description of the
 * desired image, or the desired changes for "editing" requests.
 * NOTE(review): the 32000-character limit is only stated in this text; whether
 * the schema enforces it is not visible here — confirm.
 *
 * DO NOT MODIFY - kept exactly as original
 */
export const PROMPT_DESCRIPTION =
  'A detailed text description of the desired image, up to 32000 characters. For "editing" requests, describe the changes you want to make to the referenced image. Be specific about composition, style, lighting, and subject matter.';
/**
 * Gemini API response candidate.
 * Every field is optional, so consumers must handle absent values.
 */
export interface GeminiCandidate {
  /** Reason generation finished; GeminiImageGen treats 'SAFETY' and 'PROHIBITED_CONTENT' as safety blocks. */
  finishReason?: string;
  /** Generated content, split into parts that may carry text and/or inline data. */
  content?: {
    parts?: Array<{
      text?: string;
      /** Inline payload with its MIME type; presumably base64-encoded — confirm against the Gemini API docs. */
      inlineData?: {
        mimeType: string;
        data: string;
      };
    }>;
  };
  /** Safety ratings attached to this candidate; inspected when no image data is returned. */
  safetyRatings?: SafetyRating[];
}
/**
 * Constructor fields for GeminiImageGen.
 * Every field is optional.
 */
export interface GeminiImageGenFields {
  // NOTE(review): the semantics of `override` and `returnMetadata` are not visible here — confirm in the GeminiImageGen constructor.
  override?: boolean;
  returnMetadata?: boolean;
  /** ID of the user the tool runs for; presumably the value passed along when saving generated images — verify. */
  userId?: string;
  /** File storage strategy identifier; GeminiImageGen special-cases the 'local' value when saving images. */
  fileStrategy?: string;
  /** Whether the tool is running in agent mode (presumably selects the agent-style return shape — confirm). */
  isAgent?: boolean;
  /** Server request object; the download-stream function type takes one as its first argument. */
  req?: ServerRequest;
  /** Callback that processes a file URL and resolves to the stored file's `filepath`. */
  processFileURL?: (params: ProcessFileURLParams) => Promise<{ filepath: string }>;
  /** Image files made available to the tool (presumably request attachments usable as context — verify against caller). */
  imageFiles?: MongoFile[];
  /** Function to fetch files from database */
  getFiles?: GetFilesFunction;
  /** Function to get file strategy functions (download streams) */
  getStrategyFunctions?: GetStrategyFunctionsType;
}
export interface ImageToolContextOptions {
  /** The tool key identifier, e.g., 'gemini_image_gen' or 'image_edit_oai' */
  toolKey: string;
  /** The purpose description, e.g., 'image context' or 'image editing' */
  purpose: string;
}

/**
 * Creates tool context for image generation/editing tools by building a description
 * of available image files that can be used as visual context.
 *
 * The result has three parts: a header naming the purpose, one `\n\t- <file_id>`
 * line per image, and a footer telling the model to pass IDs in `image_ids` when
 * calling the tool. Missing (`undefined`) entries are dropped before the text is
 * assembled, so the header and footer are emitted whenever at least one valid
 * file exists, regardless of where the gaps fall in the array.
 *
 * @param imageFiles - Array of image file objects with file_id property; entries may be undefined
 * @param toolContextMap - Map to store the generated context, keyed by tool name;
 *   only written to when at least one valid image file is present
 * @param options - Configuration options including toolKey and purpose
 * @returns The generated tool context string, or empty string if no images
 */
export function createImageToolContext(
  imageFiles: Array<{ file_id: string } | undefined>,
  toolContextMap: Record<string, string>,
  options: ImageToolContextOptions,
): string {
  const { toolKey, purpose } = options;

  // Collect IDs first so header/footer placement does not depend on array positions.
  const fileIds: string[] = [];
  for (const file of imageFiles) {
    if (file) {
      fileIds.push(file.file_id);
    }
  }

  if (fileIds.length === 0) {
    return '';
  }

  const header = `Image files provided in this request (their image IDs listed in order of appearance) available for ${purpose}:`;
  const listing = fileIds.map((fileId) => `\n\t- ${fileId}`).join('');
  const footer = `\n\nInclude any you need in the \`image_ids\` array when calling \`${toolKey}\`. You may also include previously referenced or generated image IDs.`;

  const toolContext = header + listing + footer;
  toolContextMap[toolKey] = toolContext;
  return toolContext;
}
'dalle', + 'dall-e', + 'stable-diffusion', + 'flux', + 'gemini_image_gen', +]); /** * Enum for collections using infinite queries