Skip to content

feat(go): Added ModelArg interface and ModelRef #2487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
cf91032
feat(go): Add ModelArg interface and ModelRef
huangjeff5 Mar 27, 2025
1f2c67c
Fix tests
huangjeff5 Mar 27, 2025
e9aa30f
Update prompt.go
huangjeff5 Mar 27, 2025
102a353
force types in GoogleAIModelRef
huangjeff5 Mar 27, 2025
efcc32e
Merge branch 'jh-modelarg' of https://github.com/firebase/genkit into…
huangjeff5 Mar 27, 2025
9901617
Merge branch 'main' into jh-modelarg
huangjeff5 Apr 1, 2025
2c00047
Add support for Dev UI and custom configs
huangjeff5 Apr 1, 2025
a6d577a
Fix defineEmbedder
huangjeff5 Apr 1, 2025
bf8977f
fix
huangjeff5 Apr 1, 2025
84ffefd
remove references to Resolve
huangjeff5 Apr 1, 2025
f44d635
Add configs
huangjeff5 Apr 2, 2025
0b1d2c4
Fix custom options
huangjeff5 Apr 2, 2025
136ff08
add Version back
huangjeff5 Apr 2, 2025
a37348d
clean up convertRequest
huangjeff5 Apr 2, 2025
8efa682
Fix dev UI
huangjeff5 Apr 2, 2025
771bfb4
Update go/plugins/googlegenai/googlegenai.go
huangjeff5 Apr 2, 2025
02ee501
Update go/ai/option.go
huangjeff5 Apr 2, 2025
a8063ca
Update go/ai/option.go
huangjeff5 Apr 2, 2025
06651c7
Update go/ai/generate.go
huangjeff5 Apr 2, 2025
885d977
Change ModelArg to Model
huangjeff5 Apr 2, 2025
496d6ce
Fix
huangjeff5 Apr 2, 2025
ff79a42
Fix types
huangjeff5 Apr 2, 2025
407610e
Make config provide and expose via Config()
huangjeff5 Apr 2, 2025
3cd7b89
remove extra types
huangjeff5 Apr 2, 2025
5e7a06d
change config types
huangjeff5 Apr 3, 2025
1d5fb9d
Address comments
huangjeff5 Apr 3, 2025
2160def
remove Ptr
huangjeff5 Apr 3, 2025
9212d21
fix VertexAIModelRef def
huangjeff5 Apr 3, 2025
5aea219
extract config from all types
huangjeff5 Apr 3, 2025
ed5d4ee
Update gemini.go
huangjeff5 Apr 3, 2025
b1626f7
clean up
huangjeff5 Apr 3, 2025
a86304a
Merge branch 'jh-modelarg' of https://github.com/firebase/genkit into…
huangjeff5 Apr 3, 2025
8437142
fix
huangjeff5 Apr 3, 2025
ecd2abf
fix
huangjeff5 Apr 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ type (
Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error)
}

// ModelArg is the interface for model arguments.
ModelArg interface {
Name() string
}

// ModelRef is a struct to hold model name and configuration.
ModelRef struct {
name string
config any
}

