diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index 40afd0008d..792247e665 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -661,17 +661,14 @@ func TestResolveFormat(t *testing.T) { } }) - t.Run("defaults to text even when schema present but no format", func(t *testing.T) { + t.Run("defaults to json even when schema present but no format", func(t *testing.T) { schema := map[string]any{"type": "object"} formatter, err := resolveFormat(r, schema, "") if err != nil { t.Fatalf("resolveFormat() error = %v", err) } - // Note: The current implementation defaults to text when format is empty, - // even if schema is present. The schema/format combination is typically - // handled at a higher level (e.g., in Generate options). - if formatter.Name() != OutputFormatText { - t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatText) + if formatter.Name() != OutputFormatJSON { + t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatJSON) } }) diff --git a/go/go.mod b/go/go.mod index 3aa1cd948b..3472c0f4cb 100644 --- a/go/go.mod +++ b/go/go.mod @@ -41,7 +41,7 @@ require ( golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 golang.org/x/tools v0.34.0 google.golang.org/api v0.236.0 - google.golang.org/genai v1.36.0 + google.golang.org/genai v1.40.0 ) require ( diff --git a/go/go.sum b/go/go.sum index 43f5ac29cd..e7abcc1495 100644 --- a/go/go.sum +++ b/go/go.sum @@ -537,8 +537,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= -google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= -google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= 
+google.golang.org/genai v1.40.0 h1:kYxyQSH+vsib8dvsgyLJzsVEIv5k3ZmHJyVqdvGncmc= +google.golang.org/genai v1.40.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go
index c7e49a8bec..192b7ab4bf 100644
--- a/go/plugins/googlegenai/gemini.go
+++ b/go/plugins/googlegenai/gemini.go
@@ -484,6 +484,34 @@ func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) {
 	return outTools, nil
 }
 
+// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of
+// [genai.FunctionResponsePart]. Data parts and media parts whose Text starts
+// with "data:" are decoded via uri.Data and attached as bytes; other media
+// parts are attached by URI using the part's Text and ContentType. Any other
+// part kind, or a nil entry, yields an error instead of a result.
+func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) {
+	frp := make([]*genai.FunctionResponsePart, 0, len(parts))
+	for i, p := range parts {
+		switch {
+		case p == nil:
+			// Reject nil entries up front; the default branch reads p.Kind.
+			return nil, fmt.Errorf("function response part %d is nil", i)
+		case p.IsData(), p.IsMedia() && strings.HasPrefix(p.Text, "data:"):
+			contentType, data, err := uri.Data(p)
+			if err != nil {
+				return nil, fmt.Errorf("decoding function response part %d: %w", i, err)
+			}
+			frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType))
+		case p.IsMedia():
+			// Media without a "data:" prefix is passed by URI, not decoded.
+			frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType))
+		default:
+			return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind)
+		}
+	}
+	return frp, nil
+}
+
 // mergeTools consolidates all FunctionDeclarations into a single Tool
 // while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.)
func mergeTools(ts []*genai.Tool) []*genai.Tool { @@ -814,6 +842,14 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { Name: part.FunctionCall.Name, Input: part.FunctionCall.Args, }) + // FunctionCall parts may contain a ThoughtSignature that must be preserved + // and returned in subsequent requests for the tool call to be valid. + if len(part.ThoughtSignature) > 0 { + if p.Metadata == nil { + p.Metadata = make(map[string]any) + } + p.Metadata["signature"] = part.ThoughtSignature + } } if part.CodeExecutionResult != nil { partFound++ @@ -894,7 +930,7 @@ func toGeminiParts(parts []*ai.Part) ([]*genai.Part, error) { func toGeminiPart(p *ai.Part) (*genai.Part, error) { switch { case p.IsReasoning(): - // TODO: go-genai does not support genai.NewPartFromThought() + // NOTE: go-genai does not support genai.NewPartFromThought() signature := []byte{} if p.Metadata != nil { if sig, ok := p.Metadata["signature"].([]byte); ok { @@ -934,8 +970,22 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { "content": toolResp.Output, } } - fr := genai.NewPartFromFunctionResponse(toolResp.Name, output) - return fr, nil + + var isMultipart bool + if multiPart, ok := p.Metadata["multipart"].(bool); ok { + isMultipart = multiPart + } + if len(toolResp.Content) > 0 { + isMultipart = true + } + if isMultipart { + toolRespParts, err := toGeminiFunctionResponsePart(toolResp.Content) + if err != nil { + return nil, err + } + return genai.NewPartFromFunctionResponseWithParts(toolResp.Name, output, toolRespParts), nil + } + return genai.NewPartFromFunctionResponse(toolResp.Name, output), nil case p.IsToolRequest(): toolReq := p.ToolRequest var input map[string]any @@ -947,6 +997,12 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { } } fc := genai.NewPartFromFunctionCall(toolReq.Name, input) + // Restore ThoughtSignature if present in metadata + if p.Metadata != nil { + if sig, ok := p.Metadata["signature"].([]byte); ok { + fc.ThoughtSignature = sig + } + } 
return fc, nil default: panic("unknown part type in a request")
diff --git a/go/plugins/googlegenai/gemini_test.go b/go/plugins/googlegenai/gemini_test.go
index daa4da215e..9aae76a054 100644
--- a/go/plugins/googlegenai/gemini_test.go
+++ b/go/plugins/googlegenai/gemini_test.go
@@ -707,6 +707,88 @@ func TestValidToolName(t *testing.T) {
 	}
 }
 
