Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 49 additions & 25 deletions cmd/cli/commands/nim.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package commands

import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
Expand All @@ -28,12 +29,21 @@ const (
nimPrefix = "nvcr.io/nim/"
// nimContainerPrefix is the prefix for NIM container names
nimContainerPrefix = "docker-model-nim-"
// nimDefaultPort is the default port for NIM containers
nimDefaultPort = 8000
// nimDefaultShmSize is the default shared memory size for NIM containers (16GB)
nimDefaultShmSize = 17179869184
)

var (
// nimDefaultPort is the default port for NIM containers
nimDefaultPort = 8000
)

// Message represents a single message in the chat conversation. It matches
// the OpenAI-compatible chat message shape that chatWithNIM sends to and
// accumulates from the NIM container's /v1/chat/completions endpoint.
type Message struct {
	// Role identifies the speaker; this file uses "user" and "assistant".
	Role string `json:"role"`
	// Content is the plain-text body of the message.
	Content string `json:"content"`
}

// isNIMImage reports whether the given model reference is an NVIDIA NIM
// image, i.e. whether it starts with the nvcr.io/nim/ registry prefix
// (nimPrefix).
func isNIMImage(model string) bool {
	return strings.HasPrefix(model, nimPrefix)
Expand Down Expand Up @@ -389,7 +399,7 @@ func runNIMModel(ctx context.Context, dockerClient *client.Client, model string,
}

// chatWithNIM sends chat requests to a NIM container
func chatWithNIM(cmd *cobra.Command, model, prompt string) error {
func chatWithNIM(cmd *cobra.Command, model string, messages *[]Message, prompt string) error {
// Use the desktop client to chat with the NIM through its OpenAI-compatible API
// The NIM container runs on localhost:8000 and provides an OpenAI-compatible API

Expand All @@ -404,15 +414,25 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error {
modelName = modelName[:idx]
}

reqBody := fmt.Sprintf(`{
"model": "%s",
"messages": [
{"role": "user", "content": %q}
],
"stream": true
}`, modelName, prompt)
// Append user message to history
*messages = append(*messages, Message{Role: "user", Content: prompt})

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", nimDefaultPort), strings.NewReader(reqBody))
requestPayload := struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
}{
Model: modelName,
Messages: *messages,
Stream: true,
}

reqBodyBytes, err := json.Marshal(requestPayload)
if err != nil {
return fmt.Errorf("failed to marshal request payload: %w", err)
}

req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", nimDefaultPort), bytes.NewReader(reqBodyBytes))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
Expand All @@ -431,6 +451,7 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error {
}

// Stream the response - parse SSE events
var assistantResponse strings.Builder
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
Expand All @@ -445,21 +466,20 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error {
}

// Parse the JSON and extract the content
// For simplicity, we'll use basic string parsing
// In production, we'd use proper JSON parsing
if strings.Contains(data, `"content"`) {
// Extract content field - simple approach
contentStart := strings.Index(data, `"content":"`)
if contentStart != -1 {
contentStart += len(`"content":"`)
contentEnd := strings.Index(data[contentStart:], `"`)
if contentEnd != -1 {
content := data[contentStart : contentStart+contentEnd]
// Unescape basic JSON escapes
content = strings.ReplaceAll(content, `\n`, "\n")
content = strings.ReplaceAll(content, `\t`, "\t")
content = strings.ReplaceAll(content, `\"`, `"`)
var chatCompletion struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
}

if err := json.Unmarshal([]byte(data), &chatCompletion); err == nil {
if len(chatCompletion.Choices) > 0 {
content := chatCompletion.Choices[0].Delta.Content
if content != "" {
cmd.Print(content)
assistantResponse.WriteString(content)
}
}
}
Expand All @@ -470,5 +490,9 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error {
return fmt.Errorf("error reading response: %w", err)
}

// Append assistant message to history
*messages = append(*messages, Message{Role: "assistant", Content: assistantResponse.String()})

return nil
}

121 changes: 121 additions & 0 deletions cmd/cli/commands/nim_chat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package commands

import (
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"sync"
	"testing"

	"github.com/spf13/cobra"
)

// TestChatWithNIM_Context verifies that chatWithNIM maintains conversation
// history across calls: the second request must include the first user
// prompt, the assistant's streamed reply, and the new user prompt.
func TestChatWithNIM_Context(t *testing.T) {
	// chatWithNIM dials 127.0.0.1:nimDefaultPort, so point it at the mock
	// server below and restore the original port when the test finishes.
	originalPort := nimDefaultPort
	defer func() { nimDefaultPort = originalPort }()

	// chatPayload mirrors the portion of the request body this test inspects.
	type chatPayload struct {
		Messages []Message `json:"messages"`
	}

	// Payloads received by the mock server. Guarded by mu because the
	// httptest handler runs on server-owned goroutines, not the test goroutine.
	var (
		mu       sync.Mutex
		received []chatPayload
	)

	// Mock NIM server speaking the OpenAI-compatible chat-completions API.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report failures with t.Errorf here: t.Fatalf only exits the
		// goroutine it is called from, so calling it inside the handler
		// would not stop the test (see testing.T.FailNow docs).
		if r.URL.Path != "/v1/chat/completions" {
			t.Errorf("Expected path /v1/chat/completions, got %s", r.URL.Path)
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}

		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("Failed to read request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}

		var p chatPayload
		if err := json.Unmarshal(body, &p); err != nil {
			t.Errorf("Failed to unmarshal request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}

		mu.Lock()
		received = append(received, p)
		mu.Unlock()

		// Mock streaming response in SSE format; chatWithNIM scans it line
		// by line, so a single trailing newline per event is sufficient.
		w.Header().Set("Content-Type", "text/event-stream")
		for _, event := range []string{
			"data: {\"choices\":[{\"delta\":{\"content\":\"Response\"}}]}\n",
			"data: [DONE]\n",
		} {
			if _, err := io.WriteString(w, event); err != nil {
				t.Errorf("Failed to write SSE event: %v", err)
				return
			}
		}
	}))
	defer server.Close()

	// Redirect chatWithNIM to the mock server's ephemeral port.
	u, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("Failed to parse server URL: %v", err)
	}
	port, err := strconv.Atoi(u.Port())
	if err != nil {
		t.Fatalf("Failed to parse port: %v", err)
	}
	nimDefaultPort = port

	var messages []Message
	cmd := &cobra.Command{}

	// First interaction: the request should carry exactly the user prompt.
	if err := chatWithNIM(cmd, "ai/model", &messages, "Hello"); err != nil {
		t.Fatalf("First chatWithNIM failed: %v", err)
	}

	mu.Lock()
	first := append([]chatPayload(nil), received...)
	mu.Unlock()

	if len(first) != 1 {
		t.Fatalf("Expected 1 request, got %d", len(first))
	}
	// Fatal rather than Errorf: indexing Messages[0] below would panic on an
	// empty slice.
	if len(first[0].Messages) != 1 {
		t.Fatalf("Expected 1 message in first request, got %d", len(first[0].Messages))
	}
	if first[0].Messages[0].Content != "Hello" {
		t.Errorf("Expected content 'Hello', got '%s'", first[0].Messages[0].Content)
	}

	// Second interaction: the request must contain the full history:
	//   1. user:      Hello
	//   2. assistant: Response
	//   3. user:      How are you?
	if err := chatWithNIM(cmd, "ai/model", &messages, "How are you?"); err != nil {
		t.Fatalf("Second chatWithNIM failed: %v", err)
	}

	mu.Lock()
	second := append([]chatPayload(nil), received...)
	mu.Unlock()

	if len(second) != 2 {
		t.Fatalf("Expected 2 requests, got %d", len(second))
	}
	msgs := second[1].Messages
	if len(msgs) != 3 {
		for i, m := range msgs {
			t.Logf("Message %d: Role=%s, Content=%s", i, m.Role, m.Content)
		}
		t.Fatalf("Expected 3 messages in second request, got %d", len(msgs))
	}
	if msgs[0].Content != "Hello" {
		t.Errorf("Msg 0: Expected 'Hello', got '%s'", msgs[0].Content)
	}
	if msgs[1].Role != "assistant" {
		t.Errorf("Msg 1: Expected role 'assistant', got '%s'", msgs[1].Role)
	}
	if msgs[2].Content != "How are you?" {
		t.Errorf("Msg 2: Expected 'How are you?', got '%s'", msgs[2].Content)
	}
}
29 changes: 19 additions & 10 deletions cmd/cli/commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.

var sb strings.Builder
var multiline bool
// Maintain conversation history
var messages []desktop.OpenAIChatMessage

// Add a helper function to handle file inclusion when @ is pressed
// We'll implement a basic version here that shows a message when @ is pressed
Expand Down Expand Up @@ -246,7 +248,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
}
}()

