Skip to content

Commit 50ec641

Browse files
committed
Added iterator streaming functions and typed prompts.
1 parent 9dcde54 commit 50ec641

File tree

3 files changed

+427
-3
lines changed

3 files changed

+427
-3
lines changed

go/ai/generate.go

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,56 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) (
531531
return res.Text(), nil
532532
}
533533

534-
// Generate run generate request for this model. Returns ModelResponse struct.
535-
// TODO: Stream GenerateData with partial JSON
534+
// StreamValue is either a streamed chunk or the final response of a generate request.
535+
type StreamValue[Out, Stream any] struct {
536+
Done bool
537+
Chunk Stream // valid if Done is false
538+
Output Out // valid if Done is true
539+
Response *ModelResponse // valid if Done is true
540+
}
541+
542+
// ModelStreamValue is a stream value for a model response.
543+
type ModelStreamValue = StreamValue[*ModelResponse, *ModelResponseChunk]
544+
545+
// errGenerateStop is a sentinel error used to signal early termination of streaming.
546+
var errGenerateStop = errors.New("stop")
547+
548+
// GenerateStream generates a model response and streams the output.
549+
// It returns a function whose argument function (the "yield function") will be repeatedly
550+
// called with the results.
551+
//
552+
// If the yield function is passed a non-nil error, generation has failed with that
553+
// error; the yield function will not be called again.
554+
//
555+
// If the yield function's [StreamValue] argument has Done == true, the value's
556+
// Response field contains the final response; the yield function will not be called
557+
// again.
558+
//
559+
// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk.
560+
func GenerateStream(ctx context.Context, r api.Registry, opts ...GenerateOption) func(func(*ModelStreamValue, error) bool) {
561+
return func(yield func(*ModelStreamValue, error) bool) {
562+
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
563+
if ctx.Err() != nil {
564+
return ctx.Err()
565+
}
566+
if !yield(&ModelStreamValue{Chunk: chunk}, nil) {
567+
return errGenerateStop
568+
}
569+
return nil
570+
}
571+
572+
allOpts := append(slices.Clone(opts), WithStreaming(cb))
573+
574+
resp, err := Generate(ctx, r, allOpts...)
575+
if err != nil {
576+
yield(nil, err)
577+
} else {
578+
yield(&ModelStreamValue{Done: true, Response: resp}, nil)
579+
}
580+
}
581+
}
582+
583+
// GenerateData runs a generate request and returns strongly-typed output.
536584
func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) {
537585
var value Out
538586
opts = append(opts, WithOutputType(value))
@@ -550,6 +598,50 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate
550598
return &value, resp, nil
551599
}
552600

601+
// GenerateDataStream generates a model response with streaming and returns strongly-typed output.
602+
// It returns a function whose argument function (the "yield function") will be repeatedly
603+
// called with the results.
604+
//
605+
// If the yield function is passed a non-nil error, generation has failed with that
606+
// error; the yield function will not be called again.
607+
//
608+
// If the yield function's [StreamValue] argument has Done == true, the value's
609+
// Output and Response fields contain the final typed output and response; the yield function
610+
// will not be called again.
611+
//
612+
// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk.
613+
func GenerateDataStream[Out, Stream any](ctx context.Context, r api.Registry, opts ...GenerateOption) func(func(*StreamValue[Out, Stream], error) bool) {
614+
return func(yield func(*StreamValue[Out, Stream], error) bool) {
615+
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
616+
if ctx.Err() != nil {
617+
return ctx.Err()
618+
}
619+
// TODO: Convert ModelResponseChunk to Stream type.
620+
if !yield(&StreamValue[Out, Stream]{}, nil) {
621+
return errGenerateStop
622+
}
623+
return nil
624+
}
625+
626+
var value Out
627+
allOpts := append(slices.Clone(opts), WithOutputType(value), WithStreaming(cb))
628+
629+
resp, err := Generate(ctx, r, allOpts...)
630+
if err != nil {
631+
yield(nil, err)
632+
return
633+
}
634+
635+
err = resp.Output(&value)
636+
if err != nil {
637+
yield(nil, err)
638+
return
639+
}
640+
641+
yield(&StreamValue[Out, Stream]{Done: true, Output: value, Response: resp}, nil)
642+
}
643+
}
644+
553645
// Generate applies the [Action] to provided request.
554646
func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
555647
if m == nil {

go/ai/prompt.go

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type Prompt interface {
4040
Name() string
4141
// Execute executes the prompt with the given options and returns a [ModelResponse].
4242
Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error)
43+
// ExecuteStream executes the prompt with streaming and returns an iterator.
44+
ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) func(func(*ModelStreamValue, error) bool)
4345
// Render renders the prompt with the given input and returns a [GenerateActionOptions] to be used with [GenerateWithRequest].
4446
Render(ctx context.Context, input any) (*GenerateActionOptions, error)
4547
}
@@ -51,6 +53,14 @@ type prompt struct {
5153
registry api.Registry
5254
}
5355

