diff --git a/go/ai/generate.go b/go/ai/generate.go index 6652d8df8b..93e08d7aeb 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -41,6 +41,17 @@ type ( Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) } + // ModelArg is the interface for model arguments. + ModelArg interface { + Name() string + } + + // ModelRef is a struct to hold model name and configuration. + ModelRef struct { + name string + config any + } + // ToolConfig handles configuration around tool calls during generation. ToolConfig struct { MaxTurns int // Maximum number of tool call iterations before erroring. @@ -87,6 +98,7 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc // DefineModel registers the given generate function as an action, and returns a [Model] that runs it. func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, fn ModelFunc) Model { + if info == nil { // Always make sure there's at least minimal metadata. info = &ModelInfo{ @@ -113,6 +125,17 @@ func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, f metadata["label"] = info.Label } + if info.ConfigSchema != nil { + metadata["customOptions"] = info.ConfigSchema + // Make sure "model" exists in metadata + if metadata["model"] == nil { + metadata["model"] = make(map[string]any) + } + // Add customOptions to the model metadata + modelMeta := metadata["model"].(map[string]any) + modelMeta["customOptions"] = info.ConfigSchema + } + // Create the middleware list middlewares := []ModelMiddleware{ simulateSystemPrompt(info, nil), @@ -162,7 +185,9 @@ func LookupModelByName(r *registry.Registry, modelName string) (Model, error) { // GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call. 
func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *GenerateActionOptions, mw []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) { if opts.Model == "" { - opts.Model = r.LookupValue(registry.DefaultModelKey).(string) + if defaultModel, ok := r.LookupValue(registry.DefaultModelKey).(string); ok && defaultModel != "" { + opts.Model = defaultModel + } if opts.Model == "" { return nil, errors.New("ai.GenerateWithRequest: model is required") } @@ -209,7 +234,6 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera output.Format = string(OutputFormatJSON) } } - req := &ModelRequest{ Messages: opts.Messages, Config: opts.Config, @@ -280,9 +304,11 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) } } - modelName := genOpts.ModelName - if modelName == "" && genOpts.Model != nil { + var modelName string + if genOpts.Model != nil { modelName = genOpts.Model.Name() + } else { + modelName = genOpts.ModelName } tools := make([]string, len(genOpts.Tools)) @@ -316,6 +342,13 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) messages = append(messages, NewUserTextMessage(prompt)) } + // Apply Model config if no Generate config. + modelArg := genOpts.Model + if modelRef, ok := modelArg.(ModelRef); ok { + if genOpts.Config == nil { + genOpts.Config = modelRef.Config() + } + } actionOpts := &GenerateActionOptions{ Model: modelName, Messages: messages, @@ -626,3 +659,18 @@ func (m *Message) Text() string { } return sb.String() } + +// NewModelRef creates a new ModelRef with the given name and configuration. +func NewModelRef(name string, config any) ModelRef { + return ModelRef{name: name, config: config} +} + +// Name returns the name of the ModelRef. +func (m ModelRef) Name() string { + return m.name +} + +// Config returns the configuration of a ModelRef. 
+func (m ModelRef) Config() any { + return m.config +} diff --git a/go/ai/option.go b/go/ai/option.go index 17bfe318ac..41d9654b58 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -34,8 +34,8 @@ type messagesFn = func(context.Context, any) ([]*Message, error) // commonOptions are common options for model generation, prompt definition, and prompt execution. type commonOptions struct { - ModelName string // Name of the model to use. - Model Model // Model to use. + Model ModelArg // Resolvable reference to a model to use with optional embedded config. + ModelName string // Name of model to use MessagesFn messagesFn // Messages function. If this is set, Messages should be an empty. Config any // Model configuration. If nil will be taken from the prompt config. Tools []ToolRef // References to tools to use. @@ -67,6 +67,7 @@ func (o *commonOptions) applyCommon(opts *commonOptions) error { return errors.New("cannot set model more than once (either WithModel or WithModelName)") } opts.Model = o.Model + return nil } if o.ModelName != "" { @@ -164,8 +165,8 @@ func WithConfig(config any) CommonOption { return &commonOptions{Config: config} } -// WithModel sets the model to call for generation. -func WithModel(model Model) CommonOption { +// WithModel sets a resolvable model reference to use for generation. +func WithModel(model ModelArg) CommonOption { return &commonOptions{Model: model} } diff --git a/go/ai/prompt.go b/go/ai/prompt.go index ca3dc0a45e..cb8cb82434 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -50,7 +50,13 @@ func DefinePrompt(r *registry.Registry, name string, opts ...PromptOption) (*Pro return nil, err } } - + // Apply Model config if no Prompt config. 
+ modelArg := pOpts.Model + if modelRef, ok := modelArg.(ModelRef); ok { + if pOpts.Config == nil { + pOpts.Config = modelRef.Config() + } + } p := &Prompt{ registry: r, promptOptions: *pOpts, @@ -113,6 +119,13 @@ func (p *Prompt) Execute(ctx context.Context, opts ...PromptGenerateOption) (*Mo return nil, err } } + // Apply Model config if no Prompt Generate config. + modelArg := genOpts.Model + if modelRef, ok := modelArg.(ModelRef); ok { + if genOpts.Config == nil { + genOpts.Config = modelRef.Config() + } + } p.MessagesFn = mergeMessagesFn(p.MessagesFn, genOpts.MessagesFn) diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index 04be1c5d15..71dbde8ae5 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -18,6 +18,7 @@ package googlegenai import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -25,6 +26,8 @@ import ( "strings" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/base" + "github.com/invopop/jsonschema" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" @@ -69,6 +72,141 @@ type EmbedOptions struct { TaskType string `json:"task_type,omitempty"` } +func convertConfigSchemaToMap(config any) map[string]any { + r := jsonschema.Reflector{ + DoNotReference: true, // Prevent $ref usage + ExpandedStruct: true, // Include all fields directly + } + schema := r.Reflect(config) + result := base.SchemaAsMap(schema) + return result +} + +func mapToStruct(m map[string]any, v any) error { + jsonData, err := json.Marshal(m) + if err != nil { + return err + } + return json.Unmarshal(jsonData, v) +} + +func convertSafetySettings(settings []*SafetySetting) []*genai.SafetySetting { + if len(settings) == 0 { + return nil + } + + result := make([]*genai.SafetySetting, len(settings)) + for i, s := range settings { + result[i] = &genai.SafetySetting{ + Method: genai.HarmBlockMethod(s.Method), + Category: genai.HarmCategory(s.Category), + Threshold: 
genai.HarmBlockThreshold(s.Threshold), + } + } + return result +} + +type HarmCategory string + +const ( + // The harm category is unspecified. + HarmCategoryUnspecified HarmCategory = "HARM_CATEGORY_UNSPECIFIED" + // The harm category is hate speech. + HarmCategoryHateSpeech HarmCategory = "HARM_CATEGORY_HATE_SPEECH" + // The harm category is dangerous content. + HarmCategoryDangerousContent HarmCategory = "HARM_CATEGORY_DANGEROUS_CONTENT" + // The harm category is harassment. + HarmCategoryHarassment HarmCategory = "HARM_CATEGORY_HARASSMENT" + // The harm category is sexually explicit content. + HarmCategorySexuallyExplicit HarmCategory = "HARM_CATEGORY_SEXUALLY_EXPLICIT" + // The harm category is civic integrity. + HarmCategoryCivicIntegrity HarmCategory = "HARM_CATEGORY_CIVIC_INTEGRITY" +) + +// Specify if the threshold is used for probability or severity score. If not specified, +// the threshold is used for probability score. +type HarmBlockMethod string + +const ( + // The harm block method is unspecified. + HarmBlockMethodUnspecified HarmBlockMethod = "HARM_BLOCK_METHOD_UNSPECIFIED" + // The harm block method uses both probability and severity scores. + HarmBlockMethodSeverity HarmBlockMethod = "SEVERITY" + // The harm block method uses the probability score. + HarmBlockMethodProbability HarmBlockMethod = "PROBABILITY" +) + +// The harm block threshold. +type HarmBlockThreshold string + +const ( + // Unspecified harm block threshold. + HarmBlockThresholdUnspecified HarmBlockThreshold = "HARM_BLOCK_THRESHOLD_UNSPECIFIED" + // Block low threshold and above (i.e. block more). + HarmBlockThresholdBlockLowAndAbove HarmBlockThreshold = "BLOCK_LOW_AND_ABOVE" + // Block medium threshold and above. + HarmBlockThresholdBlockMediumAndAbove HarmBlockThreshold = "BLOCK_MEDIUM_AND_ABOVE" + // Block only high threshold (i.e. block less). + HarmBlockThresholdBlockOnlyHigh HarmBlockThreshold = "BLOCK_ONLY_HIGH" + // Block none. 
+ HarmBlockThresholdBlockNone HarmBlockThreshold = "BLOCK_NONE" + // Turn off the safety filter. + HarmBlockThresholdOff HarmBlockThreshold = "OFF" +) + +// Safety settings. +type SafetySetting struct { + // Determines if the harm block method uses probability or probability + // and severity scores. + Method HarmBlockMethod `json:"method,omitempty"` + // Required. Harm category. + Category HarmCategory `json:"category,omitempty"` + // Required. The harm block threshold. + Threshold HarmBlockThreshold `json:"threshold,omitempty"` +} + +// GeminiConfig mirrors GenerateContentConfig without direct genai dependency +type GeminiConfig struct { + ai.GenerationCommonConfig + + // Safety settings + SafetySettings []*SafetySetting `json:"safetySettings,omitempty"` +} + +// extractConfigFromInput converts any supported config type to GeminiConfig +func extractConfigFromInput(input *ai.ModelRequest) (*GeminiConfig, error) { + var result GeminiConfig + switch config := input.Config.(type) { + case GeminiConfig: + return &config, nil + case *GeminiConfig: + return config, nil + case ai.GenerationCommonConfig: + return &GeminiConfig{ + GenerationCommonConfig: config, + }, nil + case *ai.GenerationCommonConfig: + if config == nil { + return &result, nil + } + return &GeminiConfig{ + GenerationCommonConfig: *config, + }, nil + case map[string]any: + // TODO: FYI Using map[string]any for config may silently ignore unknown parameters, may want to handle explicitly + if err := mapToStruct(config, &result); err == nil { + return &result, nil + } else { + return nil, err + } + case nil: + // Empty but valid config + return &result, nil + default: + return nil, fmt.Errorf("unexpected config type: %T", input.Config) + } +} + // DefineModel defines a model in the registry func defineModel(g *genkit.Genkit, client *genai.Client, name 
string, info ai.Mo } meta := &ai.ModelInfo{ - Label: info.Label, - Supports: info.Supports, - Versions: info.Versions, + Label: info.Label, + Supports: info.Supports, + Versions: info.Versions, + ConfigSchema: convertConfigSchemaToMap(&GeminiConfig{}), } fn := func( @@ -163,12 +302,15 @@ func generate( input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { - // since context caching is only available for specific model versions, we - // must make sure the configuration has the right version - if c, ok := input.Config.(*ai.GenerationCommonConfig); ok { - if c != nil && c.Version != "" { - model = c.Version - } + // Extract configuration to get the model version + config, err := extractConfigFromInput(input) + if err != nil { + return nil, err + } + + // Update model with version if specified + if config.Version != "" { + model = config.Version } cache, err := handleCache(ctx, client, input, model) @@ -176,7 +318,7 @@ func generate( return nil, err } - gc, err := convertRequest(client, model, input, cache) + gc, err := convertRequest(client, input, cache) if err != nil { return nil, err } @@ -263,24 +405,55 @@ func generate( // convertRequest translates from [*ai.ModelRequest] to // *genai.GenerateContentParameters -func convertRequest(client *genai.Client, model string, input *ai.ModelRequest, cache *genai.CachedContent) (*genai.GenerateContentConfig, error) { +func convertRequest(client *genai.Client, input *ai.ModelRequest, cache *genai.CachedContent) (*genai.GenerateContentConfig, error) { gc := genai.GenerateContentConfig{} gc.CandidateCount = genai.Ptr[int32](1) - if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil { - if c.MaxOutputTokens != 0 { - gc.MaxOutputTokens = genai.Ptr[int32](int32(c.MaxOutputTokens)) - } - if len(c.StopSequences) > 0 { - gc.StopSequences = c.StopSequences - } - if c.Temperature != 0 { - gc.Temperature = genai.Ptr[float32](float32(c.Temperature)) - } - if c.TopK 
!= 0 { - gc.TopK = genai.Ptr[float32](float32(c.TopK)) + c, err := extractConfigFromInput(input) + if err != nil { + return nil, err + } + // Convert standard fields + if c.MaxOutputTokens != 0 { + gc.MaxOutputTokens = genai.Ptr[int32](int32(c.MaxOutputTokens)) + } + if len(c.StopSequences) > 0 { + gc.StopSequences = c.StopSequences + } + if c.Temperature != 0 { + gc.Temperature = genai.Ptr[float32](float32(c.Temperature)) + } + if c.TopK != 0 { + gc.TopK = genai.Ptr[float32](float32(c.TopK)) + } + if c.TopP != 0 { + gc.TopP = genai.Ptr[float32](float32(c.TopP)) + } + // Convert non-primitive fields + gc.SafetySettings = convertSafetySettings(c.SafetySettings) + + // Set response MIME type based on output format if specified + hasOutput := input.Output != nil + isJsonFormat := hasOutput && input.Output.Format == "json" + isJsonContentType := hasOutput && input.Output.ContentType == "application/json" + jsonMode := isJsonFormat || (isJsonContentType && len(input.Tools) == 0) + if jsonMode { + gc.ResponseMIMEType = "application/json" + } + + // Add tool configuration from input.Tools and input.ToolChoice directly + // This overrides any functionCallingConfig in the passed config + if len(input.Tools) > 0 { + // First convert the tools + tools, err := convertTools(input.Tools) + if err != nil { + return nil, err } - if c.TopP != 0 { - gc.TopP = genai.Ptr[float32](float32(c.TopP)) + gc.Tools = tools + + // Then set up the tool configuration based on ToolChoice + tc := convertToolChoice(input.ToolChoice, input.Tools) + if tc != nil { + gc.ToolConfig = tc } } @@ -302,15 +475,6 @@ func convertRequest(client *genai.Client, model string, input *ai.ModelRequest, } } - tools, err := convertTools(input.Tools) - if err != nil { - return nil, err - } - gc.Tools = tools - - choice := convertToolChoice(input.ToolChoice, input.Tools) - gc.ToolConfig = choice - if cache != nil { gc.CachedContent = cache.Name } diff --git a/go/plugins/googlegenai/googlegenai.go 
b/go/plugins/googlegenai/googlegenai.go index e5f16c4ca8..51908d89c8 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -277,6 +277,16 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool { return genkit.LookupEmbedder(g, vertexAIProvider, name) != nil } +// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given name and configuration. +func GoogleAIModelRef(name string, config *GeminiConfig) ai.ModelRef { + return ai.NewModelRef(googleAIProvider+"/"+name, config) +} + +// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given name and configuration. +func VertexAIModelRef(name string, config *GeminiConfig) ai.ModelRef { + return ai.NewModelRef(vertexAIProvider+"/"+name, config) +} + // GoogleAIModel returns the [ai.Model] with the given name. // It returns nil if the model was not defined. func GoogleAIModel(g *genkit.Genkit, name string) ai.Model { diff --git a/go/samples/basic-gemini/main.go b/go/samples/basic-gemini/main.go index d4b2afc56b..d6d4d248fa 100644 --- a/go/samples/basic-gemini/main.go +++ b/go/samples/basic-gemini/main.go @@ -46,9 +46,17 @@ func main() { resp, err := genkit.Generate(ctx, g, ai.WithModel(m), - ai.WithConfig(&ai.GenerationCommonConfig{ - Temperature: 1, - Version: "gemini-2.0-flash-001", + ai.WithConfig(&googlegenai.GeminiConfig{ + GenerationCommonConfig: ai.GenerationCommonConfig{ + Temperature: 1.0, + MaxOutputTokens: 256, + }, + SafetySettings: []*googlegenai.SafetySetting{ + { + Category: googlegenai.HarmCategoryHarassment, + Threshold: googlegenai.HarmBlockThresholdBlockMediumAndAbove, + }, + }, }), ai.WithPromptText(fmt.Sprintf(`Tell silly short jokes about %s`, input))) if err != nil {