@@ -40,6 +40,8 @@ type Prompt interface {
4040 Name () string
4141 // Execute executes the prompt with the given options and returns a [ModelResponse].
4242 Execute (ctx context.Context , opts ... PromptExecuteOption ) (* ModelResponse , error )
43+ // ExecuteStream executes the prompt with streaming and returns an iterator.
44+ ExecuteStream (ctx context.Context , opts ... PromptExecuteOption ) func (func (* ModelStreamValue , error ) bool )
4345 // Render renders the prompt with the given input and returns a [GenerateActionOptions] to be used with [GenerateWithRequest].
4446 Render (ctx context.Context , input any ) (* GenerateActionOptions , error )
4547}
@@ -51,6 +53,14 @@ type prompt struct {
5153 registry api.Registry
5254}
5355
56+ // DataPrompt is a prompt with strongly-typed input and output.
57+ // It wraps an underlying Prompt and provides type-safe Execute and Render methods.
58+ // The Out type parameter can be string for text outputs or any struct type for JSON outputs.
59+ type DataPrompt [In , Out , Stream any ] struct {
60+ prompt Prompt
61+ registry api.Registry
62+ }
63+
5464// DefinePrompt creates a new [Prompt] and registers it.
5565func DefinePrompt (r api.Registry , name string , opts ... PromptOption ) Prompt {
5666 if name == "" {
@@ -232,6 +242,51 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
232242 return GenerateWithRequest (ctx , r , actionOpts , execOpts .Middleware , execOpts .Stream )
233243}
234244
245+ // ExecuteStream executes the prompt with streaming and returns an iterator.
246+ // It returns a function whose argument function (the "yield function") will be repeatedly
247+ // called with the results.
248+ //
249+ // If the yield function is passed a non-nil error, execution has failed with that
250+ // error; the yield function will not be called again.
251+ //
252+ // If the yield function's [ModelStreamValue] argument has Done == true, the value's
253+ // Response field contains the final response; the yield function will not be called again.
254+ //
255+ // Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk.
256+ func (p * prompt ) ExecuteStream (ctx context.Context , opts ... PromptExecuteOption ) func (func (* ModelStreamValue , error ) bool ) {
257+ return func (yield func (* ModelStreamValue , error ) bool ) {
258+ if p == nil {
259+ yield (nil , errors .New ("Prompt.ExecuteStream: execute called on a nil Prompt; check that all prompts are defined" ))
260+ return
261+ }
262+
263+ cb := func (ctx context.Context , chunk * ModelResponseChunk ) error {
264+ if ctx .Err () != nil {
265+ return ctx .Err ()
266+ }
267+ if ! yield (& ModelStreamValue {Chunk : chunk }, nil ) {
268+ return errPromptStop
269+ }
270+ return nil
271+ }
272+
273+ allOpts := make ([]PromptExecuteOption , 0 , len (opts )+ 1 )
274+ allOpts = append (allOpts , opts ... )
275+ allOpts = append (allOpts , WithStreaming (cb ))
276+
277+ resp , err := p .Execute (ctx , allOpts ... )
278+ if err != nil {
279+ yield (nil , err )
280+ return
281+ }
282+
283+ yield (& ModelStreamValue {Done : true , Response : resp }, nil )
284+ }
285+ }
286+
287+ // errPromptStop is a sentinel error used to signal early termination of streaming.
288+ var errPromptStop = errors .New ("stop" )
289+
235290// Render renders the prompt template based on user input.
236291func (p * prompt ) Render (ctx context.Context , input any ) (* GenerateActionOptions , error ) {
237292 if p == nil {
@@ -759,3 +814,178 @@ func contentType(ct, uri string) (string, []byte, error) {
759814
760815 return "" , nil , errors .New ("uri content type not found" )
761816}
817+
818+ // DefineDataPrompt creates a new data prompt and registers it.
819+ // It automatically infers input schema from the In type parameter and configures
820+ // output schema and JSON format from the Out type parameter (unless Out is string).
821+ func DefineDataPrompt [In , Out , Stream any ](r api.Registry , name string , opts ... PromptOption ) * DataPrompt [In , Out , Stream ] {
822+ if name == "" {
823+ panic ("ai.DefineDataPrompt: name is required" )
824+ }
825+
826+ allOpts := make ([]PromptOption , 0 , len (opts )+ 2 )
827+
828+ var in In
829+ allOpts = append (allOpts , WithInputType (in ))
830+
831+ var out Out
832+ switch any (out ).(type ) {
833+ case string :
834+ // String output - no schema needed
835+ default :
836+ allOpts = append (allOpts , WithOutputType (out ))
837+ }
838+
839+ allOpts = append (allOpts , opts ... )
840+
841+ p := DefinePrompt (r , name , allOpts ... )
842+
843+ return & DataPrompt [In , Out , Stream ]{
844+ prompt : p ,
845+ registry : r ,
846+ }
847+ }
848+
849+ // LookupDataPrompt looks up a prompt by name and wraps it with type information.
850+ // This is useful for wrapping prompts loaded from .prompt files with strong types.
851+ // It returns nil if the prompt was not found.
852+ func LookupDataPrompt [In , Out , Stream any ](r api.Registry , name string ) * DataPrompt [In , Out , Stream ] {
853+ p := LookupPrompt (r , name )
854+ if p == nil {
855+ return nil
856+ }
857+
858+ return AsDataPrompt [In , Out , Stream ](p )
859+ }
860+
861+ // AsDataPrompt wraps an existing Prompt with type information, returning a DataPrompt.
862+ // This is useful for adding strong typing to a dynamically obtained prompt.
863+ func AsDataPrompt [In , Out , Stream any ](p Prompt ) * DataPrompt [In , Out , Stream ] {
864+ if p == nil {
865+ return nil
866+ }
867+
868+ return & DataPrompt [In , Out , Stream ]{
869+ prompt : p ,
870+ registry : p .(* prompt ).registry ,
871+ }
872+ }
873+
874+ // Name returns the name of the prompt.
875+ func (tp * DataPrompt [In , Out , Stream ]) Name () string {
876+ if tp == nil || tp .prompt == nil {
877+ return ""
878+ }
879+ return tp .prompt .Name ()
880+ }
881+
882+ // Execute executes the typed prompt and returns the strongly-typed output along with the full model response.
883+ // For structured output types (non-string Out), the prompt must be configured with the appropriate
884+ // output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt.
885+ func (tp * DataPrompt [In , Out , Stream ]) Execute (ctx context.Context , input In , opts ... PromptExecuteOption ) (* Out , * ModelResponse , error ) {
886+ if tp == nil || tp .prompt == nil {
887+ return nil , nil , errors .New ("TypedPrompt.Execute: called on a nil prompt; check that all prompts are defined" )
888+ }
889+
890+ allOpts := make ([]PromptExecuteOption , 0 , len (opts )+ 1 )
891+ allOpts = append (allOpts , WithInput (input ))
892+ allOpts = append (allOpts , opts ... )
893+
894+ resp , err := tp .prompt .Execute (ctx , allOpts ... )
895+ if err != nil {
896+ return nil , nil , err
897+ }
898+
899+ output , err := extractTypedOutput [Out ](resp )
900+ if err != nil {
901+ return nil , resp , err
902+ }
903+
904+ return output , resp , nil
905+ }
906+
907+ // ExecuteStream executes the typed prompt with streaming and returns an iterator.
908+ // It returns a function whose argument function (the "yield function") will be repeatedly
909+ // called with the results.
910+ //
911+ // If the yield function is passed a non-nil error, execution has failed with that
912+ // error; the yield function will not be called again.
913+ //
914+ // If the yield function's DataPromptStreamValue argument has Done == true, the value's
915+ // Output and Response fields contain the final typed output and response; the yield function
916+ // will not be called again.
917+ //
918+ // Otherwise the Chunk field of the passed DataPromptStreamValue holds a streamed chunk.
919+ //
920+ // For structured output types (non-string Out), the prompt must be configured with the appropriate
921+ // output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt.
922+ func (tp * DataPrompt [In , Out , Stream ]) ExecuteStream (ctx context.Context , input In , opts ... PromptExecuteOption ) func (func (* StreamValue [Out , Stream ], error ) bool ) {
923+ return func (yield func (* StreamValue [Out , Stream ], error ) bool ) {
924+ if tp == nil || tp .prompt == nil {
925+ yield (nil , errors .New ("DataPrompt.ExecuteStream: called on a nil prompt; check that all prompts are defined" ))
926+ return
927+ }
928+
929+ cb := func (ctx context.Context , chunk * ModelResponseChunk ) error {
930+ if ctx .Err () != nil {
931+ return ctx .Err ()
932+ }
933+ // TODO: Convert ModelResponseChunk to StreamValue[Out, Stream].
934+ if ! yield (& StreamValue [Out , Stream ]{}, nil ) {
935+ return errTypedPromptStop
936+ }
937+ return nil
938+ }
939+
940+ allOpts := make ([]PromptExecuteOption , 0 , len (opts )+ 2 )
941+ allOpts = append (allOpts , WithInput (input ))
942+ allOpts = append (allOpts , opts ... )
943+ allOpts = append (allOpts , WithStreaming (cb ))
944+
945+ resp , err := tp .prompt .Execute (ctx , allOpts ... )
946+ if err != nil {
947+ yield (nil , err )
948+ return
949+ }
950+
951+ output , err := extractTypedOutput [Out ](resp )
952+ if err != nil {
953+ yield (nil , err )
954+ return
955+ }
956+
957+ yield (& StreamValue [Out , Stream ]{Done : true , Output : * output , Response : resp }, nil )
958+ }
959+ }
960+
961+ // Render renders the typed prompt template with the given input.
962+ func (tp * DataPrompt [In , Out , Stream ]) Render (ctx context.Context , input In ) (* GenerateActionOptions , error ) {
963+ if tp == nil || tp .prompt == nil {
964+ return nil , errors .New ("TypedPrompt.Render: called on a nil prompt; check that all prompts are defined" )
965+ }
966+
967+ return tp .prompt .Render (ctx , input )
968+ }
969+
970+ // errTypedPromptStop is a sentinel error used to signal early termination of streaming.
971+ var errTypedPromptStop = errors .New ("stop" )
972+
973+ // extractTypedOutput extracts the typed output from a model response.
974+ func extractTypedOutput [Out any ](resp * ModelResponse ) (* Out , error ) {
975+ var output Out
976+
977+ switch any (output ).(type ) {
978+ case string :
979+ // String output - use Text()
980+ text := resp .Text ()
981+ // Type assertion to convert string to Out (which we know is string)
982+ result := any (text ).(Out )
983+ return & result , nil
984+ default :
985+ // Structured output - unmarshal from response
986+ if err := resp .Output (& output ); err != nil {
987+ return nil , fmt .Errorf ("TypedPrompt: failed to parse output: %w" , err )
988+ }
989+ return & output , nil
990+ }
991+ }
0 commit comments