Skip to content

Commit 64ef676

Browse files
committed
fix(go): defer template rendering in LoadPrompt to execution time
Previously, LoadPrompt called ToMessages with an empty DataArgument at load time, causing template variables to be replaced with empty values. This meant all subsequent Execute() calls would use prompts with empty template variable values. This change defers template rendering to execution time by using WithMessagesFn. The closure captures the raw template text and compiles/renders it with actual input values when Execute() or Render() is called. The fix properly handles: 1. Template variable substitution with actual input values 2. Multi-role messages (<<<dotprompt:role:XXX>>> markers) 3. History insertion (<<<dotprompt:history>>> markers) Added convertDotpromptMessages helper to convert dotprompt.Message to ai.Message format. Added regression test TestLoadPromptTemplateVariableSubstitution to verify template variables are correctly substituted with different input values on multiple calls. Fixes #3924
1 parent fe6469a commit 64ef676

File tree

2 files changed

+229
-34
lines changed

2 files changed

+229
-34
lines changed

go/ai/prompt.go

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,31 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
553553
return result, nil
554554
}
555555

556+
// convertDotpromptMessages converts []dotprompt.Message to []*Message
557+
func convertDotpromptMessages(msgs []dotprompt.Message) ([]*Message, error) {
558+
result := make([]*Message, 0, len(msgs))
559+
for _, msg := range msgs {
560+
parts, err := convertToPartPointers(msg.Content)
561+
if err != nil {
562+
return nil, err
563+
}
564+
// Filter out nil parts
565+
filteredParts := make([]*Part, 0, len(parts))
566+
for _, p := range parts {
567+
if p != nil {
568+
filteredParts = append(filteredParts, p)
569+
}
570+
}
571+
if len(filteredParts) > 0 {
572+
result = append(result, &Message{
573+
Role: Role(msg.Role),
574+
Content: filteredParts,
575+
})
576+
}
577+
}
578+
return result, nil
579+
}
580+
556581
// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
557582
func LoadPromptDir(r api.Registry, dir string, namespace string) {
558583
useDefaultDir := false
@@ -708,49 +733,47 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
708733

709734
key := promptKey(name, variant, namespace)
710735

711-
dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
736+
// Defer template rendering to execution time.
737+
// See: https://github.com/firebase/genkit/issues/3924
738+
templateText := parsedPrompt.Template
739+
compiledTemplate, err := dp.Compile(templateText, &dotprompt.PromptMetadata{
740+
Input: dotprompt.PromptMetadataInput{
741+
Default: opts.DefaultInput,
742+
},
743+
})
712744
if err != nil {
713-
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
745+
slog.Error("Failed to compile prompt template", "file", sourceFile, "error", err)
714746
return nil
715747
}
716748

717-
var systemText string
718-
var nonSystemMessages []*Message
719-
for _, dpMsg := range dpMessages {
720-
parts, err := convertToPartPointers(dpMsg.Content)
749+
promptOpts := []PromptOption{opts}
750+
promptOpts = append(promptOpts, WithMessagesFn(func(ctx context.Context, input any) ([]*Message, error) {
751+
inputMap, err := buildVariables(input)
721752
if err != nil {
722-
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
723-
return nil
753+
return nil, err
724754
}
725755

726-
role := Role(dpMsg.Role)
727-
if role == RoleSystem {
728-
var textParts []string
729-
for _, part := range parts {
730-
if part.IsText() {
731-
textParts = append(textParts, part.Text)
732-
}
733-
}
734-
735-
if len(textParts) > 0 {
736-
systemText = strings.Join(textParts, " ")
737-
}
738-
} else {
739-
nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts})
756+
// Prepare the data context for rendering
757+
dataContext := map[string]any{}
758+
actionCtx := core.FromContext(ctx)
759+
maps.Copy(dataContext, actionCtx)
760+
761+
// Render with actual input values at execution time
762+
rendered, err := compiledTemplate(&dotprompt.DataArgument{
763+
Input: inputMap,
764+
Context: dataContext,
765+
}, &dotprompt.PromptMetadata{
766+
Input: dotprompt.PromptMetadataInput{
767+
Default: opts.DefaultInput,
768+
},
769+
})
770+
if err != nil {
771+
return nil, fmt.Errorf("failed to render template: %w", err)
740772
}
741-
}
742-
743-
promptOpts := []PromptOption{opts}
744773

745-
if systemText != "" {
746-
promptOpts = append(promptOpts, WithSystem(systemText))
747-
}
748-
749-
if len(nonSystemMessages) > 0 {
750-
promptOpts = append(promptOpts, WithMessages(nonSystemMessages...))
751-
} else if systemText == "" {
752-
promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template))
753-
}
774+
// Convert dotprompt messages to ai messages
775+
return convertDotpromptMessages(rendered.Messages)
776+
}))
754777

