Skip to content

Commit

Permalink
clean and test
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek committed Sep 12, 2024
1 parent 0c8e6b7 commit e27ed8d
Show file tree
Hide file tree
Showing 6 changed files with 466 additions and 49 deletions.
16 changes: 16 additions & 0 deletions apisix/core/request.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

local lfs = require("lfs")
local log = require("apisix.core.log")
local json = require("apisix.core.json")
local io = require("apisix.core.io")
local req_add_header
if ngx.config.subsystem == "http" then
Expand Down Expand Up @@ -334,6 +335,21 @@ function _M.get_body(max_size, ctx)
end


function _M.get_json_request_body_table()
local body, err = _M.get_body()
if not body then
return nil, { message = "could not get body: " .. (err or "request body is empty") }
end

local body_tab, err = json.decode(body)
if not body_tab then
return nil, { message = "could not get parse JSON request body: " .. err }
end

return body_tab
end


function _M.get_scheme(ctx)
if not ctx then
ctx = ngx.ctx.api_ctx
Expand Down
25 changes: 12 additions & 13 deletions apisix/plugins/ai-rag.lua
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,15 @@ function _M.check_schema(conf)
return core.schema.check(schema, conf)
end


function _M.access(conf, ctx)
-- local conf = conf.rag
-- if conf then
local httpc = http.new()
local body_tab = core.request.get_json_request_body_table()

local body_tab, err = core.request.get_json_request_body_table()
if not body_tab then
return 400, err
end
if not body_tab["ai_rag"] then
core.log.error("request body must have \"ai-rag\" field")
return 400
Expand All @@ -128,25 +131,20 @@ function _M.access(conf, ctx)
return 400
end

local embeddings, err = embeddings_driver.get_embeddings(embeddings_provider_conf, body_tab["ai_rag"].embeddings, httpc)
local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf,
body_tab["ai_rag"].embeddings, httpc)
if not embeddings then
-- TODO: bring order
core.log.error("could not get embeddings: ", err)
return 500
return status, err
end
core.log.error("dibag err: ", err)
core.log.warn("dibag res: ", core.json.encode(embeddings))

local search_body = body_tab["ai_rag"].vector_search
search_body.embeddings = embeddings
local res, err = vector_search_driver.search(vector_search_provider_conf, search_body, httpc)
local res, status, err = vector_search_driver.search(vector_search_provider_conf, search_body, httpc)
if not res then
-- TODO: bring order
core.log.error("could not get vector_search: ", err)
return 500
core.log.error("could not get vector_search result: ", err)
return status, err
end
core.log.error("dibag err: ", err)
core.log.warn("dibag res: ", core.json.encode(res, true))

body_tab["ai_rag"] = nil
local prepend = {
Expand All @@ -166,4 +164,5 @@ function _M.access(conf, ctx)
ngx_req.set_body_data(req_body_json)
end


return _M
15 changes: 10 additions & 5 deletions apisix/plugins/ai-rag/embeddings/azure_openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
-- limitations under the License.
--
local core = require("apisix.core")

local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local _M = {}

function _M.get_embeddings(conf, body, httpc)

function _M.get_embeddings(conf, body, httpc)
local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
Expand All @@ -32,23 +32,28 @@ function _M.get_embeddings(conf, body, httpc)
return nil, err
end

if res.status ~= 200 then
return nil, res.status, res.body
end

local res_tab, err = core.json.decode(res.body)
if not res_tab then
return nil, err
return nil, internal_server_error, err
end

if type(res_tab.data) ~= "table" or #res_tab.data < 1 then
return nil, res.body
return nil, internal_server_error, res.body
end

local embeddings, err = core.json.encode(res_tab.data[1].embedding)
if not embeddings then
return nil, err
return nil, internal_server_error, err
end

return res_tab.data[1].embedding
end


_M.request_schema = {
type = "object",
properties = {
Expand Down
39 changes: 8 additions & 31 deletions apisix/plugins/ai-rag/vector-search/azure_ai_search.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
-- limitations under the License.
--
local core = require("apisix.core")
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR

local _M = {}


function _M.search(conf, search_body, httpc)
local body = {
vectorQueries = {
Expand All @@ -29,18 +31,6 @@ function _M.search(conf, search_body, httpc)
}
}
local final_body = core.json.encode(body)
local sb = core.json.encode(search_body)
core.log.warn("dibag final body: ", final_body)
core.log.warn("dibag final body: ", final_body)
-- [[{
-- "vectorQueries": [
-- {
-- "kind": "vector",
-- "vector": ]].. embeddings .. [[,
-- "fields": "contentVector"
-- }
-- ]
-- }]]
local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
Expand All @@ -51,29 +41,16 @@ function _M.search(conf, search_body, httpc)
})

if not res or not res.body then
return nil, err
return nil, internal_server_error, err
end

if res.status ~= 200 then
return nil, res.status, res.body
end

return res.body, err
return res.body
end

-- _M.request_schema = {
-- type = "object",
-- properties = {
-- vectorQueries = {
-- type = "array",
-- items = {
-- type = "object",
-- properties = {
-- fields = {
-- type = "string"
-- }
-- },
-- required = { "fields" }
-- },
-- },
-- },
-- }

_M.request_schema = {
type = "object",
Expand Down
25 changes: 25 additions & 0 deletions t/assets/embeddings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
123456789,
0.01902593,
0.008967914,
-0.013226582,
-0.026961878,
-0.017892223,
-0.0007785152,
-0.011031842,
0.0068531134
]
}
],
"model": "text-embedding-3-small",
"usage": {
"prompt_tokens": 4,
"total_tokens": 4
}
}
Loading

0 comments on commit e27ed8d

Please sign in to comment.