Skip to content

Commit 6fbdae1

Browse files
authored
Fix: use state.input if input is empty (firebase#1056)
* Fix: use state.input if input is empty * Fix: add validation
1 parent f0bf098 commit 6fbdae1

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

go/genkit/flow.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"log"
2424
"net/http"
25+
"reflect"
2526
"strconv"
2627
"sync"
2728
"time"
@@ -476,6 +477,20 @@ func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamin
476477
return state, nil
477478
}
478479

480+
func isInputMissing(input any) bool {
481+
if input == nil {
482+
return true
483+
}
484+
v := reflect.ValueOf(input)
485+
switch v.Kind() {
486+
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Chan, reflect.Func:
487+
return v.IsNil()
488+
default:
489+
// For other types like structs, zero value might be a valid input.
490+
return false
491+
}
492+
}
493+
479494
// execute performs one flow execution.
480495
// Using its flowState argument as a starting point, it runs the flow function until
481496
// it finishes or is interrupted.
@@ -510,7 +525,22 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In
510525
traceID := rootSpanContext.TraceID().String()
511526
exec.TraceIDs = append(exec.TraceIDs, traceID)
512527
// TODO: Save rootSpanContext in the state.
513-
// TODO: If input is missing, get it from state.input and overwrite metadata.input.
528+
if isInputMissing(input) {
529+
if state == nil {
530+
return base.Zero[Out](), errors.New("input is missing and state is nil")
531+
}
532+
if isInputMissing(state.Input) {
533+
return base.Zero[Out](), errors.New("input is missing and state.Input is also empty")
534+
}
535+
input = state.Input
536+
537+
// Convert input to JSON string for tracing metadata
538+
bytes, err := json.Marshal(input)
539+
if err != nil {
540+
return base.Zero[Out](), fmt.Errorf("failed to marshal input for tracing: %w", err)
541+
}
542+
tracing.SetCustomMetadataAttr(ctx, "input", string(bytes))
543+
}
514544
start := time.Now()
515545
var err error
516546
if err = base.ValidateValue(input, f.inputSchema); err != nil {

0 commit comments

Comments
 (0)