Skip to content

Commit b07ca42

Browse files
authored
feat(go): refactored formatters + added support for formatting streams (#3905)
1 parent 520b99a commit b07ca42

File tree

19 files changed

+2058
-1541
lines changed

19 files changed

+2058
-1541
lines changed

go/ai/action_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,15 @@ func TestGenerateAction(t *testing.T) {
136136
t.Fatalf("action failed: %v", err)
137137
}
138138

139-
if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" {
139+
if diff := cmp.Diff(tc.ExpectChunks, chunks, cmp.Options{
140+
cmpopts.IgnoreFields(ModelResponseChunk{}, "formatHandler"),
141+
}); diff != "" {
140142
t.Errorf("chunks mismatch (-want +got):\n%s", diff)
141143
}
142144

143145
if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{
144146
cmpopts.EquateEmpty(),
145-
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
147+
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"),
146148
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
147149
cmpopts.IgnoreFields(ToolDefinition{}, "Metadata"),
148150
}); diff != "" {
@@ -156,7 +158,7 @@ func TestGenerateAction(t *testing.T) {
156158

157159
if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{
158160
cmpopts.EquateEmpty(),
159-
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs"),
161+
cmpopts.IgnoreFields(ModelResponse{}, "LatencyMs", "formatHandler"),
160162
cmpopts.IgnoreFields(GenerationUsage{}, "InputCharacters", "OutputCharacters"),
161163
cmpopts.IgnoreFields(ToolDefinition{}, "Metadata"),
162164
}); diff != "" {

go/ai/format_array.go

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ package ai
1616

1717
import (
1818
"encoding/json"
19-
"errors"
2019
"fmt"
21-
"strings"
2220

2321
"github.com/firebase/genkit/go/internal/base"
2422
)
@@ -45,6 +43,7 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
4543
handler := &arrayHandler{
4644
instructions: instructions,
4745
config: ModelOutputConfig{
46+
Constrained: true,
4847
Format: OutputFormatArray,
4948
Schema: schema,
5049
ContentType: "application/json",
@@ -55,58 +54,49 @@ func (a arrayFormatter) Handler(schema map[string]any) (FormatHandler, error) {
5554
}
5655

5756
type arrayHandler struct {
58-
instructions string
59-
config ModelOutputConfig
57+
instructions string
58+
config ModelOutputConfig
59+
accumulatedText string
60+
currentIndex int
61+
cursor int
6062
}
6163

6264
// Instructions returns the instructions for the formatter.
63-
func (a arrayHandler) Instructions() string {
65+
func (a *arrayHandler) Instructions() string {
6466
return a.instructions
6567
}
6668

6769
// Config returns the output config for the formatter.
68-
func (a arrayHandler) Config() ModelOutputConfig {
70+
func (a *arrayHandler) Config() ModelOutputConfig {
6971
return a.config
7072
}
7173

72-
// ParseMessage parses the message and returns the formatted message.
73-
func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
74-
if a.config.Format == OutputFormatArray {
75-
if m == nil {
76-
return nil, errors.New("message is empty")
77-
}
78-
if len(m.Content) == 0 {
79-
return nil, errors.New("message has no content")
80-
}
81-
82-
var nonTextParts []*Part
83-
accumulatedText := strings.Builder{}
74+
// ParseOutput parses the final message and returns the parsed array.
75+
func (a *arrayHandler) ParseOutput(m *Message) (any, error) {
76+
result := base.ExtractItems(m.Text(), 0)
77+
return result.Items, nil
78+
}
8479

85-
for _, part := range m.Content {
86-
if !part.IsText() {
87-
nonTextParts = append(nonTextParts, part)
88-
} else {
89-
accumulatedText.WriteString(part.Text)
90-
}
91-
}
80+
// ParseChunk processes a streaming chunk and returns parsed output.
81+
func (a *arrayHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) {
82+
if chunk.Index != a.currentIndex {
83+
a.accumulatedText = ""
84+
a.currentIndex = chunk.Index
85+
a.cursor = 0
86+
}
9287

93-
var newParts []*Part
94-
lines := base.GetJSONObjectLines(accumulatedText.String())
95-
for _, line := range lines {
96-
var schemaBytes []byte
97-
schemaBytes, err := json.Marshal(a.config.Schema["items"])
98-
if err != nil {
99-
return nil, fmt.Errorf("expected schema is not valid: %w", err)
100-
}
101-
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
102-
return nil, err
103-
}
104-
105-
newParts = append(newParts, NewJSONPart(line))
88+
for _, part := range chunk.Content {
89+
if part.IsText() {
90+
a.accumulatedText += part.Text
10691
}
107-
108-
m.Content = append(newParts, nonTextParts...)
10992
}
11093

94+
result := base.ExtractItems(a.accumulatedText, a.cursor)
95+
a.cursor = result.Cursor
96+
return result.Items, nil
97+
}
98+
99+
// ParseMessage parses the message and returns the formatted message.
100+
func (a *arrayHandler) ParseMessage(m *Message) (*Message, error) {
111101
return m, nil
112102
}

go/ai/format_enum.go

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"regexp"
2121
"slices"
2222
"strings"
23+
24+
"github.com/firebase/genkit/go/core"
2325
)
2426

2527
type enumFormatter struct{}
@@ -33,14 +35,15 @@ func (e enumFormatter) Name() string {
3335
func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
3436
enums := objectEnums(schema)
3537
if schema == nil || len(enums) == 0 {
36-
return nil, fmt.Errorf("schema is not valid JSON enum")
38+
return nil, core.NewError(core.INVALID_ARGUMENT, "schema must be an object with an 'enum' property for enum format")
3739
}
3840

3941
instructions := fmt.Sprintf("Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n```%s```", strings.Join(enums, "\n"))
4042

4143
handler := &enumHandler{
4244
instructions: instructions,
4345
config: ModelOutputConfig{
46+
Constrained: true,
4447
Format: OutputFormatEnum,
4548
Schema: schema,
4649
ContentType: "text/enum",
@@ -52,23 +55,49 @@ func (e enumFormatter) Handler(schema map[string]any) (FormatHandler, error) {
5255
}
5356

5457
type enumHandler struct {
55-
instructions string
56-
config ModelOutputConfig
57-
enums []string
58+
instructions string
59+
config ModelOutputConfig
60+
enums []string
61+
accumulatedText string
62+
currentIndex int
5863
}
5964

6065
// Instructions returns the instructions for the formatter.
61-
func (e enumHandler) Instructions() string {
66+
func (e *enumHandler) Instructions() string {
6267
return e.instructions
6368
}
6469

6570
// Config returns the output config for the formatter.
66-
func (e enumHandler) Config() ModelOutputConfig {
71+
func (e *enumHandler) Config() ModelOutputConfig {
6772
return e.config
6873
}
6974

75+
// ParseOutput parses the final message and returns the enum value.
76+
func (e *enumHandler) ParseOutput(m *Message) (any, error) {
77+
return e.parseEnum(m.Text())
78+
}
79+
80+
// ParseChunk processes a streaming chunk and returns parsed output.
81+
func (e *enumHandler) ParseChunk(chunk *ModelResponseChunk) (any, error) {
82+
if chunk.Index != e.currentIndex {
83+
e.accumulatedText = ""
84+
e.currentIndex = chunk.Index
85+
}
86+
87+
for _, part := range chunk.Content {
88+
if part.IsText() {
89+
e.accumulatedText += part.Text
90+
}
91+
}
92+
93+
// Ignore error since we are doing best effort parsing.
94+
enum, _ := e.parseEnum(e.accumulatedText)
95+
96+
return enum, nil
97+
}
98+
7099
// ParseMessage parses the message and returns the formatted message.
71-
func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
100+
func (e *enumHandler) ParseMessage(m *Message) (*Message, error) {
72101
if e.config.Format == OutputFormatEnum {
73102
if m == nil {
74103
return nil, errors.New("message is empty")
@@ -107,23 +136,63 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
107136
return m, nil
108137
}
109138

110-
// Get enum strings from json schema
139+
// Get enum strings from json schema.
140+
// Supports both top-level enum (e.g. {"type": "string", "enum": ["a", "b"]})
141+
// and nested property enum (e.g. {"properties": {"value": {"enum": ["a", "b"]}}}).
111142
func objectEnums(schema map[string]any) []string {
112-
var enums []string
143+
if enums := extractEnumStrings(schema["enum"]); len(enums) > 0 {
144+
return enums
145+
}
113146

114147
if properties, ok := schema["properties"].(map[string]any); ok {
115148
for _, propValue := range properties {
116149
if propMap, ok := propValue.(map[string]any); ok {
117-
if enumSlice, ok := propMap["enum"].([]any); ok {
118-
for _, enumVal := range enumSlice {
119-
if enumStr, ok := enumVal.(string); ok {
120-
enums = append(enums, enumStr)
121-
}
122-
}
150+
if enums := extractEnumStrings(propMap["enum"]); len(enums) > 0 {
151+
return enums
123152
}
124153
}
125154
}
126155
}
127156

128-
return enums
157+
return nil
158+
}
159+
160+
// Extracts string values from an enum field, supporting both []any (from JSON) and []string (from Go code).
161+
func extractEnumStrings(v any) []string {
162+
if v == nil {
163+
return nil
164+
}
165+
166+
if strs, ok := v.([]string); ok {
167+
return strs
168+
}
169+
170+
if slice, ok := v.([]any); ok {
171+
enums := make([]string, 0, len(slice))
172+
for _, val := range slice {
173+
if s, ok := val.(string); ok {
174+
enums = append(enums, s)
175+
}
176+
}
177+
return enums
178+
}
179+
180+
return nil
181+
}
182+
183+
// parseEnum is the shared parsing logic used by both ParseOutput and ParseChunk.
184+
func (e *enumHandler) parseEnum(text string) (string, error) {
185+
if text == "" {
186+
return "", nil
187+
}
188+
189+
re := regexp.MustCompile(`['"]`)
190+
clean := re.ReplaceAllString(text, "")
191+
trimmed := strings.TrimSpace(clean)
192+
193+
if !slices.Contains(e.enums, trimmed) {
194+
return "", fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
195+
}
196+
197+
return trimmed, nil
129198
}

0 commit comments

Comments
 (0)