err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput)
err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput, &messages)

// Clean up signal handler
signal.Stop(sigChan)
Expand All @@ -273,6 +275,8 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
// generateInteractiveBasic provides a basic interactive mode (fallback)
func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
scanner := bufio.NewScanner(os.Stdin)
// Maintain conversation history
var messages []desktop.OpenAIChatMessage
for {
userInput, err := readMultilineInput(cmd, scanner)
if err != nil {
Expand Down Expand Up @@ -307,7 +311,7 @@ func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client,
}
}()

err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput)
err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput, &messages)

cancelChat()
signal.Stop(sigChan)
Expand Down Expand Up @@ -509,12 +513,12 @@ func renderMarkdown(content string) (string, error) {
}

// chatWithMarkdown performs chat and streams the response with selective markdown rendering.
func chatWithMarkdown(cmd *cobra.Command, client *desktop.Client, model, prompt string) error {
return chatWithMarkdownContext(cmd.Context(), cmd, client, model, prompt)
func chatWithMarkdown(cmd *cobra.Command, client *desktop.Client, model, prompt string, messages *[]desktop.OpenAIChatMessage) error {
return chatWithMarkdownContext(cmd.Context(), cmd, client, model, prompt, messages)
}

// chatWithMarkdownContext performs chat with context support and streams the response with selective markdown rendering.
func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *desktop.Client, model, prompt string) error {
func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *desktop.Client, model, prompt string, messages *[]desktop.OpenAIChatMessage) error {
colorMode, _ := cmd.Flags().GetString("color")
useMarkdown := shouldUseMarkdown(colorMode)
debug, _ := cmd.Flags().GetBool("debug")
Expand All @@ -535,15 +539,15 @@ func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *de

if !useMarkdown {
// Simple case: just stream as plain text
return client.ChatWithContext(ctx, model, prompt, imageURLs, func(content string) {
return client.ChatWithContext(ctx, model, prompt, imageURLs, messages, func(content string) {
cmd.Print(content)
}, false)
}

// For markdown: use streaming buffer to render code blocks as they complete
markdownBuffer := NewStreamingMarkdownBuffer()

err = client.ChatWithContext(ctx, model, prompt, imageURLs, func(content string) {
err = client.ChatWithContext(ctx, model, prompt, imageURLs, messages, func(content string) {
// Use the streaming markdown buffer to intelligently render content
rendered, err := markdownBuffer.AddContent(content, true)
if err != nil {
Expand Down Expand Up @@ -639,6 +643,8 @@ func newRunCmd() *cobra.Command {
scanner := bufio.NewScanner(os.Stdin)
cmd.Println("Interactive chat mode started. Type '/bye' to exit.")

var messages []Message // Declare messages slice for NIM interactive mode

for {
userInput, err := readMultilineInput(cmd, scanner)
if err != nil {
Expand All @@ -658,7 +664,8 @@ func newRunCmd() *cobra.Command {
continue
}

if err := chatWithNIM(cmd, model, userInput); err != nil {
// Pass the address of the messages slice
if err := chatWithNIM(cmd, model, &messages, userInput); err != nil {
cmd.PrintErr(fmt.Errorf("failed to chat with NIM: %w", err))
continue
}
Expand All @@ -669,7 +676,9 @@ func newRunCmd() *cobra.Command {
}

// Single prompt mode
if err := chatWithNIM(cmd, model, prompt); err != nil {
// Declare messages slice for NIM single prompt mode
var messages []Message
if err := chatWithNIM(cmd, model, &messages, prompt); err != nil {
return fmt.Errorf("failed to chat with NIM: %w", err)
}
cmd.Println()
Expand Down Expand Up @@ -707,7 +716,7 @@ func newRunCmd() *cobra.Command {
}

if prompt != "" {
if err := chatWithMarkdown(cmd, desktopClient, model, prompt); err != nil {
if err := chatWithMarkdown(cmd, desktopClient, model, prompt, nil); err != nil {
return handleClientError(err, "Failed to generate a response")
}
cmd.Println()
Expand Down
Loading