Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .vim/coc-settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"Lua.workspace.library": [
"${3rd}/luassert/library"
]
}
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,23 @@ require("neoai").setup({
output_popup_height = 80, -- As percentage eg. 80%
submit = "<Enter>", -- Key binding to submit the prompt
},
selected_model_index = 0,
models = {
{
name = "openai",
model = "gpt-3.5-turbo",
params = nil,
},
{
name = "spark",
model = "v1",
params = nil,
},
{
name = "qianfan",
model = "ErnieBot-turbo",
params = nil,
}
},
register_output = {
["g"] = function(output)
Expand Down Expand Up @@ -222,6 +233,21 @@ require("neoai").setup({
-- end,
},
},
spark = {
random_threshold = 0.5,
max_tokens = 4096,
api_key = {
appid_env = "SPARK_APPID",
secret_env = "SPARK_SECRET",
apikey_env = "SPARK_APIKEY",
},
},
qianfan = {
api_key = {
secret_env = "QIANFAN_SECRET",
apikey_env = "QIANFAN_APIKEY",
},
},
shortcuts = {
{
name = "textify",
Expand Down Expand Up @@ -305,6 +331,22 @@ end
- `api_key.value`: The OpenAI API key, which takes precedence over `api_key .env`.
- `api_key.get`: A function that retrieves the OpenAI API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence.

### Spark Options:
- `random_threshold` Kernel sampling threshold that controls the randomness of the output: the higher the value, the more random the result, i.e. the higher the probability of getting different answers to the same question.
- `max_tokens` The maximum length of tokens answered by the model
- `api_key.appid_env` The environment variable containing the Spark appid. The default value is "SPARK_APPID".
- `api_key.secret_env` The environment variable containing the Spark secret key. The default value is "SPARK_SECRET".
- `api_key.apikey_env` The environment variable containing the Spark api key. The default value is "SPARK_APIKEY".
- `api_key.appid` App appid, obtained from an app created in the Open Platform console
- `api_key.secret` App secret key, obtained from an app created in the Open Platform console
- `api_key.apikey` App api key, obtained from an app created in the Open Platform console
- `api_key.get` A function that retrieves the Spark API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence.

### Qianfan Options:
- `api_key.secret_env` The environment variable containing the Qianfan secret key. The default value is "QIANFAN_SECRET".
- `api_key.apikey_env` The environment variable containing the Qianfan api key. The default value is "QIANFAN_APIKEY".
- `api_key.get` A function that retrieves the Qianfan API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence.

### Mappings
- `mappings`: A table containing the following actions that can be keys:

Expand Down Expand Up @@ -381,6 +423,8 @@ visually selected text or the entire buffer if no selection is made.
Triggers a NeoAI shortcut that is created in the config via its name instead of
a keybinding.

### :NeoAISetSource
Prompts you to choose which of the configured `models` NeoAI should use, by its 0-based index.

## Roadmap:

Expand Down
46 changes: 33 additions & 13 deletions lua/neoai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,19 @@ M.inject = function(prompt, strip_function, start_line)

local current_line = start_line or vim.api.nvim_win_get_cursor(0)[1]
chat.send_prompt(
prompt,
function(txt, _)
-- Get differences between text
local txt1 = strip_function(chat.get_current_output())
local txt2 = strip_function(table.concat({ chat.get_current_output(), txt }, ""))

inject.append_to_buffer(string.sub(txt2, #txt1 + 1), current_line)
end,
false,
function(_)
inject.current_line = nil
vim.notify("NeoAI: Done generating AI response", vim.log.levels.INFO)
end
prompt,
function(txt, _)
-- Get differences between text
local txt1 = strip_function(chat.get_current_output())
local txt2 = strip_function(table.concat({ chat.get_current_output(), txt }, ""))

inject.append_to_buffer(string.sub(txt2, #txt1 + 1), current_line)
end,
false,
function(_)
inject.current_line = nil
vim.notify("NeoAI: Done generating AI response", vim.log.levels.INFO)
end
)
end

Expand All @@ -140,4 +140,24 @@ M.context_inject = function(prompt, strip_function, line1, line2)
M.inject(prompt, strip_function, line2)
end

-- Build the numbered menu text shown by :NeoAISetSource, e.g.
-- "0.openai\n1.spark\n...Choose:". Indices are 0-based to match
-- config.options.selected_model_index.
local make_prompt = function()
    local pieces = {}
    for index, model_obj in ipairs(config.options.models) do
        pieces[#pieces + 1] = (index - 1) .. "." .. model_obj.name .. "\n"
    end
    pieces[#pieces + 1] = "Choose:"
    return table.concat(pieces)
end

-- Prompt the user (via vim.ui.input) for a 0-based model index and switch
-- the active model to it. Invalid, empty, out-of-range, or non-integer
-- input is silently ignored.
M.set_source = function()
    vim.ui.input({ prompt = make_prompt() }, function(text)
        if text == nil or text == "" then
            return
        end
        local n = tonumber(text)
        -- Require a whole number: previously "0.5" passed the range check,
        -- set a fractional index, and made models[index + 1] resolve to nil.
        if n ~= nil and n == math.floor(n) and n >= 0 and n < #config.options.models then
            config.options.selected_model_index = n
            chat.setup_models()
        end
    end)
end

return M
4 changes: 3 additions & 1 deletion lua/neoai/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ local append_to_output = nil

---@type {name: ModelModule, model: string, params: table<string, string> | nil}[] A list of models
M.models = {}
M.selected_model = 0

M.selected_model = nil

M.setup_models = function()
M.selected_model = config.options.selected_model_index or 0
for _, model_obj in ipairs(config.options.models) do
local raw_model = model_obj.model
local models
Expand Down
77 changes: 77 additions & 0 deletions lua/neoai/chat/models/qianfan.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
local utils = require("neoai.utils")
local config = require("neoai.config")

---@type ModelModule
local M = {}

-- Display name used by the chat UI.
M.name = "Qianfan"

-- Streamed text pieces of the in-progress response; concatenated by
-- M.get_current_output().
M._chunks = {}
-- Raw response chunks, kept so the completion handler can re-parse the
-- whole payload (e.g. to detect an "error" field) after streaming ends.
local raw_chunks = {}

-- Returns everything the model has produced so far for the current request.
M.get_current_output = function()
    return table.concat(M._chunks, "")
end

---@param chunk string Raw stdout chunk from qianfan.py — expected to be one JSON object
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
M._recieve_chunk = function(chunk, on_stdout_chunk)
    -- Keep the raw text so the completion handler can inspect the full
    -- response (e.g. for an "error" field) after streaming finishes.
    table.insert(raw_chunks, chunk)

    local ok, decoded = pcall(vim.json.decode, chunk)
    -- vim.json.decode can yield scalars (numbers, strings, booleans);
    -- indexing those crashed the old code. Only tables carry a result.
    if not ok or type(decoded) ~= "table" then
        return
    end

    local text = decoded.result
    if type(text) ~= "string" then
        return
    end

    on_stdout_chunk(text)
    table.insert(M._chunks, text)
end

---@param chat_history ChatHistory
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
---@param on_complete fun(err?: string, output?: string) Function to call when model has finished
M.send_to_model = function(chat_history, on_stdout_chunk, on_complete)
    local secret, apikey = config.options.qianfan.api_key.get()
    -- selected_model_index is 0-based (user facing); Lua tables are 1-based.
    local model = config.options.models[config.options.selected_model_index + 1].model

    -- Resolve this file's directory so the bundled python helper is found
    -- regardless of where the plugin is installed.
    local script_dir = debug.getinfo(1, "S").source:sub(2):match("(.*/)")
    local py_script_path = script_dir .. "qianfan.py"

    -- Quote the path: an unquoted path breaks (and is shell-injectable) when
    -- the install directory contains spaces or shell metacharacters.
    os.execute("chmod +x '" .. py_script_path .. "'")

    -- Reset per-request state. Previously `chunks = {}` created an accidental
    -- global and M._chunks was never cleared, so earlier responses leaked
    -- into get_current_output().
    M._chunks = {}
    raw_chunks = {}
    utils.exec(py_script_path, {
        apikey,
        secret,
        vim.json.encode(chat_history.messages),
        model,
    }, function(chunk)
        M._recieve_chunk(chunk, on_stdout_chunk)
    end, function(err, _)
        -- If the aggregated response is one JSON object carrying an "error"
        -- field, surface its message instead of the streamed output.
        local ok, json = pcall(vim.json.decode, table.concat(raw_chunks, ""))
        if ok and type(json) == "table" and json.error ~= nil then
            on_complete(json.error.message, nil)
            return
        end
        on_complete(err, M.get_current_output())
    end)
end

return M
42 changes: 42 additions & 0 deletions lua/neoai/chat/models/qianfan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
import requests
import json
import argparse
import sys


# Chat-completion endpoint per supported model name.
# NOTE(review): both entries currently point at the same "eb-instant"
# endpoint — presumably the full ErnieBot model uses a different URL;
# TODO confirm against the Baidu Qianfan API documentation.
urls = {
    "ErnieBot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant",
    "ErnieBot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant",
}


def chat(api_key, secret_key, messages, m):
    """Send ``messages`` to the Qianfan chat endpoint for model ``m``.

    Prints the raw JSON response body to stdout; the Lua caller reads it
    back as a single chunk and parses it.

    :param api_key: Qianfan API key.
    :param secret_key: Qianfan secret key.
    :param messages: Chat history (list of message dicts).
    :param m: Model name; must be a key of ``urls``.
    """
    # Fail with a clear message instead of an opaque KeyError when the
    # configured model name is not one we have an endpoint for.
    if m not in urls:
        raise SystemExit(f"unknown qianfan model {m!r}; expected one of {sorted(urls)}")

    url = urls[m] + "?access_token=" + get_access_token(api_key, secret_key)

    payload = json.dumps({"messages": messages})
    headers = {"Content-Type": "application/json"}
    response = requests.post(url, headers=headers, data=payload)
    print(response.text)


def get_access_token(api_key, secret_key):
    """Exchange the API key / secret key pair for an OAuth access token.

    :param api_key: Qianfan API key (``client_id``).
    :param secret_key: Qianfan secret key (``client_secret``).
    :return: The access token as a string.
    :raises SystemExit: if the response carries no ``access_token`` —
        previously ``str(None)`` ("None") was silently appended to the chat
        URL, producing a confusing downstream failure.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials",
              "client_id": api_key, "client_secret": secret_key}
    data = requests.post(url, params=params).json()
    token = data.get("access_token")
    if token is None:
        raise SystemExit(f"qianfan token request failed: {data}")
    return str(token)


def _main():
    """CLI entry point, invoked by the NeoAI Lua plugin as:

        qianfan.py <api_key> <secret_key> <messages-json> <model>
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('api_key')
    parser.add_argument('secret_key')
    parser.add_argument('messages')
    parser.add_argument('model')
    args = parser.parse_args()
    # `messages` arrives JSON-encoded on the command line.
    chat(args.api_key, args.secret_key, json.loads(args.messages), args.model)


if __name__ == '__main__':
    _main()
109 changes: 109 additions & 0 deletions lua/neoai/chat/models/spark.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
local utils = require("neoai.utils")
local config = require("neoai.config")

---@type ModelModule
local M = {}

-- Display name used by the chat UI.
M.name = "Spark"

-- Streamed text pieces of the in-progress response; concatenated by
-- M.get_current_output().
M._chunks = {}
-- Raw response lines, kept so the completion handler can re-parse the
-- whole payload (e.g. to detect an "error" field) after streaming ends.
local raw_chunks = {}

-- Returns everything the model has produced so far for the current request.
M.get_current_output = function()
    return table.concat(M._chunks, "")
end

---@param chunk string Raw stdout chunk; may contain several newline-separated JSON objects
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
M._recieve_chunk = function(chunk, on_stdout_chunk)
    -- Safely walk a chain of nested keys, returning nil if any hop is
    -- missing or not a table. vim.json.decode can yield scalars, and the
    -- old goto chain crashed when it tried to index one.
    local function dig(value, ...)
        for _, key in ipairs({ ... }) do
            if type(value) ~= "table" then
                return nil
            end
            value = value[key]
        end
        return value
    end

    for line in chunk:gmatch("[^\n]+") do
        -- Keep the raw line so the completion handler can inspect the full
        -- response (e.g. for an "error" field) after streaming finishes.
        table.insert(raw_chunks, line)

        local ok, decoded = pcall(vim.json.decode, line)
        if ok then
            -- Spark nests the generated text at payload.choices.text[1].content.
            local content = dig(decoded, "payload", "choices", "text", 1, "content")
            if content ~= nil then
                on_stdout_chunk(content)
                table.insert(M._chunks, content)
            end
        end
    end
end

---@param chat_history ChatHistory
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
---@param on_complete fun(err?: string, output?: string) Function to call when model has finished
M.send_to_model = function(chat_history, on_stdout_chunk, on_complete)
    local appid, secret, apikey = config.options.spark.api_key.get()
    -- selected_model_index is 0-based (user facing); Lua tables are 1-based.
    local ver = config.options.models[config.options.selected_model_index + 1].model
    local random_threshold = config.options.spark.random_threshold
    local max_tokens = config.options.spark.max_tokens

    -- Resolve this file's directory so the bundled python helper is found
    -- regardless of where the plugin is installed.
    local script_dir = debug.getinfo(1, "S").source:sub(2):match("(.*/)")
    local py_script_path = script_dir .. "spark.py"

    -- Quote the path: an unquoted path breaks (and is shell-injectable) when
    -- the install directory contains spaces or shell metacharacters.
    os.execute("chmod +x '" .. py_script_path .. "'")

    -- Reset per-request state. Previously `chunks = {}` created an accidental
    -- global and M._chunks was never cleared, so earlier responses leaked
    -- into get_current_output().
    M._chunks = {}
    raw_chunks = {}
    utils.exec(py_script_path, {
        appid,
        secret,
        apikey,
        vim.json.encode(chat_history.messages),
        "--ver",
        ver,
        "--random_threshold",
        -- argv entries are handed to a spawned process and should be strings;
        -- convert explicitly rather than relying on downstream coercion.
        tostring(random_threshold),
        "--max_tokens",
        tostring(max_tokens),
    }, function(chunk)
        M._recieve_chunk(chunk, on_stdout_chunk)
    end, function(err, _)
        -- If the aggregated response is one JSON object carrying an "error"
        -- field, surface its message instead of the streamed output.
        local ok, json = pcall(vim.json.decode, table.concat(raw_chunks, ""))
        if ok and type(json) == "table" and json.error ~= nil then
            on_complete(json.error.message, nil)
            return
        end
        on_complete(err, M.get_current_output())
    end)
end

return M
Loading