@@ -20,6 +20,8 @@ import (
2020 "regexp"
2121 "slices"
2222 "strings"
23+
24+ "github.com/firebase/genkit/go/core"
2325)
2426
2527type enumFormatter struct {}
@@ -33,14 +35,15 @@ func (e enumFormatter) Name() string {
3335func (e enumFormatter ) Handler (schema map [string ]any ) (FormatHandler , error ) {
3436 enums := objectEnums (schema )
3537 if schema == nil || len (enums ) == 0 {
36- return nil , fmt . Errorf ( "schema is not valid JSON enum" )
38+ return nil , core . NewError ( core . INVALID_ARGUMENT , "schema must be an object with an ' enum' property for enum format " )
3739 }
3840
3941 instructions := fmt .Sprintf ("Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n \n ```%s```" , strings .Join (enums , "\n " ))
4042
4143 handler := & enumHandler {
4244 instructions : instructions ,
4345 config : ModelOutputConfig {
46+ Constrained : true ,
4447 Format : OutputFormatEnum ,
4548 Schema : schema ,
4649 ContentType : "text/enum" ,
@@ -52,23 +55,49 @@ func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
5255}
5356
5457type enumHandler struct {
55- instructions string
56- config ModelOutputConfig
57- enums []string
58+ instructions string
59+ config ModelOutputConfig
60+ enums []string
61+ accumulatedText string
62+ currentIndex int
5863}
5964
6065// Instructions returns the instructions for the formatter.
61- func (e enumHandler ) Instructions () string {
66+ func (e * enumHandler ) Instructions () string {
6267 return e .instructions
6368}
6469
6570// Config returns the output config for the formatter.
66- func (e enumHandler ) Config () ModelOutputConfig {
71+ func (e * enumHandler ) Config () ModelOutputConfig {
6772 return e .config
6873}
6974
75+ // ParseOutput parses the final message and returns the enum value.
76+ func (e * enumHandler ) ParseOutput (m * Message ) (any , error ) {
77+ return e .parseEnum (m .Text ())
78+ }
79+
80+ // ParseChunk processes a streaming chunk and returns parsed output.
81+ func (e * enumHandler ) ParseChunk (chunk * ModelResponseChunk ) (any , error ) {
82+ if chunk .Index != e .currentIndex {
83+ e .accumulatedText = ""
84+ e .currentIndex = chunk .Index
85+ }
86+
87+ for _ , part := range chunk .Content {
88+ if part .IsText () {
89+ e .accumulatedText += part .Text
90+ }
91+ }
92+
93+ // Ignore error since we are doing best effort parsing.
94+ enum , _ := e .parseEnum (e .accumulatedText )
95+
96+ return enum , nil
97+ }
98+
7099// ParseMessage parses the message and returns the formatted message.
71- func (e enumHandler ) ParseMessage (m * Message ) (* Message , error ) {
100+ func (e * enumHandler ) ParseMessage (m * Message ) (* Message , error ) {
72101 if e .config .Format == OutputFormatEnum {
73102 if m == nil {
74103 return nil , errors .New ("message is empty" )
@@ -107,23 +136,63 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
107136 return m , nil
108137}
109138
110- // Get enum strings from json schema
139+ // Get enum strings from json schema.
140+ // Supports both top-level enum (e.g. {"type": "string", "enum": ["a", "b"]})
141+ // and nested property enum (e.g. {"properties": {"value": {"enum": ["a", "b"]}}}).
111142func objectEnums (schema map [string ]any ) []string {
112- var enums []string
143+ if enums := extractEnumStrings (schema ["enum" ]); len (enums ) > 0 {
144+ return enums
145+ }
113146
114147 if properties , ok := schema ["properties" ].(map [string ]any ); ok {
115148 for _ , propValue := range properties {
116149 if propMap , ok := propValue .(map [string ]any ); ok {
117- if enumSlice , ok := propMap ["enum" ].([]any ); ok {
118- for _ , enumVal := range enumSlice {
119- if enumStr , ok := enumVal .(string ); ok {
120- enums = append (enums , enumStr )
121- }
122- }
150+ if enums := extractEnumStrings (propMap ["enum" ]); len (enums ) > 0 {
151+ return enums
123152 }
124153 }
125154 }
126155 }
127156
128- return enums
157+ return nil
158+ }
159+
160+ // Extracts string values from an enum field, supporting both []any (from JSON) and []string (from Go code).
161+ func extractEnumStrings (v any ) []string {
162+ if v == nil {
163+ return nil
164+ }
165+
166+ if strs , ok := v .([]string ); ok {
167+ return strs
168+ }
169+
170+ if slice , ok := v .([]any ); ok {
171+ enums := make ([]string , 0 , len (slice ))
172+ for _ , val := range slice {
173+ if s , ok := val .(string ); ok {
174+ enums = append (enums , s )
175+ }
176+ }
177+ return enums
178+ }
179+
180+ return nil
181+ }
182+
183+ // parseEnum is the shared parsing logic used by both ParseOutput and ParseChunk.
184+ func (e * enumHandler ) parseEnum (text string ) (string , error ) {
185+ if text == "" {
186+ return "" , nil
187+ }
188+
189+ re := regexp .MustCompile (`['"]` )
190+ clean := re .ReplaceAllString (text , "" )
191+ trimmed := strings .TrimSpace (clean )
192+
193+ if ! slices .Contains (e .enums , trimmed ) {
194+ return "" , fmt .Errorf ("message %s not in list of valid enums: %s" , trimmed , strings .Join (e .enums , ", " ))
195+ }
196+
197+ return trimmed , nil
129198}
0 commit comments