Skip to content

[Vertex AI] Add APIConfig to userInfo dictionary in coders #14592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@ struct GenerateContentRequest: Sendable {
let toolConfig: ToolConfig?
let systemInstruction: ModelContent?

let apiConfig: APIConfig
let apiMethod: APIMethod
let options: RequestOptions
}
@@ -73,7 +72,7 @@ extension GenerateContentRequest {
extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
func requestURL(apiConfig: APIConfig) -> URL {
let modelURL = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
switch apiMethod {
case .generateContent:
4 changes: 2 additions & 2 deletions FirebaseVertexAI/Sources/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
@@ -18,9 +18,9 @@ import Foundation
protocol GenerativeAIRequest: Sendable, Encodable {
associatedtype Response: Decodable

var url: URL { get }

var options: RequestOptions { get }

func requestURL(apiConfig: APIConfig) -> URL
}

/// Configuration parameters for sending requests to the backend.
22 changes: 14 additions & 8 deletions FirebaseVertexAI/Sources/GenerativeAIService.swift
Original file line number Diff line number Diff line change
@@ -28,10 +28,19 @@ struct GenerativeAIService {

private let firebaseInfo: FirebaseInfo

/// Configuration for the backend API used by this model.
private let apiConfig: APIConfig

private let jsonDecoder: JSONDecoder
private let jsonEncoder: JSONEncoder

private let urlSession: URLSession

init(firebaseInfo: FirebaseInfo, urlSession: URLSession) {
init(firebaseInfo: FirebaseInfo, apiConfig: APIConfig, urlSession: URLSession) {
self.firebaseInfo = firebaseInfo
self.apiConfig = apiConfig
jsonDecoder = JSONDecoder(apiConfig: apiConfig)
jsonEncoder = JSONEncoder(apiConfig: apiConfig)
self.urlSession = urlSession
}

@@ -125,8 +134,6 @@ struct GenerativeAIService {
// Received lines that are not server-sent events (SSE); these are not prefixed with "data:"
var extraLines = ""

let decoder = JSONDecoder()
decoder.keyDecodingStrategy = .convertFromSnakeCase
for try await line in stream.lines {
VertexLog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")

@@ -167,7 +174,7 @@ struct GenerativeAIService {
// MARK: - Private Helpers

private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
var urlRequest = URLRequest(url: request.url)
var urlRequest = URLRequest(url: request.requestURL(apiConfig: apiConfig))
urlRequest.httpMethod = "POST"
urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")
urlRequest.setValue(
@@ -200,8 +207,7 @@ struct GenerativeAIService {
}
}

let encoder = JSONEncoder()
urlRequest.httpBody = try encoder.encode(request)
urlRequest.httpBody = try jsonEncoder.encode(request)
urlRequest.timeoutInterval = request.options.timeout

return urlRequest
@@ -246,7 +252,7 @@ struct GenerativeAIService {

private func parseError(responseData: Data) -> Error {
do {
let rpcError = try JSONDecoder().decode(BackendError.self, from: responseData)
let rpcError = try jsonDecoder.decode(BackendError.self, from: responseData)
logRPCError(rpcError)
return rpcError
} catch {
@@ -273,7 +279,7 @@ struct GenerativeAIService {

private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
do {
return try JSONDecoder().decode(type, from: data)
return try jsonDecoder.decode(type, from: data)
} catch {
if let json = String(data: data, encoding: .utf8) {
VertexLog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
4 changes: 1 addition & 3 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
@@ -75,6 +75,7 @@ public final class GenerativeModel: Sendable {
self.apiConfig = apiConfig
generativeAIService = GenerativeAIService(
firebaseInfo: firebaseInfo,
apiConfig: apiConfig,
urlSession: urlSession
)
self.generationConfig = generationConfig
@@ -137,7 +138,6 @@ public final class GenerativeModel: Sendable {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .generateContent,
options: requestOptions
)
@@ -197,7 +197,6 @@ public final class GenerativeModel: Sendable {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .streamGenerateContent,
options: requestOptions
)
@@ -279,7 +278,6 @@ public final class GenerativeModel: Sendable {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .countTokens,
options: requestOptions
)
53 changes: 53 additions & 0 deletions FirebaseVertexAI/Sources/Types/Internal/APIConfig.swift
Original file line number Diff line number Diff line change
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// Configuration for the generative AI backend API used by this SDK.
struct APIConfig: Sendable, Hashable {
/// The service to use for generative AI.
@@ -90,3 +92,54 @@ extension APIConfig {
case v1beta
}
}

// MARK: - Coding Utilities

extension CodingUserInfoKey {
static let apiConfig = {
let keyName = "com.google.firebase.VertexAI.APIConfig"
guard let userInfoKey = CodingUserInfoKey(rawValue: keyName) else {
fatalError("The key name '\(keyName)' is not a valid raw value for CodingUserInfoKey.")
}
return userInfoKey
}()
}

extension APIConfig {
static func from(userInfo: [CodingUserInfoKey: Any]) -> APIConfig {
guard let config = userInfo[CodingUserInfoKey.apiConfig] else {
fatalError(
"No value provided for '\(CodingUserInfoKey.apiConfig)' in the coder's userInfo."
)
}
guard let config = config as? APIConfig else {
fatalError("""
The value provided for '\(CodingUserInfoKey.apiConfig)' in the coder's userInfo is not of \
type '\(APIConfig.self)'; found type '\(config)'.
""")
}
return config
}
}

extension Decoder {
var apiConfig: APIConfig { APIConfig.from(userInfo: userInfo) }
}

extension JSONDecoder {
convenience init(apiConfig: APIConfig) {
self.init()
userInfo[CodingUserInfoKey.apiConfig] = apiConfig
}
}

extension Encoder {
var apiConfig: APIConfig { APIConfig.from(userInfo: userInfo) }
}

extension JSONEncoder {
convenience init(apiConfig: APIConfig) {
self.init()
userInfo[CodingUserInfoKey.apiConfig] = apiConfig
}
}
Original file line number Diff line number Diff line change
@@ -17,18 +17,15 @@ import Foundation
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImagenGenerationRequest<ImageType: ImagenImageRepresentable>: Sendable {
let model: String
let apiConfig: APIConfig
let options: RequestOptions
let instances: [ImageGenerationInstance]
let parameters: ImageGenerationParameters

init(model: String,
apiConfig: APIConfig,
options: RequestOptions,
instances: [ImageGenerationInstance],
parameters: ImageGenerationParameters) {
self.model = model
self.apiConfig = apiConfig
self.options = options
self.instances = instances
self.parameters = parameters
@@ -39,7 +36,7 @@ struct ImagenGenerationRequest<ImageType: ImagenImageRepresentable>: Sendable {
extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImagenGenerationResponse<ImageType>

var url: URL {
func requestURL(apiConfig: APIConfig) -> URL {
return URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict")!
}
Original file line number Diff line number Diff line change
@@ -25,9 +25,7 @@ extension CountTokensRequest: GenerativeAIRequest {

var options: RequestOptions { generateContentRequest.options }

var apiConfig: APIConfig { generateContentRequest.apiConfig }

var url: URL {
func requestURL(apiConfig: APIConfig) -> URL {
let version = apiConfig.version.rawValue
let endpoint = apiConfig.service.endpoint.rawValue
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
@@ -66,7 +64,7 @@ extension CountTokensRequest: Encodable {
}

func encode(to encoder: any Encoder) throws {
switch apiConfig.service {
switch encoder.apiConfig.service {
case .vertexAI:
try encodeForVertexAI(to: encoder)
case .developer:
Original file line number Diff line number Diff line change
@@ -31,9 +31,6 @@ public final class ImagenModel {
/// The resource name of the model in the backend; has the format "models/model-name".
let modelResourceName: String

/// Configuration for the backend API used by this model.
let apiConfig: APIConfig

/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

@@ -52,9 +49,9 @@ public final class ImagenModel {
requestOptions: RequestOptions,
urlSession: URLSession = .shared) {
modelResourceName = name
self.apiConfig = apiConfig
generativeAIService = GenerativeAIService(
firebaseInfo: firebaseInfo,
apiConfig: apiConfig,
urlSession: urlSession
)
self.generationConfig = generationConfig
@@ -129,7 +126,6 @@ public final class ImagenModel {
-> ImagenGenerationResponse<T> where T: Decodable, T: ImagenImageRepresentable {
let request = ImagenGenerationRequest<T>(
model: modelResourceName,
apiConfig: apiConfig,
options: requestOptions,
instances: [ImageGenerationInstance(prompt: prompt)],
parameters: parameters
Original file line number Diff line number Diff line change
@@ -47,7 +47,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testInitializeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
apiConfig: apiConfig,
options: requestOptions,
instances: [instance],
parameters: parameters
@@ -58,7 +57,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
request.url,
request.requestURL(apiConfig: apiConfig),
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
@@ -67,7 +66,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testInitializeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
apiConfig: apiConfig,
options: requestOptions,
instances: [instance],
parameters: parameters
@@ -78,7 +76,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
request.url,
request.requestURL(apiConfig: apiConfig),
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
@@ -89,7 +87,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testEncodeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
apiConfig: apiConfig,
options: RequestOptions(),
instances: [instance],
parameters: parameters
@@ -118,7 +115,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testEncodeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
apiConfig: apiConfig,
options: RequestOptions(),
instances: [instance],
parameters: parameters
Original file line number Diff line number Diff line change
@@ -19,23 +19,16 @@ import XCTest

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class CountTokensRequestTests: XCTestCase {
let encoder = JSONEncoder()

let modelResourceName = "models/test-model-name"
let textPart = TextPart("test-prompt")
let vertexAPIConfig = APIConfig(service: .vertexAI, version: .v1beta)
let developerAPIConfig = APIConfig(
service: .developer(endpoint: .firebaseVertexAIProd),
version: .v1beta
let vertexEncoder = CountTokensRequestTests.encoder(
apiConfig: APIConfig(service: .vertexAI, version: .v1beta)
)
let developerEncoder = CountTokensRequestTests.encoder(
apiConfig: APIConfig(service: .developer(endpoint: .firebaseVertexAIProd), version: .v1beta)
)
let requestOptions = RequestOptions()

override func setUp() {
encoder.outputFormatting = .init(
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
)
}

// MARK: CountTokensRequest Encoding

func testEncodeCountTokensRequest_vertexAI_minimal() throws {
@@ -48,13 +41,12 @@ final class CountTokensRequestTests: XCTestCase {
tools: nil,
toolConfig: nil,
systemInstruction: nil,
apiConfig: vertexAPIConfig,
apiMethod: .countTokens,
options: requestOptions
)
let request = CountTokensRequest(generateContentRequest: generateContentRequest)

let jsonData = try encoder.encode(request)
let jsonData = try vertexEncoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
@@ -82,13 +74,12 @@ final class CountTokensRequestTests: XCTestCase {
tools: nil,
toolConfig: nil,
systemInstruction: nil,
apiConfig: developerAPIConfig,
apiMethod: .countTokens,
options: requestOptions
)
let request = CountTokensRequest(generateContentRequest: generateContentRequest)

let jsonData = try encoder.encode(request)
let jsonData = try developerEncoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
@@ -108,4 +99,12 @@ final class CountTokensRequestTests: XCTestCase {
}
""")
}

static func encoder(apiConfig: APIConfig) -> JSONEncoder {
let encoder = JSONEncoder(apiConfig: apiConfig)
encoder.outputFormatting = .init(
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
)
return encoder
}
}

Unchanged files with check annotations Beta

#if compiler(>=6)
private nonisolated(unsafe) var _history: [ModelContent] = []
#else
private var _history: [ModelContent] = []

Check warning on line 33 in FirebaseVertexAI/Sources/Chat.swift

GitHub Actions / spm-unit (macos-13, Xcode_15.2, iOS)

stored property '_history' of 'Sendable'-conforming class 'Chat' is mutable

Check warning on line 33 in FirebaseVertexAI/Sources/Chat.swift

GitHub Actions / spm-unit (macos-13, Xcode_15.2, iOS)

stored property '_history' of 'Sendable'-conforming class 'Chat' is mutable

Check warning on line 33 in FirebaseVertexAI/Sources/Chat.swift

GitHub Actions / spm-unit (macos-14, Xcode_15.4, iOS)

stored property '_history' of 'Sendable'-conforming class 'Chat' is mutable

Check warning on line 33 in FirebaseVertexAI/Sources/Chat.swift

GitHub Actions / spm-unit (macos-14, Xcode_15.4, iOS)

stored property '_history' of 'Sendable'-conforming class 'Chat' is mutable

Check warning on line 33 in FirebaseVertexAI/Sources/Chat.swift

GitHub Actions / spm-unit (macos-14, Xcode_15.4, iOS)

stored property '_history' of 'Sendable'-conforming class 'Chat' is mutable

Check warning on line 33 in FirebaseVertexAI/Sources/Chat.swift

GitHub Actions / spm-unit (macos-14, Xcode_15.4, iOS)

stored property '_history' of 'Sendable'-conforming class 'Chat' is mutable
#endif
/// The previous content from the chat that has been successfully sent and received from the
/// model. This will be provided to the model for each message sent as context for the discussion.
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
MockURLProtocol.requestHandler = { request in

Check warning on line 48 in FirebaseVertexAI/Tests/Unit/ChatTests.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let response = HTTPURLResponse(
url: request.url!,
statusCode: 200,
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
return { request in

Check warning on line 1555 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
// This is *not* an HTTPURLResponse
let response = URLResponse(
url: request.url!,
#if os(watchOS)
throw XCTSkip("Custom URL protocols are unsupported in watchOS 2 and later.")
#endif // os(watchOS)
let bundle = BundleTestUtil.bundle()

Check warning on line 1583 in FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, watchOS)

code after 'throw' will never be executed
let fileURL = try XCTUnwrap(
bundle.url(forResource: name, withExtension: ext, subdirectory: subpath)
)
return "application/octet-stream"
}
if let type = UTTypeCreatePreferredIdentifierForTag(

Check warning on line 79 in FirebaseStorage/Sources/Internal/StorageUtils.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, visionOS)

'UTTypeCreatePreferredIdentifierForTag' was deprecated in visionOS 1.0: Use the UTType class instead.
kUTTagClassFilenameExtension,

Check warning on line 80 in FirebaseStorage/Sources/Internal/StorageUtils.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, visionOS)

'kUTTagClassFilenameExtension' was deprecated in visionOS 1.0: Use UTTagClassFilenameExtension instead.
fileExtension as NSString,
nil
)?.takeRetainedValue() {
if let mimeType = UTTypeCopyPreferredTagWithClass(type, kUTTagClassMIMEType)?

Check warning on line 84 in FirebaseStorage/Sources/Internal/StorageUtils.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, visionOS)

'UTTypeCopyPreferredTagWithClass' was deprecated in visionOS 1.0: Use the UTType class instead.

Check warning on line 84 in FirebaseStorage/Sources/Internal/StorageUtils.swift

GitHub Actions / spm-unit (macos-15, Xcode_16.2, visionOS)

'kUTTagClassMIMEType' was deprecated in visionOS 1.0: Use UTTagClassMIMEType instead.
.takeRetainedValue() {
return mimeType as String
}