diff --git a/packages/firebase_vertexai/firebase_vertexai/example/ios/Runner/AppDelegate.swift b/packages/firebase_vertexai/firebase_vertexai/example/ios/Runner/AppDelegate.swift index 70693e4a8c12..b6363034812b 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/ios/Runner/AppDelegate.swift +++ b/packages/firebase_vertexai/firebase_vertexai/example/ios/Runner/AppDelegate.swift @@ -1,7 +1,7 @@ import UIKit import Flutter -@UIApplicationMain +@main @objc class AppDelegate: FlutterAppDelegate { override func application( _ application: UIApplication, diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart index aa70b575c1dc..0e3855eee06a 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart @@ -22,7 +22,7 @@ import 'pages/function_calling_page.dart'; import 'pages/image_prompt_page.dart'; import 'pages/token_count_page.dart'; import 'pages/schema_page.dart'; -import 'pages/storage_uri_page.dart'; +import 'pages/imagen_page.dart'; // REQUIRED if you want to run on Web const FirebaseOptions? options = null; @@ -79,7 +79,7 @@ class _HomeScreenState extends State { title: 'Function Calling', ), // function calling will initial its own model ImagePromptPage(title: 'Image Prompt', model: widget.model), - StorageUriPromptPage(title: 'Storage URI Prompt', model: widget.model), + ImagenPage(title: 'Imagen Model', model: widget.model), SchemaPromptPage(title: 'Schema Prompt', model: widget.model), ]; @@ -134,11 +134,11 @@ class _HomeScreenState extends State { ), BottomNavigationBarItem( icon: Icon( - Icons.folder, + Icons.image_search, color: Theme.of(context).colorScheme.primary, ), - label: 'Storage URI Prompt', - tooltip: 'Storage URI Prompt', + label: 'Imagen Model', + tooltip: 'Imagen Model', ), BottomNavigationBarItem( icon: Icon( diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart index f8b111296287..0d84c5941c03 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart @@ -89,14 +89,28 @@ class _ImagePromptPageState extends State { const SizedBox.square( dimension: 15, ), - ElevatedButton( - onPressed: !_loading - ? () async { - await _sendImagePrompt(_textController.text); - } - : null, - child: const Text('Send Image Prompt'), - ), + if (!_loading) + IconButton( + onPressed: () async { + await _sendImagePrompt(_textController.text); + }, + icon: Icon( + Icons.image, + color: Theme.of(context).colorScheme.primary, + ), + ), + if (!_loading) + IconButton( + onPressed: () async { + await _sendStorageUriPrompt(_textController.text); + }, + icon: Icon( + Icons.storage, + color: Theme.of(context).colorScheme.primary, + ), + ) + else + const CircularProgressIndicator(), ], ), ), @@ -162,6 +176,49 @@ class _ImagePromptPageState extends State { } } + Future _sendStorageUriPrompt(String message) async { + setState(() { + _loading = true; + }); + try { + final content = [ + Content.multi([ + TextPart(message), + FileData( + 'image/jpeg', + 'gs://vertex-ai-example-ef5a2.appspot.com/foodpic.jpg', + ), + ]), + ]; + _generatedContent.add(MessageData(text: message, fromUser: true)); + + var response = await widget.model.generateContent(content); + var text = response.text; + _generatedContent.add(MessageData(text: text, fromUser: false)); + + if (text == null) { + _showError('No response from API.'); + return; + } else { + setState(() { + _loading = false; + _scrollDown(); + }); + } + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + void _showError(String message) { showDialog( context: context, diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart new file mode 100644 index 000000000000..3fa103419c41 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart @@ -0,0 +1,229 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'package:flutter/material.dart'; +import 'package:firebase_vertexai/firebase_vertexai.dart'; +//import 'package:firebase_storage/firebase_storage.dart'; +import '../widgets/message_widget.dart'; + +class ImagenPage extends StatefulWidget { + const ImagenPage({ + super.key, + required this.title, + required this.model, + }); + + final String title; + final GenerativeModel model; + + @override + State createState() => _ImagenPageState(); +} + +class _ImagenPageState extends State { + final ScrollController _scrollController = ScrollController(); + final TextEditingController _textController = TextEditingController(); + final FocusNode _textFieldFocus = FocusNode(); + final List _generatedContent = []; + bool _loading = false; + late final ImagenModel _imagenModel; + + @override + void initState() { + super.initState(); + var generationConfig = ImagenGenerationConfig( + negativePrompt: 'frog', + numberOfImages: 1, + aspectRatio: ImagenAspectRatio.square1x1, + imageFormat: ImagenFormat.jpeg(compressionQuality: 75), + ); + _imagenModel = FirebaseVertexAI.instance.imagenModel( + model: 'imagen-3.0-generate-001', + generationConfig: generationConfig, + safetySettings: ImagenSafetySettings( + ImagenSafetyFilterLevel.blockLowAndAbove, + ImagenPersonFilterLevel.allowAdult, + ), + ); + } + + void _scrollDown() { + WidgetsBinding.instance.addPostFrameCallback( + (_) => _scrollController.animateTo( + _scrollController.position.maxScrollExtent, + duration: const Duration( + milliseconds: 750, + ), + curve: Curves.easeOutCirc, + ), + ); + } + + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar( + title: Text(widget.title), + ), + body: Padding( + padding: const EdgeInsets.all(8), + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Expanded( + child: ListView.builder( + controller: _scrollController, + itemBuilder: (context, idx) { + return MessageWidget( + text: _generatedContent[idx].text, + image: _generatedContent[idx].image, + isFromUser: _generatedContent[idx].fromUser ?? false, + ); + }, + itemCount: _generatedContent.length, + ), + ), + Padding( + padding: const EdgeInsets.symmetric( + vertical: 25, + horizontal: 15, + ), + child: Row( + children: [ + Expanded( + child: TextField( + autofocus: true, + focusNode: _textFieldFocus, + controller: _textController, + ), + ), + const SizedBox.square( + dimension: 15, + ), + if (!_loading) + IconButton( + onPressed: () async { + await _testImagen(_textController.text); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Imagen raw data', + ) + else + const CircularProgressIndicator(), + // NOTE: Keep this API private until future release. + // if (!_loading) + // IconButton( + // onPressed: () async { + // await _testImagenGCS(_textController.text); + // }, + // icon: Icon( + // Icons.imagesearch_roller, + // color: Theme.of(context).colorScheme.primary, + // ), + // tooltip: 'Imagen GCS', + // ) + // else + // const CircularProgressIndicator(), + ], + ), + ), + ], + ), + ), + ); + } + + Future _testImagen(String prompt) async { + setState(() { + _loading = true; + }); + + var response = await _imagenModel.generateImages(prompt); + + if (response.images.isNotEmpty) { + var imagenImage = response.images[0]; + // Process the image + _generatedContent.add( + MessageData( + image: Image.memory(imagenImage.bytesBase64Encoded), + text: prompt, + fromUser: false, + ), + ); + } else { + // Handle the case where no images were generated + _showError('Error: No images were generated.'); + } + setState(() { + _loading = false; + _scrollDown(); + }); + } + // NOTE: Keep this API private until future release. + // Future _testImagenGCS(String prompt) async { + // setState(() { + // _loading = true; + // }); + // var gcsUrl = 'gs://vertex-ai-example-ef5a2.appspot.com/imagen'; + + // var response = await _imagenModel.generateImagesGCS(prompt, gcsUrl); + + // if (response.images.isNotEmpty) { + // var imagenImage = response.images[0]; + // final returnImageUri = imagenImage.gcsUri; + // final reference = FirebaseStorage.instance.refFromURL(returnImageUri); + // final downloadUrl = await reference.getDownloadURL(); + // // Process the image + // _generatedContent.add( + // MessageData( + // image: Image(image: NetworkImage(downloadUrl)), + // text: prompt, + // fromUser: false, + // ), + // ); + // } else { + // // Handle the case where no images were generated + // _showError('Error: No images were generated.'); + // } + // setState(() { + // _loading = false; + // }); + // } + + void _showError(String message) { + showDialog( + context: context, + builder: (context) { + return AlertDialog( + title: const Text('Something went wrong'), + content: SingleChildScrollView( + child: SelectableText(message), + ), + actions: [ + TextButton( + onPressed: () { + Navigator.of(context).pop(); + }, + child: const Text('OK'), + ), + ], + ); + }, + ); + } +} diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/storage_uri_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/storage_uri_page.dart deleted file mode 100644 index b1624d52efc6..000000000000 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/storage_uri_page.dart +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import 'package:flutter/material.dart'; -import 'package:firebase_vertexai/firebase_vertexai.dart'; -import '../widgets/message_widget.dart'; - -class StorageUriPromptPage extends StatefulWidget { - const StorageUriPromptPage({ - super.key, - required this.title, - required this.model, - }); - - final String title; - final GenerativeModel model; - - @override - State createState() => _StorageUriPromptPageState(); -} - -class _StorageUriPromptPageState extends State { - final ScrollController _scrollController = ScrollController(); - final TextEditingController _textController = TextEditingController(); - final FocusNode _textFieldFocus = FocusNode(); - final List _messages = []; - bool _loading = false; - - void _scrollDown() { - WidgetsBinding.instance.addPostFrameCallback( - (_) => _scrollController.animateTo( - _scrollController.position.maxScrollExtent, - duration: const Duration( - milliseconds: 750, - ), - curve: Curves.easeOutCirc, - ), - ); - } - - @override - Widget build(BuildContext context) { - return Scaffold( - appBar: AppBar( - title: Text(widget.title), - ), - body: Padding( - padding: const EdgeInsets.all(8), - child: Column( - mainAxisAlignment: MainAxisAlignment.center, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Expanded( - child: ListView.builder( - controller: _scrollController, - itemBuilder: (context, idx) { - return MessageWidget( - text: _messages[idx].text, - isFromUser: _messages[idx].fromUser ?? false, - ); - }, - itemCount: _messages.length, - ), - ), - Padding( - padding: const EdgeInsets.symmetric( - vertical: 25, - horizontal: 15, - ), - child: Row( - children: [ - Expanded( - child: TextField( - autofocus: true, - focusNode: _textFieldFocus, - controller: _textController, - ), - ), - const SizedBox.square( - dimension: 15, - ), - ElevatedButton( - onPressed: !_loading - ? () async { - await _sendStorageUriPrompt(_textController.text); - } - : null, - child: const Text('Send Storage URI Prompt'), - ), - ], - ), - ), - ], - ), - ), - ); - } - - Future _sendStorageUriPrompt(String message) async { - setState(() { - _loading = true; - }); - try { - final content = [ - Content.multi([ - TextPart(message), - FileData( - 'image/jpeg', - 'gs://vertex-ai-example-ef5a2.appspot.com/foodpic.jpg', - ), - ]), - ]; - _messages.add(MessageData(text: message, fromUser: true)); - - var response = await widget.model.generateContent(content); - var text = response.text; - _messages.add(MessageData(text: text, fromUser: false)); - - if (text == null) { - _showError('No response from API.'); - return; - } else { - setState(() { - _loading = false; - _scrollDown(); - }); - } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } - } - - void _showError(String message) { - showDialog( - context: context, - builder: (context) { - return AlertDialog( - title: const Text('Something went wrong'), - content: SingleChildScrollView( - child: SelectableText(message), - ), - actions: [ - TextButton( - onPressed: () { - Navigator.of(context).pop(); - }, - child: const Text('OK'), - ), - ], - ); - }, - ); - } -} diff --git a/packages/firebase_vertexai/firebase_vertexai/example/pubspec.yaml b/packages/firebase_vertexai/firebase_vertexai/example/pubspec.yaml index f87dde5eca16..ea7a04750d02 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/pubspec.yaml +++ b/packages/firebase_vertexai/firebase_vertexai/example/pubspec.yaml @@ -20,6 +20,7 @@ dependencies: # Use with the CupertinoIcons class for iOS style icons. cupertino_icons: ^1.0.6 firebase_core: ^3.12.0 + firebase_storage: ^12.4.1 firebase_vertexai: ^1.3.0 flutter: sdk: flutter diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart b/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart index a170d8deb969..17b05428ed3a 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart @@ -28,6 +28,7 @@ export 'src/api.dart' PromptFeedback, SafetyRating, SafetySetting, + // TODO(cynthiajiang) remove in next breaking change. TaskType, UsageMetadata; export 'src/chat.dart' show ChatSession, StartChatExtension; @@ -55,5 +56,15 @@ export 'src/function_calling.dart' FunctionDeclaration, Tool, ToolConfig; -export 'src/model.dart' show GenerativeModel; +export 'src/generative_model.dart' show GenerativeModel; +export 'src/imagen_api.dart' + show + ImagenSafetySettings, + ImagenFormat, + ImagenSafetyFilterLevel, + ImagenPersonFilterLevel, + ImagenGenerationConfig, + ImagenAspectRatio; +export 'src/imagen_content.dart' show ImagenInlineImage; +export 'src/imagen_model.dart' show ImagenModel; export 'src/schema.dart' show Schema, SchemaType; diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart index 14052ab539ef..3b3143fecfac 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart @@ -18,7 +18,7 @@ import 'schema.dart'; /// Response for Count Tokens final class CountTokensResponse { - /// Constructor + // ignore: public_member_api_docs CountTokensResponse(this.totalTokens, {this.totalBillableCharacters, this.promptTokensDetails}); @@ -38,7 +38,7 @@ final class CountTokensResponse { /// Response from the model; supports multiple candidates. final class GenerateContentResponse { - /// Constructor + // ignore: public_member_api_docs GenerateContentResponse(this.candidates, this.promptFeedback, {this.usageMetadata}); @@ -112,7 +112,7 @@ final class GenerateContentResponse { /// Feedback metadata of a prompt specified in a [GenerativeModel] request. final class PromptFeedback { - /// Constructor + // ignore: public_member_api_docs PromptFeedback(this.blockReason, this.blockReasonMessage, this.safetyRatings); /// If set, the prompt was blocked and no candidates are returned. @@ -131,7 +131,7 @@ final class PromptFeedback { /// Metadata on the generation request's token usage. final class UsageMetadata { - /// Constructor + // ignore: public_member_api_docs UsageMetadata._( {this.promptTokenCount, this.candidatesTokenCount, @@ -158,7 +158,7 @@ final class UsageMetadata { /// Response candidate generated from a [GenerativeModel]. final class Candidate { // TODO: token count? - /// Constructor + // ignore: public_member_api_docs Candidate(this.content, this.safetyRatings, this.citationMetadata, this.finishReason, this.finishMessage); @@ -223,7 +223,7 @@ final class Candidate { /// safety across a number of harm categories and the probability of the harm /// classification is included here. final class SafetyRating { - /// Constructor + // ignore: public_member_api_docs SafetyRating(this.category, this.probability, {this.probabilityScore, this.isBlocked, @@ -417,7 +417,7 @@ enum HarmSeverity { /// Source attributions for a piece of content. final class CitationMetadata { - /// Constructor + // ignore: public_member_api_docs CitationMetadata(this.citations); /// Citations to sources for a specific response. @@ -426,7 +426,7 @@ final class CitationMetadata { /// Citation to a source for a portion of a specific response. final class Citation { - /// Constructor + // ignore: public_member_api_docs Citation(this.startIndex, this.endIndex, this.uri, this.license); /// Start of segment of the response that is attributed to this source. @@ -553,7 +553,7 @@ enum ContentModality { /// Passing a safety setting for a category changes the allowed probability that /// content is blocked. final class SafetySetting { - /// Constructor + // ignore: public_member_api_docs SafetySetting(this.category, this.threshold); /// The category for this setting. @@ -609,7 +609,7 @@ enum HarmBlockThreshold { /// Configuration options for model generation and outputs. final class GenerationConfig { - /// Constructor + // ignore: public_member_api_docs GenerationConfig( {this.candidateCount, this.stopSequences, diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart new file mode 100644 index 000000000000..65e64550ebf1 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart @@ -0,0 +1,117 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'dart:async'; + +import 'package:firebase_app_check/firebase_app_check.dart'; +import 'package:firebase_auth/firebase_auth.dart'; +import 'package:firebase_core/firebase_core.dart'; + +import 'client.dart'; +import 'vertex_version.dart'; + +/// [Task] enum class for [GenerativeModel] to make request. +enum Task { + /// Request type to generate content. + generateContent, + + /// Request type to stream content. + streamGenerateContent, + + /// Request type to count token. + countTokens, + + /// Request type to talk to Prediction Services like Imagen. + predict, +} + +/// Base class for models. +/// +/// Do not instantiate directly. +abstract class BaseModel { + // ignore: public_member_api_docs + BaseModel({ + required String model, + required String location, + required FirebaseApp app, + required ApiClient client, + }) : _model = normalizeModelName(model), + _projectUri = _vertexUri(app, location), + _client = client; + + static const _baseUrl = 'firebasevertexai.googleapis.com'; + static const _apiVersion = 'v1beta'; + + final ({String prefix, String name}) _model; + + final Uri _projectUri; + final ApiClient _client; + + /// The normalized model name. + ({String prefix, String name}) get model => _model; + + /// The API client. + ApiClient get client => _client; + + /// Returns the model code for a user friendly model name. + /// + /// If the model name is already a model code (contains a `/`), use the parts + /// directly. Otherwise, return a `models/` model code. + static ({String prefix, String name}) normalizeModelName(String modelName) { + if (!modelName.contains('/')) return (prefix: 'models', name: modelName); + final parts = modelName.split('/'); + return (prefix: parts.first, name: parts.skip(1).join('/')); + } + + static Uri _vertexUri(FirebaseApp app, String location) { + var projectId = app.options.projectId; + return Uri.https( + _baseUrl, + '/$_apiVersion/projects/$projectId/locations/$location/publishers/google', + ); + } + + /// Returns a function that generates Firebase auth tokens. + static FutureOr> Function() firebaseTokens( + FirebaseAppCheck? appCheck, FirebaseAuth? auth) { + return () async { + Map headers = {}; + // Override the client name in Google AI SDK + headers['x-goog-api-client'] = + 'gl-dart/$packageVersion fire/$packageVersion'; + if (appCheck != null) { + final appCheckToken = await appCheck.getToken(); + if (appCheckToken != null) { + headers['X-Firebase-AppCheck'] = appCheckToken; + } + } + if (auth != null) { + final idToken = await auth.currentUser?.getIdToken(); + if (idToken != null) { + headers['Authorization'] = 'Firebase $idToken'; + } + } + return headers; + }; + } + + /// Returns a URI for the given [task]. + Uri taskUri(Task task) => _projectUri.replace( + pathSegments: _projectUri.pathSegments + .followedBy([_model.prefix, '${_model.name}:${task.name}'])); + + /// Make a unary request for [task] with JSON encodable [params]. + Future makeRequest(Task task, Map params, + T Function(Map) parse) => + _client.makeRequest(taskUri(task), params).then(parse); +} diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/chat.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/chat.dart index fa9229f46552..553ccb4e05bd 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/chat.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/chat.dart @@ -16,7 +16,7 @@ import 'dart:async'; import 'api.dart'; import 'content.dart'; -import 'model.dart'; +import 'generative_model.dart'; import 'utils/mutex.dart'; /// A back-and-forth chat with a generative model. diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/content.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/content.dart index 1fa79d2f5bf0..74a435fd6ef6 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/content.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/content.dart @@ -18,7 +18,7 @@ import 'error.dart'; /// The base structured datatype containing multi-part content of a message. final class Content { - /// Constructor + // ignore: public_member_api_docs Content(this.role, this.parts); /// The producer of the content. @@ -112,7 +112,7 @@ sealed class Part { /// A [Part] with the text content. final class TextPart implements Part { - /// Constructor + // ignore: public_member_api_docs TextPart(this.text); /// The text content of the [Part] @@ -123,7 +123,7 @@ final class TextPart implements Part { /// A [Part] with the byte content of a file. final class InlineDataPart implements Part { - /// Constructor + // ignore: public_member_api_docs InlineDataPart(this.mimeType, this.bytes); /// File type of the [InlineDataPart]. @@ -142,7 +142,7 @@ final class InlineDataPart implements Part { /// a string representing the `FunctionDeclaration.name` with the /// arguments and their values. final class FunctionCall implements Part { - /// Constructor + // ignore: public_member_api_docs FunctionCall(this.name, this.args); /// The name of the function to call. @@ -160,7 +160,7 @@ final class FunctionCall implements Part { /// The response class for [FunctionCall] final class FunctionResponse implements Part { - /// Constructor + // ignore: public_member_api_docs FunctionResponse(this.name, this.response); /// The name of the function that was called. @@ -180,7 +180,7 @@ final class FunctionResponse implements Part { /// A [Part] with Firebase Storage uri as prompt content final class FileData implements Part { - /// Constructor + // ignore: public_member_api_docs FileData(this.mimeType, this.fileUri); /// File type of the [FileData]. diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/error.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/error.dart index 66ab38e325c0..ad4a5e09d9e2 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/error.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/error.dart @@ -16,7 +16,7 @@ /// /// The [message] may explain the cause of the failure. final class VertexAIException implements Exception { - /// Constructor + // ignore: public_member_api_docs VertexAIException(this.message); /// Message of the exception @@ -28,7 +28,7 @@ final class VertexAIException implements Exception { /// Exception thrown when the server rejects the API key. final class InvalidApiKey implements VertexAIException { - /// Constructor + // ignore: public_member_api_docs InvalidApiKey(this.message); @override final String message; @@ -81,7 +81,7 @@ final class QuotaExceeded implements VertexAIException { /// Exception thrown when the server failed to generate content. final class ServerException implements VertexAIException { - /// Constructor + // ignore: public_member_api_docs ServerException(this.message); @override final String message; @@ -96,7 +96,7 @@ final class ServerException implements VertexAIException { /// as an inability to parse a new response format. Resolution paths may include /// updating to a new version of the SDK, or filing an issue. final class VertexAISdkException implements Exception { - /// Constructor + // ignore: public_member_api_docs VertexAISdkException(this.message); /// Message of the exception @@ -111,6 +111,21 @@ final class VertexAISdkException implements Exception { 'https://github.com/firebase/flutterfire/issues.'; } +/// Exception indicating all images filtered out. +/// +/// This exception indicates all images were filtered out because they violated +/// Vertex AI's usage guidelines. +final class ImagenImagesBlockedException implements Exception { + // ignore: public_member_api_docs + ImagenImagesBlockedException(this.message); + + /// Message of the exception + final String message; + + @override + String toString() => message; +} + /// Parse the error json object. VertexAIException parseError(Object jsonObject) { return switch (jsonObject) { diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart index df9c26f9eaaf..364c8d8f687b 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart @@ -21,7 +21,9 @@ import 'package:firebase_core_platform_interface/firebase_core_platform_interfac import 'api.dart'; import 'content.dart'; import 'function_calling.dart'; -import 'model.dart'; +import 'generative_model.dart'; +import 'imagen_api.dart'; +import 'imagen_model.dart'; const _defaultLocation = 'us-central1'; @@ -110,4 +112,22 @@ class FirebaseVertexAI extends FirebasePluginPlatform { systemInstruction: systemInstruction, ); } + + /// Create a [ImagenModel]. + /// + /// The optional [safetySettings] can be used to control and guide the + /// generation. See [ImagenSafetySettings] for details. + ImagenModel imagenModel( + {required String model, + ImagenGenerationConfig? generationConfig, + ImagenSafetySettings? safetySettings}) { + return createImagenModel( + app: app, + location: location, + model: model, + generationConfig: generationConfig, + safetySettings: safetySettings, + appCheck: appCheck, + auth: auth); + } } diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/function_calling.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/function_calling.dart index 5d552e33024e..f70bff0b3ff7 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/function_calling.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/function_calling.dart @@ -20,7 +20,7 @@ import 'schema.dart'; /// external systems to perform an action, or set of actions, outside of /// knowledge and scope of the model. final class Tool { - /// Constructor + // ignore: public_member_api_docs Tool._(this._functionDeclarations); /// Returns a [Tool] instance with list of [FunctionDeclaration]. @@ -54,7 +54,7 @@ final class Tool { /// FunctionDeclaration is a representation of a block of code that can be used /// as a `Tool` by the model and executed by the client. final class FunctionDeclaration { - /// Constructor + // ignore: public_member_api_docs FunctionDeclaration(this.name, this.description, {required Map parameters, List optionalParameters = const []}) @@ -82,7 +82,7 @@ final class FunctionDeclaration { /// Config for tools to use with model. final class ToolConfig { - /// Constructor + // ignore: public_member_api_docs ToolConfig({this.functionCallingConfig}); /// Config for function calling. @@ -98,7 +98,7 @@ final class ToolConfig { /// Configuration specifying how the model should use the functions provided as /// tools. final class FunctionCallingConfig { - /// Constructor + // ignore: public_member_api_docs FunctionCallingConfig._({this.mode, this.allowedFunctionNames}); /// The mode in which function calling should execute. diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart similarity index 71% rename from packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart rename to packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart index 605adf9b4fa7..ddf0d1e5a37e 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/generative_model.dart @@ -19,41 +19,19 @@ import 'dart:async'; import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; - import 'package:http/http.dart' as http; import 'api.dart'; +import 'base_model.dart'; import 'client.dart'; import 'content.dart'; import 'function_calling.dart'; -import 'vertex_version.dart'; - -const _baseUrl = 'firebasevertexai.googleapis.com'; -const _apiVersion = 'v1beta'; - -/// [Task] enum class for [GenerativeModel] to make request. -enum Task { - /// Request type to generate content. - generateContent, - - /// Request type to stream content. - streamGenerateContent, - - /// Request type to count token. - countTokens, - - /// Request type to embed content. - embedContent, - - /// Request type to batch embed content. - batchEmbedContents; -} /// A multimodel generative model (like Gemini). /// /// Allows generating content, creating embeddings, and counting the number of /// tokens in a piece of content. -final class GenerativeModel { +final class GenerativeModel extends BaseModel { /// Create a [GenerativeModel] backed by the generative model named [model]. /// /// The [model] argument can be a model name (such as `'gemini-pro'`) or a @@ -78,17 +56,19 @@ final class GenerativeModel { ToolConfig? toolConfig, Content? systemInstruction, http.Client? httpClient, - }) : _model = _normalizeModelName(model), - _baseUri = _vertexUri(app, location), - _safetySettings = safetySettings ?? [], + }) : _safetySettings = safetySettings ?? [], _generationConfig = generationConfig, _tools = tools, _toolConfig = toolConfig, _systemInstruction = systemInstruction, - _client = HttpApiClient( - apiKey: app.options.apiKey, - httpClient: httpClient, - requestHeaders: _firebaseTokens(appCheck, auth)); + super( + model: model, + app: app, + location: location, + client: HttpApiClient( + apiKey: app.options.apiKey, + httpClient: httpClient, + requestHeaders: BaseModel.firebaseTokens(appCheck, auth))); GenerativeModel._constructTestModel({ required String model, @@ -102,79 +82,30 @@ final class GenerativeModel { ToolConfig? toolConfig, Content? systemInstruction, ApiClient? apiClient, - }) : _model = _normalizeModelName(model), - _baseUri = _vertexUri(app, location), - _safetySettings = safetySettings ?? [], + }) : _safetySettings = safetySettings ?? [], _generationConfig = generationConfig, _tools = tools, _toolConfig = toolConfig, _systemInstruction = systemInstruction, - _client = apiClient ?? - HttpApiClient( - apiKey: app.options.apiKey, - requestHeaders: _firebaseTokens(appCheck, auth)); + super( + model: model, + app: app, + location: location, + client: apiClient ?? + HttpApiClient( + apiKey: app.options.apiKey, + requestHeaders: BaseModel.firebaseTokens(appCheck, auth))); - final ({String prefix, String name}) _model; final List _safetySettings; final GenerationConfig? _generationConfig; final List? _tools; - final ApiClient _client; - final Uri _baseUri; + + //final Uri _baseUri; final ToolConfig? _toolConfig; final Content? _systemInstruction; //static const _modelsPrefix = 'models/'; - /// Returns the model code for a user friendly model name. - /// - /// If the model name is already a model code (contains a `/`), use the parts - /// directly. Otherwise, return a `models/` model code. - static ({String prefix, String name}) _normalizeModelName(String modelName) { - if (!modelName.contains('/')) return (prefix: 'models', name: modelName); - final parts = modelName.split('/'); - return (prefix: parts.first, name: parts.skip(1).join('/')); - } - - static Uri _vertexUri(FirebaseApp app, String location) { - var projectId = app.options.projectId; - return Uri.https( - _baseUrl, - '/$_apiVersion/projects/$projectId/locations/$location/publishers/google', - ); - } - - static FutureOr> Function() _firebaseTokens( - FirebaseAppCheck? appCheck, FirebaseAuth? auth) { - return () async { - Map headers = {}; - // Override the client name in Google AI SDK - headers['x-goog-api-client'] = - 'gl-dart/$packageVersion fire/$packageVersion'; - if (appCheck != null) { - final appCheckToken = await appCheck.getToken(); - if (appCheckToken != null) { - headers['X-Firebase-AppCheck'] = appCheckToken; - } - } - if (auth != null) { - final idToken = await auth.currentUser?.getIdToken(); - if (idToken != null) { - headers['Authorization'] = 'Firebase $idToken'; - } - } - return headers; - }; - } - - Uri _taskUri(Task task) => _baseUri.replace( - pathSegments: _baseUri.pathSegments - .followedBy([_model.prefix, '${_model.name}:${task.name}'])); - - /// Make a unary request for [task] with JSON encodable [params]. - Future makeRequest(Task task, Map params, - T Function(Map) parse) => - _client.makeRequest(_taskUri(task), params).then(parse); - Map _generateContentRequest( Iterable contents, { List? safetySettings, @@ -187,7 +118,7 @@ final class GenerativeModel { tools ??= _tools; toolConfig ??= _toolConfig; return { - 'model': '${_model.prefix}/${_model.name}', + 'model': '${model.prefix}/${model.name}', 'contents': contents.map((c) => c.toJson()).toList(), if (safetySettings.isNotEmpty) 'safetySettings': safetySettings.map((s) => s.toJson()).toList(), @@ -244,8 +175,8 @@ final class GenerativeModel { GenerationConfig? generationConfig, List? tools, ToolConfig? toolConfig}) { - final response = _client.streamRequest( - _taskUri(Task.streamGenerateContent), + final response = client.streamRequest( + taskUri(Task.streamGenerateContent), _generateContentRequest( prompt, safetySettings: safetySettings, diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart new file mode 100644 index 000000000000..86ef4baae0d3 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart @@ -0,0 +1,221 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// Specifies the level of safety filtering for image generation. +/// +/// If not specified, default will be "block_medium_and_above". +enum ImagenSafetyFilterLevel { + /// Strongest filtering level, most strict blocking. + blockLowAndAbove('block_low_and_above'), + + /// Block some problematic prompts and responses. + blockMediumAndAbove('block_medium_and_above'), + + /// Reduces the number of requests blocked due to safety filters. + /// May increase objectionable content generated by Imagen. + blockOnlyHigh('block_only_high'), + + /// Block very few problematic prompts and responses. + /// Access to this feature is restricted. + blockNone('block_none'); + + const ImagenSafetyFilterLevel(this._jsonString); + + final String _jsonString; + + // ignore: public_member_api_docs + String toJson() => _jsonString; + + // ignore: unused_element + static ImagenSafetyFilterLevel _parseValue(Object jsonObject) { + return switch (jsonObject) { + 'block_low_and_above' => ImagenSafetyFilterLevel.blockLowAndAbove, + 'block_medium_and_above' => ImagenSafetyFilterLevel.blockMediumAndAbove, + 'block_only_high' => ImagenSafetyFilterLevel.blockOnlyHigh, + 'block_none' => ImagenSafetyFilterLevel.blockNone, + _ => throw FormatException( + 'Unhandled ImagenSafetyFilterLevel format', jsonObject), + }; + } + + @override + String toString() => name; +} + +/// Allow generation of people by the model. +/// +/// If not specified, the default value is "allow_adult". +enum ImagenPersonFilterLevel { + /// Disallow the inclusion of people or faces in images. + blockAll('dont_allow'), + + /// Allow generation of adults only. + allowAdult('allow_adult'), + + /// Allow generation of people of all ages. + allowAll('allow_all'); + + const ImagenPersonFilterLevel(this._jsonString); + + final String _jsonString; + + // ignore: public_member_api_docs + String toJson() => _jsonString; + + // ignore: unused_element + static ImagenPersonFilterLevel _parseValue(Object jsonObject) { + return switch (jsonObject) { + 'dont_allow' => ImagenPersonFilterLevel.blockAll, + 'allow_adult' => ImagenPersonFilterLevel.allowAdult, + 'allow_all' => ImagenPersonFilterLevel.allowAll, + _ => throw FormatException( + 'Unhandled ImagenPersonFilterLevel format', jsonObject), + }; + } + + @override + String toString() => name; +} + +/// A class representing safety settings for image generation. +/// +/// It includes a safety filter level and a person filter level. +final class ImagenSafetySettings { + // ignore: public_member_api_docs + ImagenSafetySettings(this.safetyFilterLevel, this.personFilterLevel); + + /// The safety filter level + final ImagenSafetyFilterLevel? safetyFilterLevel; + + /// The person filter level + final ImagenPersonFilterLevel? personFilterLevel; + + // ignore: public_member_api_docs + Object toJson() => { + if (safetyFilterLevel != null) + 'safetySetting': safetyFilterLevel!.toJson(), + if (personFilterLevel != null) + 'personGeneration': personFilterLevel!.toJson(), + }; +} + +/// The aspect ratio for the image. +/// +/// The default value is "1:1". +enum ImagenAspectRatio { + /// Square (1:1). + square1x1('1:1'), + + /// Portrait (9:16). + portrait9x16('9:16'), + + /// Landscape (16:9). + landscape16x9('16:9'), + + /// Portrait (3:4). + portrait3x4('3:4'), + + /// Landscape (4:3). + landscape4x3('4:3'); + + const ImagenAspectRatio(this._jsonString); + + final String _jsonString; + + // ignore: public_member_api_docs + String toJson() => _jsonString; + + // ignore: unused_element + static ImagenAspectRatio _parseValue(Object jsonObject) { + return switch (jsonObject) { + '1:1' => ImagenAspectRatio.square1x1, + '9:16' => ImagenAspectRatio.portrait9x16, + '16:9' => ImagenAspectRatio.landscape16x9, + '3:4' => ImagenAspectRatio.portrait3x4, + '4:3' => ImagenAspectRatio.landscape4x3, + _ => + throw FormatException('Unhandled ImagenAspectRatio format', jsonObject), + }; + } + + @override + String toString() => name; +} + +/// Configuration options for image generation. +final class ImagenGenerationConfig { + // ignore: public_member_api_docs + ImagenGenerationConfig( + {this.numberOfImages, + this.negativePrompt, + this.aspectRatio, + this.imageFormat, + this.addWatermark}); + + /// The number of images to generate. + /// + /// Default value is 1. + final int? numberOfImages; + + /// A description of what to discourage in the generated images. + final String? negativePrompt; + + /// The aspect ratio for the image. The default value is "1:1". + final ImagenAspectRatio? aspectRatio; + + /// The image format of the generated images. + final ImagenFormat? imageFormat; + + /// Whether to add an invisible watermark to generated images. + /// + /// Default value for each imagen model can be found in + /// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#generate_images + final bool? addWatermark; + + // ignore: public_member_api_docs + Map toJson() => { + if (negativePrompt != null) 'negativePrompt': negativePrompt, + if (numberOfImages != null) 'numberOfImages': numberOfImages, + if (aspectRatio != null) 'aspectRatio': aspectRatio!.toJson(), + if (addWatermark != null) 'addWatermark': addWatermark, + if (imageFormat != null) 'outputOption': imageFormat!.toJson(), + }; +} + +/// Represents the image format and compression quality. +final class ImagenFormat { + // ignore: public_member_api_docs + ImagenFormat(this.mimeType, this.compressionQuality); + + // ignore: public_member_api_docs + ImagenFormat.png() : this('image/png', null); + + // ignore: public_member_api_docs + ImagenFormat.jpeg({int? compressionQuality}) + : this('image/jpeg', compressionQuality); + + /// The MIME type of the image format. The default value is "image/png". + final String mimeType; + + /// The level of compression if the output type is "image/jpeg". + /// Accepted values are 0 through 100. The default value is 75. + final int? compressionQuality; + + // ignore: public_member_api_docs + Map toJson() => { + 'mimeType': mimeType, + if (compressionQuality != null) + 'compressionQuality': compressionQuality, + }; +} diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart new file mode 100644 index 000000000000..71d16f9da704 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart @@ -0,0 +1,155 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'dart:convert'; +import 'dart:typed_data'; +import 'error.dart'; + +/// Base type of Imagen Image. +sealed class ImagenImage { + // ignore: public_member_api_docs + ImagenImage({required this.mimeType}); + + /// The MIME type of the image format. + final String mimeType; + + /// Convert the [ImagenImage] content to json format. + Object toJson(); +} + +/// Represents an image stored as a base64-encoded string. +final class ImagenInlineImage implements ImagenImage { + // ignore: public_member_api_docs + ImagenInlineImage({ + required this.bytesBase64Encoded, + required this.mimeType, + }); + + /// Factory method to create an [ImagenInlineImage] from a JSON object. + factory ImagenInlineImage.fromJson(Map json) { + final mimeType = json['mimeType'] as String; + final bytes = json['bytesBase64Encoded'] as String; + final decodedBytes = base64Decode(bytes); + return ImagenInlineImage( + mimeType: mimeType, + bytesBase64Encoded: Uint8List.fromList(decodedBytes), + ); + } + + /// The data contents in bytes, encoded as base64. + final Uint8List bytesBase64Encoded; + + @override + final String mimeType; + + @override + Object toJson() => { + 'mimeType': mimeType, + 'bytesBase64Encoded': base64Encode(bytesBase64Encoded), + }; +} + +/// Represents an image stored in Google Cloud Storage. +final class ImagenGCSImage implements ImagenImage { + // ignore: public_member_api_docs + ImagenGCSImage({ + required this.gcsUri, + required this.mimeType, + }); + + /// Factory method to create an [ImagenGCSImage] from a JSON object. + factory ImagenGCSImage.fromJson(Map json) { + final mimeType = json['mimeType'] as String; + final uri = json['gcsUri'] as String; + + return ImagenGCSImage( + mimeType: mimeType, + gcsUri: uri, + ); + } + + /// The storage URI of the image. + final String gcsUri; + + @override + final String mimeType; + + @override + Object toJson() => { + 'mimeType': mimeType, + 'gcsUri': gcsUri, + }; +} + +/// Represents the response from an image generation request. +final class ImagenGenerationResponse { + // ignore: public_member_api_docs + ImagenGenerationResponse({ + required this.images, + this.filteredReason, + }); + + /// Factory method to create an [ImagenGenerationResponse] from a JSON object. + factory ImagenGenerationResponse.fromJson(Map json) { + final predictions = json['predictions']; + if (predictions.isEmpty) { + throw ServerException('Got empty prediction with no reason'); + } + + List images = []; + String? filteredReason; + + if (T == ImagenInlineImage) { + for (final prediction in predictions) { + if (prediction.containsKey('bytesBase64Encoded')) { + final image = ImagenInlineImage.fromJson(prediction) as T; + images.add(image); + } else if (prediction.containsKey('raiFilteredReason')) { + filteredReason = prediction['raiFilteredReason'] as String; + } + } + } else if (T == ImagenGCSImage) { + for (final prediction in predictions) { + if (prediction.containsKey('gcsUri')) { + final image = ImagenGCSImage.fromJson(prediction) as T; + images.add(image); + } else if (prediction.containsKey('raiFilteredReason')) { + filteredReason = prediction['raiFilteredReason'] as String; + } + } + } else { + throw ArgumentError('Unsupported ImagenImage type: $T'); + } + + if (images.isEmpty && filteredReason != null) { + throw ImagenImagesBlockedException(filteredReason); + } + + return ImagenGenerationResponse( + images: images, filteredReason: filteredReason); + } + + /// A list of generated images. The type of the images depends on the T parameter. + final List images; + + /// If the generation was filtered due to safety reasons, a message explaining the reason. + final String? filteredReason; +} + +/// Parse the json to [ImagenGenerationResponse] +ImagenGenerationResponse + parseImagenGenerationResponse(Object jsonObject) { + if (jsonObject case {'error': final Object error}) throw parseError(error); + Map json = jsonObject as Map; + return ImagenGenerationResponse.fromJson(json); +} diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart new file mode 100644 index 000000000000..631d7fa94849 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart @@ -0,0 +1,135 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'package:firebase_app_check/firebase_app_check.dart'; +import 'package:firebase_auth/firebase_auth.dart'; +import 'package:firebase_core/firebase_core.dart'; + +import 'base_model.dart'; +import 'client.dart'; +import 'imagen_api.dart'; +import 'imagen_content.dart'; + +/// Represents a remote Imagen model with the ability to generate images using +/// text prompts. +/// +/// See the [Cloud +/// documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images) +/// for more details about the image generation capabilities offered by the Imagen model. +/// +/// > Warning: For Vertex AI in Firebase, image generation using Imagen 3 models +/// is in Public Preview, which means that the feature is not subject to any SLA +/// or deprecation policy and could change in backwards-incompatible ways. +final class ImagenModel extends BaseModel { + ImagenModel._( + {required FirebaseApp app, + required String model, + required String location, + FirebaseAppCheck? appCheck, + FirebaseAuth? auth, + ImagenGenerationConfig? generationConfig, + ImagenSafetySettings? safetySettings}) + : _generationConfig = generationConfig, + _safetySettings = safetySettings, + super( + model: model, + app: app, + location: location, + client: HttpApiClient( + apiKey: app.options.apiKey, + requestHeaders: BaseModel.firebaseTokens(appCheck, auth))); + + final ImagenGenerationConfig? _generationConfig; + final ImagenSafetySettings? _safetySettings; + + Map _generateImagenRequest( + String prompt, { + String? gcsUri, + }) { + final parameters = { + if (gcsUri != null) 'storageUri': gcsUri, + 'sampleCount': _generationConfig?.numberOfImages ?? 1, + if (_generationConfig?.aspectRatio case final aspectRatio?) + 'aspectRatio': aspectRatio, + if (_generationConfig?.negativePrompt case final negativePrompt?) + 'negativePrompt': negativePrompt, + if (_generationConfig?.addWatermark case final addWatermark?) + 'addWatermark': addWatermark, + if (_generationConfig?.imageFormat case final imageFormat?) + 'outputOption': imageFormat.toJson(), + if (_safetySettings?.personFilterLevel case final personFilterLevel?) + 'personGeneration': personFilterLevel.toJson(), + if (_safetySettings?.safetyFilterLevel case final safetyFilterLevel?) + 'safetySetting': safetyFilterLevel.toJson(), + }; + + return { + 'instances': [ + {'prompt': prompt} + ], + 'parameters': parameters, + }; + } + + /// Generates images with format of [ImagenInlineImage] based on the given + /// prompt. + Future> generateImages( + String prompt, + ) => + makeRequest( + Task.predict, + _generateImagenRequest( + prompt, + ), + (jsonObject) => + parseImagenGenerationResponse(jsonObject), + ); + + /// Generates images with format of [ImagenGCSImage] based on the given + /// prompt. + /// Note: Keep this API private until future release. + // ignore: unused_element + Future> _generateImagesGCS( + String prompt, + String gcsUri, + ) => + makeRequest( + Task.predict, + _generateImagenRequest( + prompt, + gcsUri: gcsUri, + ), + (jsonObject) => + parseImagenGenerationResponse(jsonObject), + ); +} + +/// Returns a [ImagenModel] using it's private constructor. +ImagenModel createImagenModel({ + required FirebaseApp app, + required String location, + required String model, + FirebaseAppCheck? appCheck, + FirebaseAuth? auth, + ImagenGenerationConfig? generationConfig, + ImagenSafetySettings? safetySettings, +}) => + ImagenModel._( + model: model, + app: app, + appCheck: appCheck, + auth: auth, + location: location, + safetySettings: safetySettings, + generationConfig: generationConfig, + ); diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/schema.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/schema.dart index 0dda0f564d8a..e73f44b355f3 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/schema.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/schema.dart @@ -18,7 +18,7 @@ /// Represents a select subset of an /// [OpenAPI 3.0 schema object](https://spec.openapis.org/oas/v3.0.3#schema). final class Schema { - /// Constructor + // ignore: public_member_api_docs Schema( this.type, { this.format, diff --git a/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart index 511c539f1a7e..a0f404cdac20 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/chat_test.dart @@ -13,7 +13,7 @@ // limitations under the License. import 'package:firebase_core/firebase_core.dart'; import 'package:firebase_vertexai/firebase_vertexai.dart'; -import 'package:firebase_vertexai/src/model.dart'; +import 'package:firebase_vertexai/src/generative_model.dart'; import 'package:flutter_test/flutter_test.dart'; import 'mock.dart'; diff --git a/packages/firebase_vertexai/firebase_vertexai/test/imagen_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/imagen_test.dart new file mode 100644 index 000000000000..d030e6f89495 --- /dev/null +++ b/packages/firebase_vertexai/firebase_vertexai/test/imagen_test.dart @@ -0,0 +1,241 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:convert'; +import 'dart:typed_data'; + +import 'package:firebase_vertexai/src/error.dart'; +import 'package:firebase_vertexai/src/imagen_content.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('ImagenInlineImage', () { + test('fromJson with valid base64', () { + final json = { + 'mimeType': 'image/png', + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=' + }; + final image = ImagenInlineImage.fromJson(json); + expect(image.mimeType, 'image/png'); + expect(image.bytesBase64Encoded, isA()); + expect(image.bytesBase64Encoded, isNotEmpty); + }); + + test('fromJson with invalid base64', () { + final json = { + 'mimeType': 'image/png', + 'bytesBase64Encoded': 'invalid_base64_string' + }; + // Expect that the constructor throws an exception. + expect(() => ImagenInlineImage.fromJson(json), throwsFormatException); + }); + + test('toJson', () { + final image = ImagenInlineImage( + mimeType: 'image/png', + bytesBase64Encoded: Uint8List.fromList(utf8.encode('Hello, world!')), + ); + final json = image.toJson(); + expect(json, { + 'mimeType': 'image/png', + 'bytesBase64Encoded': 'SGVsbG8sIHdvcmxkIQ==', + }); + }); + }); + + group('ImagenGCSImage', () { + test('fromJson', () { + final json = { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg' + }; + final image = ImagenGCSImage.fromJson(json); + expect(image.mimeType, 'image/jpeg'); + expect(image.gcsUri, + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg'); + }); + + test('toJson', () { + final image = ImagenGCSImage( + mimeType: 'image/jpeg', + gcsUri: + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg', + ); + final json = image.toJson(); + expect(json, { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg', + }); + }); + }); + + group('ImagenGenerationResponse', () { + test('fromJson with gcsUri', () { + final json = { + 'predictions': [ + { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg' + }, + { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_1.jpg' + }, + { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_2.jpg' + }, + { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_3.jpg' + } + ] + }; + final response = ImagenGenerationResponse.fromJson(json); + expect(response.images, isA>()); + expect(response.images.length, 4); + expect(response.filteredReason, isNull); + }); + + test('fromJson with bytesBase64Encoded', () { + final json = { + 'predictions': [ + { + 'mimeType': 'image/jpeg', + 'bytesBase64Encoded': 'SGVsbG8sIHdvcmxkIQ==' + }, + { + 'mimeType': 'image/jpeg', + 'bytesBase64Encoded': 'SGVsbG8sIHdvcmxkIQ==' + }, + { + 'mimeType': 'image/jpeg', + 'bytesBase64Encoded': 'SGVsbG8sIHdvcmxkIQ==' + }, + { + 'mimeType': 'image/jpeg', + 'bytesBase64Encoded': 'SGVsbG8sIHdvcmxkIQ==' + } + ] + }; + final response = + ImagenGenerationResponse.fromJson(json); + expect(response.images, isA>()); + expect(response.images.length, 4); + expect(response.filteredReason, isNull); + }); + + test('fromJson with bytesBase64Encoded and raiFilteredReason', () { + final json = { + 'predictions': [ + { + 'mimeType': 'image/png', + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=' + }, + { + 'mimeType': 'image/png', + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=' + }, + { + 'raiFilteredReason': + 'Your current safety filter threshold filtered out 2 generated images. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.' + } + ] + }; + final response = + ImagenGenerationResponse.fromJson(json); + expect(response.images, isA>()); + expect(response.images.length, 2); + expect(response.filteredReason, + 'Your current safety filter threshold filtered out 2 generated images. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.'); + }); + + test('fromJson with only raiFilteredReason', () { + final json = { + 'predictions': [ + { + 'raiFilteredReason': + "Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback. Support codes: 39322892, 29310472" + } + ] + }; + // Expect that the constructor throws an exception. + expect(() => ImagenGenerationResponse.fromJson(json), + throwsA(isA())); + }); + + test('fromJson with empty predictions', () { + final json = {'predictions': {}}; + // Expect that the constructor throws an exception. + expect(() => ImagenGenerationResponse.fromJson(json), + throwsA(isA())); + }); + + test('fromJson with unsupported type', () { + final json = { + 'predictions': [ + { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg' + }, + ] + }; + // Expect that the constructor throws an exception. + expect(() => ImagenGenerationResponse.fromJson(json), + throwsA(isA())); + }); + }); + + group('parseImagenGenerationResponse', () { + test('with valid response', () { + final json = { + 'predictions': [ + { + 'mimeType': 'image/jpeg', + 'gcsUri': + 'gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_0.jpg' + }, + ] + }; + final response = parseImagenGenerationResponse(json); + expect(response.images, isA>()); + expect(response.images.length, 1); + expect(response.filteredReason, isNull); + }); + + test('with error', () { + final json = { + 'error': { + 'code': 400, + 'message': + "Image generation failed with the following error: The prompt could not be submitted. This prompt contains sensitive words that violate Google's Responsible AI practices. Try rephrasing the prompt. If you think this was an error, send feedback. Support codes: 42876398", + 'status': 'INVALID_ARGUMENT' + } + }; + // Expect that the function throws an exception. + expect(() => parseImagenGenerationResponse(json), + throwsA(isA())); + }); + }); +} diff --git a/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart b/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart index f0c05f71e8dc..9dba6c498bb7 100644 --- a/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart +++ b/packages/firebase_vertexai/firebase_vertexai/test/model_test.dart @@ -13,7 +13,7 @@ // limitations under the License. import 'package:firebase_core/firebase_core.dart'; import 'package:firebase_vertexai/firebase_vertexai.dart'; -import 'package:firebase_vertexai/src/model.dart'; +import 'package:firebase_vertexai/src/generative_model.dart'; import 'package:flutter_test/flutter_test.dart'; import 'mock.dart';