-
-
Notifications
You must be signed in to change notification settings - Fork 52
Expand file tree
/
Copy pathinit.lua
More file actions
109 lines (91 loc) · 3.11 KB
/
init.lua
File metadata and controls
109 lines (91 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
local config = require("neoai.config")
local ChatHistory = require("neoai.chat.history")
---@class NeoAIChat
---@field context string|nil Text passed to `ChatHistory:new` when a chat history is created
---@field chat_history ChatHistory|nil The active chat history; nil until one is created
---@field models {name: ModelModule, model: string, params: table<string, string>|nil}[] A list of models
---@field selected_model integer|nil Zero-based index into `models`
local M = { models = {} }

-- Declared nil here; assigned (if at all) by code later in this module.
local append_to_output
---Builds `M.models` from `config.options.models` and selects the initial model.
---
---Each config entry's `model` may be a single string or a list of strings.
---Every string becomes its own entry that shares the entry's model module and
---params. Calling this again rebuilds the list rather than appending duplicates.
---`M.selected_model` (zero-based) comes from `config.options.selected_model_index`.
---It falls back to 0, with a warning, when that index does not address an entry.
M.setup_models = function()
    local models = {}
    for _, model_obj in ipairs(config.options.models or {}) do
        local variants = model_obj.model
        if type(variants) == "string" then
            variants = { variants }
        end
        for _, model in ipairs(variants) do
            models[#models + 1] = {
                -- require() caches modules, so repeated lookups are cheap.
                name = require("neoai.chat.models." .. model_obj.name),
                model = model,
                params = model_obj.params,
            }
        end
    end
    -- Replace rather than append so a second setup does not duplicate entries.
    M.models = models

    local index = config.options.selected_model_index or 0
    local count = #models
    local valid = type(index) == "number" and index >= 0 and index < count and index % 1 == 0
    if count > 0 and not valid then
        vim.notify(
            "NeoAI: selected_model_index " .. tostring(index) .. " is out of range; using 0",
            vim.log.levels.WARN
        )
        index = 0
    end
    M.selected_model = index
end
---Replaces `M.chat_history` with a new `ChatHistory` built for the currently
---selected model, its params, and `M.context`.
---Raises a descriptive error when no model entry is selected (e.g. no models
---are configured), instead of failing on a nil index.
M.new_chat_history = function()
    local model = M.get_current_model()
    if model == nil then
        error("NeoAI: no model is selected; check that models are configured", 2)
    end
    M.chat_history = ChatHistory:new(model.model, model.params, M.context)
end
---Advances to the next model, wrapping to the first after the last, and starts
---a new chat history for it.
---With no models configured it only warns. Taking `% 0` would yield NaN under
---LuaJIT/Lua 5.1 and leave `M.selected_model` unusable.
M.select_next_model = function()
    local count = #M.models
    if count == 0 then
        vim.notify("NeoAI: no models are configured", vim.log.levels.WARN)
        return
    end
    M.selected_model = (M.selected_model + 1) % count
    M.new_chat_history()
end
---Looks up the model entry addressed by the zero-based `M.selected_model`.
---Yields nil when no entry exists at that position.
---@return { name: ModelModule, model: string, params: table<string, string> | nil } current_model The current model
function M.get_current_model()
    -- `M.models` is a 1-based Lua sequence; `selected_model` counts from 0.
    local position = M.selected_model + 1
    return M.models[position]
end
---Uses lines `line1`..`line2` (1-based, inclusive) of `buffer` as the chat
---context, joined with newlines, then starts a new chat history.
---@param buffer number
---@param line1 number
---@param line2 number
function M.set_context(buffer, line1, line2)
    -- nvim_buf_get_lines takes a zero-based, end-exclusive range.
    local first, last = line1 - 1, line2
    local selected = vim.api.nvim_buf_get_lines(buffer, first, last, false)
    M.context = table.concat(selected, "\n")
    M.new_chat_history()
end
---Drops the active chat history and clears the stored context.
function M.reset()
    M.chat_history = nil
    M.context = nil
end
-- Output fragments; `get_current_output` concatenates them in order.
local chunks = {}

---Returns all fragments in `chunks` joined into one string.
---@return string
function M.get_current_output()
    return table.concat(chunks)
end
---Sends the prompt to the currently selected model and streams the reply.
---
---The prompt is added to the chat history before sending; the reply is added
---once the model completes. Streamed chunks are collected so that
---`M.get_current_output()` returns the reply received so far for this prompt.
---If there is no chat history yet (e.g. after `M.reset()`), one is created.
---@param prompt string The prompt to send
---@param append_to_output_func fun(txt: string, type: integer) Receives output text: 1 for the echoed prompt and separators, 0 for model chunks
---@param separators boolean True to echo the prompt and separator lines around the reply
---@param on_complete fun(output: string)|nil Called with the full reply on success; not called on error
M.send_prompt = function(prompt, append_to_output_func, separators, on_complete)
    -- Local to this call so overlapping prompts cannot redirect each other's output.
    local function emit(txt, kind)
        -- Best-effort: the output target may disappear while the reply streams,
        -- and that must not abort the request.
        pcall(append_to_output_func, txt, kind)
    end

    -- Fresh buffer per prompt; captured locally so a newer prompt's chunks
    -- never mix into an older, still-streaming one.
    local received = {}
    chunks = received

    if M.chat_history == nil then
        M.new_chat_history()
    end

    if separators then
        emit(prompt .. "\n\n--------\n\n", 1)
    end

    local function on_stdout_chunk(chunk)
        received[#received + 1] = chunk
        emit(chunk, 0)
    end

    local function on_model_complete(err, output)
        if err ~= nil then
            -- tostring: err is not guaranteed to be a string.
            vim.notify("NeoAI Error: " .. tostring(err), vim.log.levels.ERROR)
            return
        end
        if separators then
            emit("\n\n--------\n\n", 1)
        end
        M.chat_history:add_message(false, output)
        if on_complete ~= nil then
            on_complete(output)
        end
    end

    M.chat_history:add_message(true, prompt)
    local send_to_model = M.get_current_model().name.send_to_model
    send_to_model(M.chat_history, on_stdout_chunk, on_model_complete)
end

return M