755778
prompt := DefinePrompt(r, key, promptOpts...)
756779

go/ai/prompt_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,3 +1500,175 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) {
15001500
t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err)
15011501
}
15021502
}
1503+
1504+
// TestLoadPromptTemplateVariableSubstitution tests that template variables are
1505+
// properly substituted with actual input values at execution time.
1506+
// This is a regression test for https://github.com/firebase/genkit/issues/3924
1507+
func TestLoadPromptTemplateVariableSubstitution(t *testing.T) {
1508+
t.Run("single role", func(t *testing.T) {
1509+
tempDir := t.TempDir()
1510+
1511+
mockPromptFile := filepath.Join(tempDir, "greeting.prompt")
1512+
mockPromptContent := `---
1513+
model: test/chat
1514+
description: A greeting prompt with variables
1515+
---
1516+
Hello {{name}}, welcome to {{place}}!
1517+
`
1518+
1519+
if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
1520+
t.Fatalf("Failed to create mock prompt file: %v", err)
1521+
}
1522+
1523+
prompt := LoadPrompt(registry.New(), tempDir, "greeting.prompt", "template-var-test")
1524+
1525+
// Test with first set of input values
1526+
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
1527+
"name": "Alice",
1528+
"place": "Wonderland",
1529+
})
1530+
if err != nil {
1531+
t.Fatalf("Failed to render prompt with first input: %v", err)
1532+
}
1533+
1534+
if len(actionOpts1.Messages) != 1 {
1535+
t.Fatalf("Expected 1 message, got %d", len(actionOpts1.Messages))
1536+
}
1537+
1538+
text1 := actionOpts1.Messages[0].Content[0].Text
1539+
if !strings.Contains(text1, "Alice") {
1540+
t.Errorf("Expected message to contain 'Alice', got: %s", text1)
1541+
}
1542+
if !strings.Contains(text1, "Wonderland") {
1543+
t.Errorf("Expected message to contain 'Wonderland', got: %s", text1)
1544+
}
1545+
1546+
// Test with second set of input values (different from first)
1547+
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
1548+
"name": "Bob",
1549+
"place": "Paradise",
1550+
})
1551+
if err != nil {
1552+
t.Fatalf("Failed to render prompt with second input: %v", err)
1553+
}
1554+
1555+
if len(actionOpts2.Messages) != 1 {
1556+
t.Fatalf("Expected 1 message, got %d", len(actionOpts2.Messages))
1557+
}
1558+
1559+
text2 := actionOpts2.Messages[0].Content[0].Text
1560+
if !strings.Contains(text2, "Bob") {
1561+
t.Errorf("Expected message to contain 'Bob', got: %s", text2)
1562+
}
1563+
if !strings.Contains(text2, "Paradise") {
1564+
t.Errorf("Expected message to contain 'Paradise', got: %s", text2)
1565+
}
1566+
1567+
// Critical: Ensure the second render did NOT use the first input values
1568+
if strings.Contains(text2, "Alice") {
1569+
t.Errorf("BUG: Second render contains 'Alice' from first input! Got: %s", text2)
1570+
}
1571+
if strings.Contains(text2, "Wonderland") {
1572+
t.Errorf("BUG: Second render contains 'Wonderland' from first input! Got: %s", text2)
1573+
}
1574+
})
1575+
1576+
t.Run("multi role", func(t *testing.T) {
1577+
tempDir := t.TempDir()
1578+
1579+
mockPromptFile := filepath.Join(tempDir, "multi_role.prompt")
1580+
mockPromptContent := `---
1581+
model: test/chat
1582+
description: A multi-role prompt with variables
1583+
---
1584+
<<<dotprompt:role:system>>>
1585+
You are a {{personality}} assistant.
1586+
1587+
<<<dotprompt:role:user>>>
1588+
Hello {{name}}, please help me with {{task}}.
1589+
`
1590+
1591+
if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644); err != nil {
1592+
t.Fatalf("Failed to create mock prompt file: %v", err)
1593+
}
1594+
1595+
prompt := LoadPrompt(registry.New(), tempDir, "multi_role.prompt", "multi-role-var-test")
1596+
1597+
// Test with first set of input values
1598+
actionOpts1, err := prompt.Render(context.Background(), map[string]any{
1599+
"personality": "helpful",
1600+
"name": "Alice",
1601+
"task": "coding",
1602+
})
1603+
if err != nil {
1604+
t.Fatalf("Failed to render prompt with first input: %v", err)
1605+
}
1606+
1607+
if len(actionOpts1.Messages) != 2 {
1608+
t.Fatalf("Expected 2 messages, got %d", len(actionOpts1.Messages))
1609+
}
1610+
1611+
// Check system message
1612+
systemMsg := actionOpts1.Messages[0]
1613+
if systemMsg.Role != RoleSystem {
1614+
t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role)
1615+
}
1616+
systemText := systemMsg.Content[0].Text
1617+
if !strings.Contains(systemText, "helpful") {
1618+
t.Errorf("Expected system message to contain 'helpful', got: %s", systemText)
1619+
}
1620+
1621+
// Check user message
1622+
userMsg := actionOpts1.Messages[1]
1623+
if userMsg.Role != RoleUser {
1624+
t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role)
1625+
}
1626+
userText := userMsg.Content[0].Text
1627+
if !strings.Contains(userText, "Alice") {
1628+
t.Errorf("Expected user message to contain 'Alice', got: %s", userText)
1629+
}
1630+
if !strings.Contains(userText, "coding") {
1631+
t.Errorf("Expected user message to contain 'coding', got: %s", userText)
1632+
}
1633+
1634+
// Test with second set of input values (different from first)
1635+
actionOpts2, err := prompt.Render(context.Background(), map[string]any{
1636+
"personality": "professional",
1637+
"name": "Bob",
1638+
"task": "writing",
1639+
})
1640+
if err != nil {
1641+
t.Fatalf("Failed to render prompt with second input: %v", err)
1642+
}
1643+
1644+
if len(actionOpts2.Messages) != 2 {
1645+
t.Fatalf("Expected 2 messages, got %d", len(actionOpts2.Messages))
1646+
}
1647+
1648+
// Check system message with new values
1649+
systemMsg2 := actionOpts2.Messages[0]
1650+
systemText2 := systemMsg2.Content[0].Text
1651+
if !strings.Contains(systemText2, "professional") {
1652+
t.Errorf("Expected system message to contain 'professional', got: %s", systemText2)
1653+
}
1654+
if strings.Contains(systemText2, "helpful") {
1655+
t.Errorf("BUG: Second render system message contains 'helpful' from first input! Got: %s", systemText2)
1656+
}
1657+
1658+
// Check user message with new values
1659+
userMsg2 := actionOpts2.Messages[1]
1660+
userText2 := userMsg2.Content[0].Text
1661+
if !strings.Contains(userText2, "Bob") {
1662+
t.Errorf("Expected user message to contain 'Bob', got: %s", userText2)
1663+
}
1664+
if !strings.Contains(userText2, "writing") {
1665+
t.Errorf("Expected user message to contain 'writing', got: %s", userText2)
1666+
}
1667+
if strings.Contains(userText2, "Alice") {
1668+
t.Errorf("BUG: Second render user message contains 'Alice' from first input! Got: %s", userText2)
1669+
}
1670+
if strings.Contains(userText2, "coding") {
1671+
t.Errorf("BUG: Second render user message contains 'coding' from first input! Got: %s", userText2)
1672+
}
1673+
})
1674+
}

0 commit comments

Comments
 (0)