diff --git a/lua/neoai/chat/history.lua b/lua/neoai/chat/history.lua index bf0ae05..e45ef07 100644 --- a/lua/neoai/chat/history.lua +++ b/lua/neoai/chat/history.lua @@ -12,40 +12,39 @@ local ChatHistory = { model = "", params = {}, messages = {} } ---@param context string | nil The context to use ---@return ChatHistory function ChatHistory:new(model, params, context) - local obj = {} + local obj = {} - setmetatable(obj, self) - self.__index = self + setmetatable(obj, self) + self.__index = self - self.model = model - self.params = params or {} - self.messages = {} + self.model = model + self.params = params or {} + self.messages = {} - if context ~= nil then - local context_prompt = config.options.prompts.context_prompt(context) - self:set_prompt(context_prompt) - end - return obj + if context ~= nil then + local context_prompt = config.options.prompts.context_prompt(context) + self:set_prompt(context_prompt) + end + return obj end --- @param prompt string system prompt function ChatHistory:set_prompt(prompt) - local system_msg = { - role = "system", - content = prompt, - } - table.insert(self.messages, system_msg) + local system_msg = { + role = "system", + content = prompt, + } + table.insert(self.messages, system_msg) end ---@param user boolean True if user sent msg ---@param msg string The message to add function ChatHistory:add_message(user, msg) - local role = user and "user" or "assistant" - - table.insert(self.messages, { - role = role, - content = msg, - }) + local role = user and "user" or "assistant" + table.insert(self.messages, { + role = role, + content = msg, + }) end return ChatHistory diff --git a/lua/neoai/chat/init.lua b/lua/neoai/chat/init.lua index 8848bb4..d29df5a 100644 --- a/lua/neoai/chat/init.lua +++ b/lua/neoai/chat/init.lua @@ -62,6 +62,7 @@ end M.reset = function() M.context = nil M.chat_history = nil + M.get_current_model().name.cancel_stream() end local chunks = {} diff --git a/lua/neoai/chat/models/openai.lua b/lua/neoai/chat/models/openai.lua index 011d8ed..da50b8f 100644 --- a/lua/neoai/chat/models/openai.lua +++ b/lua/neoai/chat/models/openai.lua @@ -1,91 +1,105 @@ -local utils = require("neoai.utils") local config = require("neoai.config") +local curl = require("plenary.curl") +local utils = require("neoai.utils") ---@type ModelModule local M = {} - M.name = "OpenAI" +local handler local chunks = {} local raw_chunks = {} + +---@brief Cancel the current stream and shut down the handler +M.cancel_stream = function() + if handler ~= nil then + handler:shutdown() + handler = nil + end +end + M.get_current_output = function() - return table.concat(chunks, "") + return table.concat(chunks, "") end ---@param chunk string ---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs -M._recieve_chunk = function(chunk, on_stdout_chunk) - for line in chunk:gmatch("[^\n]+") do - local raw_json = string.gsub(line, "^data: ", "") +M._receive_chunk = function(chunk, on_stdout_chunk) + local function safely_extract_delta_content(decoded_json) + local path = decoded_json.choices + if not path then + return nil + end + + path = path[1] + if not path then + return nil + end + + path = path.delta + if not path then + return nil + end + + return path.content + end + -- Remove "data:" prefix from chunk + local raw_json = string.gsub(chunk, "%s*data:%s*", "") + table.insert(raw_chunks, raw_json) - table.insert(raw_chunks, raw_json) - local ok, path = pcall(vim.json.decode, raw_json) - if not ok then - goto continue - end + local ok, decoded_json = pcall(vim.json.decode, raw_json) + if not ok then + return -- Ignore invalid JSON chunks + end - path = path.choices - if path == nil then - goto continue - end - path = path[1] - if path == nil then - goto continue - end - path = path.delta - if path == nil then - goto continue - end - path = path.content - if path == nil then - goto continue - end - on_stdout_chunk(path) - -- append_to_output(path, 0) - table.insert(chunks, path) - ::continue:: - end + local delta_content = safely_extract_delta_content(decoded_json) + if delta_content then + table.insert(chunks, delta_content) + on_stdout_chunk(delta_content) + end end ---@param chat_history ChatHistory ---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs ---@param on_complete fun(err?: string, output?: string) Function to call when model has finished M.send_to_model = function(chat_history, on_stdout_chunk, on_complete) - local api_key = os.getenv(config.options.open_api_key_env) + local api_key = os.getenv(config.options.open_api_key_env) - local data = { - model = chat_history.model, - stream = true, - messages = chat_history.messages, - } - data = vim.tbl_deep_extend("force", {}, data, chat_history.params) + local data = { + model = chat_history.model, + stream = true, + messages = chat_history.messages, + } + data = vim.tbl_deep_extend("force", {}, data, chat_history.params) - chunks = {} - raw_chunks = {} - utils.exec("curl", { - "--silent", - "--show-error", - "--no-buffer", - "https://api.openai.com/v1/chat/completions", - "-H", - "Content-Type: application/json", - "-H", - "Authorization: Bearer " .. api_key, - "-d", - vim.json.encode(data), - }, function(chunk) - M._recieve_chunk(chunk, on_stdout_chunk) - end, function(err, _) - local total_message = table.concat(raw_chunks, "") - local ok, json = pcall(vim.json.decode, total_message) - if ok then - if json.error ~= nil then - on_complete(json.error.message, nil) - return - end - end - on_complete(err, M.get_current_output()) - end) + chunks = {} + raw_chunks = {} + handler = curl.post({ + url = "https://api.openai.com/v1/chat/completions", + raw = { "--no-buffer" }, + headers = { + content_type = "application/json", + Authorization = "Bearer " .. api_key, + }, + body = vim.json.encode(data), + stream = function(_, chunk) + if chunk ~= "" then + -- The following workaround helps to identify when the model has completed its task. + if string.match(chunk, "%[DONE%]") then + vim.schedule(function() + on_complete(nil, M.get_current_output()) + end) + else + vim.schedule(function() + M._receive_chunk(chunk, on_stdout_chunk) + end) + end + end + end, + on_error = function(err, _, _) + return on_complete(err, nil) + end, + }) end return M diff --git a/lua/neoai/ui.lua b/lua/neoai/ui.lua index 525ffca..b24ede0 100644 --- a/lua/neoai/ui.lua +++ b/lua/neoai/ui.lua @@ -13,284 +13,290 @@ M.input_popup = nil M.layout = nil ---@param prompt string -M.submit_prompt = function(prompt) -end +M.submit_prompt = function(prompt) end M.clear_input = function() - if M.input_popup ~= nil then - local buffer = M.input_popup.bufnr - vim.api.nvim_buf_set_lines(buffer, 0, -1, false, {}) - end + if M.input_popup ~= nil then + local buffer = M.input_popup.bufnr + vim.api.nvim_buf_set_lines(buffer, 0, -1, false, {}) + end end M.get_component_heights = function(output_height_percentage) - local lines_height = vim.api.nvim_get_option("lines") - local statusline_height = vim.o.laststatus == 0 and 0 or 1 -- height of the statusline if present - local cmdline_height = vim.o.cmdheight -- height of the cmdline if present - local tabline_height = vim.o.showtabline == 0 and 0 or 1 -- height of the tabline if present - local total_height = lines_height - local used_height = statusline_height + cmdline_height + tabline_height - local layout_height = total_height - used_height - local output_height = math.floor(layout_height * output_height_percentage / 100) - local prompt_height = layout_height - output_height - local starting_row = tabline_height == 0 and 0 or 1 - - return { - starting_row = starting_row, - layout = layout_height, - output = output_height, - prompt = prompt_height, - } + local lines_height = vim.api.nvim_get_option("lines") + local statusline_height = vim.o.laststatus == 0 and 0 or 1 -- height of the statusline if present + local cmdline_height = vim.o.cmdheight -- height of the cmdline if present + local tabline_height = vim.o.showtabline == 0 and 0 or 1 -- height of the tabline if present + local total_height = lines_height + local used_height = statusline_height + cmdline_height + tabline_height + local layout_height = total_height - used_height + local output_height = math.floor(layout_height * output_height_percentage / 100) + local prompt_height = layout_height - output_height + local starting_row = tabline_height == 0 and 0 or 1 + + return { + starting_row = starting_row, + layout = layout_height, + output = output_height, + prompt = prompt_height, + } end M.is_focused = function() - if M.input_popup == nil then - vim.notify("NeoAI GUI needs to be open", vim.log.levels.ERROR) - return - end - local win = vim.api.nvim_get_current_win() - return win == M.output_popup.winid or win == M.input_popup.winid + if M.input_popup == nil then + vim.notify("NeoAI GUI needs to be open", vim.log.levels.ERROR) + return + end + local win = vim.api.nvim_get_current_win() + return win == M.output_popup.winid or win == M.input_popup.winid end M.focus = function() - if M.input_popup == nil then - vim.notify("NeoAI GUI needs to be open", vim.log.levels.ERROR) - return - end - vim.api.nvim_set_current_win(M.input_popup.winid) + if M.input_popup == nil then + vim.notify("NeoAI GUI needs to be open", vim.log.levels.ERROR) + return + end + vim.api.nvim_set_current_win(M.input_popup.winid) end M.create_ui = function() - -- Destroy UI if already open - if M.is_open() then - return - end - - local current_model = chat.get_current_model() - - M.output_popup = Popup({ - enter = false, - focusable = true, - zindex = 50, - position = "50%", - border = { - style = "rounded", - text = { - top = " " .. config.options.ui.output_popup_text .. " ", - top_align = "center", - bottom = " Model: " .. current_model.model .. " (" .. current_model.name.name .. ") ", - bottom_align = "left", - }, - }, - buf_options = { - -- modifiable = true, - -- readonly = false, - filetype = "neoai-output", - }, - -- win_options = { - -- winblend = 10, - -- winhighlight = "Normal:Normal,FloatBorder:FloatBorder", - -- }, - win_options = { - wrap = true, - }, - }) - - M.input_popup = Popup({ - enter = true, - focusable = true, - zindex = 50, - position = "50%", - border = { - style = "rounded", - padding = { - left = 1, - right = 1, - }, - text = { - top = " " .. config.options.ui.input_popup_text .. " ", - top_align = "center", - }, - }, - buf_options = { - modifiable = true, - readonly = false, - filetype = "neoai-input", - }, - win_options = { - winblend = 0, - winhighlight = "Normal:Normal,FloatBorder:FloatBorder", - wrap = true, - }, - }) - - local component_heights = M.get_component_heights(config.options.ui.output_popup_height) - M.layout = Layout( - { - relative = "editor", - position = { - row = component_heights.starting_row, - col = "100%", - }, - size = { - width = config.options.ui.width .. "%", - height = component_heights.layout, - }, - }, - Layout.Box({ - Layout.Box(M.output_popup, { size = component_heights.output }), - Layout.Box(M.input_popup, { size = component_heights.prompt }), - }, { dir = "col" }) - ) - M.layout:mount() - - M.output_popup:on({ event.BufDelete, event.WinClosed }, function() - M.destroy_ui() - end) - M.input_popup:on({ event.BufDelete, event.WinClosed }, function() - M.destroy_ui() - end) - - chat.new_chat_history() - - local input_buffer = M.input_popup.bufnr - local output_buffer = M.output_popup.bufnr - - M.submit_prompt = function() - local lines = vim.api.nvim_buf_get_lines(input_buffer, 0, -1, false) - local prompt = table.concat(lines, "\n") - M.send_prompt(prompt) - M.clear_input() - end - - local opts = { noremap = true, silent = true } - vim.api.nvim_buf_set_keymap(input_buffer, "i", "", "", opts) - vim.api.nvim_buf_set_keymap(input_buffer, "i", "", "lua require('neoai.ui').submit_prompt()", opts) - - local key = config.options.mappings["select_up"] - if key ~= nil then - local keys = {} - if type(key) == "table" then - keys = key - else - keys = { key } - end - for _, k in ipairs(keys) do - vim.api.nvim_buf_set_keymap( - input_buffer, - "n", - k, - "lua vim.api.nvim_set_current_win(require('neoai.ui').output_popup.winid)", - opts - ) - end - end - key = config.options.mappings["select_down"] - if key ~= nil then - local keys = {} - if type(key) == "table" then - keys = key - else - keys = { key } - end - for _, k in ipairs(keys) do - vim.api.nvim_buf_set_keymap( - output_buffer, - "n", - k, - "lua vim.api.nvim_set_current_win(require('neoai.ui').input_popup.winid)", - opts - ) - end - end - - M.set_destroy_key_mappings(input_buffer) + -- Destroy UI if already open + if M.is_open() then + return + end + + local current_model = chat.get_current_model() + + M.output_popup = Popup({ + enter = false, + focusable = true, + zindex = 50, + position = "50%", + border = { + style = "rounded", + text = { + top = " " .. config.options.ui.output_popup_text .. " ", + top_align = "center", + bottom = " Model: " .. current_model.model .. " (" .. current_model.name.name .. ") ", + bottom_align = "left", + }, + }, + buf_options = { + -- modifiable = true, + -- readonly = false, + filetype = "neoai-output", + }, + -- win_options = { + -- winblend = 10, + -- winhighlight = "Normal:Normal,FloatBorder:FloatBorder", + -- }, + win_options = { + wrap = true, + }, + }) + + M.input_popup = Popup({ + enter = true, + focusable = true, + zindex = 50, + position = "50%", + border = { + style = "rounded", + padding = { + left = 1, + right = 1, + }, + text = { + top = " " .. config.options.ui.input_popup_text .. " ", + top_align = "center", + }, + }, + buf_options = { + modifiable = true, + readonly = false, + filetype = "neoai-input", + }, + win_options = { + winblend = 0, + winhighlight = "Normal:Normal,FloatBorder:FloatBorder", + wrap = true, + }, + }) + + local component_heights = M.get_component_heights(config.options.ui.output_popup_height) + M.layout = Layout( + { + relative = "editor", + position = { + row = component_heights.starting_row, + col = "100%", + }, + size = { + width = config.options.ui.width .. "%", + height = component_heights.layout, + }, + }, + Layout.Box({ + Layout.Box(M.output_popup, { size = component_heights.output }), + Layout.Box(M.input_popup, { size = component_heights.prompt }), + }, { dir = "col" }) + ) + M.layout:mount() + + M.output_popup:on({ event.BufDelete, event.WinClosed }, function() + M.destroy_ui() + end) + M.input_popup:on({ event.BufDelete, event.WinClosed }, function() + M.destroy_ui() + end) + + chat.new_chat_history() + + local input_buffer = M.input_popup.bufnr + local output_buffer = M.output_popup.bufnr + + M.submit_prompt = function() + local lines = vim.api.nvim_buf_get_lines(input_buffer, 0, -1, false) + local prompt = table.concat(lines, "\n") + M.send_prompt(prompt) + M.clear_input() + end + + local opts = { noremap = true, silent = true } + vim.api.nvim_buf_set_keymap(input_buffer, "i", "", "", opts) + vim.api.nvim_buf_set_keymap(input_buffer, "i", "", "lua require('neoai.ui').submit_prompt()", opts) + vim.api.nvim_buf_set_keymap( + input_buffer, + "n", + "", + "lua vim.print(require('neoai.chat').get_current_model().name.cancel_stream())", + opts + ) + + local key = config.options.mappings["select_up"] + if key ~= nil then + local keys = {} + if type(key) == "table" then + keys = key + else + keys = { key } + end + for _, k in ipairs(keys) do + vim.api.nvim_buf_set_keymap( + input_buffer, + "n", + k, + "lua vim.api.nvim_set_current_win(require('neoai.ui').output_popup.winid)", + opts + ) + end + end + key = config.options.mappings["select_down"] + if key ~= nil then + local keys = {} + if type(key) == "table" then + keys = key + else + keys = { key } + end + for _, k in ipairs(keys) do + vim.api.nvim_buf_set_keymap( + output_buffer, + "n", + k, + "lua vim.api.nvim_set_current_win(require('neoai.ui').input_popup.winid)", + opts + ) + end + end + + M.set_destroy_key_mappings(input_buffer) end -- This function sets a keymap for the input buffer. In normal mode, pressing -- the '' or key triggers the 'neoai.ui' module's 'destroy_ui' function, which M.set_destroy_key_mappings = function(input_buffer) - local mappings = { - "", - "", - } - for _, key in ipairs(mappings) do - vim.api.nvim_buf_set_keymap( - input_buffer, - "n", - key, - "lua require('neoai.ui').destroy_ui()", - { noremap = true, silent = true } - ) - end + local mappings = { + "", + "", + } + for _, key in ipairs(mappings) do + vim.api.nvim_buf_set_keymap( + input_buffer, + "n", + key, + "lua require('neoai.ui').destroy_ui()", + { noremap = true, silent = true } + ) + end end M.send_prompt = function(prompt) - chat.send_prompt(prompt, M.append_to_output, true, function(output) - utils.save_to_registers(output) - end) + chat.send_prompt(prompt, M.append_to_output, true, function(output) + utils.save_to_registers(output) + end) end M.destroy_ui = function() - if M.layout ~= nil then - M.layout:unmount() - end - M.layout = nil - M.submit_prompt = function() - -- Empty function - end - chat.reset() + if M.layout ~= nil then + M.layout:unmount() + end + M.layout = nil + M.submit_prompt = function() + -- Empty function + end + chat.reset() end M.is_open = function() - return M.layout ~= nil + return M.layout ~= nil end ---Append text to the output, GPT should populate this ---@param txt string The text to append to the UI ---@param type integer 0/nil = normal, 1 = input M.append_to_output = function(txt, type) - local lines = vim.split(txt, "\n", {}) - - local ns = vim.api.nvim_get_namespaces().neoai_output - - if ns == nil then - ns = vim.api.nvim_create_namespace("neoai_output") - end - - local hl = "Normal" - if type == 1 then - -- hl = "NeoAIInput" - hl = "ErrorMsg" - end - - local length = #lines - - if M.output_popup == nil then - vim.notify("NeoAI window needs to be open", vim.log.levels.ERROR) - return - end - local buffer = M.output_popup.bufnr - local win = M.output_popup.winid - - for i, line in ipairs(lines) do - local currentLine = vim.api.nvim_buf_get_lines(buffer, -2, -1, false)[1] - vim.api.nvim_buf_set_lines(buffer, -2, -1, false, { currentLine .. line }) - - -- local last_line_num = vim.api.nvim_buf_line_count(buffer) - -- local last_line = vim.api.nvim_buf_get_lines(buffer, last_line_num - 1, last_line_num, false)[1] - -- local new_text = last_line .. line - - -- vim.api.nvim_buf_set_lines(buffer, last_line_num - 1, last_line_num, false, { new_text }) - local last_line_num = vim.api.nvim_buf_line_count(buffer) - -- vim.api.nvim_buf_add_highlight(buffer, ns, hl, last_line_num - 1, 0, -1) - - if i < length then - -- Add new line - vim.api.nvim_buf_set_lines(buffer, -1, -1, false, { "" }) - end - vim.api.nvim_win_set_cursor(win, { last_line_num, 0 }) - end + local lines = vim.split(txt, "\n", {}) + + local ns = vim.api.nvim_get_namespaces().neoai_output + + if ns == nil then + ns = vim.api.nvim_create_namespace("neoai_output") + end + + local hl = "Normal" + if type == 1 then + -- hl = "NeoAIInput" + hl = "ErrorMsg" + end + + local length = #lines + + if M.output_popup == nil then + vim.notify("NeoAI window needs to be open", vim.log.levels.ERROR) + return + end + local buffer = M.output_popup.bufnr + local win = M.output_popup.winid + + for i, line in ipairs(lines) do + local currentLine = vim.api.nvim_buf_get_lines(buffer, -2, -1, false)[1] + vim.api.nvim_buf_set_lines(buffer, -2, -1, false, { currentLine .. line }) + + -- local last_line_num = vim.api.nvim_buf_line_count(buffer) + -- local last_line = vim.api.nvim_buf_get_lines(buffer, last_line_num - 1, last_line_num, false)[1] + -- local new_text = last_line .. line + + -- vim.api.nvim_buf_set_lines(buffer, last_line_num - 1, last_line_num, false, { new_text }) + local last_line_num = vim.api.nvim_buf_line_count(buffer) + -- vim.api.nvim_buf_add_highlight(buffer, ns, hl, last_line_num - 1, 0, -1) + + if i < length then + -- Add new line + vim.api.nvim_buf_set_lines(buffer, -1, -1, false, { "" }) + end + vim.api.nvim_win_set_cursor(win, { last_line_num, 0 }) + end end return M diff --git a/lua/neoai/utils.lua b/lua/neoai/utils.lua index c9f882b..5361fb4 100644 --- a/lua/neoai/utils.lua +++ b/lua/neoai/utils.lua @@ -1,6 +1,7 @@ local config = require("neoai.config") local M = {} + ---@param text string ---@return string M.extract_code_snippets = function(text) @@ -21,58 +22,8 @@ end ---@param output string M.save_to_registers = function(output) - for register, strip_func in pairs(config.options.register_output) do - vim.fn.setreg(register, strip_func(output)) - end -end - ----Executes command getting stdout chunks ----@param cmd string ----@param args string[] ----@param on_stdout_chunk fun(chunk: string): nil ----@param on_complete fun(err: string?, output: string?): nil -function M.exec(cmd, args, on_stdout_chunk, on_complete) - local stdout = vim.loop.new_pipe() - local function on_stdout_read(_, chunk) - if chunk then - vim.schedule(function() - on_stdout_chunk(chunk) - end) - end - end - - local stderr = vim.loop.new_pipe() - local stderr_chunks = {} - local function on_stderr_read(_, chunk) - if chunk then - table.insert(stderr_chunks, chunk) - end - end - - local handle - - handle, err = vim.loop.spawn(cmd, { - args = args, - stdio = { nil, stdout, stderr }, - }, function(code) - stdout:close() - stderr:close() - handle:close() - - vim.schedule(function() - if code ~= 0 then - on_complete(vim.trim(table.concat(stderr_chunks, ""))) - else - on_complete() - end - end) - end) - - if not handle then - on_complete(cmd .. " could not be started: " .. err) - else - stdout:read_start(on_stdout_read) - stderr:read_start(on_stderr_read) + for register, strip_func in pairs(config.options.register_output) do + vim.fn.setreg(register, strip_func(output)) end end @@ -80,4 +31,5 @@ M.is_empty = function(s) return s == nil or s == "" end + return M