Skip to content

[Vertex AI] Add countTokens support for Developer API via VinF #14644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ public final class GenerativeModel: Sendable {
/// Model name prefix to identify Gemini models.
static let geminiModelNamePrefix = "gemini-"

/// The resource name of the model in the backend; has the format "models/model-name".
/// The name of the model, for example "gemini-2.0-flash".
let modelName: String

/// The model resource name corresponding with `modelName` in the backend.
let modelResourceName: String

/// Configuration for the backend API used by this model.
Expand Down Expand Up @@ -53,8 +56,13 @@ public final class GenerativeModel: Sendable {
/// Initializes a new remote model with the given parameters.
///
/// - Parameters:
/// - modelResourceName: The resource name of the model to use, for example
/// `"projects/{project-id}/locations/{location-id}/publishers/google/models/{model-name}"`.
/// - modelName: The name of the model, for example "gemini-2.0-flash".
/// - modelResourceName: The model resource name corresponding with `modelName` in the backend.
/// The form depends on the backend and will be one of:
/// - Vertex AI via Vertex AI in Firebase:
/// `"projects/{projectID}/locations/{locationID}/publishers/google/models/{modelName}"`
/// - Developer API via Vertex AI in Firebase: `"projects/{projectID}/models/{modelName}"`
/// - Developer API via Generative Language: `"models/{modelName}"`
/// - firebaseInfo: Firebase data used by the SDK, including project ID and API key.
/// - apiConfig: Configuration for the backend API used by this model.
/// - generationConfig: The content generation parameters your model should use.
Expand All @@ -65,7 +73,8 @@ public final class GenerativeModel: Sendable {
/// only text content is supported.
/// - requestOptions: Configuration parameters for sending requests to the backend.
/// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
init(modelResourceName: String,
init(modelName: String,
modelResourceName: String,
firebaseInfo: FirebaseInfo,
apiConfig: APIConfig,
generationConfig: GenerationConfig? = nil,
Expand All @@ -75,6 +84,7 @@ public final class GenerativeModel: Sendable {
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions,
urlSession: URLSession = .shared) {
self.modelName = modelName
self.modelResourceName = modelResourceName
self.apiConfig = apiConfig
generativeAIService = GenerativeAIService(
Expand Down Expand Up @@ -275,8 +285,20 @@ public final class GenerativeModel: Sendable {
content.map { ModelContent(role: nil, parts: $0.parts) }
}

// When using the Developer API via the Firebase backend, the model name of the
// `GenerateContentRequest` nested in the `CountTokensRequest` must be of the form
// "models/model-name" because the Firebase backend forwards this field unaltered to the
// Generative Language backend, which expects that form.
let generateContentRequestModelResourceName = switch apiConfig.service {
case .vertexAI, .developer(endpoint: .generativeLanguage):
modelResourceName
case .developer(endpoint: .firebaseVertexAIProd),
.developer(endpoint: .firebaseVertexAIStaging):
"models/\(modelName)"
}

let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
model: generateContentRequestModelResourceName,
contents: requestContent,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand All @@ -287,7 +309,9 @@ public final class GenerativeModel: Sendable {
apiMethod: .countTokens,
options: requestOptions
)
let countTokensRequest = CountTokensRequest(generateContentRequest: generateContentRequest)
let countTokensRequest = CountTokensRequest(
modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
)

return try await generativeAIService.loadRequest(request: countTokensRequest)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import Foundation

/// A request to count the tokens of the content carried by a `GenerateContentRequest`.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct CountTokensRequest {
/// The model resource name used to address the request; kept separately from
/// `generateContentRequest.model`, which may use a different form.
let modelResourceName: String

/// The content-generation request whose tokens are counted.
let generateContentRequest: GenerateContentRequest
}

Expand All @@ -30,7 +32,7 @@ extension CountTokensRequest: GenerativeAIRequest {
/// The URL for this request, of the form `"{endpoint}/{version}/{model}:countTokens"`.
///
/// The endpoint and version are taken from `apiConfig`. The string is force-unwrapped into a
/// `URL`, so a string that `URL(string:)` rejects would trap at runtime.
var url: URL {
let version = apiConfig.version.rawValue
let endpoint = apiConfig.service.endpoint.rawValue
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
return URL(string: "\(endpoint)/\(version)/\(modelResourceName):countTokens")!
}
}

Expand Down
5 changes: 2 additions & 3 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public class VertexAI {
}

return GenerativeModel(
modelName: modelName,
modelResourceName: modelResourceName(modelName: modelName),
firebaseInfo: firebaseInfo,
apiConfig: apiConfig,
Expand Down Expand Up @@ -240,13 +241,11 @@ public class VertexAI {

/// Returns the model resource name to use with the Developer API for `modelName`.
///
/// The form depends on `apiConfig.service.endpoint`:
/// - `.firebaseVertexAIStaging` / `.firebaseVertexAIProd`:
///   `"projects/{projectID}/models/{modelName}"`, with the project ID read from `firebaseInfo`.
/// - `.generativeLanguage`: `"models/{modelName}"`.
///
/// - Parameter modelName: The name of the model to build the resource name for.
/// - Returns: The resource name string for the configured endpoint.
private func developerModelResourceName(modelName: String) -> String {
switch apiConfig.service.endpoint {
case .firebaseVertexAIStaging:
case .firebaseVertexAIStaging, .firebaseVertexAIProd:
let projectID = firebaseInfo.projectID
return "projects/\(projectID)/models/\(modelName)"
case .generativeLanguage:
return "models/\(modelName)"
default:
fatalError("The Developer API is not supported on '\(apiConfig.service.endpoint)'.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct CountTokensIntegrationTests {

@Test(arguments: [
/* System instructions are not supported on the v1 Developer API. */
InstanceConfig.developerV1,
InstanceConfig.developerV1Spark,
])
func countTokens_text_systemInstruction_unsupported(_ config: InstanceConfig) async throws {
let model = VertexAI.componentInstance(config).generativeModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ struct InstanceConfig {
static let vertexV1BetaStaging = InstanceConfig(
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseVertexAIStaging), version: .v1beta)
)
static let developerV1 = InstanceConfig(
static let developerV1Beta = InstanceConfig(
apiConfig: APIConfig(service: .developer(endpoint: .firebaseVertexAIProd), version: .v1beta)
)
static let developerV1Spark = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1)
)
static let developerV1Beta = InstanceConfig(
static let developerV1BetaSpark = InstanceConfig(
appName: FirebaseAppNames.spark,
apiConfig: APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
)
Expand All @@ -45,8 +48,9 @@ struct InstanceConfig {
vertexV1Staging,
vertexV1Beta,
vertexV1BetaStaging,
developerV1,
developerV1Beta,
developerV1Spark,
developerV1BetaSpark,
]

static let vertexV1AppCheckNotConfigured = InstanceConfig(
Expand Down
6 changes: 5 additions & 1 deletion FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ChatTests: XCTestCase {
let modelName = "test-model-name"
let modelResourceName = "projects/my-project/locations/us-central1/models/test-model-name"

var urlSession: URLSession!

override func setUp() {
Expand All @@ -45,7 +48,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
MockURLProtocol.requestHandler = { request in

Check warning on line 51 in FirebaseVertexAI/Tests/Unit/ChatTests.swift

View workflow job for this annotation

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let response = HTTPURLResponse(
url: request.url!,
statusCode: 200,
Expand All @@ -59,7 +62,8 @@
options: FirebaseOptions(googleAppID: "ignore",
gcmSenderID: "ignore"))
let model = GenerativeModel(
modelResourceName: "my-model",
modelName: modelName,
modelResourceName: modelResourceName,
firebaseInfo: FirebaseInfo(
projectID: "my-project-id",
apiKey: "API_KEY",
Expand Down
20 changes: 17 additions & 3 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
blocked: false
),
].sorted()
let testModelName = "test-model"
let testModelResourceName =
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
let apiConfig = VertexAI.defaultVertexAIAPIConfig
Expand All @@ -70,6 +71,7 @@
configuration.protocolClasses = [MockURLProtocol.self]
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(),
apiConfig: apiConfig,
Expand Down Expand Up @@ -275,8 +277,8 @@
subdirectory: vertexSubdirectory
)
let model = GenerativeModel(
// Model name is prefixed with "models/".
modelResourceName: "models/test-model",
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(),
apiConfig: apiConfig,
tools: nil,
Expand Down Expand Up @@ -399,6 +401,7 @@
func testGenerateContent_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)),
apiConfig: apiConfig,
Expand All @@ -420,6 +423,7 @@
func testGenerateContent_dataCollectionOff() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken),
privateAppID: true),
Expand All @@ -442,6 +446,7 @@

func testGenerateContent_appCheck_tokenRefreshError() async throws {
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())),
apiConfig: apiConfig,
Expand All @@ -463,6 +468,7 @@
func testGenerateContent_auth_validAuthToken() async throws {
let authToken = "test-valid-token"
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: authToken)),
apiConfig: apiConfig,
Expand All @@ -483,6 +489,7 @@

func testGenerateContent_auth_nilAuthToken() async throws {
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(token: nil)),
apiConfig: apiConfig,
Expand All @@ -503,7 +510,8 @@

func testGenerateContent_auth_authTokenRefreshError() async throws {
model = GenerativeModel(
modelResourceName: "my-model",
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(auth: AuthInteropFake(error: AuthErrorFake())),
apiConfig: apiConfig,
tools: nil,
Expand Down Expand Up @@ -900,6 +908,7 @@
)
let requestOptions = RequestOptions(timeout: expectedTimeout)
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(),
apiConfig: apiConfig,
Expand Down Expand Up @@ -1204,6 +1213,7 @@
func testGenerateContentStream_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(token: appCheckToken)),
apiConfig: apiConfig,
Expand All @@ -1225,6 +1235,7 @@

func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(appCheck: AppCheckInteropFake(error: AppCheckErrorFake())),
apiConfig: apiConfig,
Expand Down Expand Up @@ -1375,6 +1386,7 @@
)
let requestOptions = RequestOptions(timeout: expectedTimeout)
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(),
apiConfig: apiConfig,
Expand Down Expand Up @@ -1451,6 +1463,7 @@
parts: "You are a calculator. Use the provided tools."
)
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(),
apiConfig: apiConfig,
Expand Down Expand Up @@ -1511,6 +1524,7 @@
)
let requestOptions = RequestOptions(timeout: expectedTimeout)
model = GenerativeModel(
modelName: testModelName,
modelResourceName: testModelResourceName,
firebaseInfo: testFirebaseInfo(),
apiConfig: apiConfig,
Expand Down Expand Up @@ -1552,7 +1566,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
return { request in

Check warning on line 1569 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

View workflow job for this annotation

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
// This is *not* an HTTPURLResponse
let response = URLResponse(
url: request.url!,
Expand Down Expand Up @@ -1580,7 +1594,7 @@
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
let bundle = BundleTestUtil.bundle()

Check warning on line 1597 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

View workflow job for this annotation

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let fileURL = try XCTUnwrap(
bundle.url(forResource: name, withExtension: ext, subdirectory: subpath)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ final class CountTokensRequestTests: XCTestCase {
apiMethod: .countTokens,
options: requestOptions
)
let request = CountTokensRequest(generateContentRequest: generateContentRequest)
let request = CountTokensRequest(
modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
)

let jsonData = try encoder.encode(request)

Expand Down Expand Up @@ -86,7 +88,9 @@ final class CountTokensRequestTests: XCTestCase {
apiMethod: .countTokens,
options: requestOptions
)
let request = CountTokensRequest(generateContentRequest: generateContentRequest)
let request = CountTokensRequest(
modelResourceName: modelResourceName, generateContentRequest: generateContentRequest
)

let jsonData = try encoder.encode(request)

Expand Down
Loading