diff --git a/.vim/coc-settings.json b/.vim/coc-settings.json new file mode 100644 index 0000000..4d6fb60 --- /dev/null +++ b/.vim/coc-settings.json @@ -0,0 +1,5 @@ +{ + "Lua.workspace.library": [ + "${3rd}/luassert/library" + ] +} \ No newline at end of file diff --git a/README.md b/README.md index ca22c00..800d496 100644 --- a/README.md +++ b/README.md @@ -178,12 +178,23 @@ require("neoai").setup({ output_popup_height = 80, -- As percentage eg. 80% submit = "", -- Key binding to submit the prompt }, + selected_model_index = 0, models = { { name = "openai", model = "gpt-3.5-turbo", params = nil, }, + { + name = "spark", + model = "v1", + params = nil, + }, + { + name = "qianfan", + model = "ErnieBot-turbo", + params = nil, + } }, register_output = { ["g"] = function(output) @@ -222,6 +233,21 @@ require("neoai").setup({ -- end, }, }, + spark = { + random_threshold = 0.5, + max_tokens = 4096, + api_key = { + appid_env = "SPARK_APPID", + secret_env = "SPARK_SECRET", + apikey_env = "SPARK_APIKEY", + }, + }, + qianfan = { + api_key = { + secret_env = "QIANFAN_SECRET", + apikey_env = "QIANFAN_APIKEY", + }, + }, shortcuts = { { name = "textify", @@ -305,6 +331,22 @@ end - `api_key.value`: The OpenAI API key, which takes precedence over `api_key .env`. - `api_key.get`: A function that retrieves the OpenAI API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence. +### Spark Options: +- `random_threshold` Kernel sampling threshold. Used to determine the randomness of the outcome, the higher the value, the stronger the randomness, that is, the higher the probability of different answers to the same question +- `max_tokens` The maximum length of tokens answered by the model +- `api_key.appid_env` The environment variable containing the Spark appid. The default value is "SPARK_APPID". +- `api_key.secret_env` The environment variable containing the Spark secret key. The default value is "SPARK_SECRET". 
+- `api_key.apikey_env` The environment variable containing the Spark api key. The default value is "SPARK_APIKEY". +- `api_key.appid` App appid, obtained from an app created in the Open Platform console +- `api_key.secret` App secret key, obtained from an app created in the Open Platform console +- `api_key.apikey` App api key, obtained from an app created in the Open Platform console +- `api_key.get` A function that retrieves the Spark API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence. + +### Qianfan Options: +- `api_key.secret_env` The environment variable containing the Qianfan secret key. The default value is "QIANFAN_SECRET". +- `api_key.apikey_env` The environment variable containing the Qianfan api key. The default value is "QIANFAN_APIKEY". +- `api_key.get` A function that retrieves the Qianfan API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence. + ### Mappings - `mappings`: A table containing the following actions that can be keys: @@ -381,6 +423,8 @@ visually selected text or the entire buffer if no selection is made. Triggers a NeoAI shortcut that is created in the config via it's name instead of a keybinding. +### :NeoAISetSource +Sets the source of the AI model. 
## Roadmap: diff --git a/lua/neoai.lua b/lua/neoai.lua index 7f39a4f..b59ce6e 100644 --- a/lua/neoai.lua +++ b/lua/neoai.lua @@ -114,19 +114,19 @@ M.inject = function(prompt, strip_function, start_line) local current_line = start_line or vim.api.nvim_win_get_cursor(0)[1] chat.send_prompt( - prompt, - function(txt, _) - -- Get differences between text - local txt1 = strip_function(chat.get_current_output()) - local txt2 = strip_function(table.concat({ chat.get_current_output(), txt }, "")) - - inject.append_to_buffer(string.sub(txt2, #txt1 + 1), current_line) - end, - false, - function(_) - inject.current_line = nil - vim.notify("NeoAI: Done generating AI response", vim.log.levels.INFO) - end + prompt, + function(txt, _) + -- Get differences between text + local txt1 = strip_function(chat.get_current_output()) + local txt2 = strip_function(table.concat({ chat.get_current_output(), txt }, "")) + + inject.append_to_buffer(string.sub(txt2, #txt1 + 1), current_line) + end, + false, + function(_) + inject.current_line = nil + vim.notify("NeoAI: Done generating AI response", vim.log.levels.INFO) + end ) end @@ -140,4 +140,24 @@ M.context_inject = function(prompt, strip_function, line1, line2) M.inject(prompt, strip_function, line2) end +local make_prompt = function() + local ret = "" + for i, model_obj in ipairs(config.options.models) do + ret = ret .. i-1 .. "." .. model_obj.name .. "\n" + end + ret = ret .. 
"Choose:" + return ret +end + +M.set_source = function() + vim.ui.input({ prompt = make_prompt() }, function(text) + if text == nil or string.len(text) == 0 then return end + local n = tonumber(text) + if n ~= nil and n >= 0 and n < #config.options.models then + config.options.selected_model_index = n + chat.setup_models() + end + end) +end + return M diff --git a/lua/neoai/chat/init.lua b/lua/neoai/chat/init.lua index 8848bb4..cdaf3e5 100644 --- a/lua/neoai/chat/init.lua +++ b/lua/neoai/chat/init.lua @@ -12,9 +12,11 @@ local append_to_output = nil ---@type {name: ModelModule, model: string, params: table | nil}[] A list of models M.models = {} -M.selected_model = 0 + +M.selected_model = nil M.setup_models = function() + M.selected_model = config.options.selected_model_index or 0 for _, model_obj in ipairs(config.options.models) do local raw_model = model_obj.model local models diff --git a/lua/neoai/chat/models/qianfan.lua b/lua/neoai/chat/models/qianfan.lua new file mode 100644 index 0000000..f8cf087 --- /dev/null +++ b/lua/neoai/chat/models/qianfan.lua @@ -0,0 +1,77 @@ +local utils = require("neoai.utils") +local config = require("neoai.config") + +---@type ModelModule +local M = {} + +M.name = "Qianfan" + +M._chunks = {} +local raw_chunks = {} + +M.get_current_output = function() + return table.concat(M._chunks, "") +end + +---@param chunk string +---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs +M._recieve_chunk = function(chunk, on_stdout_chunk) + local raw_json = chunk + + table.insert(raw_chunks, raw_json) + + local ok, path = pcall(vim.json.decode, raw_json) + if not ok then + return + end + + path = path.result + if path == nil then + return + end + + on_stdout_chunk(path) + -- append_to_output(path, 0) + table.insert(M._chunks, path) +end + +---@param chat_history ChatHistory +---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs +---@param on_complete fun(err?: string, 
output?: string) Function to call when model has finished +M.send_to_model = function(chat_history, on_stdout_chunk, on_complete) + local secret, apikey = config.options.qianfan.api_key.get() + local model = config.options.models[config.options.selected_model_index+1].model + + local get_script_dir = function() + local info = debug.getinfo(1, "S") + local script_path = info.source:sub(2) + return script_path:match("(.*/)") + end + + local py_script_path = get_script_dir() .. "qianfan.py" + + os.execute("chmod +x "..py_script_path) + + chunks = {} + raw_chunks = {} + utils.exec(py_script_path, { + apikey, + secret, + vim.json.encode(chat_history.messages), + model, + }, function(chunk) + M._recieve_chunk(chunk, on_stdout_chunk) + end, function(err, _) + local total_message = table.concat(raw_chunks, "") + local ok, json = pcall(vim.json.decode, total_message) + if ok then + if json.error ~= nil then + on_complete(json.error.message, nil) + return + end + end + on_complete(err, M.get_current_output()) +end) +end + +return M diff --git a/lua/neoai/chat/models/qianfan.py b/lua/neoai/chat/models/qianfan.py new file mode 100755 index 0000000..3ea20bf --- /dev/null +++ b/lua/neoai/chat/models/qianfan.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +import requests +import json +import argparse +import sys + + +urls = { + "ErnieBot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ErnieBot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", +} + + +def chat(api_key, secret_key, messages, m): + url = urls[m] + "?access_token=" + get_access_token(api_key, secret_key) + + payload = json.dumps({ + "messages": messages + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + print(response.text) + + +def get_access_token(api_key, secret_key): + url = "https://aip.baidubce.com/oauth/2.0/token" + params = {"grant_type": 
"client_credentials", + "client_id": api_key, "client_secret": secret_key} + return str(requests.post(url, params=params).json().get("access_token")) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('api_key') + parser.add_argument('secret_key') + parser.add_argument('messages') + parser.add_argument('model') + args = parser.parse_args() + messages = json.loads(args.messages) + chat(args.api_key, args.secret_key, messages, args.model) diff --git a/lua/neoai/chat/models/spark.lua b/lua/neoai/chat/models/spark.lua new file mode 100644 index 0000000..fa8d340 --- /dev/null +++ b/lua/neoai/chat/models/spark.lua @@ -0,0 +1,109 @@ +local utils = require("neoai.utils") +local config = require("neoai.config") + +---@type ModelModule +local M = {} + +M.name = "Spark" + +M._chunks = {} +local raw_chunks = {} + +M.get_current_output = function() + return table.concat(M._chunks, "") +end + +---@param chunk string +---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs +M._recieve_chunk = function(chunk, on_stdout_chunk) + for line in chunk:gmatch("[^\n]+") do + local raw_json = line + + table.insert(raw_chunks, raw_json) + + local ok, path = pcall(vim.json.decode, raw_json) + if not ok then + goto continue + end + + path = path.payload + if path == nil then + goto continue + end + + path = path.choices + if path == nil then + goto continue + end + + path = path.text + if path == nil then + goto continue + end + + path = path[1] + if path == nil then + goto continue + end + + path = path.content + if path == nil then + goto continue + end + + + on_stdout_chunk(path) + -- append_to_output(path, 0) + table.insert(M._chunks, path) + ::continue:: + end +end + +---@param chat_history ChatHistory +---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs +---@param on_complete fun(err?: string, output?: string) Function to call when model has finished +M.send_to_model = 
function(chat_history, on_stdout_chunk, on_complete) + local appid, secret, apikey = config.options.spark.api_key.get() + local ver = config.options.models[config.options.selected_model_index+1].model + local random_threshold = config.options.spark.random_threshold + local max_tokens = config.options.spark.max_tokens + + local get_script_dir = function() + local info = debug.getinfo(1, "S") + local script_path = info.source:sub(2) + return script_path:match("(.*/)") + end + + local py_script_path = get_script_dir() .. "spark.py" + + os.execute("chmod +x "..py_script_path) + + chunks = {} + raw_chunks = {} + utils.exec(py_script_path, { + appid, + secret, + apikey, + vim.json.encode(chat_history.messages), + "--ver", + ver, + "--random_threshold", + random_threshold, + "--max_tokens", + max_tokens + }, function(chunk) + M._recieve_chunk(chunk, on_stdout_chunk) + end, function(err, _) + local total_message = table.concat(raw_chunks, "") + local ok, json = pcall(vim.json.decode, total_message) + if ok then + if json.error ~= nil then + on_complete(json.error.message, nil) + return + end + end + on_complete(err, M.get_current_output()) +end) +end + +return M diff --git a/lua/neoai/chat/models/spark.py b/lua/neoai/chat/models/spark.py new file mode 100755 index 0000000..76c6202 --- /dev/null +++ b/lua/neoai/chat/models/spark.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +import _thread as thread +import base64 +import datetime +import hashlib +import hmac +import json +from urllib.parse import urlparse +import ssl +from datetime import datetime +from time import mktime +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time +import argparse +import subprocess +import os + + +def install_package(package_name): + try: + with open(os.devnull, 'w') as null: + subprocess.check_call( + ["pip", "install", package_name], stderr=null, stdout=null) + except subprocess.CalledProcessError: + print(f"Failed to install {package_name}") + + +try: + import 
websocket +except ImportError: + install_package(package_name) + import websocket + + +class WSParam(object): + def __init__(self, APPID, APIKey, APISecret, spark_url): + self.APPID = APPID + self.APIKey = APIKey + self.APISecret = APISecret + self.host = urlparse(spark_url).netloc + self.path = urlparse(spark_url).path + self.spark_url = spark_url + + def create_url(self): + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + signature_origin = "host: " + self.host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + self.path + " HTTP/1.1" + + signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + + signature_sha_base64 = base64.b64encode( + signature_sha).decode(encoding='utf-8') + + authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' + + authorization = base64.b64encode( + authorization_origin.encode('utf-8')).decode(encoding='utf-8') + + v = { + "authorization": authorization, + "date": date, + "host": self.host + } + url = self.spark_url + '?' 
+ urlencode(v) + return url + + +def on_error(ws, error): + print("### error:", error) + + +def on_close(ws, one, two): + print(" ") + + +def on_open(ws): + thread.start_new_thread(run, (ws,)) + + +def run(ws, *args): + data = json.dumps(gen_params( + appid=ws.appid, + domain=ws.domain, + messages=ws.messages, + random_threshold=ws.random_threshold, + max_tokens=ws.max_tokens)) + ws.send(data) + + +def gen_params(appid, domain, random_threshold, max_tokens, messages): + data = { + "header": { + "app_id": appid, + "uid": "1234" + }, + "parameter": { + "chat": { + "domain": domain, + "random_threshold": random_threshold, + "max_tokens": max_tokens, + "auditing": "default" + } + }, + "payload": { + "message": { + "text": messages + } + } + } + return data + + +def Request(appid, secret, apikey, messages, version, random_threshold, max_token): + if version == "v1": + spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" + domain = "general" + elif version == "v2": + spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" + domain = "generalv2" + ws_param = WSParam(appid, apikey, secret, spark_url) + ws_url = ws_param.create_url() + + def on_message(ws, message): + data = json.loads(message) + print(json.dumps(data, ensure_ascii=False)) + code = data['header']['code'] + if code != 0: + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + if status == 2: + ws.close() + ws = websocket.WebSocketApp( + ws_url, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open) + ws.appid = appid + ws.messages = messages + ws.domain = domain + ws.random_threshold = random_threshold + ws.max_tokens = max_token + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("appid") + parser.add_argument("secret") + parser.add_argument("apikey") + parser.add_argument("messages") + parser.add_argument("--ver", "-v", default="v1") + 
parser.add_argument("--random_threshold", "-r", default=0.5, type=float) + parser.add_argument("--max_tokens", "-t", default=4096, type=int) + parse_result = parser.parse_args() + messages = json.loads(parse_result.messages) + Request(parse_result.appid, parse_result.secret, parse_result.apikey, messages, + parse_result.ver, parse_result.random_threshold, parse_result.max_tokens) + + +if __name__ == "__main__": + main() diff --git a/lua/neoai/config.lua b/lua/neoai/config.lua index b0d7893..d42108c 100644 --- a/lua/neoai/config.lua +++ b/lua/neoai/config.lua @@ -16,12 +16,23 @@ M.get_defaults = function() output_popup_height = 80, -- As percentage eg. 80% submit = "", }, + selected_model_index = 0, models = { { name = "openai", model = "gpt-3.5-turbo", params = nil, }, + { + name = "spark", + model = "v1", + params = nil, + }, + { + name = "qianfan", + model = "ErnieBot-turbo", + params = nil, + } }, register_output = { ["g"] = function(output) @@ -35,9 +46,9 @@ M.get_defaults = function() prompts = { context_prompt = function(context) return "Hey, I'd like to provide some context for future " - .. "messages. Here is the code/text that I want to refer " - .. "to in our upcoming conversations (TEXT/CODE ONLY):\n\n" - .. context + .. "messages. Here is the code/text that I want to refer " + .. "to in our upcoming conversations (TEXT/CODE ONLY):\n\n" + .. context end, }, mappings = { @@ -67,12 +78,66 @@ M.get_defaults = function() return open_api_key end local msg = M.options.open_ai.api_key.env - .. " environment variable is not set, and open_api_key.value is empty" + .. 
" environment variable is not set, and open_api_key.value is empty" logger.error(msg) error(msg) end, }, }, + spark = { + random_threshold = 0.5, + max_tokens = 4096, + api_key = { + appid_env = "SPARK_APPID", + secret_env = "SPARK_SECRET", + apikey_env = "SPARK_APIKEY", + appid = nil, + secret = nil, + apikey = nil, + get = function() + if not M.options.spark.api_key.appid then + M.options.spark.api_key.appid = os.getenv(M.options.spark.api_key.appid_env) + end + if not M.options.spark.api_key.secret then + M.options.spark.api_key.secret = os.getenv(M.options.spark.api_key.secret_env) + end + if not M.options.spark.api_key.apikey then + M.options.spark.api_key.apikey = os.getenv(M.options.spark.api_key.apikey_env) + end + if M.options.spark.api_key.appid and M.options.spark.api_key.secret and M.options.spark.api_key.apikey then + return M.options.spark.api_key.appid, M.options.spark.api_key.secret, M.options.spark.api_key.apikey + end + local msg = M.options.spark.api_key.appid_env .. "/" + .. M.options.spark.api_key.secret_env .. "/" + .. M.options.spark.api_key.apikey_env .. " environment variable is not set" + logger.error(msg) + error(msg) + end + }, + }, + qianfan = { + api_key = { + secret_env = "QIANFAN_SECRET", + apikey_env = "QIANFAN_APIKEY", + secret = nil, + apikey = nil, + get = function() + if not M.options.qianfan.api_key.secret then + M.options.qianfan.api_key.secret = os.getenv(M.options.qianfan.api_key.secret_env) + end + if not M.options.qianfan.api_key.apikey then + M.options.qianfan.api_key.apikey = os.getenv(M.options.qianfan.api_key.apikey_env) + end + if M.options.qianfan.api_key.secret and M.options.qianfan.api_key.apikey then + return M.options.qianfan.api_key.secret, M.options.qianfan.api_key.apikey + end + local msg = M.options.qianfan.api_key.secret_env .. "/" + .. M.options.qianfan.api_key.apikey_env .. 
" environment variable is not set" + logger.error(msg) + error(msg) + end + }, + }, shortcuts = { { name = "textify", @@ -80,9 +145,9 @@ M.get_defaults = function() desc = "NeoAI fix text with AI", use_context = true, prompt = [[ - Please rewrite the text to make it more readable, clear, - concise, and fix any grammatical, punctuation, or spelling - errors + Please rewrite the text to make it more readable, clear, + concise, and fix any grammatical, punctuation, or spelling + errors ]], modes = { "v" }, strip_function = nil, @@ -94,9 +159,9 @@ M.get_defaults = function() use_context = false, prompt = function() return [[ - Using the following git diff generate a consise and - clear git commit message, with a short title summary - that is 75 characters or less: + Using the following git diff generate a consise and + clear git commit message, with a short title summary + that is 75 characters or less: ]] .. vim.fn.system("git diff --cached") end, modes = { "n" }, @@ -141,15 +206,43 @@ end ---@field value string | nil The value of the open api key to use, if nil then use the environment variable ---@field get fun(): string The function to get the open api key +---@class Spark_Options +---@field random_threshold float The random_threshold +---@field max_tokens int The max tokens count +---@field version ("v1" | "v2") The model version +---@field api_key Spark_Key_Options The Spark api key options + +---@class Qianfan_Options +---@field api_key Qianfan_Key_Options The Qianfan api key options + +---@class Spark_Key_Options +---@field appid_env string The environment variable to get the spark appid +---@field secret_env string The environment variable to get the spark secret +---@field apikey_env string The environment variable to get the spark apikey +---@field appid string The spark appid +---@field secret string The spark secret +---@field apikey string The spark apikey +---@field get fun(): string,string,string The function to get the open api key + +---@class 
Qianfan_Key_Options +---@field secret_env string The environment variable to get the qianfan secret +---@field apikey_env string The environment variable to get the qianfan apikey +---@field secret string The qianfan secret +---@field apikey string The qianfan apikey +---@field get fun(): string,string The function to get the open api key + ---@class Options ---@field ui UI_Options UI configurations ---@field model string The OpenAI model to use by default @depricated ---@field models Model_Options[] A list of different model options to use. First element will be default +---@field selected_model_index int Selected model index (started from zero) ---@field register_output table A table with a register as the key and a function that takes the raw output from the AI and outputs what you want to save into that register ---@field inject Inject_Options The inject options ---@field prompts Prompt_Options The custom prompt options ---@field open_api_key_env string The environment variable that contains the openai api key ---@field open_ai Open_AI_Options The open api key options +---@field spark Spark_Options The Spark api key options +---@field qianfan Qianfan_Options The Qianfan api key options ---@field mappings table<"select_up" | "select_down", nil|string|string[]> A table of actions with it's mapping(s) ---@field shortcuts Shortcut[] Array of shortcuts M.options = {} diff --git a/plugin/neoai.lua b/plugin/neoai.lua index 48c6f37..c5f007f 100644 --- a/plugin/neoai.lua +++ b/plugin/neoai.lua @@ -1,75 +1,75 @@ -- Plain GUI vim.api.nvim_create_user_command("NeoAI", function(opts) - require("neoai").smart_toggle(opts.args) + require("neoai").smart_toggle(opts.args) end, { - nargs = "*", +nargs = "*", }) vim.api.nvim_create_user_command("NeoAIToggle", function(opts) - require("neoai").toggle(opts.args) + require("neoai").toggle(opts.args) end, { - nargs = "*", +nargs = "*", }) vim.api.nvim_create_user_command("NeoAIOpen", function(opts) - require("neoai").toggle(true, 
opts.args) + require("neoai").toggle(true, opts.args) end, { - nargs = "*", +nargs = "*", }) vim.api.nvim_create_user_command("NeoAIClose", function() - require("neoai").toggle(false) + require("neoai").toggle(false) end, {}) -- Context GUI vim.api.nvim_create_user_command("NeoAIContext", function(opts) - require("neoai").context_smart_toggle(opts.args, opts.line1, opts.line2) + require("neoai").context_smart_toggle(opts.args, opts.line1, opts.line2) end, { - range = "%", - nargs = "*", +range = "%", +nargs = "*", }) vim.api.nvim_create_user_command("NeoAIContextOpen", function(opts) - require("neoai").context_toggle(true, opts.args, opts.line1, opts.line2) + require("neoai").context_toggle(true, opts.args, opts.line1, opts.line2) end, { - range = "%", - nargs = "*", +range = "%", +nargs = "*", }) vim.api.nvim_create_user_command("NeoAIContextClose", function() - require("neoai").context_toggle(false, "", nil, nil) + require("neoai").context_toggle(false, "", nil, nil) end, {}) -- Inject Mode vim.api.nvim_create_user_command("NeoAIInject", function(opts) - require("neoai").inject(opts.args) + require("neoai").inject(opts.args) end, { - nargs = "+", +nargs = "+", }) vim.api.nvim_create_user_command("NeoAIInjectCode", function(opts) - local extract_code_snippets = require("neoai.utils").extract_code_snippets - require("neoai").inject(opts.args, extract_code_snippets) + local extract_code_snippets = require("neoai.utils").extract_code_snippets + require("neoai").inject(opts.args, extract_code_snippets) end, { - nargs = "+", +nargs = "+", }) vim.api.nvim_create_user_command("NeoAIInjectContext", function(opts) - require("neoai").context_inject(opts.args, nil, opts.line1, opts.line2) + require("neoai").context_inject(opts.args, nil, opts.line1, opts.line2) end, { - range = "%", - nargs = "+", +range = "%", +nargs = "+", }) vim.api.nvim_create_user_command("NeoAIInjectContextCode", function(opts) - local extract_code_snippets = 
require("neoai.utils").extract_code_snippets - require("neoai").context_inject(opts.args, extract_code_snippets, opts.line1, opts.line2) + local extract_code_snippets = require("neoai.utils").extract_code_snippets + require("neoai").context_inject(opts.args, extract_code_snippets, opts.line1, opts.line2) end, { - range = "%", - nargs = "+", +range = "%", +nargs = "+", }) vim.api.nvim_create_user_command("NeoAIShortcut", function (opts) @@ -87,7 +87,88 @@ vim.api.nvim_create_user_command("NeoAIShortcut", function (opts) end func() end, { - nargs = 1, - range = true, - complete = require("neoai.shortcuts").complete_shortcut +nargs = 1, +range = true, +complete = require("neoai.shortcuts").complete_shortcut }) + +-- Versions of the commands that use vim.ui.input for retrieving the prompt/context +-- Plain +vim.api.nvim_create_user_command("NeoAIPrompt", function(_) + vim.ui.input({ prompt = "Prompt: " }, function(text) + if text == nil or string.len(text) == 0 then return end + require("neoai").smart_toggle(text) + end) +end, { +nargs = 0, +}) + +vim.api.nvim_create_user_command("NeoAISetSource", require("neoai").set_source, { + nargs = 0, +}) + +-- Context +vim.api.nvim_create_user_command("NeoAIContextPrompt", function(opts) + vim.ui.input({ prompt = "Context: " }, function(text) + if text == nil or string.len(text) == 0 then return end + require("neoai").context_smart_toggle(text, opts.line1, opts.line2) + end) +end, { +nargs = 0, +range = "%", +}) + +-- Inject +vim.api.nvim_create_user_command( +"NeoAIInjectPrompt", +function(_) + vim.ui.input({ prompt = "Prompt: " }, function(text) + if text == nil or string.len(text) == 0 then return end + require("neoai").inject(text) + end) +end, +{ + nargs = 0, +} +) + +vim.api.nvim_create_user_command( +"NeoAIInjectCodePrompt", +function(_) + vim.ui.input({ prompt = "Prompt: " }, function(text) + if text == nil or string.len(text) == 0 then return end + require("neoai").inject(text, 
require('neoai.utils').extract_code_snippets) + end) +end, +{ + nargs = 0, +} +) + +vim.api.nvim_create_user_command( +"NeoAIInjectContextPrompt", +function(opts) + vim.ui.input({ prompt = "Context: " }, function(text) + if text == nil or string.len(text) == 0 then return end + require("neoai").context_inject(text, nil, opts.line1, opts.line2) + end) +end, +{ + range = "%", + nargs = 0, +} +) + +vim.api.nvim_create_user_command( +"NeoAIInjectContextCodePrompt", +function(opts) + vim.ui.input({ prompt = "Context: " }, function(text) + if text == nil or string.len(text) == 0 then return end + require("neoai").context_inject(text, require('neoai.utils').extract_code_snippets, opts.line1, opts.line2) + end) +end, +{ + range = "%", + nargs = 0, +} +)