// ToolConfig handles configuration around tool calls during generation.
ToolConfig struct {
MaxTurns int // Maximum number of tool call iterations before erroring.
Expand Down Expand Up @@ -87,6 +98,7 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc

// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, fn ModelFunc) Model {

if info == nil {
// Always make sure there's at least minimal metadata.
info = &ModelInfo{
Expand All @@ -113,6 +125,17 @@ func DefineModel(r *registry.Registry, provider, name string, info *ModelInfo, f
metadata["label"] = info.Label
}

if info.ConfigSchema != nil {
metadata["customOptions"] = info.ConfigSchema
// Make sure "model" exists in metadata
if metadata["model"] == nil {
metadata["model"] = make(map[string]any)
}
// Add customOptios to the model metadata
modelMeta := metadata["model"].(map[string]any)
modelMeta["customOptions"] = info.ConfigSchema
}

// Create the middleware list
middlewares := []ModelMiddleware{
simulateSystemPrompt(info, nil),
Expand Down Expand Up @@ -162,7 +185,9 @@ func LookupModelByName(r *registry.Registry, modelName string) (Model, error) {
// GenerateWithRequest is the central generation implementation for ai.Generate(), prompt.Execute(), and the GenerateAction direct call.
func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *GenerateActionOptions, mw []ModelMiddleware, cb ModelStreamCallback) (*ModelResponse, error) {
if opts.Model == "" {
opts.Model = r.LookupValue(registry.DefaultModelKey).(string)
if defaultModel, ok := r.LookupValue(registry.DefaultModelKey).(string); ok && defaultModel != "" {
opts.Model = defaultModel
}
if opts.Model == "" {
return nil, errors.New("ai.GenerateWithRequest: model is required")
}
Expand Down Expand Up @@ -209,7 +234,6 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera
output.Format = string(OutputFormatJSON)
}
}

req := &ModelRequest{
Messages: opts.Messages,
Config: opts.Config,
Expand Down Expand Up @@ -280,9 +304,11 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
}
}

modelName := genOpts.ModelName
if modelName == "" && genOpts.Model != nil {
var modelName string
if genOpts.Model != nil {
modelName = genOpts.Model.Name()
} else {
modelName = genOpts.ModelName
}

tools := make([]string, len(genOpts.Tools))
Expand Down Expand Up @@ -316,6 +342,13 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
messages = append(messages, NewUserTextMessage(prompt))
}

// Apply Model config if no Generate config.
modelArg := genOpts.Model
if modelRef, ok := modelArg.(ModelRef); ok {
if genOpts.Config == nil {
genOpts.Config = modelRef.Config()
}
}
actionOpts := &GenerateActionOptions{
Model: modelName,
Messages: messages,
Expand Down Expand Up @@ -626,3 +659,18 @@ func (m *Message) Text() string {
}
return sb.String()
}

// NewModelRef creates a new ModelRef with the given name and configuration.
func NewModelRef(name string, config any) ModelRef {
return ModelRef{name: name, config: config}
}

// Name returns the name of the ModelRef.
func (m ModelRef) Name() string {
return m.name
}

// ModelConfig returns the configuration of a ModelRef.
func (m ModelRef) Config() any {
return m.config
}
9 changes: 5 additions & 4 deletions go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type messagesFn = func(context.Context, any) ([]*Message, error)

// commonOptions are common options for model generation, prompt definition, and prompt execution.
type commonOptions struct {
ModelName string // Name of the model to use.
Model Model // Model to use.
Model ModelArg // Resolvable reference to a model to use with optional embedded config.
ModelName string // Name of model to use
MessagesFn messagesFn // Messages function. If this is set, Messages should be an empty.
Config any // Model configuration. If nil will be taken from the prompt config.
Tools []ToolRef // References to tools to use.
Expand Down Expand Up @@ -67,6 +67,7 @@ func (o *commonOptions) applyCommon(opts *commonOptions) error {
return errors.New("cannot set model more than once (either WithModel or WithModelName)")
}
opts.Model = o.Model
return nil
}

if o.ModelName != "" {
Expand Down Expand Up @@ -164,8 +165,8 @@ func WithConfig(config any) CommonOption {
return &commonOptions{Config: config}
}

// WithModel sets the model to call for generation.
func WithModel(model Model) CommonOption {
// WithModel sets a resolvable model reference to use for generation.
func WithModel(model ModelArg) CommonOption {
return &commonOptions{Model: model}
}

Expand Down
15 changes: 14 additions & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ func DefinePrompt(r *registry.Registry, name string, opts ...PromptOption) (*Pro
return nil, err
}
}

// Apply Model config if no Prompt config.
modelArg := pOpts.Model
if modelRef, ok := modelArg.(ModelRef); ok {
if pOpts.Config == nil {
pOpts.Config = modelRef.Config()
}
}
p := &Prompt{
registry: r,
promptOptions: *pOpts,
Expand Down Expand Up @@ -113,6 +119,13 @@ func (p *Prompt) Execute(ctx context.Context, opts ...PromptGenerateOption) (*Mo
return nil, err
}
}
// Apply Model config if no Prompt Generate config.
modelArg := genOpts.Model
if modelRef, ok := modelArg.(ModelRef); ok {
if genOpts.Config == nil {
genOpts.Config = modelRef.Config()
}
}

p.MessagesFn = mergeMessagesFn(p.MessagesFn, genOpts.MessagesFn)

Expand Down
Loading
Loading