diff --git a/cmd/cli/commands/nim.go b/cmd/cli/commands/nim.go index 66e27910..6c3f4f9d 100644 --- a/cmd/cli/commands/nim.go +++ b/cmd/cli/commands/nim.go @@ -2,6 +2,7 @@ package commands import ( "bufio" + "bytes" "context" "encoding/base64" "encoding/json" @@ -19,6 +20,7 @@ import ( "github.com/docker/docker/api/types/mount" "github.com/docker/docker/client" "github.com/docker/go-connections/nat" + "github.com/docker/model-runner/cmd/cli/desktop" gpupkg "github.com/docker/model-runner/cmd/cli/pkg/gpu" "github.com/spf13/cobra" ) @@ -28,12 +30,15 @@ const ( nimPrefix = "nvcr.io/nim/" // nimContainerPrefix is the prefix for NIM container names nimContainerPrefix = "docker-model-nim-" - // nimDefaultPort is the default port for NIM containers - nimDefaultPort = 8000 // nimDefaultShmSize is the default shared memory size for NIM containers (16GB) nimDefaultShmSize = 17179869184 ) +var ( + // nimDefaultPort is the default port for NIM containers + nimDefaultPort = 8000 +) + // isNIMImage checks if the given model reference is an NVIDIA NIM image func isNIMImage(model string) bool { return strings.HasPrefix(model, nimPrefix) @@ -389,7 +394,7 @@ func runNIMModel(ctx context.Context, dockerClient *client.Client, model string, } // chatWithNIM sends chat requests to a NIM container -func chatWithNIM(cmd *cobra.Command, model, prompt string) error { +func chatWithNIM(cmd *cobra.Command, model string, messages *[]desktop.OpenAIChatMessage, prompt string) error { // Use the desktop client to chat with the NIM through its OpenAI-compatible API // The NIM container runs on localhost:8000 and provides an OpenAI-compatible API @@ -404,15 +409,25 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error { modelName = modelName[:idx] } - reqBody := fmt.Sprintf(`{ - "model": "%s", - "messages": [ - {"role": "user", "content": %q} - ], - "stream": true - }`, modelName, prompt) + // Append user message to history + *messages = append(*messages, desktop.OpenAIChatMessage{Role: "user", Content: prompt}) + + requestPayload := struct { + Model string `json:"model"` + Messages []desktop.OpenAIChatMessage `json:"messages"` + Stream bool `json:"stream"` + }{ + Model: modelName, + Messages: *messages, + Stream: true, + } + + reqBodyBytes, err := json.Marshal(requestPayload) + if err != nil { + return fmt.Errorf("failed to marshal request payload: %w", err) + } - req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", nimDefaultPort), strings.NewReader(reqBody)) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/v1/chat/completions", nimDefaultPort), bytes.NewReader(reqBodyBytes)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } @@ -431,6 +446,7 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error { } // Stream the response - parse SSE events + var assistantResponse strings.Builder scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() @@ -445,21 +461,20 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error { } // Parse the JSON and extract the content - // For simplicity, we'll use basic string parsing - // In production, we'd use proper JSON parsing - if strings.Contains(data, `"content"`) { - // Extract content field - simple approach - contentStart := strings.Index(data, `"content":"`) - if contentStart != -1 { - contentStart += len(`"content":"`) - contentEnd := strings.Index(data[contentStart:], `"`) - if contentEnd != -1 { - content := 
data[contentStart : contentStart+contentEnd]
-					// Unescape basic JSON escapes
-					content = strings.ReplaceAll(content, `\n`, "\n")
-					content = strings.ReplaceAll(content, `\t`, "\t")
-					content = strings.ReplaceAll(content, `\"`, `"`)
+			var chatCompletion struct {
+				Choices []struct {
+					Delta struct {
+						Content string `json:"content"`
+					} `json:"delta"`
+				} `json:"choices"`
+			}
+
+			if err := json.Unmarshal([]byte(data), &chatCompletion); err == nil {
+				if len(chatCompletion.Choices) > 0 {
+					content := chatCompletion.Choices[0].Delta.Content
+					if content != "" {
 						cmd.Print(content)
+						assistantResponse.WriteString(content)
 					}
 				}
 			}
@@ -470,5 +485,8 @@ func chatWithNIM(cmd *cobra.Command, model, prompt string) error {
 		return fmt.Errorf("error reading response: %w", err)
 	}
 
+	// Append assistant message to history
+	*messages = append(*messages, desktop.OpenAIChatMessage{Role: "assistant", Content: assistantResponse.String()})
+
 	return nil
 }
diff --git a/cmd/cli/commands/nim_chat_test.go b/cmd/cli/commands/nim_chat_test.go
new file mode 100644
index 00000000..acbf5a80
--- /dev/null
+++ b/cmd/cli/commands/nim_chat_test.go
@@ -0,0 +1,125 @@
+package commands
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strconv"
+	"testing"
+
+	"github.com/docker/model-runner/cmd/cli/desktop"
+	"github.com/spf13/cobra"
+)
+
+func TestChatWithNIM_Context(t *testing.T) {
+	// Save original port and restore after test
+	originalPort := nimDefaultPort
+	defer func() { nimDefaultPort = originalPort }()
+
+	// Track received messages
+	var receivedPayloads []struct {
+		Messages []desktop.OpenAIChatMessage `json:"messages"`
+	}
+
+	// Setup Mock Server
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Errorf("Expected path /v1/chat/completions, got %s", r.URL.Path)
+			http.Error(w, "Not found", http.StatusNotFound)
+			return
+		}
+
+		body, err := io.ReadAll(r.Body)
+		if err != nil {
+			// The handler runs on a server goroutine, so use Errorf rather than Fatalf here.
+			t.Errorf("Failed to read request body: %v", err)
+			return
+		}
+
+		var payload struct {
+			Messages []desktop.OpenAIChatMessage `json:"messages"`
+		}
+		if err := json.Unmarshal(body, &payload); err != nil {
+			t.Errorf("Failed to unmarshal request body: %v", err)
+			return
+		}
+
+		receivedPayloads = append(receivedPayloads, payload)
+
+		// Mock response (SSE format)
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Write([]byte(`data: {"choices":[{"delta":{"content":"Response"}}]}
+`))
+		w.Write([]byte(`data: [DONE]
+`))
+	}))
+	defer server.Close()
+
+	// Parse server URL to get the port
+	u, err := url.Parse(server.URL)
+	if err != nil {
+		t.Fatalf("Failed to parse server URL: %v", err)
+	}
+	port, err := strconv.Atoi(u.Port())
+	if err != nil {
+		t.Fatalf("Failed to parse port: %v", err)
+	}
+	nimDefaultPort = port
+
+	// Initialize messages slice
+	var messages []desktop.OpenAIChatMessage
+	cmd := &cobra.Command{}
+
+	// First interaction
+	err = chatWithNIM(cmd, "ai/model", &messages, "Hello")
+	if err != nil {
+		t.Fatalf("First chatWithNIM failed: %v", err)
+	}
+
+	// Verify first request
+	if len(receivedPayloads) != 1 {
+		t.Fatalf("Expected 1 request, got %d", len(receivedPayloads))
+	}
+	if len(receivedPayloads[0].Messages) != 1 {
+		t.Errorf("Expected 1 message in first request, got %d", len(receivedPayloads[0].Messages))
+	}
+	if receivedPayloads[0].Messages[0].Content != "Hello" {
+		t.Errorf("Expected content 'Hello', got '%s'", receivedPayloads[0].Messages[0].Content)
+	}
+
+	// Second interaction
+	
err = chatWithNIM(cmd, "ai/model", &messages, "How are you?") + if err != nil { + t.Fatalf("Second chatWithNIM failed: %v", err) + } + + // Verify second request + if len(receivedPayloads) != 2 { + t.Fatalf("Expected 2 requests, got %d", len(receivedPayloads)) + } + + // This is where we expect it to fail if the issue exists + // We expect: + // 1. User: Hello + // 2. Assistant: Response + // 3. User: How are you? + if len(receivedPayloads[1].Messages) != 3 { + t.Errorf("Expected 3 messages in second request, got %d", len(receivedPayloads[1].Messages)) + for i, m := range receivedPayloads[1].Messages { + t.Logf("Message %d: Role=%s, Content=%s", i, m.Role, m.Content) + } + } else { + // Verify message content + if receivedPayloads[1].Messages[0].Content != "Hello" { + t.Errorf("Msg 0: Expected 'Hello', got '%s'", receivedPayloads[1].Messages[0].Content) + } + if receivedPayloads[1].Messages[1].Role != "assistant" { + t.Errorf("Msg 1: Expected role 'assistant', got '%s'", receivedPayloads[1].Messages[1].Role) + } + if receivedPayloads[1].Messages[2].Content != "How are you?" { + t.Errorf("Msg 2: Expected 'How are you?', got '%s'", receivedPayloads[1].Messages[2].Content) + } + } +} diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 8f99d3cf..b2eeab53 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -155,6 +155,8 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. var sb strings.Builder var multiline bool + // Maintain conversation history + var messages []desktop.OpenAIChatMessage // Add a helper function to handle file inclusion when @ is pressed // We'll implement a basic version here that shows a message when @ is pressed @@ -246,7 +248,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. } }() - err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput) + err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput, &messages) // Clean up signal handler signal.Stop(sigChan) @@ -273,6 +275,8 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. // generateInteractiveBasic provides a basic interactive mode (fallback) func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { scanner := bufio.NewScanner(os.Stdin) + // Maintain conversation history + var messages []desktop.OpenAIChatMessage for { userInput, err := readMultilineInput(cmd, scanner) if err != nil { @@ -307,7 +311,7 @@ func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client, } }() - err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput) + err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput, &messages) cancelChat() signal.Stop(sigChan) @@ -509,12 +513,12 @@ func renderMarkdown(content string) (string, error) { } // chatWithMarkdown performs chat and streams the response with selective markdown rendering. -func chatWithMarkdown(cmd *cobra.Command, client *desktop.Client, model, prompt string) error { - return chatWithMarkdownContext(cmd.Context(), cmd, client, model, prompt) +func chatWithMarkdown(cmd *cobra.Command, client *desktop.Client, model, prompt string, messages *[]desktop.OpenAIChatMessage) error { + return chatWithMarkdownContext(cmd.Context(), cmd, client, model, prompt, messages) } // chatWithMarkdownContext performs chat with context support and streams the response with selective markdown rendering. 
-func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *desktop.Client, model, prompt string) error { +func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *desktop.Client, model, prompt string, messages *[]desktop.OpenAIChatMessage) error { colorMode, _ := cmd.Flags().GetString("color") useMarkdown := shouldUseMarkdown(colorMode) debug, _ := cmd.Flags().GetBool("debug") @@ -535,7 +539,7 @@ func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *de if !useMarkdown { // Simple case: just stream as plain text - return client.ChatWithContext(ctx, model, prompt, imageURLs, func(content string) { + return client.ChatWithContext(ctx, model, prompt, imageURLs, messages, func(content string) { cmd.Print(content) }, false) } @@ -543,7 +547,7 @@ func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *de // For markdown: use streaming buffer to render code blocks as they complete markdownBuffer := NewStreamingMarkdownBuffer() - err = client.ChatWithContext(ctx, model, prompt, imageURLs, func(content string) { + err = client.ChatWithContext(ctx, model, prompt, imageURLs, messages, func(content string) { // Use the streaming markdown buffer to intelligently render content rendered, err := markdownBuffer.AddContent(content, true) if err != nil { @@ -639,6 +643,8 @@ func newRunCmd() *cobra.Command { scanner := bufio.NewScanner(os.Stdin) cmd.Println("Interactive chat mode started. Type '/bye' to exit.") + var messages []desktop.OpenAIChatMessage // Declare messages slice for NIM interactive mode + for { userInput, err := readMultilineInput(cmd, scanner) if err != nil { @@ -658,7 +664,8 @@ func newRunCmd() *cobra.Command { continue } - if err := chatWithNIM(cmd, model, userInput); err != nil { + // Pass the address of the messages slice + if err := chatWithNIM(cmd, model, &messages, userInput); err != nil { cmd.PrintErr(fmt.Errorf("failed to chat with NIM: %w", err)) continue } @@ -669,7 +676,9 @@ func newRunCmd() *cobra.Command { } // Single prompt mode - if err := chatWithNIM(cmd, model, prompt); err != nil { + // Declare messages slice for NIM single prompt mode + var messages []desktop.OpenAIChatMessage + if err := chatWithNIM(cmd, model, &messages, prompt); err != nil { return fmt.Errorf("failed to chat with NIM: %w", err) } cmd.Println() @@ -707,7 +716,7 @@ func newRunCmd() *cobra.Command { } if prompt != "" { - if err := chatWithMarkdown(cmd, desktopClient, model, prompt); err != nil { + if err := chatWithMarkdown(cmd, desktopClient, model, prompt, nil); err != nil { return handleClientError(err, "Failed to generate a response") } cmd.Println() diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 2d15462f..b20b170b 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -307,11 +307,13 @@ func (c *Client) fullModelID(id string) (string, error) { // Chat performs a chat request and streams the response content with selective markdown rendering. func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) error { - return c.ChatWithContext(context.Background(), model, prompt, imageURLs, outputFunc, shouldUseMarkdown) + return c.ChatWithContext(context.Background(), model, prompt, imageURLs, nil, outputFunc, shouldUseMarkdown) } // ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering. 
-func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) error { +// If messages is provided, it will be used as conversation history. The function updates +// the provided messages slice to include the new user message and the assistant's response. +func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imageURLs []string, messages *[]OpenAIChatMessage, outputFunc func(string), shouldUseMarkdown bool) error { model = normalizeHuggingFaceModelName(model) if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -350,15 +352,26 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag messageContent = prompt } + // Prepare messages for the request + userMessage := OpenAIChatMessage{ + Role: "user", + Content: messageContent, + } + + var requestMessages []OpenAIChatMessage + if messages != nil { + // For a conversation, append the new message and use the full history for the request. + *messages = append(*messages, userMessage) + requestMessages = *messages + } else { + // For a single-shot chat, just send the new user message. + requestMessages = []OpenAIChatMessage{userMessage} + } + reqBody := OpenAIChatRequest{ - Model: model, - Messages: []OpenAIChatMessage{ - { - Role: "user", - Content: messageContent, - }, - }, - Stream: true, + Model: model, + Messages: requestMessages, + Stream: true, } jsonData, err := json.Marshal(reqBody) @@ -400,6 +413,9 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag TotalTokens int `json:"total_tokens"` } + // Accumulate assistant response for conversation history + var assistantResponse strings.Builder + scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { // Check if context was cancelled @@ -453,6 +469,7 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag } else { outputFunc(chunk) } + // Note: reasoning content is not included in the assistant message content } if streamResp.Choices[0].Delta.Content != "" { chunk := streamResp.Choices[0].Delta.Content @@ -461,6 +478,8 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag } printerState = chatPrinterContent outputFunc(chunk) + // Accumulate the assistant's content for conversation history + assistantResponse.WriteString(chunk) } } } @@ -469,6 +488,15 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag return fmt.Errorf("error reading response stream: %w", err) } + // Append assistant message to conversation history + if messages != nil && assistantResponse.Len() > 0 { + assistantMessage := OpenAIChatMessage{ + Role: "assistant", + Content: assistantResponse.String(), + } + *messages = append(*messages, assistantMessage) + } + if finalUsage != nil { usageInfo := fmt.Sprintf("\n\nToken usage: %d prompt + %d completion = %d total", finalUsage.PromptTokens, diff --git a/pkg/inference/backends/llamacpp/download.go b/pkg/inference/backends/llamacpp/download.go index 519d08d1..1862dec4 100644 --- a/pkg/inference/backends/llamacpp/download.go +++ b/pkg/inference/backends/llamacpp/download.go @@ -27,7 +27,7 @@ const ( var ( ShouldUseGPUVariant bool ShouldUseGPUVariantLock sync.Mutex - ShouldUpdateServer = true + ShouldUpdateServer = false ShouldUpdateServerLock sync.Mutex DesiredServerVersion = "latest" DesiredServerVersionLock sync.Mutex diff --git 
a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index c6b29201..a8016412 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -95,8 +95,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 
-	// We don't currently support this backend on Windows. We'll likely
-	// never support it on Intel Macs.
-	if (runtime.GOOS == "darwin" && runtime.GOARCH == "amd64") ||
-		(runtime.GOOS == "windows" && !(runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64")) {
+	// We don't currently support this backend on Windows architectures
+	// other than amd64 and arm64.
+	if runtime.GOOS == "windows" && !(runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64") {
 		return errors.New("platform not supported")
 	}
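For orientation, the conversation-history change above reduces to one pattern: append the user turn to the history before the request, stream the reply, then append the assistant turn so the next request carries the whole conversation. Below is a minimal, self-contained sketch of that pattern, using a hypothetical chatMessage type and a stubbed send function in place of desktop.OpenAIChatMessage and the streaming HTTP call:

package main

import (
	"fmt"
	"strings"
)

// chatMessage mirrors the role/content shape used by desktop.OpenAIChatMessage;
// it is defined locally so this sketch stays self-contained.
type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// send stands in for the streaming chat request; it just echoes the last prompt.
func send(history []chatMessage) string {
	var reply strings.Builder
	reply.WriteString("echo: ")
	reply.WriteString(history[len(history)-1].Content)
	return reply.String()
}

// chatTurn appends the user message, obtains a reply, and records the assistant
// message so the next turn sends the full conversation history.
func chatTurn(messages *[]chatMessage, prompt string) string {
	*messages = append(*messages, chatMessage{Role: "user", Content: prompt})
	reply := send(*messages)
	*messages = append(*messages, chatMessage{Role: "assistant", Content: reply})
	return reply
}

func main() {
	var messages []chatMessage
	chatTurn(&messages, "Hello")
	chatTurn(&messages, "How are you?")
	// After two turns the history holds four messages: user, assistant, user, assistant.
	for i, m := range messages {
		fmt.Printf("%d %s: %s\n", i, m.Role, m.Content)
	}
}

The test in nim_chat_test.go checks exactly this accumulation: the second request is expected to carry three messages (user, assistant, user) rather than only the latest prompt.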