Skip to content

Commit 3b4f001

Browse files
Zerekerzhanghaojie01
authored andcommitted
fix(go): defer template rendering in LoadPrompt to execution time
Previously, LoadPrompt called ToMessages with an empty DataArgument at load time, causing template variables to be replaced with empty values. This meant all subsequent Execute() calls would use prompts with empty template variable values. This change defers template rendering to execution time by using WithMessagesFn. The closure captures the raw template text and compiles/renders it with actual input values when Execute() or Render() is called. The fix properly handles: 1. Template variable substitution with actual input values 2. Multi-role messages (<<<dotprompt:role:XXX>>> markers) 3. History insertion (<<<dotprompt:history>>> markers) Added convertDotpromptMessages helper to convert dotprompt.Message to ai.Message format. Added regression test TestLoadPromptTemplateVariableSubstitution to verify template variables are correctly substituted with different input values on multiple calls. Fixes #3924
1 parent 9dcde54 commit 3b4f001

File tree

2 files changed

+136
-38
lines changed

2 files changed

+136
-38
lines changed

go/ai/prompt.go

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,31 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
517517
return result, nil
518518
}
519519

520+
// convertDotpromptMessages converts []dotprompt.Message to []*Message
521+
func convertDotpromptMessages(msgs []dotprompt.Message) ([]*Message, error) {
522+
result := make([]*Message, 0, len(msgs))
523+
for _, msg := range msgs {
524+
parts, err := convertToPartPointers(msg.Content)
525+
if err != nil {
526+
return nil, err
527+
}
528+
// Filter out nil parts
529+
filteredParts := make([]*Part, 0, len(parts))
530+
for _, p := range parts {
531+
if p != nil {
532+
filteredParts = append(filteredParts, p)
533+
}
534+
}
535+
if len(filteredParts) > 0 {
536+
result = append(result, &Message{
537+
Role: Role(msg.Role),
538+
Content: filteredParts,
539+
})
540+
}
541+
}
542+
return result, nil
543+
}
544+
520545
// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
521546
func LoadPromptDir(r api.Registry, dir string, namespace string) {
522547
useDefaultDir := false
@@ -662,51 +687,53 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
662687

663688
key := promptKey(name, variant, namespace)
664689

665-
dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
666-
if err != nil {
667-
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
668-
return nil
669-
}
690+
// Store the raw template text to defer rendering until Execute() is called.
691+
// This ensures template variables are properly substituted with actual input values.
692+
// Previously, ToMessages was called with empty DataArgument which caused template
693+
// variables to be replaced with empty values at load time.
694+
// See: https://github.com/firebase/genkit/issues/3924
695+
templateText := parsedPrompt.Template
696+
697+
promptOpts := []PromptOption{opts}
670698

671-
var systemText string
672-
var nonSystemMessages []*Message
673-
for _, dpMsg := range dpMessages {
674-
parts, err := convertToPartPointers(dpMsg.Content)
699+
// Use WithMessagesFn to defer template rendering until execution time.
700+
// This approach properly handles:
701+
// 1. Template variable substitution with actual input values
702+
// 2. Multi-role messages (<<<dotprompt:role:XXX>>> markers)
703+
// 3. History insertion (<<<dotprompt:history>>> markers)
704+
promptOpts = append(promptOpts, WithMessagesFn(func(ctx context.Context, input any) ([]*Message, error) {
705+
inputMap, err := buildVariables(input)
675706
if err != nil {
676-
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
677-
return nil
707+
return nil, err
678708
}
679709

680-
role := Role(dpMsg.Role)
681-
if role == RoleSystem {
682-
var textParts []string
683-
for _, part := range parts {
684-
if part.IsText() {
685-
textParts = append(textParts, part.Text)
686-
}
687-
}
688-
689-
if len(textParts) > 0 {
690-
systemText = strings.Join(textParts, " ")
691-
}
692-
} else {
693-
nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts})
710+
// Compile and render the template with actual input values
711+
renderedFunc, err := dp.Compile(templateText, &dotprompt.PromptMetadata{})
712+
if err != nil {
713+
return nil, fmt.Errorf("failed to compile template: %w", err)
694714
}
695-
}
696715

697-
promptOpts := []PromptOption{opts}
698-
699-
// Add system prompt if found
700-
if systemText != "" {
701-
promptOpts = append(promptOpts, WithSystem(systemText))
702-
}
716+
// Prepare the context for rendering
717+
context := map[string]any{}
718+
actionCtx := core.FromContext(ctx)
719+
maps.Copy(context, actionCtx)
720+
721+
// Render with actual input values
722+
rendered, err := renderedFunc(&dotprompt.DataArgument{
723+
Input: inputMap,
724+
Context: context,
725+
}, &dotprompt.PromptMetadata{
726+
Input: dotprompt.PromptMetadataInput{
727+
Default: opts.DefaultInput,
728+
},
729+
})
730+
if err != nil {
731+
return nil, fmt.Errorf("failed to render template: %w", err)
732+
}
703733

704-
// If there are non-system messages, use WithMessages, otherwise use WithPrompt for template
705-
if len(nonSystemMessages) > 0 {
706-
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
707-
} else if systemText == "" {
708-
promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template))
709-
}
734+
// Convert dotprompt messages to ai messages
735+
return convertDotpromptMessages(rendered.Messages)
736+
}))
710737