56+
// DataPrompt is a prompt with strongly-typed input and output.
57+
// It wraps an underlying Prompt and provides type-safe Execute and Render methods.
58+
// The Out type parameter can be string for text outputs or any struct type for JSON outputs.
59+
type DataPrompt[In, Out, Stream any] struct {
60+
prompt Prompt
61+
registry api.Registry
62+
}
63+
5464
// DefinePrompt creates a new [Prompt] and registers it.
5565
func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
5666
if name == "" {
@@ -232,6 +242,51 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
232242
return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream)
233243
}
234244

245+
// ExecuteStream executes the prompt with streaming and returns an iterator.
246+
// It returns a function whose argument function (the "yield function") will be repeatedly
247+
// called with the results.
248+
//
249+
// If the yield function is passed a non-nil error, execution has failed with that
250+
// error; the yield function will not be called again.
251+
//
252+
// If the yield function's [ModelStreamValue] argument has Done == true, the value's
253+
// Response field contains the final response; the yield function will not be called again.
254+
//
255+
// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk.
256+
func (p *prompt) ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) func(func(*ModelStreamValue, error) bool) {
257+
return func(yield func(*ModelStreamValue, error) bool) {
258+
if p == nil {
259+
yield(nil, errors.New("Prompt.ExecuteStream: execute called on a nil Prompt; check that all prompts are defined"))
260+
return
261+
}
262+
263+
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
264+
if ctx.Err() != nil {
265+
return ctx.Err()
266+
}
267+
if !yield(&ModelStreamValue{Chunk: chunk}, nil) {
268+
return errPromptStop
269+
}
270+
return nil
271+
}
272+
273+
allOpts := make([]PromptExecuteOption, 0, len(opts)+1)
274+
allOpts = append(allOpts, opts...)
275+
allOpts = append(allOpts, WithStreaming(cb))
276+
277+
resp, err := p.Execute(ctx, allOpts...)
278+
if err != nil {
279+
yield(nil, err)
280+
return
281+
}
282+
283+
yield(&ModelStreamValue{Done: true, Response: resp}, nil)
284+
}
285+
}
286+
287+
// errPromptStop is a sentinel error used to signal early termination of streaming.
288+
var errPromptStop = errors.New("stop")
289+
235290
// Render renders the prompt template based on user input.
236291
func (p *prompt) Render(ctx context.Context, input any) (*GenerateActionOptions, error) {
237292
if p == nil {
@@ -759,3 +814,178 @@ func contentType(ct, uri string) (string, []byte, error) {
759814

760815
return "", nil, errors.New("uri content type not found")
761816
}
817+
818+
// DefineDataPrompt creates a new data prompt and registers it.
819+
// It automatically infers input schema from the In type parameter and configures
820+
// output schema and JSON format from the Out type parameter (unless Out is string).
821+
func DefineDataPrompt[In, Out, Stream any](r api.Registry, name string, opts ...PromptOption) *DataPrompt[In, Out, Stream] {
822+
if name == "" {
823+
panic("ai.DefineDataPrompt: name is required")
824+
}
825+
826+
allOpts := make([]PromptOption, 0, len(opts)+2)
827+
828+
var in In
829+
allOpts = append(allOpts, WithInputType(in))
830+
831+
var out Out
832+
switch any(out).(type) {
833+
case string:
834+
// String output - no schema needed
835+
default:
836+
allOpts = append(allOpts, WithOutputType(out))
837+
}
838+
839+
allOpts = append(allOpts, opts...)
840+
841+
p := DefinePrompt(r, name, allOpts...)
842+
843+
return &DataPrompt[In, Out, Stream]{
844+
prompt: p,
845+
registry: r,
846+
}
847+
}
848+
849+
// LookupDataPrompt looks up a prompt by name and wraps it with type information.
850+
// This is useful for wrapping prompts loaded from .prompt files with strong types.
851+
// It returns nil if the prompt was not found.
852+
func LookupDataPrompt[In, Out, Stream any](r api.Registry, name string) *DataPrompt[In, Out, Stream] {
853+
p := LookupPrompt(r, name)
854+
if p == nil {
855+
return nil
856+
}
857+
858+
return AsDataPrompt[In, Out, Stream](p)
859+
}
860+
861+
// AsDataPrompt wraps an existing Prompt with type information, returning a DataPrompt.
862+
// This is useful for adding strong typing to a dynamically obtained prompt.
863+
func AsDataPrompt[In, Out, Stream any](p Prompt) *DataPrompt[In, Out, Stream] {
864+
if p == nil {
865+
return nil
866+
}
867+
868+
return &DataPrompt[In, Out, Stream]{
869+
prompt: p,
870+
registry: p.(*prompt).registry,
871+
}
872+
}
873+
874+
// Name returns the name of the prompt.
875+
func (tp *DataPrompt[In, Out, Stream]) Name() string {
876+
if tp == nil || tp.prompt == nil {
877+
return ""
878+
}
879+
return tp.prompt.Name()
880+
}
881+
882+
// Execute executes the typed prompt and returns the strongly-typed output along with the full model response.
883+
// For structured output types (non-string Out), the prompt must be configured with the appropriate
884+
// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt.
885+
func (tp *DataPrompt[In, Out, Stream]) Execute(ctx context.Context, input In, opts ...PromptExecuteOption) (*Out, *ModelResponse, error) {
886+
if tp == nil || tp.prompt == nil {
887+
return nil, nil, errors.New("TypedPrompt.Execute: called on a nil prompt; check that all prompts are defined")
888+
}
889+
890+
allOpts := make([]PromptExecuteOption, 0, len(opts)+1)
891+
allOpts = append(allOpts, WithInput(input))
892+
allOpts = append(allOpts, opts...)
893+
894+
resp, err := tp.prompt.Execute(ctx, allOpts...)
895+
if err != nil {
896+
return nil, nil, err
897+
}
898+
899+
output, err := extractTypedOutput[Out](resp)
900+
if err != nil {
901+
return nil, resp, err
902+
}
903+
904+
return output, resp, nil
905+
}
906+
907+
// ExecuteStream executes the typed prompt with streaming and returns an iterator.
908+
// It returns a function whose argument function (the "yield function") will be repeatedly
909+
// called with the results.
910+
//
911+
// If the yield function is passed a non-nil error, execution has failed with that
912+
// error; the yield function will not be called again.
913+
//
914+
// If the yield function's DataPromptStreamValue argument has Done == true, the value's
915+
// Output and Response fields contain the final typed output and response; the yield function
916+
// will not be called again.
917+
//
918+
// Otherwise the Chunk field of the passed DataPromptStreamValue holds a streamed chunk.
919+
//
920+
// For structured output types (non-string Out), the prompt must be configured with the appropriate
921+
// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt.
922+
func (tp *DataPrompt[In, Out, Stream]) ExecuteStream(ctx context.Context, input In, opts ...PromptExecuteOption) func(func(*StreamValue[Out, Stream], error) bool) {
923+
return func(yield func(*StreamValue[Out, Stream], error) bool) {
924+
if tp == nil || tp.prompt == nil {
925+
yield(nil, errors.New("DataPrompt.ExecuteStream: called on a nil prompt; check that all prompts are defined"))
926+
return
927+
}
928+
929+
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
930+
if ctx.Err() != nil {
931+
return ctx.Err()
932+
}
933+
// TODO: Convert ModelResponseChunk to StreamValue[Out, Stream].
934+
if !yield(&StreamValue[Out, Stream]{}, nil) {
935+
return errTypedPromptStop
936+
}
937+
return nil
938+
}
939+
940+
allOpts := make([]PromptExecuteOption, 0, len(opts)+2)
941+
allOpts = append(allOpts, WithInput(input))
942+
allOpts = append(allOpts, opts...)
943+
allOpts = append(allOpts, WithStreaming(cb))
944+
945+
resp, err := tp.prompt.Execute(ctx, allOpts...)
946+
if err != nil {
947+
yield(nil, err)
948+
return
949+
}
950+
951+
output, err := extractTypedOutput[Out](resp)
952+
if err != nil {
953+
yield(nil, err)
954+
return
955+
}
956+
957+
yield(&StreamValue[Out, Stream]{Done: true, Output: *output, Response: resp}, nil)
958+
}
959+
}
960+
961+
// Render renders the typed prompt template with the given input.
962+
func (tp *DataPrompt[In, Out, Stream]) Render(ctx context.Context, input In) (*GenerateActionOptions, error) {
963+
if tp == nil || tp.prompt == nil {
964+
return nil, errors.New("TypedPrompt.Render: called on a nil prompt; check that all prompts are defined")
965+
}
966+
967+
return tp.prompt.Render(ctx, input)
968+
}
969+
970+
// errTypedPromptStop is a sentinel error used to signal early termination of streaming.
971+
var errTypedPromptStop = errors.New("stop")
972+
973+
// extractTypedOutput extracts the typed output from a model response.
974+
func extractTypedOutput[Out any](resp *ModelResponse) (*Out, error) {
975+
var output Out
976+
977+
switch any(output).(type) {
978+
case string:
979+
// String output - use Text()
980+
text := resp.Text()
981+
// Type assertion to convert string to Out (which we know is string)
982+
result := any(text).(Out)
983+
return &result, nil
984+
default:
985+
// Structured output - unmarshal from response
986+
if err := resp.Output(&output); err != nil {
987+
return nil, fmt.Errorf("TypedPrompt: failed to parse output: %w", err)
988+
}
989+
return &output, nil
990+
}
991+
}

0 commit comments

Comments
 (0)