+// TestToGeminiParts_MultipartToolResponse checks how toGeminiParts converts a
+// tool response that carries extra Content parts: a media part must yield a
+// single FunctionResponse part, while a text part must be rejected.
+func TestToGeminiParts_MultipartToolResponse(t *testing.T) {
+	t.Run("ValidPartType", func(t *testing.T) {
+		// Create a tool response with both output and additional content (media)
+		toolResp := &ai.ToolResponse{
+			Name:   "generateImage",
+			Output: map[string]any{"status": "success"},
+			Content: []*ai.Part{
+				ai.NewMediaPart("image/png", ""),
+			},
+		}
+
+		// Mark the part as multipart (toGeminiPart also infers this from non-empty Content).
+		part := ai.NewToolResponsePart(toolResp)
+		part.Metadata = map[string]any{"multipart": true}
+
+		geminiParts, err := toGeminiParts([]*ai.Part{part})
+		if err != nil {
+			t.Fatalf("toGeminiParts failed: %v", err)
+		}
+
+		// Expecting 1 part which contains the function response with internal parts
+		if len(geminiParts) != 1 {
+			t.Fatalf("expected 1 Gemini part, got %d", len(geminiParts))
+		}
+
+		// Fatal, not Error: the next check dereferences FunctionResponse.
+		if geminiParts[0].FunctionResponse == nil {
+			t.Fatal("expected first part to be FunctionResponse")
+		}
+		if geminiParts[0].FunctionResponse.Name != "generateImage" {
+			t.Errorf("expected function name 'generateImage', got %q", geminiParts[0].FunctionResponse.Name)
+		}
+	})
+
+	t.Run("UnsupportedPartType", func(t *testing.T) {
+		// Create a tool response with text content (unsupported for multipart)
+		toolResp := &ai.ToolResponse{
+			Name:   "generateText",
+			Output: map[string]any{"status": "success"},
+			Content: []*ai.Part{
+				ai.NewTextPart("Generated text"),
+			},
+		}
+
+		part := ai.NewToolResponsePart(toolResp)
+		part.Metadata = map[string]any{"multipart": true}
+
+		_, err := toGeminiParts([]*ai.Part{part})
+		if err == nil {
+			t.Fatal("expected error for unsupported text part in multipart response, got nil")
+		}
+	})
+}
+
+// TestToGeminiParts_SimpleToolResponse checks that a tool response without
+// Content is converted into a single FunctionResponse part.
+func TestToGeminiParts_SimpleToolResponse(t *testing.T) {
+	// Create a simple tool response (no content)
+	toolResp := &ai.ToolResponse{
+		Name:   "search",
+		Output: map[string]any{"result": "foo"},
+	}
+
+	part := ai.NewToolResponsePart(toolResp)
+
+	geminiParts, err := toGeminiParts([]*ai.Part{part})
+	if err != nil {
+		t.Fatalf("toGeminiParts failed: %v", err)
+	}
+
+	if len(geminiParts) != 1 {
+		t.Fatalf("expected 1 Gemini part, got %d", len(geminiParts))
+	}
+
+	if geminiParts[0].FunctionResponse == nil {
+		t.Error("expected part to be FunctionResponse")
+	}
+}
+
 // genToolName generates a string of a specified length using only
 // the valid characters for a Gemini Tool name
 func genToolName(length int, chars string) string {
diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 4e78ab4f4c..783eccd239 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -170,7 +170,7 @@ func TestGoogleAILive(t *testing.T) { t.Fatal(err) } - out := resp.Message.Content[0].Text + out := resp.Text() const want = "11.31" if !strings.Contains(out, want) { t.Errorf("got %q, expecting it to contain %q", out, want) @@ -219,7 +219,7 @@ func TestGoogleAILive(t *testing.T) { t.Fatal(err) } - out := resp.Message.Content[0].Text + out := resp.Text() const want = "11.31" if !strings.Contains(out, want) { t.Errorf("got %q, expecting it to contain %q", out, want) @@ -307,7 +307,7 @@ func TestGoogleAILive(t *testing.T) { t.Fatal(err) } - out := resp.Message.Content[0].Text + out := resp.Text() const doNotWant = "11.31" if strings.Contains(out, doNotWant) { t.Errorf("got %q, expecting it NOT to contain %q", out, doNotWant) @@ -582,6 +582,37 @@ func TestGoogleAILive(t *testing.T) { t.Fatal("thoughts tokens should be zero") } }) + t.Run("multipart tool", func(t *testing.T) { + m := googlegenai.GoogleAIModel(g, "gemini-3-pro-preview") + img64, err := fetchImgAsBase64()
if err != nil { + t.Fatal(err) + } + + tool := genkit.DefineMultipartTool(g, "getImage", "returns a misterious image", + func(ctx *ai.ToolContext, input any) (*ai.MultipartToolResponse, error) { + return &ai.MultipartToolResponse{ + Output: map[string]any{"status": "success"}, + Content: []*ai.Part{ + ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,"+img64), + }, + }, nil + }, + ) + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithTools(tool), + ai.WithPrompt("get an image and tell me what is in it"), + ) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(strings.ToLower(resp.Text()), "cat") { + t.Errorf("expected response to contain 'cat', got: %s", resp.Text()) + } + }) } func TestCacheHelper(t *testing.T) {
diff --git a/go/samples/multipart-tools/main.go b/go/samples/multipart-tools/main.go
new file mode 100644
index 0000000000..c9cb04bdc3
--- /dev/null
+++ b/go/samples/multipart-tools/main.go
@@ -0,0 +1,68 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"context"
+
+	"github.com/firebase/genkit/go/ai"
+	"github.com/firebase/genkit/go/genkit"
+	"github.com/firebase/genkit/go/plugins/googlegenai"
+	"google.golang.org/genai"
+)
+
+// main initializes Genkit with the Google AI plugin, registers a multipart
+// "screenshot" tool and a streaming flow that calls it, then blocks on ctx.
+func main() {
+	ctx := context.Background()
+	g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}))
+
+	// Handler for the multipart tool: returns structured output together with
+	// an image part, simulating a tool that takes a screenshot.
+	takeScreenshot := func(_ *ai.ToolContext, _ any) (*ai.MultipartToolResponse, error) {
+		// NOTE(review): this payload carries no "data:" prefix — confirm the
+		// literal was not clipped before relying on it as an image.
+		rectangle := "" +
+			"AAAAI0lEQVR4nGNgGHaA/z8UHIDwOWASDqP8Uf7w56On/1FAQwAAVM0exw1hqwkAAAAASUVORK5CYII="
+		image := ai.NewMediaPart("image/png", rectangle)
+		return &ai.MultipartToolResponse{
+			Output:  map[string]any{"success": true},
+			Content: []*ai.Part{image},
+		}, nil
+	}
+	screenshot := genkit.DefineMultipartTool(g, "screenshot", "Takes a screenshot", takeScreenshot)
+
+	// Streaming flow that lets the model call the screenshot tool.
+	genkit.DefineStreamingFlow(g, "cardFlow", func(ctx context.Context, input any, cb ai.ModelStreamCallback) (string, error) {
+		cfg := &genai.GenerateContentConfig{
+			Temperature:    genai.Ptr[float32](1.0),
+			ThinkingConfig: &genai.ThinkingConfig{ThinkingLevel: genai.ThinkingLevelHigh},
+		}
+		resp, err := genkit.Generate(ctx, g,
+			ai.WithModelName("googleai/gemini-3-pro-preview"),
+			ai.WithConfig(cfg),
+			ai.WithTools(screenshot),
+			ai.WithStreaming(cb),
+			ai.WithPrompt("Tell me what I'm seeing in the screen"),
+		)
+		if err != nil {
+			return "", err
+		}
+		return resp.Text(), nil
+	})
+
+	// Block on the context's Done channel; presumably this keeps the sample running.
+	<-ctx.Done()
+}