Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions go/ai/formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,17 +661,14 @@ func TestResolveFormat(t *testing.T) {
}
})

t.Run("defaults to text even when schema present but no format", func(t *testing.T) {
t.Run("defaults to json when schema present but no format", func(t *testing.T) {
schema := map[string]any{"type": "object"}
formatter, err := resolveFormat(r, schema, "")
if err != nil {
t.Fatalf("resolveFormat() error = %v", err)
}
// Note: The current implementation defaults to text when format is empty,
// even if schema is present. The schema/format combination is typically
// handled at a higher level (e.g., in Generate options).
if formatter.Name() != OutputFormatText {
t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatText)
if formatter.Name() != OutputFormatJSON {
t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatJSON)
}
})

Expand Down
147 changes: 135 additions & 12 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"errors"
"fmt"
"iter"
"slices"
"strings"

Expand Down Expand Up @@ -550,7 +551,7 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) (
return res.Text(), nil
}

// Generate run generate request for this model. Returns ModelResponse struct.
// GenerateData runs a generate request and returns strongly-typed output.
func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) {
var value Out
opts = append(opts, WithOutputType(value))
Expand All @@ -568,6 +569,104 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate
return &value, resp, nil
}

// StreamValue is either a streamed chunk or the final response of a generate request.
type StreamValue[Out, Stream any] struct {
Done bool
Chunk Stream // valid if Done is false
Output Out // valid if Done is true
Response *ModelResponse // valid if Done is true
}

// ModelStreamValue is a stream value for a model response.
// Out is never set because the output is already available in the Response field.
type ModelStreamValue = StreamValue[struct{}, *ModelResponseChunk]

// errGenerateStop is a sentinel error used to signal early termination of streaming.
var errGenerateStop = errors.New("stop")

// GenerateStream generates a model response and streams the output.
// It returns an iterator that yields streaming results.
//
// If the yield function is passed a non-nil error, generation has failed with that
// error; the yield function will not be called again.
//
// If the yield function's [ModelStreamValue] argument has Done == true, the value's
// Response field contains the final response; the yield function will not be called
// again.
//
// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk.
func GenerateStream(ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*ModelStreamValue, error] {
return func(yield func(*ModelStreamValue, error) bool) {
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
if ctx.Err() != nil {
return ctx.Err()
}
if !yield(&ModelStreamValue{Chunk: chunk}, nil) {
return errGenerateStop
}
return nil
}

allOpts := append(slices.Clone(opts), WithStreaming(cb))

resp, err := Generate(ctx, r, allOpts...)
if err != nil {
yield(nil, err)
} else {
yield(&ModelStreamValue{Done: true, Response: resp}, nil)
}
}
}

// GenerateDataStream generates a model response with streaming and returns strongly-typed output.
// It returns an iterator that yields streaming results.
//
// If the yield function is passed a non-nil error, generation has failed with that
// error; the yield function will not be called again.
//
// If the yield function's [StreamValue] argument has Done == true, the value's
// Output and Response fields contain the final typed output and response; the yield function
// will not be called again.
//
// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk.
func GenerateDataStream[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*StreamValue[Out, Out], error] {
return func(yield func(*StreamValue[Out, Out], error) bool) {
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
if ctx.Err() != nil {
return ctx.Err()
}
var streamValue Out
if err := chunk.Output(&streamValue); err != nil {
yield(nil, err)
return err
}
if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) {
return errGenerateStop
}
return nil
}

// Prepend WithOutputType so the user can override the output format.
var value Out
allOpts := append([]GenerateOption{WithOutputType(value)}, opts...)
allOpts = append(allOpts, WithStreaming(cb))

resp, err := Generate(ctx, r, allOpts...)
if err != nil {
yield(nil, err)
return
}

output, err := extractTypedOutput[Out](resp)
if err != nil {
yield(nil, err)
return
}

yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil)
}
}

// Generate applies the [Action] to provided request.
func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
if m == nil {
Expand Down Expand Up @@ -744,7 +843,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
// [ModelResponse] as a string. It returns an empty string if there
// are no candidates or if the candidate has no message.
func (mr *ModelResponse) Text() string {
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return ""
}
return mr.Message.Text()
Expand All @@ -753,7 +852,7 @@ func (mr *ModelResponse) Text() string {
// History returns messages from the request combined with the response message
// to represent the conversation history.
func (mr *ModelResponse) History() []*Message {
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return mr.Request.Messages
}
return append(mr.Request.Messages, mr.Message)
Expand All @@ -762,7 +861,7 @@ func (mr *ModelResponse) History() []*Message {
// Reasoning concatenates all reasoning parts present in the message
func (mr *ModelResponse) Reasoning() string {
var sb strings.Builder
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return ""
}

Expand Down Expand Up @@ -806,7 +905,7 @@ func (mr *ModelResponse) Output(v any) error {
// ToolRequests returns the tool requests from the response.
func (mr *ModelResponse) ToolRequests() []*ToolRequest {
toolReqs := []*ToolRequest{}
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return toolReqs
}
for _, part := range mr.Message.Content {
Expand All @@ -820,7 +919,7 @@ func (mr *ModelResponse) ToolRequests() []*ToolRequest {
// Interrupts returns the interrupted tool request parts from the response.
func (mr *ModelResponse) Interrupts() []*Part {
parts := []*Part{}
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return parts
}
for _, part := range mr.Message.Content {
Expand All @@ -833,7 +932,7 @@ func (mr *ModelResponse) Interrupts() []*Part {

// Media returns the media content of the [ModelResponse] as a string.
func (mr *ModelResponse) Media() string {
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return ""
}
for _, part := range mr.Message.Content {
Expand Down Expand Up @@ -902,17 +1001,41 @@ func (c *ModelResponseChunk) Output(v any) error {

// outputer is an interface for types that can unmarshal structured output.
type outputer interface {
Output(v any) error
// Text returns the contents of the output as a string.
Text() string
// Output parses the structured output from the response and unmarshals it into value.
Output(value any) error
}

// OutputFrom is a convenience function that parses structured output from a
// [ModelResponse] or [ModelResponseChunk] and returns it as a typed value.
// This is equivalent to calling Output() but returns the value directly instead
// of requiring a pointer argument. If you need to handle the error, use Output() instead.
func OutputFrom[T any](src outputer) T {
var v T
src.Output(&v)
return v
func OutputFrom[Out any](src outputer) Out {
output, err := extractTypedOutput[Out](src)
if err != nil {
return base.Zero[Out]()
}
return output
}

// extractTypedOutput extracts the typed output from a model response.
// It supports string output by calling Text() and returning the result.
func extractTypedOutput[Out any](o outputer) (Out, error) {
var output Out

switch any(output).(type) {
case string:
text := o.Text()
// Type assertion to convert string to Out (which we know is string).
result := any(text).(Out)
return result, nil
default:
if err := o.Output(&output); err != nil {
return base.Zero[Out](), fmt.Errorf("failed to parse output: %w", err)
}
return output, nil
}
}

// Text returns the contents of a [Message] as a string. It
Expand Down
Loading
Loading