711738
prompt := DefinePrompt(r, key, promptOpts...)
712739

go/ai/prompt_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,3 +1274,74 @@ Hello!
12741274
t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Content[0].Text)
12751275
}
12761276
}
1277+
1278+
// TestLoadPromptTemplateVariableSubstitution tests that template variables are
1279+
// properly substituted with actual input values at execution time.
1280+
// This is a regression test for https://github.com/firebase/genkit/issues/3924
1281+
func TestLoadPromptTemplateVariableSubstitution(t *testing.T) {
1282+
tempDir := t.TempDir()
1283+
1284+
mockPromptFile := filepath.Join(tempDir, "greeting.prompt")
1285+
mockPromptContent := `---
1286+
model: test/chat
1287+
description: A greeting prompt with variables
1288+
---
1289+
Hello {{name}}, welcome to {{place}}!
1290+
`
1291+
1292+
if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
1293+
t.Fatalf("Failed to create mock prompt file: %v", err)
1294+
}
1295+
1296+
prompt := LoadPrompt(registry.New(), tempDir, "greeting.prompt", "template-var-test")
1297+
1298+
// Test with first set of input values
1299+
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
1300+
"name": "Alice",
1301+
"place": "Wonderland",
1302+
})
1303+
if err != nil {
1304+
t.Fatalf("Failed to render prompt with first input: %v", err)
1305+
}
1306+
1307+
if len(actionOpts1.Messages) != 1 {
1308+
t.Fatalf("Expected 1 message, got %d", len(actionOpts1.Messages))
1309+
}
1310+
1311+
text1 := actionOpts1.Messages[0].Content[0].Text
1312+
if !strings.Contains(text1, "Alice") {
1313+
t.Errorf("Expected message to contain 'Alice', got: %s", text1)
1314+
}
1315+
if !strings.Contains(text1, "Wonderland") {
1316+
t.Errorf("Expected message to contain 'Wonderland', got: %s", text1)
1317+
}
1318+
1319+
// Test with second set of input values (different from first)
1320+
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
1321+
"name": "Bob",
1322+
"place": "Paradise",
1323+
})
1324+
if err != nil {
1325+
t.Fatalf("Failed to render prompt with second input: %v", err)
1326+
}
1327+
1328+
if len(actionOpts2.Messages) != 1 {
1329+
t.Fatalf("Expected 1 message, got %d", len(actionOpts2.Messages))
1330+
}
1331+
1332+
text2 := actionOpts2.Messages[0].Content[0].Text
1333+
if !strings.Contains(text2, "Bob") {
1334+
t.Errorf("Expected message to contain 'Bob', got: %s", text2)
1335+
}
1336+
if !strings.Contains(text2, "Paradise") {
1337+
t.Errorf("Expected message to contain 'Paradise', got: %s", text2)
1338+
}
1339+
1340+
// Critical: Ensure the second render did NOT use the first input values
1341+
if strings.Contains(text2, "Alice") {
1342+
t.Errorf("BUG: Second render contains 'Alice' from first input! Got: %s", text2)
1343+
}
1344+
if strings.Contains(text2, "Wonderland") {
1345+
t.Errorf("BUG: Second render contains 'Wonderland' from first input! Got: %s", text2)
1346+
}
1347+
}

0 commit comments

Comments
 (0)