diff --git a/changelog/unreleased/kong/feat-add-llm-bedrock-agent-runtime-sdk.yml b/changelog/unreleased/kong/feat-add-llm-bedrock-agent-runtime-sdk.yml new file mode 100644 index 00000000000..c8d7d91c739 --- /dev/null +++ b/changelog/unreleased/kong/feat-add-llm-bedrock-agent-runtime-sdk.yml @@ -0,0 +1,3 @@ +message: "**llm**: Added support for the Bedrock Agent Runtime SDK, including rerank, converse, converse-stream, retrieveAndGenerate, and retrieveAndGenerate-stream." +type: "feature" +scope: "Plugin" diff --git a/changelog/unreleased/kong/feat-add-support-for-huggingface-serverless-inference-provider.yml b/changelog/unreleased/kong/feat-add-support-for-huggingface-serverless-inference-provider.yml new file mode 100644 index 00000000000..82ceb15a88d --- /dev/null +++ b/changelog/unreleased/kong/feat-add-support-for-huggingface-serverless-inference-provider.yml @@ -0,0 +1,4 @@ +message: | + **ai-proxy**: Added support for HuggingFace's new serverless inference API in the AI Proxy plugin. +type: "feature" +scope: "Plugin" \ No newline at end of file diff --git a/changelog/unreleased/kong/fix-ai-analytics.yml b/changelog/unreleased/kong/fix-ai-analytics.yml new file mode 100644 index 00000000000..66764bf4f23 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-analytics.yml @@ -0,0 +1,3 @@ +message: Fixed an issue where some AI metrics were missing in analytics +scope: Plugin +type: bugfix diff --git a/changelog/unreleased/kong/fix-ai-bedrock-tool-use.yml b/changelog/unreleased/kong/fix-ai-bedrock-tool-use.yml new file mode 100644 index 00000000000..2535c2eabdc --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-bedrock-tool-use.yml @@ -0,0 +1,4 @@ +message: > + **AI Plugins**: Fixed an issue where Bedrock (Converse API) did not properly parse multiple `toolUse` tool calls in one turn on some models. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-ai-driver-upstream-scheme-400.yml b/changelog/unreleased/kong/fix-ai-driver-upstream-scheme-400.yml new file mode 100644 index 00000000000..76b7fda37b4 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-driver-upstream-scheme-400.yml @@ -0,0 +1,4 @@ +message: > + **ai-driver**: Fixed an issue where AI Proxy and AI Proxy Advanced generated incorrect default ports for upstream schemes, leading to 400 errors. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-ai-gcp-model-armor-floor-blocked.yml b/changelog/unreleased/kong/fix-ai-gcp-model-armor-floor-blocked.yml new file mode 100644 index 00000000000..e0859d0f43b --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-gcp-model-armor-floor-blocked.yml @@ -0,0 +1,4 @@ +message: > + **AI Plugins**: Fixed an issue where the Gemini provider would not correctly return Model Armor 'Floor' blocking responses to the caller. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-ai-log-for-huge-request-payload.yml b/changelog/unreleased/kong/fix-ai-log-for-huge-request-payload.yml new file mode 100644 index 00000000000..f5ce31a652f --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-log-for-huge-request-payload.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where large request payloads were not logged.
+type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-plugins-using-capture-groups.yml b/changelog/unreleased/kong/fix-ai-plugins-using-capture-groups.yml new file mode 100644 index 00000000000..ae44fd779a4 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-plugins-using-capture-groups.yml @@ -0,0 +1,4 @@ +message: | + **ai-proxy**: Fixed an issue where patterns using multiple capture groups (e.g., `$(group1)/$(group2)`) failed to extract expected matches. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-anthropic-completions-stop-turn-mapping.yml b/changelog/unreleased/kong/fix-ai-proxy-anthropic-completions-stop-turn-mapping.yml new file mode 100644 index 00000000000..5102f1851cd --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-anthropic-completions-stop-turn-mapping.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where the Anthropic provider mapped the wrong stop_reason to Chat and Completions responses. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-anthropic-streaming-tool-use.yml b/changelog/unreleased/kong/fix-ai-proxy-anthropic-streaming-tool-use.yml new file mode 100644 index 00000000000..1141c04e3aa --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-anthropic-streaming-tool-use.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where the Anthropic provider failed to stream function call responses. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-ai-proxy-anthropic-streaming-truncation.yml b/changelog/unreleased/kong/fix-ai-proxy-anthropic-streaming-truncation.yml new file mode 100644 index 00000000000..61c76a591ea --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-anthropic-streaming-truncation.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where the Anthropic provider may truncate tokens in streaming responses. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-ai-proxy-anthropic-structured-output.yml b/changelog/unreleased/kong/fix-ai-proxy-anthropic-structured-output.yml new file mode 100644 index 00000000000..0b037b37ce6 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-anthropic-structured-output.yml @@ -0,0 +1,3 @@ +message: "**AI Plugins**: Fixed an issue Structured Output did not work correctly with Anthropic Claude models." +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-bedrock-streaming-structured-output.yml b/changelog/unreleased/kong/fix-ai-proxy-bedrock-streaming-structured-output.yml new file mode 100644 index 00000000000..6a1f76973f2 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-bedrock-streaming-structured-output.yml @@ -0,0 +1,3 @@ +message: "**AI Plugins**: Fixed an issue where structured output did not work for the Bedrock provider when SSE streaming responses." +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-bedrock-structured-output.yml b/changelog/unreleased/kong/fix-ai-proxy-bedrock-structured-output.yml new file mode 100644 index 00000000000..063347aec92 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-bedrock-structured-output.yml @@ -0,0 +1,3 @@ +message: "**AI Plugins**: Fixed an issue Structured Output did not work correctly with Bedrock (Converse API) models." 
+type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-gemini-anthropic-rawpredict.yml b/changelog/unreleased/kong/fix-ai-proxy-gemini-anthropic-rawpredict.yml new file mode 100644 index 00000000000..bf45e8189b7 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-gemini-anthropic-rawpredict.yml @@ -0,0 +1,3 @@ +message: "**AI Plugins**: Fixed an issue where the Gemini provider could not use Anthropic 'rawPredict' endpoint models hosted in Vertex." +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-gemini-id-created-missing.yml b/changelog/unreleased/kong/fix-ai-proxy-gemini-id-created-missing.yml new file mode 100644 index 00000000000..de9fba21bfe --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-gemini-id-created-missing.yml @@ -0,0 +1,3 @@ +message: "**AI Plugins**: Fixed an issue where `id`, `created`, and/or `finish_reason` fields were missing from or incorrect in Gemini responses." +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-gemini-model-missing.yml b/changelog/unreleased/kong/fix-ai-proxy-gemini-model-missing.yml new file mode 100644 index 00000000000..d0d33b2e1e6 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-gemini-model-missing.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed an issue where `model` field was missing in OpenAI format response from Gemini provider in some situations." +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-gemini-tool-calls.yml b/changelog/unreleased/kong/fix-ai-proxy-gemini-tool-calls.yml new file mode 100644 index 00000000000..07e54640c8b --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-gemini-tool-calls.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where the Gemini provider did not properly handle multiple tool calls across different turns. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-latency-for-streaming.yml b/changelog/unreleased/kong/fix-ai-proxy-latency-for-streaming.yml new file mode 100644 index 00000000000..38395f9e7eb --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-latency-for-streaming.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where latency metric was not implemented for streaming responses. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-structured-output-empty-array.yml b/changelog/unreleased/kong/fix-ai-proxy-structured-output-empty-array.yml new file mode 100644 index 00000000000..c70bb87369b --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-structured-output-empty-array.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed an issue when empty array in structured output was encoded as empty object." +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-proxy-structured-output.yml b/changelog/unreleased/kong/fix-ai-proxy-structured-output.yml new file mode 100644 index 00000000000..836cd11793e --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-proxy-structured-output.yml @@ -0,0 +1,3 @@ +message: "**AI Plugins**: Fixed an issue where Structured-Output requests were ignored." 
+type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-ai-retry-plugin-ctx-incorrect-ns.yml b/changelog/unreleased/kong/fix-ai-retry-plugin-ctx-incorrect-ns.yml new file mode 100644 index 00000000000..db693b33453 --- /dev/null +++ b/changelog/unreleased/kong/fix-ai-retry-plugin-ctx-incorrect-ns.yml @@ -0,0 +1,4 @@ +message: | + **ai-proxy-advanced**: Fixed an issue where ai-retry phase was not correctly setting the namespace in the kong.plugin.ctx, causing ai-proxy-advanced balancer retry first target more than once. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-aws-stream-parse-incomplete-frame.yml b/changelog/unreleased/kong/fix-aws-stream-parse-incomplete-frame.yml new file mode 100644 index 00000000000..fd6f6b48950 --- /dev/null +++ b/changelog/unreleased/kong/fix-aws-stream-parse-incomplete-frame.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where aws stream parser didn't parse correctly when the frame was incomplete. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-gemini-vertex.yml b/changelog/unreleased/kong/fix-gemini-vertex.yml new file mode 100644 index 00000000000..12d8e3feea1 --- /dev/null +++ b/changelog/unreleased/kong/fix-gemini-vertex.yml @@ -0,0 +1,4 @@ +message: | + **Semantic Plugins*: Fixed an issue where Gemini Vertex AI embeddings failed due to incorrect URL construction and response parsing. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-huggingface-embedding-response-inproperly-parsed.yml b/changelog/unreleased/kong/fix-huggingface-embedding-response-inproperly-parsed.yml new file mode 100644 index 00000000000..38028b761c1 --- /dev/null +++ b/changelog/unreleased/kong/fix-huggingface-embedding-response-inproperly-parsed.yml @@ -0,0 +1,4 @@ +message: | + **ai-proxy**: Fixed an issue where HuggingFace embedding driver were incorrectly parsed responses from embedding API +type: bugfix +scope: Plugin \ No newline at end of file diff --git a/changelog/unreleased/kong/fix-json-array-iterator.yml b/changelog/unreleased/kong/fix-json-array-iterator.yml new file mode 100644 index 00000000000..eeaa2b871c9 --- /dev/null +++ b/changelog/unreleased/kong/fix-json-array-iterator.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where JSON array data is not correctly parsed. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-llm-failover-bedrock.yml b/changelog/unreleased/kong/fix-llm-failover-bedrock.yml new file mode 100644 index 00000000000..eadad986994 --- /dev/null +++ b/changelog/unreleased/kong/fix-llm-failover-bedrock.yml @@ -0,0 +1,4 @@ +message: > + Fixed an issue where AI Proxy and AI Proxy Advanced can't properly failover to a Bedrock target. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-llm-sse-parser-with-truncated-sse.yml b/changelog/unreleased/kong/fix-llm-sse-parser-with-truncated-sse.yml new file mode 100644 index 00000000000..581eb528b6f --- /dev/null +++ b/changelog/unreleased/kong/fix-llm-sse-parser-with-truncated-sse.yml @@ -0,0 +1,4 @@ +message: > + Fixed an issue where AI Proxy and AI Proxy Advanced might produce duplicate content in the response when the SSE event was truncated. 
+type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-llm-sse-parser-with-truncated-sse2.yml b/changelog/unreleased/kong/fix-llm-sse-parser-with-truncated-sse2.yml new file mode 100644 index 00000000000..74f63088fe2 --- /dev/null +++ b/changelog/unreleased/kong/fix-llm-sse-parser-with-truncated-sse2.yml @@ -0,0 +1,4 @@ +message: > + Fixed an issue where AI Proxy and AI Proxy Advanced might drop content in the response when the SSE event was truncated. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix-sse-terminator.yml b/changelog/unreleased/kong/fix-sse-terminator.yml new file mode 100644 index 00000000000..444209ee4b0 --- /dev/null +++ b/changelog/unreleased/kong/fix-sse-terminator.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where the SSE terminator might not have the correct ending characters. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix-wrong-body-for-observability.yml b/changelog/unreleased/kong/fix-wrong-body-for-observability.yml new file mode 100644 index 00000000000..047f310af30 --- /dev/null +++ b/changelog/unreleased/kong/fix-wrong-body-for-observability.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where the response body recorded for observability could be larger than the actual body because of stale data. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix_ai-proxy-huggingface-request-extra-parameters-not-allowd.yml b/changelog/unreleased/kong/fix_ai-proxy-huggingface-request-extra-parameters-not-allowd.yml new file mode 100644 index 00000000000..47228cf3c7c --- /dev/null +++ b/changelog/unreleased/kong/fix_ai-proxy-huggingface-request-extra-parameters-not-allowd.yml @@ -0,0 +1,4 @@ +message: | + **AI Proxy**: Fixed an issue where extra inputs were not permitted for the HuggingFace inference provider. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/fix_ai_embeddings_titan_embed_dimensions.yml b/changelog/unreleased/kong/fix_ai_embeddings_titan_embed_dimensions.yml new file mode 100644 index 00000000000..690a20489d6 --- /dev/null +++ b/changelog/unreleased/kong/fix_ai_embeddings_titan_embed_dimensions.yml @@ -0,0 +1,4 @@ +message: > + **AI Plugins**: Fixed an issue where Titan Text Embeddings models could not be used. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix_ai_gemini_model_garden.yml b/changelog/unreleased/kong/fix_ai_gemini_model_garden.yml new file mode 100644 index 00000000000..08f4d0d62ae --- /dev/null +++ b/changelog/unreleased/kong/fix_ai_gemini_model_garden.yml @@ -0,0 +1,4 @@ +message: > + **AI Proxy Advanced**: Fixed an issue where the Gemini provider did not support Model Garden. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix_ai_prompt_decorator_content_length.yml b/changelog/unreleased/kong/fix_ai_prompt_decorator_content_length.yml new file mode 100644 index 00000000000..261dc0dbcd7 --- /dev/null +++ b/changelog/unreleased/kong/fix_ai_prompt_decorator_content_length.yml @@ -0,0 +1,4 @@ +message: | + Fixed an issue where the `ai-prompt-decorator` plugin prompt field was too short.
+type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix_ai_prompt_decorator_missing_fields.yml b/changelog/unreleased/kong/fix_ai_prompt_decorator_missing_fields.yml new file mode 100644 index 00000000000..2e6b7b9696a --- /dev/null +++ b/changelog/unreleased/kong/fix_ai_prompt_decorator_missing_fields.yml @@ -0,0 +1,4 @@ +message: | + **ai-prompt-decorator**: Fixed an issue where the plugin would omit the model, temperature, and other user-defined fields. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix_ai_proxy_gemini_driver_return_403.yml b/changelog/unreleased/kong/fix_ai_proxy_gemini_driver_return_403.yml new file mode 100644 index 00000000000..351ba11e3de --- /dev/null +++ b/changelog/unreleased/kong/fix_ai_proxy_gemini_driver_return_403.yml @@ -0,0 +1,4 @@ +message: > + **ai-proxy**: Fixed an issue where AI Proxy returned a 403 when using the Gemini provider. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix_ai_proxy_ollama_rag_and_tools.yml b/changelog/unreleased/kong/fix_ai_proxy_ollama_rag_and_tools.yml new file mode 100644 index 00000000000..5017dce1f35 --- /dev/null +++ b/changelog/unreleased/kong/fix_ai_proxy_ollama_rag_and_tools.yml @@ -0,0 +1,4 @@ +message: > + **ollama**: Fixed an issue where 'tools' and 'RAG' decoration did not work for the Ollama (llama2) provider. +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/fix_mistral_remove_seed.yml b/changelog/unreleased/kong/fix_mistral_remove_seed.yml new file mode 100644 index 00000000000..8fcf2d8e4b0 --- /dev/null +++ b/changelog/unreleased/kong/fix_mistral_remove_seed.yml @@ -0,0 +1,4 @@ +message: | + **AI Proxy Advanced**: Fixed an issue where Mistral models would return `Unsupported field: seed` when using some inference libraries. +type: "bugfix" +scope: "Plugin" diff --git a/changelog/unreleased/kong/perf-parse-sse-chunk.yml b/changelog/unreleased/kong/perf-parse-sse-chunk.yml new file mode 100644 index 00000000000..5e92cf358d9 --- /dev/null +++ b/changelog/unreleased/kong/perf-parse-sse-chunk.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Implemented a faster SSE parser."
+type: "performance" +scope: "Plugin" diff --git a/kong-latest.rockspec b/kong-latest.rockspec index b5f9806b264..b35de84f92b 100644 --- a/kong-latest.rockspec +++ b/kong-latest.rockspec @@ -664,6 +664,7 @@ build = { ["kong.llm.plugin.shared-filters.parse-json-response"] = "kong/llm/plugin/shared-filters/parse-json-response.lua", ["kong.llm.plugin.shared-filters.parse-request"] = "kong/llm/plugin/shared-filters/parse-request.lua", ["kong.llm.plugin.shared-filters.parse-sse-chunk"] = "kong/llm/plugin/shared-filters/parse-sse-chunk.lua", + ["kong.llm.plugin.shared-filters.save-request-body"] = "kong/llm/plugin/shared-filters/save-request-body.lua", ["kong.llm.plugin.shared-filters.serialize-analytics"] = "kong/llm/plugin/shared-filters/serialize-analytics.lua", ["kong.llm.adapters.bedrock"] = "kong/llm/adapters/bedrock.lua", diff --git a/kong/llm/adapters/bedrock.lua b/kong/llm/adapters/bedrock.lua index 3ec0d72a95d..7643ad5ee05 100644 --- a/kong/llm/adapters/bedrock.lua +++ b/kong/llm/adapters/bedrock.lua @@ -1,6 +1,7 @@ local cjson = require("cjson.safe") local fmt = string.format + local _BedrockAdapter = {} _BedrockAdapter.role_map = { @@ -305,10 +306,51 @@ function _BedrockAdapter:to_kong_req(bedrock_table, kong) openai_table.tools = self:extract_tools(bedrock_table.toolConfig.tools) end + local url = kong.request.get_path() + if url:find("/rerank", 1, true) then + self.forward_path = "/rerank" + openai_table.messages = bedrock_table.queries + return openai_table + elseif url:find("/retrieveAndGenerateStream", 1, true) then + self.forward_path = "/retrieveAndGenerateStream" + openai_table.prompt = bedrock_table.input and bedrock_table.input.text + openai_table.stream = true + return openai_table + elseif url:find("/retrieveAndGenerate", 1, true) then + self.forward_path = "/retrieveAndGenerate" + openai_table.prompt = bedrock_table.input and bedrock_table.input.text + return openai_table + elseif url:find('converse-stream',1, true) then + self.forward_path = "/model/%s/converse-stream" + openai_table.stream = true + return openai_table + elseif url:find("converse",1, true) then + self.forward_path = "/model/%s/converse" + return openai_table + end return openai_table end +function _BedrockAdapter:get_forwarded_path(model) + if not self.forward_path then + return + end + + -- if the forward path is a string, it means we have to format it with the model name + if type(self.forward_path) == "string" then + local forward_path = fmt(self.forward_path, model) + if self.forward_path == "/rerank" or + self.forward_path == "/retrieveAndGenerate" or + self.forward_path == "/retrieveAndGenerateStream" then + return forward_path, "bedrock_agent" + end + return forward_path, "bedrock" + end + + return +end + -- for unit tests if _G.TEST then _BedrockAdapter._set_kong = function(this_kong) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 6540654fffd..d20fdd129a9 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -1,7 +1,7 @@ local _M = {} -- imports -local cjson = require("cjson.safe") +local cjson = require("kong.tools.cjson") local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" @@ -12,12 +12,21 @@ local ai_plugin_ctx = require("kong.llm.plugin.ctx") -- globals local DRIVER_NAME = "anthropic" +local get_global_ctx, set_global_ctx = ai_plugin_ctx.get_global_accessors(DRIVER_NAME) -- local function kong_prompt_to_claude_prompt(prompt) return fmt("Human: 
%s\n\nAssistant:", prompt) end +local _OPENAI_STOP_REASON_MAPPING = { + ["max_tokens"] = "length", + ["end_turn"] = "stop", + ["tool_use"] = "tool_calls", + ["guardrail_intervened"] = "guardrail_intervened", + ["stop_sequence"] = "stop" +} + local function kong_messages_to_claude_prompt(messages) local buf = buffer.new() @@ -120,6 +129,7 @@ local function to_tools(in_tools) if v['function'] then v['function'].input_schema = v['function'].parameters v['function'].parameters = nil + v['function'].strict = nil -- unsupported table.insert(out_tools, v['function']) end @@ -155,6 +165,33 @@ local function to_tool_choice(openai_tool_choice) return nil end +local function extract_structured_content_schema(request_table) + -- bounds check EVERYTHING first + if not request_table + or not request_table.response_format + or type(request_table.response_format.type) ~= "string" + or type(request_table.response_format.json_schema) ~= "table" + or type(request_table.response_format.json_schema.schema) ~= "table" + then + return nil + end + + -- return + return request_table.response_format.json_schema.schema +end + +local function read_content(delta) + if delta then + if delta.delta and delta.delta.text then + return delta.delta.text + end + + -- reserve for more content types here later + end + + return "" +end + local transformers_to = { ["llm/v1/chat"] = function(request_table, model) local messages = {} @@ -174,6 +211,32 @@ local transformers_to = { messages.tools = request_table.tools and to_tools(request_table.tools) messages.tool_choice = request_table.tool_choice and to_tool_choice(request_table.tool_choice) + local structured_content_schema = extract_structured_content_schema(request_table) + if structured_content_schema then + set_global_ctx("structured_output_mode", true) + + -- make a tool call that covers all responses, adhering to specific json schema + -- we will extract this tool call into the OpenAI "response content" later + messages.tools = messages.tools or {} + messages.tools[#messages.tools + 1] = { + name = ai_shared._CONST.STRUCTURED_OUTPUT_TOOL_NAME, + description = "Responds with strict structured output.", + input_schema = structured_content_schema + } + + -- force the json tool use + messages.tool_choice = { + name = ai_shared._CONST.STRUCTURED_OUTPUT_TOOL_NAME, + type = "tool", + } + end + + -- these are specific customizations for when another provider calls this transformer + if model.provider == "gemini" then + messages.anthropic_version = request_table.anthropic_version + messages.model = nil -- gemini throws an error for some reason if model is in the body (it's in the URL) + end + return messages, "application/json", nil end, @@ -196,30 +259,128 @@ local transformers_to = { } local function delta_to_event(delta, model_info) - local data = { - choices = { - [1] = { - delta = { - content = (delta.delta - and delta.delta.text) - or (delta.content_block - and "") - or "", + local data + + local message_id = kong + and kong.ctx + and kong.ctx.plugin + and kong.ctx.plugin.ai_proxy_anthropic_stream_id + + if delta['type'] and delta['type'] == "content_block_start" then + if delta.content_block then + if delta.content_block['type'] == "text" then + -- mark the entrypoint for messages + data = { + choices = { + [1] = { + delta = { + content = read_content(delta), + role = "assistant", + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, }, - index = 0, - finish_reason = cjson.null, - logprobs = cjson.null, - }, - }, - id = kong - and kong.ctx - and 
kong.ctx.plugin - and kong.ctx.plugin.ai_proxy_anthropic_stream_id, - model = model_info.name, - object = "chat.completion.chunk", - } + id = message_id, + model = model_info.name, + object = "chat.completion.chunk", + } - return cjson.encode(data), nil, nil + elseif delta.content_block['type'] == "tool_use" then + -- this is an ENTRYPOINT for a function call + data = { + choices = { + [1] = { + delta = { + content = cjson.null, + role = "assistant", + tool_calls = { + { + index = 0, + id = delta.content_block['id'], + ['type'] = "function", + ['function'] = { + name = delta.content_block['name'], + arguments = "", + } + } + } + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + id = message_id, + model = model_info.name, + object = "chat.completion.chunk", + } + end + end + + else + local delta_type + + if delta.delta + and delta.delta['type'] + then + delta_type = delta.delta['type'] + end + + if delta_type then + if delta_type == "text_delta" then + -- handle generic chat + data = { + choices = { + [1] = { + delta = { + content = read_content(delta), + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + id = message_id, + model = model_info.name, + object = "chat.completion.chunk", + } + + elseif delta_type == "input_json_delta" then + -- this is an ARGS DELTA for a function call + data = { + choices = { + [1] = { + delta = { + tool_calls = { + { + index = 0, + ['function'] = { + arguments = delta.delta.partial_json, + } + } + } + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + id = message_id, + model = model_info.name, + object = "chat.completion.chunk", + } + + end + end + end + + if data then + return cjson.encode(data) + end + + return end local function start_to_event(event_data, model_info) @@ -231,7 +392,7 @@ local function start_to_event(event_data, model_info) completion_tokens = meta.usage and meta.usage.output_tokens, model = meta.model, - stop_reason = meta.stop_reason, + stop_reason = _OPENAI_STOP_REASON_MAPPING[meta.stop_reason or "end_turn"], stop_sequence = meta.stop_sequence, } @@ -260,7 +421,7 @@ end local function handle_stream_event(event_t, model_info, route_type) local event_id = event_t.event - local event_data = cjson.decode(event_t.data) + local event_data = cjson.decode_with_array_mt(event_t.data) if not event_id or not event_data then return nil, "transformation to stream event failed or empty stream event received", nil @@ -280,24 +441,71 @@ local function handle_stream_event(event_t, model_info, route_type) -- last few frames / iterations if event_data and event_data.usage then + local completion_tokens = event_data.usage.output_tokens or 0 + local stop_reason = event_data.delta + and event_data.delta.stop_reason + and _OPENAI_STOP_REASON_MAPPING[event_data.delta.stop_reason or "end_turn"] + + if get_global_ctx("structured_output_mode") then + stop_reason = _OPENAI_STOP_REASON_MAPPING["end_turn"] + end + + local stop_sequence = event_data.delta + and event_data.delta.stop_sequence + return nil, nil, { prompt_tokens = nil, - completion_tokens = event_data.usage.output_tokens, - stop_reason = event_data.delta - and event_data.delta.stop_reason, - stop_sequence = event_data.delta - and event_data.delta.stop_sequence, + completion_tokens = completion_tokens, + stop_reason = stop_reason, + stop_sequence = stop_sequence, } else return nil, "message_delta is missing the metadata block", nil end elseif event_id == "content_block_start" then - -- content_block_start is 
just an empty string and indicates - -- that we're getting an actual answer + -- if this content block is starting a structure output tool call, + -- we convert it into a standard text block start marker + if get_global_ctx("structured_output_mode") then + if event_data.content_block + and event_data.content_block.type == "tool_use" + and event_data.content_block.name == ai_shared._CONST.STRUCTURED_OUTPUT_TOOL_NAME + then + return delta_to_event( + { + content_block = { + text = "", + type = "text", + }, + index = 0, + type = "content_block_start", + } + , model_info) + end + end + return delta_to_event(event_data, model_info) elseif event_id == "content_block_delta" then + -- if this content block is continuing a structure output tool call, + -- we convert it into a standard text block + if get_global_ctx("structured_output_mode") then + if event_data.delta + and event_data.delta.type == "input_json_delta" + then + return delta_to_event( + { + delta = { + text = event_data.delta.partial_json or "", + type = "text_delta", + }, + index = 0, + type = "content_block_delta", + } + , model_info) + end + end + return delta_to_event(event_data, model_info) elseif event_id == "message_stop" then @@ -311,7 +519,7 @@ end local transformers_from = { ["llm/v1/chat"] = function(response_string) - local response_table, err = cjson.decode(response_string) + local response_table, err = cjson.decode_with_array_mt(response_string) if err then return nil, "failed to decode anthropic response" end @@ -374,7 +582,7 @@ local transformers_from = { content = extract_text_from_content(response_table.content), tool_calls = extract_tools_from_content(response_table.content) }, - finish_reason = response_table.stop_reason, + finish_reason = _OPENAI_STOP_REASON_MAPPING[response_table.stop_reason or "end_turn"], }, }, usage = usage, @@ -382,6 +590,13 @@ local transformers_from = { object = "chat.completion", } + -- check for structured output tool call + if get_global_ctx("structured_output_mode") then + -- if we have a structured output tool call, we need to convert it to a message + -- this is a workaround for the fact that Anthropic does not return structured output + res = ai_shared.convert_structured_output_tool(res) or res + end + return cjson.encode(res) else -- it's probably an error block, return generic error @@ -401,7 +616,7 @@ local transformers_from = { { index = 0, text = response_table.completion, - finish_reason = response_table.stop_reason, + finish_reason = _OPENAI_STOP_REASON_MAPPING[response_table.stop_reason or "end_turn"], }, }, model = response_table.model, @@ -567,9 +782,8 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) - - + local default_port = (parsed_url.scheme == "https") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) kong.service.request.set_header("anthropic-version", model.options.anthropic_version) diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index f5e5a2f8dfb..5992dcf7a00 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -136,7 +136,8 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https" or 
parsed_url.scheme == "wss") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua index 8e9df96b336..7f0b6ee4269 100644 --- a/kong/llm/drivers/bedrock.lua +++ b/kong/llm/drivers/bedrock.lua @@ -1,7 +1,7 @@ local _M = {} -- imports -local cjson = require("cjson.safe") +local cjson = require("kong.tools.cjson") local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require("socket.url") @@ -13,7 +13,7 @@ local ai_plugin_ctx = require("kong.llm.plugin.ctx") -- globals local DRIVER_NAME = "bedrock" -local get_global_ctx, _ = ai_plugin_ctx.get_global_accessors(DRIVER_NAME) +local get_global_ctx, set_global_ctx = ai_plugin_ctx.get_global_accessors(DRIVER_NAME) -- local _OPENAI_ROLE_MAPPING = { @@ -78,7 +78,7 @@ local function to_tools(in_tools) if v['function'] then out_tools = out_tools or {} - out_tools[i] = { + table_insert(out_tools, { toolSpec = { name = v['function'].name, description = v['function'].description, @@ -86,7 +86,7 @@ local function to_tools(in_tools) json = v['function'].parameters, }, }, - } + }) end end @@ -122,6 +122,23 @@ local function from_tool_call_response(content) return tools_used end +local function new_chat_chunk_event(model_info) + local event = { + choices = { + { + delta = {}, + index = 0, + logprobs = cjson.null, + }, + }, + model = model_info.name, + object = "chat.completion.chunk", + system_fingerprint = cjson.null, + } + + return event +end + local function handle_stream_event(event_t, model_info, route_type) local new_event, metadata @@ -144,20 +161,10 @@ local function handle_stream_event(event_t, model_info, route_type) end if event_type == "messageStart" then - new_event = { - choices = { - [1] = { - delta = { - content = "", - role = body.role, - }, - index = 0, - logprobs = cjson.null, - }, - }, - model = model_info.name, - object = "chat.completion.chunk", - system_fingerprint = cjson.null, + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].delta = { + content = "", + role = body.role, } elseif event_type == "contentBlockStart" then @@ -166,88 +173,84 @@ local function handle_stream_event(event_t, model_info, route_type) local tool_name = body.start.toolUse.name local tool_id = body.start.toolUse.toolUseId - new_event = { - choices = { - [1] = { - delta = { - tool_calls = { - { - index = body.contentBlockIndex, - id = tool_id, - ['function'] = { - name = tool_name, - arguments = "", - }, - } + if get_global_ctx("structured_output_mode") and (tool_name == ai_shared._CONST.STRUCTURED_OUTPUT_TOOL_NAME) then + -- structured output tool call: return as if we're starting content + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].delta = { + content = "", + role = "assistant", + } + + else + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].delta = { + tool_calls = { + { + index = body.contentBlockIndex, + id = tool_id, + ['function'] = { + name = tool_name, + arguments = "", } - }, - index = 0, - logprobs = cjson.null, - }, - }, - model = model_info.name, - object = "chat.completion.chunk", - system_fingerprint = cjson.null, - } + } + } + } + end end elseif event_type == "contentBlockDelta" then -- check for async streamed tool parameters if body.delta and body.delta.toolUse then - new_event = { - choices = { - [1] = 
{ - delta = { - tool_calls = { - { - index = body.contentBlockIndex, - ['function'] = { - arguments = body.delta.toolUse.input, - }, - } - } - }, - index = 0, - logprobs = cjson.null, - }, - }, - model = model_info.name, - object = "chat.completion.chunk", - system_fingerprint = cjson.null, - } + + if get_global_ctx("structured_output_mode") then + -- structured output tool call: return as if we're starting content + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].delta = { + content = (body.delta and + body.delta.toolUse + and body.delta.toolUse.input) + or "" + } + + else + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].delta = { + tool_calls = { + { + index = body.contentBlockIndex, + ['function'] = { + arguments = body.delta.toolUse.input, + }, + } + } + } + end else - new_event = { - choices = { - [1] = { - delta = { - content = (body.delta - and body.delta.text) - or "", - }, - index = 0, - logprobs = cjson.null, - }, - }, - model = model_info.name, - object = "chat.completion.chunk", - system_fingerprint = cjson.null, + if get_global_ctx("structured_output_mode") then + -- return no frame to the caller, because the LLM might be + -- chatting over the top of our structured output tool result + + return + end + + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].delta = { + content = (body.delta + and body.delta.text) + or "", } end elseif event_type == "messageStop" then - new_event = { - choices = { - [1] = { - delta = {}, - index = 0, - finish_reason = _OPENAI_STOP_REASON_MAPPING[body.stopReason] or "stop", - logprobs = cjson.null, - }, - }, - model = model_info.name, - object = "chat.completion.chunk", - } + if get_global_ctx("structured_output_mode") then + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].finish_reason = ((body.stopReason and body.stopReason ~= "tool_use") and _OPENAI_STOP_REASON_MAPPING[body.stopReason]) or "stop" + + else + new_event = new_chat_chunk_event(model_info) + new_event.choices[1].finish_reason = _OPENAI_STOP_REASON_MAPPING[body.stopReason] or "stop" + end elseif event_type == "metadata" then metadata = { @@ -271,6 +274,33 @@ local function handle_stream_event(event_t, model_info, route_type) end end +local function extract_structured_content_schema(request_table) + -- bounds check EVERYTHING first + if not request_table + or not request_table.response_format + or type(request_table.response_format.type) ~= "string" + or type(request_table.response_format.json_schema) ~= "table" + or type(request_table.response_format.json_schema.schema) ~= "table" + then + return nil + end + + -- return + return request_table.response_format.json_schema.schema +end + +local function schema_to_toolspec(structured_content_schema) + return { + toolSpec = { + name = ai_shared._CONST.STRUCTURED_OUTPUT_TOOL_NAME, + description = "Responds with strict structured output.", + inputSchema = { + json = structured_content_schema + } + } + } +end + local function to_bedrock_chat_openai(request_table, model_info, route_type) if not request_table then local err = "empty request table received for transformation" @@ -293,23 +323,43 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) system_prompts[#system_prompts+1] = { text = v.content } elseif v.role and v.role == "tool" then - local tool_literal_content - local tool_execution_content, err = cjson.decode(v.content) - if err then - return nil, nil, "failed to decode function response arguments, not JSON format" - end + -- To 
mimic OpenAI behaviour, but also support Bedrock, here's what we do: + ---- If it's a string, and JSON decode fails, just treat as a literal reply in "text" field. + ---- If it's a string, and JSON decode succeeds, set that object as "json" field, and let Kong re-encode it later. + ---- If it's already a deep JSON object from the caller, set that straight to "json" as well, and let Kong re-encode it later. + -- https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use-inference-call.html#tool-calllng-make-tool-request + local tool_content + local tool_literal_content = v.content + if type(tool_literal_content) == "string" then + -- try + local tool_decoded_content, err = cjson.decode(tool_literal_content) + if err then + -- catch + tool_content = { + text = tool_literal_content + } + + else + -- Lua bug? Catch if it decoded a single number from JSON for some reason + if type(tool_decoded_content) == "number" then + tool_content = { + text = tool_literal_content + } + else + tool_content = { + json = tool_decoded_content + } + end + + end - if type(tool_execution_content) == "table" then - tool_literal_content = { - json = tool_execution_content + elseif type(tool_literal_content) == "table" then + tool_content = { + json = tool_literal_content } else - tool_literal_content = { - json = { - result = tool_execution_content - } - } + return nil, nil, "failed to decode function response arguments: expecting string or nested-JSON" end local content = { @@ -317,18 +367,26 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) toolResult = { toolUseId = v.tool_call_id, content = { - tool_literal_content + tool_content }, status = v.status, }, }, } - new_r.messages = new_r.messages or {} - table_insert(new_r.messages, { - role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' - content = content, - }) + if i > 1 and ai_shared.is_tool_result_message(request_table.messages[i-1]) then + -- append to the previous message's 'content' array + local previous_content = new_r.messages[#new_r.messages].content or {} + + previous_content[#previous_content+1] = content[1] + new_r.messages[#new_r.messages].content = previous_content + else + new_r.messages = new_r.messages or {} + table_insert(new_r.messages, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + content = content, + }) + end else local content @@ -342,15 +400,14 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) return nil, nil, "failed to decode function response arguments from assistant's message, not JSON format" end - content = { - { - toolUse = { - toolUseId = tool.id, - name = tool['function'].name, - input = inputs, - }, - }, - } + content = content or {} + table_insert(content, { + toolUse = { + toolUseId = tool.id, + name = tool['function'].name, + input = inputs, + } + }) end else @@ -398,6 +455,26 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) new_r.toolConfig.tools = to_tools(request_table.tools) end + -- check for "structured output" mode + -- https://docs.aws.amazon.com/nova/latest/userguide/prompting-structured-output.html "Example 3" + local structured_content_schema = extract_structured_content_schema(request_table) + if structured_content_schema then + set_global_ctx("structured_output_mode", true) + + -- set the output schema as an available tool + new_r.toolConfig = new_r.toolConfig or {} + new_r.toolConfig.tools = new_r.toolConfig.tools or {} + + table_insert(new_r.toolConfig.tools, 
schema_to_toolspec(structured_content_schema)) + + -- force the use of that tool + new_r.toolChoice = { + tool = { + name = ai_shared._CONST.STRUCTURED_OUTPUT_TOOL_NAME + } + } + end + new_r.additionalModelRequestFields = request_table.bedrock and request_table.bedrock.additionalModelRequestFields and to_additional_request_fields(request_table) @@ -406,7 +483,7 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) end local function from_bedrock_chat_openai(response, model_info, route_type) - local response, err = cjson.decode(response) + local response, err = cjson.decode_with_array_mt(response) if err then local err_client = "failed to decode response from Bedrock" @@ -434,6 +511,13 @@ local function from_bedrock_chat_openai(response, model_info, route_type) client_response.object = "chat.completion" client_response.model = model_info.name + -- check for structured output tool call + if get_global_ctx("structured_output_mode") then + -- if we have a structured output tool call, we need to convert it to a message + -- this is a workaround for the fact that Bedrock-Converse does not support structured output + client_response = ai_shared.convert_structured_output_tool(client_response) or client_response + end + else -- probably a server fault or other unexpected response local err = "no generation candidates received from Bedrock, or max_tokens too short" ngx.log(ngx.ERR, "[bedrock] ", err) @@ -454,6 +538,7 @@ local function from_bedrock_chat_openai(response, model_info, route_type) return cjson.encode(client_response) end + local transformers_to = { ["llm/v1/chat"] = to_bedrock_chat_openai, } @@ -505,6 +590,7 @@ function _M.to_format(request_table, model_info, route_type) request_table, model_info ) + if err or (not ok) then return nil, nil, fmt("error transforming to %s://%s: %s", model_info.provider, route_type, err) end @@ -619,13 +705,22 @@ function _M.configure_request(conf, aws_sdk) return nil, "invalid model parameter" end - local operation = get_global_ctx("stream_mode") and "converse-stream" - or "converse" + local operation = conf.route_type ~= "llm/v1/embeddings" and (get_global_ctx("stream_mode") and "converse-stream" + or "converse") or nil + local bedrocks_driver = DRIVER_NAME + local llm_format_adapter = get_global_ctx("llm_format_adapter") + local forward_path, runtime_name + if llm_format_adapter then + forward_path, runtime_name = llm_format_adapter:get_forwarded_path(model.name) + if runtime_name then + bedrocks_driver = runtime_name + end + end local f_url = model.options and model.options.upstream_url if not f_url then -- upstream_url override is not set - local uri = fmt(ai_shared.upstream_url_format[DRIVER_NAME], aws_sdk.config.region) - local path = fmt( + local uri = fmt(ai_shared.upstream_url_format[bedrocks_driver], aws_sdk.config.region) + local path = forward_path or fmt( ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, model.name, operation) @@ -647,7 +742,8 @@ function _M.configure_request(conf, aws_sdk) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https" or parsed_url.scheme == "wss") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) -- do the IAM auth and signature headers aws_sdk.config.signatureVersion = "v4" @@ -659,7 +755,7 @@ function _M.configure_request(conf, aws_sdk) path = 
parsed_url.path, host = parsed_url.host, port = tonumber(parsed_url.port) or 443, - body = kong.request.get_raw_body() + body = kong.request.get_raw_body(), } local signature, err = signer(aws_sdk.config, r) diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 55467c093e2..3ef80a83da8 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -519,7 +519,8 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https" or parsed_url.scheme == "wss") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index d09edb88292..641377a4315 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -1,29 +1,43 @@ local _M = {} -- imports -local cjson = require("cjson.safe") +local cjson = require("kong.tools.cjson") local fmt = string.format +local anthropic = require("kong.llm.drivers.anthropic") +local openai = require("kong.llm.drivers.openai") local ai_shared = require("kong.llm.drivers.shared") local socket_url = require("socket.url") local string_gsub = string.gsub local buffer = require("string.buffer") local table_insert = table.insert -local string_lower = string.lower +local table_concat = table.concat local ai_plugin_ctx = require("kong.llm.plugin.ctx") local ai_plugin_base = require("kong.llm.plugin.base") local pl_string = require "pl.stringx" - +local model_garden_openai_chat_path = "/v1/projects/%s/locations/%s/endpoints/%s/chat/completions" +local uuid = require("kong.tools.uuid").uuid -- -- globals local DRIVER_NAME = "gemini" -local get_global_ctx, _ = ai_plugin_ctx.get_global_accessors(DRIVER_NAME) +local get_global_ctx, set_global_ctx = ai_plugin_ctx.get_global_accessors(DRIVER_NAME) -- local _OPENAI_ROLE_MAPPING = { ["system"] = "system", ["user"] = "user", ["assistant"] = "model", + ["model"] = "assistant", + ["tool"] = "user", +} + +local _OPENAI_STOP_REASON_MAPPING = { + ["MAX_TOKENS"] = "length", + ["STOP"] = "stop", +} + +local _OPENAI_STRUCTURED_OUTPUT_TYPE_MAP = { + ["json_schema"] = "application/json", } local function to_gemini_generation_config(request_table) @@ -64,15 +78,6 @@ local function is_tool_content(content) and content.candidates[1].content.parts[1].functionCall end -local function is_function_call_message(message) - return message - and message.role - and message.role == "assistant" - and message.tool_calls - and type(message.tool_calls) == "table" - and #message.tool_calls > 0 -end - local function has_finish_reason(event) return event and event.candidates @@ -81,6 +86,117 @@ local function has_finish_reason(event) or nil end +local function extract_structured_content(request_table) + -- bounds check EVERYTHING first + if not request_table + or not request_table.response_format + or type(request_table.response_format.type) ~= "string" + or type(request_table.response_format.json_schema) ~= "table" + or type(request_table.response_format.json_schema.schema) ~= "table" + then + return nil + end + + -- transform + ---- no transformations for OpenAI-to-Gemini + + -- return + return _OPENAI_STRUCTURED_OUTPUT_TYPE_MAP[request_table.response_format.type], 
request_table.response_format.json_schema.schema +end + + +local function get_model_coordinates(model_name, stream_mode) + if not model_name then + return nil, "model_name must be set to get model coordinates" + end + + -- anthropic + if model_name:sub(1, 7) == "claude-" then + return { + publisher = "anthropic", + operation = stream_mode and "streamRawPredict" or "rawPredict", + } + + -- mistral + elseif model_name:sub(1, 8) == "mistral-" then + return { + publisher = "mistral", + operation = stream_mode and "streamRawPredict" or "rawPredict", + } + + -- ai21 (jamba) + elseif model_name:sub(1, 6) == "jamba-" then + return { + publisher = "ai21", + operation = stream_mode and "streamRawPredict" or "rawPredict", + } + + else + return { + publisher = "google", + operation = stream_mode and "streamGenerateContent" or "generateContent", + } + end +end + +-- assume 'not vertex mode' if the model options are not set properly +-- the plugin schema will prevent misconfugiration +local function is_vertex_mode(model) + return model + and model.options + and model.options.gemini + and model.options.gemini.api_endpoint + and model.options.gemini.project_id + and model.options.gemini.location_id + and true +end + +-- this will never be called unless is_vertex_mode(model) above is true +-- so the deep table checks are not needed +local function get_gemini_vertex_url(model, route_type, stream_mode) + if not model.options or not model.options.gemini then + return nil, "model.options.gemini.* options must be set for vertex mode" + end + + local forward_path, err + local llm_format_adapter = get_global_ctx("llm_format_adapter") + if llm_format_adapter then + local _ + forward_path, _, err = llm_format_adapter:get_forwarded_path(model.name, model.options) + if err then + ngx.log(ngx.WARN, "failed to get forwarded path from llm_format_adapter: ", err) + forward_path = nil + end + end + + if not forward_path then + local coordinates, err = get_model_coordinates(model.name, stream_mode) + if err then + return nil, err + end + + if model.options.gemini.endpoint_id and model.options.gemini.endpoint_id ~= ngx.null then + -- means gcp vertex garden + forward_path = fmt(model_garden_openai_chat_path, + model.options.gemini.project_id, + model.options.gemini.location_id, + model.options.gemini.endpoint_id) + else + forward_path = fmt(ai_shared.operation_map["gemini_vertex"][route_type].path, + model.options.gemini.project_id, + model.options.gemini.location_id, + coordinates.publisher, + model.name, + coordinates.operation) + end + + + end + + return fmt(ai_shared.upstream_url_format["gemini_vertex"], + model.options.gemini.api_endpoint) .. 
forward_path +end + local function handle_stream_event(event_t, model_info, route_type) -- discard empty frames, it should either be a random new line, or comment if (not event_t.data) or (#event_t.data < 1) then @@ -91,7 +207,7 @@ local function handle_stream_event(event_t, model_info, route_type) return ai_shared._CONST.SSE_TERMINATOR, nil, nil end - local event, err = cjson.decode(event_t.data) + local event, err = cjson.decode_with_array_mt(event_t.data) if err then ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: ", err) return nil, "failed to decode stream event frame from gemini", nil @@ -101,9 +217,16 @@ local function handle_stream_event(event_t, model_info, route_type) if is_response_content(event) then local metadata = {} - metadata.finish_reason = finish_reason + metadata.finish_reason = (finish_reason and _OPENAI_STOP_REASON_MAPPING[finish_reason or "STOP"]) metadata.completion_tokens = event.usageMetadata and event.usageMetadata.candidatesTokenCount or 0 metadata.prompt_tokens = event.usageMetadata and event.usageMetadata.promptTokenCount or 0 + metadata.created, err = ai_shared.iso_8601_to_epoch(event.createTime or ai_shared._CONST.UNIX_EPOCH) + if err then + ngx.log(ngx.WARN, "failed to convert createTime to epoch: ", err, ", fallback to 1970-01-01T00:00:00Z") + metadata.created = 0 + end + + metadata.id = event.responseId local new_event = { model = model_info.name, @@ -114,18 +237,32 @@ local function handle_stream_event(event_t, model_info, route_type) role = "assistant", }, index = 0, - finish_reason = finish_reason + finish_reason = (finish_reason and _OPENAI_STOP_REASON_MAPPING[finish_reason or "STOP"]) }, }, } + if get_global_ctx("structured_output_mode") + and #event.candidates[1].content.parts[1].text > 2 + and event.candidates[1].content.parts[1].text == "[]" then + -- simulate OpenAI "refusal" where the question/answer doesn't fit the structured output schema + new_event.choices[1].delta.content = nil + new_event.choices[1].delta.refusal = "Kong: Vertex refused to answer the question, because it did not fit the structured output schema" + end + return cjson.encode(new_event), nil, metadata elseif is_tool_content(event) then local metadata = {} - metadata.finish_reason = finish_reason + metadata.finish_reason = _OPENAI_STOP_REASON_MAPPING[finish_reason or "STOP"] metadata.completion_tokens = event.usageMetadata and event.usageMetadata.candidatesTokenCount or 0 metadata.prompt_tokens = event.usageMetadata and event.usageMetadata.promptTokenCount or 0 + metadata.created, err = ai_shared.iso_8601_to_epoch(event.createTime or ai_shared._CONST.UNIX_EPOCH) + if err then + ngx.log(ngx.WARN, "failed to convert createTime to epoch: ", err, ", fallback to 1970-01-01T00:00:00Z") + metadata.created = 0 + end + metadata.id = event.responseId if event.candidates and #event.candidates > 0 then local new_event = { @@ -157,31 +294,7 @@ local function handle_stream_event(event_t, model_info, route_type) return cjson.encode(new_event), nil, metadata end - - - end -end - -local function to_tools(in_tools) - if not in_tools then - return nil - end - - local out_tools - - for i, v in ipairs(in_tools) do - if v['function'] then - out_tools = out_tools or { - [1] = { - function_declarations = {} - } - } - - out_tools[1].function_declarations[i] = v['function'] - end end - - return out_tools end @@ -269,10 +382,129 @@ local function openai_part_to_gemini_part(openai_part) return gemini_part end +local function extract_function_calls(message) + if message + and 
message.tool_calls + and type(message.tool_calls) == "table" + and #message.tool_calls > 0 + then + local function_calls = {} + + for _, tool_call in ipairs(message.tool_calls) do + if tool_call['type'] == "function" and tool_call['function'] then + local args, err = cjson.decode(tool_call['function'].arguments) + if err then + return nil, "failed to decode function response arguments from assistant's message, not JSON format" + end + + local gemini_tool_call = { + functionCall = { + name = tool_call['function'].name, + args = args, + }, + } + + -- Gemini expects function calls to be in the 'parts' array + table_insert(function_calls, gemini_tool_call) + end + end + + return #function_calls > 0 and function_calls or nil + + else + return nil + end +end + +local function to_tools(tools) + if not tools or type(tools) ~= "table" then + return nil + end + + local out_tools + + for i, tool in ipairs(tools) do + if tool.type == "function" and tool['function'] then + out_tools = out_tools or { + [1] = { + function_declarations = {} + } + } + + out_tools[1].function_declarations[i] = tool['function'] + end + end + + return out_tools +end + +local function openai_toolcontent_to_gemini_toolcontent(content) + if not content then + return { + result = true -- for now, we'll assume that OpenAI accepts empty result as "WAS EXECUTED" + } + end + + if type(content) == "string" then + -- try to decode json + local decoded_content, _ = cjson.decode(content) + if decoded_content and type(decoded_content) == "table" then + return decoded_content + + else + return { + result = content -- it's probably a string result on its own, emulate this OpenAI functionality for Gemini + } + end + end +end + +local function extract_function_called_name(tool_calls, tool_call_id) + if not tool_calls + or not tool_calls.tool_calls + or type(tool_calls.tool_calls) ~= "table" or #tool_calls.tool_calls < 1 then + return nil + end + + -- Find the tool_call_id from the tool_calls message previous + if tool_call_id then + for _, tool_call in ipairs(tool_calls.tool_calls) do + if tool_call.id == tool_call_id then + return tool_call['function'] and tool_call['function'].name + end + end + end + + return nil +end + +local function openai_toolresult_to_gemini_toolresult(tool_calls_msg, tool_result_message) + if not tool_result_message then + return nil, "tool result message is nil" + end + + local function_called_name = extract_function_called_name(tool_calls_msg, tool_result_message.tool_call_id) + local tool_result_content = openai_toolcontent_to_gemini_toolcontent(tool_result_message.content) + + if function_called_name and tool_result_content then + return { + functionResponse = { + name = function_called_name, -- JTODO: get the name from the previous message, link up the IDs + response = tool_result_content, + } + } + + else + return nil, "tool result message content does not match expected OpenAI format" + end +end + local function to_gemini_chat_openai(request_table, model_info, route_type) local new_r = {} if request_table then + local tool_calls_message_cache + if request_table.messages and #request_table.messages > 0 then local system_prompt @@ -283,39 +515,45 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) system_prompt = system_prompt or buffer.new() system_prompt:put(v.content or "") - elseif v.role and v.role == "tool" then - -- handle tool execution output - table_insert(new_r.contents, { - role = "function", - parts = { - { - function_response = { - response = { - content = { - v.content, 
- }, - }, - name = "get_product_info", - }, - }, - }, - }) + elseif ai_shared.is_tool_result_message(v) then + -- This will also have "content" as a string but + -- requires completely different out-of-sequence handling + + -- Is the previous message already a tool result? + -- Gemini expects all tool results from a previous message to + -- be in the same 'parts' array in its following reply message. + if i > 1 and ai_shared.is_tool_result_message(request_table.messages[i-1]) then + -- append to the previous message's 'parts' array + local previous_parts = new_r.contents[#new_r.contents].parts or {} + + local this_part, err = openai_toolresult_to_gemini_toolresult(tool_calls_message_cache, v) + if not this_part then + if not err then + err = "tool call result at position " .. i-1 .. " does not match expected OpenAI format" + end + return nil, nil, err + end - elseif is_function_call_message(v) then - -- treat specific 'assistant function call' tool execution input message - local function_calls = {} - for i, t in ipairs(v.tool_calls) do - function_calls[i] = { - function_call = { - name = t['function'].name, - }, - } - end + table_insert(previous_parts, this_part) + new_r.contents[#new_r.contents].parts = previous_parts + + else + -- This is the start of the function result blocks + tool_calls_message_cache = request_table.messages[i-1] -- hold this for next N "tool use" msgs as we iterate + + local this_part, err = openai_toolresult_to_gemini_toolresult(tool_calls_message_cache, v) + if not this_part then + if not err then + err = "tool call result at position " .. i-1 .. " does not match expected OpenAI format" + end + return nil, nil, err + end - table_insert(new_r.contents, { - role = "function", - parts = function_calls, - }) + table_insert(new_r.contents, { + role = _OPENAI_ROLE_MAPPING[v.role or "tool"], -- default to Gemini 'user' + parts = { this_part }, + }) + end else local this_parts = {} @@ -331,7 +569,7 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) } elseif type(v.content) == "table" then - if #v.content > 0 then -- check it has ome kind of array element + if #v.content > 0 then -- check it has some kind of array element for j, part in ipairs(v.content) do local this_part, err = openai_part_to_gemini_part(part) @@ -349,6 +587,18 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) end end + local function_calls, err = extract_function_calls(v) + if err then + return nil, nil, "failed to extract function calls from message at position " .. i-1 .. ": " .. 
err + end + + if function_calls then + for _, func_call in ipairs(function_calls) do + -- Gemini expects function calls to be in the 'parts' array + table_insert(this_parts, func_call) + end + end + table_insert(new_r.contents, { role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' parts = this_parts, @@ -374,6 +624,15 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) new_r.generationConfig = to_gemini_generation_config(request_table) + -- convert OpenAI structured output to Gemini, if it is specified + local response_mime_type, response_schema = extract_structured_content(request_table) + if response_mime_type and response_schema then + set_global_ctx("structured_output_mode", true) + new_r.generationConfig = new_r.generationConfig or {} + new_r.generationConfig.response_mime_type = response_mime_type + new_r.generationConfig.response_json_schema = response_schema + end + -- handle function calling translation from OpenAI format new_r.tools = request_table.tools and to_tools(request_table.tools) new_r.tool_config = request_table.tool_config @@ -382,10 +641,93 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) return new_r, "application/json", nil end +local function extract_response_finish_reason(response_candidate) + if response_candidate + and response_candidate.content + and response_candidate.content.parts + and type(response_candidate.content.parts) == "table" + and #response_candidate.content.parts > 0 + then + for _, part in ipairs(response_candidate.content.parts) do + if part.functionCall then + return "tool_calls" + end + end + end + + return "stop" +end + +local function feedback_to_kong_error(promptFeedback) + if promptFeedback + and type(promptFeedback) == "table" + then + return { + error = true, + message = promptFeedback.blockReasonMessage or cjson.null, + reason = promptFeedback.blockReason or cjson.null, + } + end + + return nil +end + +local function extract_response_tool_calls(response_candidate) + local tool_calls + + if response_candidate + and response_candidate.content + and response_candidate.content.parts + and type(response_candidate.content.parts) == "table" + and #response_candidate.content.parts > 0 + then + for _, part in ipairs(response_candidate.content.parts) do + if part.functionCall then + tool_calls = tool_calls or {} + + table_insert(tool_calls, { + ['id'] = "call_" .. 
uuid(), + ['type'] = "function", + ['function'] = { + ['name'] = part.functionCall.name, + ['arguments'] = cjson.encode(part.functionCall.args), + }, + }) + end + end + + return tool_calls + end + + return nil +end + +-- openai only supports a single string text part in response right now, +-- so concat all gemini response parts together +local function extract_response_content(response_candidate) + if response_candidate + and response_candidate.content + and response_candidate.content.parts + and type(response_candidate.content.parts) == "table" + and #response_candidate.content.parts > 0 + then + local content = {} + for _, part in ipairs(response_candidate.content.parts) do + if part.text then + table_insert(content, part.text) + end + end + + return #content > 0 and table_concat(content, "\\n") or nil + end + + return nil +end + local function from_gemini_chat_openai(response, model_info, route_type) local err if response and (type(response) == "string") then - response, err = cjson.decode(response) + response, err = cjson.decode_with_array_mt(response) end if err then @@ -394,10 +736,8 @@ local function from_gemini_chat_openai(response, model_info, route_type) return nil, err_client end - -- messages/choices table is only 1 size, so don't need to static allocate - local messages = {} - messages.choices = {} - messages.model = model_info.name -- openai format always contains the model name + local kong_response = {} + kong_response.model = model_info.name -- openai format always contains the model name if response.candidates and #response.candidates > 0 then -- for transformer plugins only @@ -410,48 +750,75 @@ local function from_gemini_chat_openai(response, model_info, route_type) return nil, err - elseif is_response_content(response) then - messages.choices[1] = { - index = 0, - message = { - role = "assistant", - content = response.candidates[1].content.parts[1].text, - }, - finish_reason = string_lower(response.candidates[1].finishReason), - } - messages.object = "chat.completion" - messages.model = model_info.name - - elseif is_tool_content(response) then - messages.choices[1] = { - index = 0, - message = { - role = "assistant", - tool_calls = {}, - }, - } + else + -- process 'content' messages one by one + for i, v in ipairs(response.candidates) do + kong_response.choices = kong_response.choices or {} + + kong_response.choices[i] = { + index = i - 1, + message = { + role = _OPENAI_ROLE_MAPPING[v.role or "model"], -- default to openai 'assistant' + content = extract_response_content(v), -- may be nil + tool_calls = extract_response_tool_calls(v), -- may be nil + }, + finish_reason = extract_response_finish_reason(v), + } + end - local function_call_responses = response.candidates[1].content.parts - for i, v in ipairs(function_call_responses) do - messages.choices[1].message.tool_calls[i] = - { - ['function'] = { - name = v.functionCall.name, - arguments = cjson.encode(v.functionCall.args), - }, - } + if get_global_ctx("structured_output_mode") + and #response.candidates[1].content.parts[1].text > 2 + and response.candidates[1].content.parts[1].text == "[]" then + -- simulate OpenAI "refusal" where the question/answer doesn't fit the structured output schema + kong_response.choices[1].message.content = nil + kong_response.choices[1].message.refusal = "Kong: Vertex refused to answer the question, because it did not fit the structured output schema" end + + kong_response.object = "chat.completion" + kong_response.model = response.modelVersion or model_info.name + + 
kong_response.created, err = ai_shared.iso_8601_to_epoch(response.createTime or ai_shared._CONST.UNIX_EPOCH) + if err then + ngx.log(ngx.WARN, "failed to convert createTime to epoch: ", err, ", fallback to 1970-01-01T00:00:00Z") + kong_response.created = 0 + end + + kong_response.id = response.responseId or "chatcmpl-" .. uuid() end -- process analytics if response.usageMetadata then - messages.usage = { + kong_response.usage = { prompt_tokens = response.usageMetadata.promptTokenCount, completion_tokens = response.usageMetadata.candidatesTokenCount, total_tokens = response.usageMetadata.totalTokenCount, } end + elseif response.promptFeedback then + kong_response = feedback_to_kong_error(response.promptFeedback) + + if get_global_ctx("stream_mode") then + set_global_ctx("blocked_by_guard", kong_response) + else + kong.response.set_status(400) -- safety call this in case we have already returned from e.g. another AI plugin + + -- This is duplicated DELIBERATELY - to avoid regression, + -- moving it outside of the block above may cause bugs that + -- we can't predict. + if response.usageMetadata and + (response.usageMetadata.promptTokenCount + or response.usageMetadata.candidatesTokenCount + or response.usageMetadata.totalTokenCount) + then + kong_response.usage = { + prompt_tokens = response.usageMetadata.promptTokenCount, + completion_tokens = response.usageMetadata.candidatesTokenCount, + total_tokens = response.usageMetadata.totalTokenCount, + } + end + end + else -- probably a server fault or other unexpected response local err = "no generation candidates received from Gemini, or max_tokens too short" ngx.log(ngx.ERR, err) @@ -459,7 +826,7 @@ local function from_gemini_chat_openai(response, model_info, route_type) end - return cjson.encode(messages) + return cjson.encode(kong_response) end local transformers_to = { @@ -479,6 +846,28 @@ function _M.from_format(response_string, model_info, route_type) return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end + -- try to get the ACTUAL model in use + -- this is gemini-specific, as it supports many different drivers in one package + local ok, model_t = pcall(ai_plugin_ctx.get_request_model_table_inuse) + if not ok then + -- set back to the plugin config's passed in object + model_t = model_info + end + + local coordinates = get_model_coordinates(model_t.name) + + if coordinates and coordinates.publisher == "anthropic" then + -- use anthropic's transformer + return anthropic.from_format(response_string, model_info, route_type) + end + -- otherwise, use the Gemini transformer + + if model_info.options and model_info.options ~= ngx.null and model_info.options.upstream_url and model_info.options.upstream_url ~= ngx.null then + if string.find(model_info.options.upstream_url, "/endpoints/") then + return openai.from_format(response_string, model_info, route_type) + end + end + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) if not ok then err = response_string @@ -508,6 +897,31 @@ function _M.to_format(request_table, model_info, route_type) request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + local coordinates, err = get_model_coordinates(model_info.name) + if err then + return nil, nil, err + end + + if coordinates and coordinates.publisher == "anthropic" then + -- use anthropic transformer + request_table.anthropic_version = (model_info.options and 
model_info.options.anthropic_version) or (request_table.anthropic_version) or "vertex-2023-10-16" + assert(request_table.anthropic_version, "anthropic_version must be set for anthropic models") + request_table.model = nil + return anthropic.to_format(request_table, model_info, route_type) + end + -- otherwise, use the Gemini transformer + + if model_info.options and model_info.options ~= ngx.null and model_info.options.upstream_url and model_info.options.upstream_url ~= ngx.null then + -- most models in the Vertex AI Model Garden are OpenAI-compatible + if string.find(model_info.options.upstream_url, "/endpoints/") then + request_table.model = nil + local req = openai.to_format(request_table, model_info, route_type) + -- the endpoint does not accept a model arg in the request body + req.model = nil + return req + end + end + local ok, response_object, content_type, err = pcall( transformers_to[route_type], request_table, @@ -535,42 +949,41 @@ function _M.subrequest(body, conf, http_opts, return_res_table, identity_interfa return nil, nil, "body must be table or string" end - local operation = get_global_ctx("stream_mode") and "streamGenerateContent" - or "generateContent" local f_url = conf.model.options and conf.model.options.upstream_url if not f_url then -- upstream_url override is not set -- check if this is "public" or "vertex" gemini deployment - if conf.model.options - and conf.model.options.gemini - and conf.model.options.gemini.api_endpoint - and conf.model.options.gemini.project_id - and conf.model.options.gemini.location_id - then - -- vertex mode - f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"], - conf.model.options.gemini.api_endpoint) .. - fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path, - conf.model.options.gemini.project_id, - conf.model.options.gemini.location_id, - conf.model.name, - operation) + if is_vertex_mode(conf.model) then + local err + f_url, err = get_gemini_vertex_url(conf.model, conf.route_type, get_global_ctx("stream_mode")) + if err then + return nil, "failed to calculate vertex URL: " .. err + end + + else - -- public mode + -- 'consumer' Gemini mode f_url = ai_shared.upstream_url_format["gemini"] ..
fmt(ai_shared.operation_map["gemini"][conf.route_type].path, conf.model.name, - operation) + (get_global_ctx("stream_mode") and "streamGenerateContent" or "generateContent")) end end local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + local auth_param_name = conf.auth and conf.auth.param_name + local auth_param_value = conf.auth and conf.auth.param_value + local auth_param_location = conf.auth and conf.auth.param_location + local headers = { ["Accept"] = "application/json", ["Content-Type"] = "application/json", } + if auth_param_name and auth_param_value and auth_param_location == "query" then + f_url = fmt("%s?%s=%s", f_url, auth_param_name, auth_param_value) + end + if identity_interface and identity_interface.interface then if identity_interface.interface:needsRefresh() then -- HACK: A bug in lua-resty-gcp tries to re-load the environment @@ -639,32 +1052,23 @@ function _M.configure_request(conf, identity_interface) end local parsed_url - local operation = get_global_ctx("stream_mode") and "streamGenerateContent" - or "generateContent" local f_url = model.options and model.options.upstream_url if not f_url then -- upstream_url override is not set -- check if this is "public" or "vertex" gemini deployment - if model.options - and model.options.gemini - and model.options.gemini.api_endpoint - and model.options.gemini.project_id - and model.options.gemini.location_id - then - -- vertex mode - f_url = fmt(ai_shared.upstream_url_format["gemini_vertex"], - model.options.gemini.api_endpoint) .. - fmt(ai_shared.operation_map["gemini_vertex"][conf.route_type].path, - model.options.gemini.project_id, - model.options.gemini.location_id, - model.name, - operation) + if is_vertex_mode(model) then + local err + f_url, err = get_gemini_vertex_url(model, conf.route_type, get_global_ctx("stream_mode")) + if err then + return nil, "failed to calculate vertex URL: " .. err + end + else - -- public mode + -- 'consumer' Gemini mode f_url = ai_shared.upstream_url_format["gemini"] .. 
fmt(ai_shared.operation_map["gemini"][conf.route_type].path, model.name, - operation) + (get_global_ctx("stream_mode") and "streamGenerateContent" or "generateContent")) end end @@ -683,7 +1087,8 @@ function _M.configure_request(conf, identity_interface) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value @@ -726,12 +1131,20 @@ function _M.configure_request(conf, identity_interface) end +_M.get_model_coordinates = get_model_coordinates + + if _G._TEST then -- export locals for testing _M._to_tools = to_tools _M._to_gemini_chat_openai = to_gemini_chat_openai _M._from_gemini_chat_openai = from_gemini_chat_openai _M._openai_part_to_gemini_part = openai_part_to_gemini_part + _M._is_vertex_mode = is_vertex_mode + _M._get_gemini_vertex_url = get_gemini_vertex_url + _M._extract_response_tool_calls = extract_response_tool_calls + _M._extract_function_calls = extract_function_calls + _M._openai_toolresult_to_gemini_toolresult = openai_toolresult_to_gemini_toolresult end diff --git a/kong/llm/drivers/huggingface.lua b/kong/llm/drivers/huggingface.lua index 20006bee050..1e17aa4636a 100644 --- a/kong/llm/drivers/huggingface.lua +++ b/kong/llm/drivers/huggingface.lua @@ -28,6 +28,7 @@ local function from_huggingface(response_string, model_info, route_type) end local transformed_response = { + id = response_table.id, model = model_info.name, object = response_table.object or route_type, choices = {}, @@ -47,8 +48,9 @@ local function from_huggingface(response_string, model_info, route_type) elseif response_table.choices then for i, choice in ipairs(response_table.choices) do local content = choice.message and choice.message.content or "" + local role = choice.message and choice.message.role or "" table.insert(transformed_response.choices, { - message = { content = content }, + message = { content = content, role = role }, index = i - 1, finish_reason = "complete", }) @@ -66,63 +68,6 @@ local function from_huggingface(response_string, model_info, route_type) return result_string, nil end -local function set_huggingface_options(model_info) - local use_cache = false - local wait_for_model = false - - if model_info and model_info.options and model_info.options.huggingface then - use_cache = model_info.options.huggingface.use_cache or false - wait_for_model = model_info.options.huggingface.wait_for_model or false - end - - return { - use_cache = use_cache, - wait_for_model = wait_for_model, - } -end - -local function set_default_parameters(request_table) - local parameters = request_table.parameters or {} - if parameters.top_k == nil then - parameters.top_k = request_table.top_k - end - if parameters.top_p == nil then - parameters.top_p = request_table.top_p - end - if parameters.temperature == nil then - parameters.temperature = request_table.temperature - end - if parameters.max_tokens == nil then - if request_table.messages then - -- conversational model use the max_length param - -- https://huggingface.co/docs/api-inference/en/detailed_parameters?code=curl#conversational-task - parameters.max_length = request_table.max_tokens - else - parameters.max_new_tokens = request_table.max_tokens - end - end - 
request_table.top_k = nil - request_table.top_p = nil - request_table.temperature = nil - request_table.max_tokens = nil - - return parameters -end - -local function to_huggingface(task, request_table, model_info) - local parameters = set_default_parameters(request_table) - local options = set_huggingface_options(model_info) - if task == "llm/v1/completions" then - request_table.inputs = request_table.prompt - request_table.prompt = nil - end - request_table.options = options - request_table.parameters = parameters - request_table.model = model_info.name or request_table.model - - return request_table, "application/json", nil -end - local function safe_access(tbl, ...) local value = tbl for _, key in ipairs({ ... }) do @@ -203,6 +148,22 @@ function _M.from_format(response_string, model_info, route_type) return response_string, nil, metadata end +local function to_huggingface(task, request_table, model_info) + if task == "llm/v1/completions" then + request_table.inputs = request_table.prompt + request_table.prompt = nil + request_table.model = model_info.name or request_table.model + + elseif task == "llm/v1/chat" then + -- For router.huggingface.co, we need to include the model in the request body + request_table.model = model_info.name or request_table.model + request_table.inputs = request_table.prompt + + end + + return request_table, "application/json", nil +end + local transformers_to = { ["llm/v1/chat"] = to_huggingface, ["llm/v1/completions"] = to_huggingface, @@ -224,10 +185,6 @@ function _M.to_format(request_table, model_info, route_type) return response_object, content_type, nil end -local function build_url(base_url, route_type) - return (route_type == "llm/v1/completions") and base_url or (base_url .. "/v1/chat/completions") -end - local function huggingface_endpoint(conf, model) local parsed_url @@ -240,8 +197,7 @@ local function huggingface_endpoint(conf, model) return nil end - local url = build_url(base_url, conf.route_type) - parsed_url = socket_url.parse(url) + parsed_url = socket_url.parse(base_url) return parsed_url end @@ -260,14 +216,20 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) end kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port) or 443) + local default_port = (parsed_url.scheme == "https" or parsed_url.scheme == "wss") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value if auth_header_name and auth_header_value then - kong.service.request.set_header(auth_header_name, auth_header_value) + kong.service.request.set_header(auth_header_name, auth_header_value) end + + -- clear upstream `X-Forwarded-Host` header due to cloudfront limitation + ngx.var.upstream_x_forwarded_host = "" + return true, nil end diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 7d2a47100ca..eeecc820dc0 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -278,7 +278,9 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https" or parsed_url.scheme == "wss") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) 
+ local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index add82ebf627..4fcab5cb5a3 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -174,7 +174,8 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https" or parsed_url.scheme == "wss") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index f9a3a05d67f..3e8d22d319d 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -216,7 +216,8 @@ function _M.configure_request(conf) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local default_port = (parsed_url.scheme == "https") and 443 or 80 + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or default_port)) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 939fec3477f..66babeebebd 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -14,9 +14,13 @@ local ai_plugin_o11y = require("kong.llm.plugin.observability") -- static local ipairs = ipairs local str_find = string.find +local str_byte = string.byte local str_sub = string.sub -local split = require("kong.tools.string").split local splitn = require("kong.tools.string").splitn +local table_remove = table.remove + +local NEWLINE = str_byte("\n") +local SPACE = str_byte(" ") local function str_ltrim(s) -- remove leading whitespace from string. 
return type(s) == "string" and s:gsub("^%s*", "") @@ -73,6 +77,9 @@ _M._CONST = { ["SSE_TERMINATOR"] = "[DONE]", ["AWS_STREAM_CONTENT_TYPE"] = "application/vnd.amazon.eventstream", ["GEMINI_STREAM_CONTENT_TYPE"] = "application/json", + ["SSE_CONTENT_TYPE"] = "text/event-stream", + ["UNIX_EPOCH"] = "1970-01-01T00:00:00.000000Z", + ["STRUCTURED_OUTPUT_TOOL_NAME"] = "kong_inc_to_structured_output", } _M._SUPPORTED_STREAMING_CONTENT_TYPES = { @@ -99,8 +106,9 @@ _M.upstream_url_format = { gemini = "https://generativelanguage.googleapis.com", gemini_vertex = "https://%s", bedrock = "https://bedrock-runtime.%s.amazonaws.com", + bedrock_agent = "https://bedrock-agent-runtime.%s.amazonaws.com", mistral = "https://api.mistral.ai:443", - huggingface = "https://api-inference.huggingface.co/models/%s", + huggingface = "https://router.huggingface.co", } _M.operation_map = { @@ -152,7 +160,7 @@ _M.operation_map = { }, gemini_vertex = { ["llm/v1/chat"] = { - path = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + path = "/v1/projects/%s/locations/%s/publishers/%s/models/%s:%s", }, }, mistral = { @@ -162,14 +170,11 @@ _M.operation_map = { }, }, huggingface = { - ["llm/v1/completions"] = { - path = "/models/%s", - method = "POST", - }, ["llm/v1/chat"] = { - path = "/models/%s", + path = "/v1/chat/completions", method = "POST", }, + -- NYI: image, audio }, bedrock = { ["llm/v1/chat"] = { @@ -217,7 +222,7 @@ _M.clear_response_headers = { function _M.merge_config_defaults(request, options, request_format) if options then request.temperature = options.temperature or request.temperature - request.max_tokens = options.max_tokens or request.max_tokens + request.max_tokens = options.max_tokens or request.max_tokens request.top_p = options.top_p or request.top_p request.top_k = options.top_k or request.top_k end @@ -225,7 +230,7 @@ function _M.merge_config_defaults(request, options, request_format) return request, nil end -local function handle_stream_event(event_table, model_info, route_type) +local function handle_ollama_stream_event(event_table, model_info, route_type) if event_table.done then -- return analytics table return "[DONE]", nil, { @@ -334,8 +339,27 @@ _M.cloud_identity_function = function(this_cache, plugin_config) end end +local _KEYBASTION = setmetatable({}, { + __mode = "k", + __index = _M.cloud_identity_function, +}) + +_M.get_key_bastion = function() + return _KEYBASTION +end + -local function json_array_iterator(input_str, prev_state) +local JSON_ARRAY_TYPE = { + JSONL = "jsonl", -- JSON Lines format + GEMINI = "gemini", -- Gemini [{}, {}] format + FLAT_JSON = "flat_json", -- Like JSON Lines format, but the lines may be separated by '\n' or '\r' +} + +-- json_array_iterator is an iterator that parses a JSON array from a string input. +-- @param input_str The input string containing the JSON array. +-- @param prev_state The previous state of the iterator, if any. +-- @param json_array_type The type of JSON array format to parse. See the JSON_ARRAY_TYPE table above for available types. 
+local function json_array_iterator(input_str, prev_state, json_array_type) local state = prev_state or { started = false, pos = 1, @@ -343,6 +367,22 @@ local function json_array_iterator(input_str, prev_state) eof = false, } + local sep_char, sep_char2, start_char + if json_array_type == JSON_ARRAY_TYPE.JSONL then -- JSON Lines format + sep_char = "\n" + start_char = nil + elseif json_array_type == JSON_ARRAY_TYPE.FLAT_JSON then -- Flat JSON format + sep_char = "\n" + sep_char2 = "\r" + start_char = nil + elseif json_array_type == JSON_ARRAY_TYPE.GEMINI then -- Gemini [{}, {}] format + sep_char = "," + start_char = "[" + else + error("Unsupported JSON array type: " .. tostring(json_array_type)) + end + + if state.eof then error("Iterator has reached end of input") end @@ -361,11 +401,11 @@ local function json_array_iterator(input_str, prev_state) while state.pos <= len and state.input:sub(state.pos, state.pos):match("%s") do state.pos = state.pos + 1 end - if state.pos > len or state.input:sub(state.pos, state.pos) ~= "[" then - error("Invalid start: expected '['") + if state.pos > len or (start_char and state.input:sub(state.pos, state.pos) ~= start_char) then + error("Invalid start: expected '" .. start_char .. "'") end state.started = true - state.pos = state.pos + 1 + state.pos = state.pos + (start_char and 1 or 0) -- move past the start character end -- Skip whitespace @@ -413,14 +453,18 @@ local function json_array_iterator(input_str, prev_state) delimiter = state.pos - 1 state.eof = true end - elseif char == ',' and brace_count == 0 and bracket_count == 0 then + elseif (char == sep_char or (sep_char2 and char == sep_char2)) + and brace_count == 0 and bracket_count == 0 + then -- Found element delimiter at top level delimiter = state.pos - 1 -- if delimiter is at start of string, skip it in next iteration - if state.pos == 1 then - start_pos = 2 + if state.pos == start_pos then + start_pos = state.pos + 1 end - elseif brace_count == 0 and bracket_count == 0 and state.pos == len then + end + + if not delimiter and brace_count == 0 and bracket_count == 0 and state.pos == len then -- Found element delimiter at end of string delimiter = state.pos end @@ -445,6 +489,80 @@ local function json_array_iterator(input_str, prev_state) end end +_M.JSON_ARRAY_TYPE = JSON_ARRAY_TYPE +_M.json_array_iterator = json_array_iterator + +local function ollama_message_has_tools(message) + return message + and type(message) == "table" + and message['tool_calls'] + and type(message['tool_calls']) == "table" + and #message['tool_calls'] > 0 +end + +-- Check for leap year +local function is_leap(y) + return (y % 4 == 0 and y % 100 ~= 0) or (y % 400 == 0) +end + +local ISO_8601_PATTERN = "^(%d%d%d%d)-(%d%d)-(%d%d)T(%d%d):(%d%d):(%d%d)%.(%d+)Z$" + +--- +-- Converts a JSON-format (ISO 8601) Timestamp into seconds since UNIX EPOCH. +-- Input only supports UTC timezone (suffix 'Z'). +-- +-- @param {string} ISO 8601 UTC input time, example '2025-04-22T13:40:31.926503Z' +-- @return {number} Seconds count since UNIX EPOCH. 
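+-- @return {string} Error message when the input is not a valid UTC ISO 8601 timestamp (the first return value is nil in that case).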
+function _M.iso_8601_to_epoch(timestamp) + if not string.match(timestamp, ISO_8601_PATTERN) then + return nil, "Invalid ISO 8601 timestamp format" + end + + local year, month, day, hour, min, sec, _ = + string.match(timestamp, ISO_8601_PATTERN) + + -- we can't use os.time as it relies on the system timezone setup and won't always be UTC + + -- Convert to numbers + year, month, day = tonumber(year), tonumber(month), tonumber(day) + hour, min, sec = tonumber(hour), tonumber(min), tonumber(sec) + + if year < 1970 or month < 1 or month > 12 or day < 1 or day > 31 + or hour < 0 or hour > 23 or min < 0 or min > 59 or sec < 0 or sec > 59 then + return nil, "Invalid date/time components" + end + + -- Days in each month (non-leap year) + local days_in_month = {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31} + + if is_leap(year) then + days_in_month[2] = 29 + end + + if day > days_in_month[month] then + return nil, "Invalid day for the given month" + end + + -- Calculate days since Unix epoch (1970-01-01) + local days = 0 + + -- Add days for complete years + for y = 1970, year - 1 do + days = days + (is_leap(y) and 366 or 365) + end + + -- Add days for complete months in current year + for m = 1, month - 1 do + days = days + days_in_month[m] + end + + -- Add remaining days + days = days + day - 1 + + -- Convert to seconds and add time components + return days * 86400 + hour * 3600 + min * 60 + sec +end + --- -- Splits a HTTPS data chunk or frame into individual -- SSE-format messages, see: @@ -467,10 +585,12 @@ function _M.frame_to_events(frame, content_type) end -- some new LLMs return the JSON object-by-object, - because that totally makes sense to parse?! + which is harder to parse if content_type == _M._CONST.GEMINI_STREAM_CONTENT_TYPE then - for element, new_state in json_array_iterator(frame, kong.ctx.plugin.gemini_state) do - kong.ctx.plugin.gemini_state = new_state + local gemini_state = kong.ctx.plugin.gemini_state + local iter = json_array_iterator(frame, gemini_state, JSON_ARRAY_TYPE.GEMINI) + while true do + local element, gemini_state = iter() if element then local _, err = cjson.decode(element) if err then @@ -478,69 +598,109 @@ function _M.frame_to_events(frame, content_type) end events[#events+1] = { data = element } end - if new_state.eof then -- array end + + -- the order is important here: we check eof before "not element", so that we don't skip this event + if gemini_state.eof then -- array end kong.ctx.plugin.gemini_state = nil events[#events+1] = { data = _M._CONST.SSE_TERMINATOR } return events end + + if not element then + kong.ctx.plugin.gemini_state = gemini_state + return events + end end elseif content_type == _M._CONST.AWS_STREAM_CONTENT_TYPE then - local parser = aws_stream:new(frame) - while true do - local msg = parser:next_message() + local parser = kong.ctx.plugin.aws_stream_parser + if not parser then + kong.ctx.plugin.aws_stream_parser = aws_stream:new(frame) + parser = kong.ctx.plugin.aws_stream_parser + else + -- add the new frame to the existing parser + parser:add(frame) + end + + while parser:has_complete_message() do + -- We need to call `has_complete_message` first, as `next_message` will consume the data + local msg, err = parser:next_message() if not msg then + kong.log.err("failed to parse AWS stream message: ", err) break end events[#events+1] = { data = cjson.encode(msg) } end + -- there may be remaining data in the stream; we will parse it with the next frame -- check if it's raw json and just return the split up data frame -- Cohere / Other
flat-JSON format parser -- just return the split up data frame - elseif (not kong or (not kong.ctx.plugin.gemini_state and not kong.ctx.plugin.truncated_frame)) and string.sub(str_ltrim(frame), 1, 1) == "{" then - for event in frame:gmatch("[^\r\n]+") do - events[#events + 1] = { - data = event, - } + elseif content_type ~= _M._CONST.SSE_CONTENT_TYPE and + -- if we are parsing a flat JSON stream or the first character is '{' + (kong.ctx.plugin.flat_json_state or string.sub(str_ltrim(frame), 1, 1) == "{") + then + local state = kong.ctx.plugin.flat_json_state + local iter = json_array_iterator(frame, state, JSON_ARRAY_TYPE.FLAT_JSON) + while true do + local element, state = iter() + if element then + local _, err = cjson.decode(element) + if err then + kong.log.err("malformed JSON in flat json stream: ", err, ": ", element) + end + events[#events+1] = { data = element } + end + + if not element then + kong.ctx.plugin.flat_json_state = state + return events + end end -- standard SSE parser else - local event_lines, count = splitn(frame, "\n") - local struct = {} -- { event = nil, id = nil, data = nil } + -- test for previous abnormal start-of-frame (truncation tail) + if kong and kong.ctx.plugin.truncated_frame then + -- this is the tail of a previous incomplete chunk + ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") + frame = fmt("%s%s", kong.ctx.plugin.truncated_frame, frame) + kong.ctx.plugin.truncated_frame = nil + end - for i, dat in ipairs(event_lines) do - if dat == "" then - events[#events + 1] = struct - struct = {} -- { event = nil, id = nil, data = nil } - end + -- check if we were mid-chunk when the frame was truncated + local struct = kong.ctx.plugin.truncated_frame_struct or {} -- { event = nil, id = nil, data = nil } + kong.ctx.plugin.truncated_frame_struct = nil - -- test for truncated chunk on the last line (no trailing \r\n\r\n) - if dat ~= "" and count == i then - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame head") + local start = 1 + local n_events = 1 + while start <= #frame do + -- SSE chunks end with `\n\n` or `\r\n` + local end_of_msg = str_find(frame, "[\r\n]", start, false) + if not end_of_msg or end_of_msg == #frame then if kong then - kong.ctx.plugin.truncated_frame = fmt("%s%s", (kong.ctx.plugin.truncated_frame or ""), dat) + -- we may be mid-chunk when the frame was truncated + -- and already have read in the "id", "event", etc + kong.ctx.plugin.truncated_frame = frame:sub(start) + kong.ctx.plugin.truncated_frame_struct = struct end - break -- stop parsing immediately, server has done something wrong - end - - -- test for abnormal start-of-frame (truncation tail) - if kong and kong.ctx.plugin.truncated_frame then - -- this is the tail of a previous incomplete chunk - ngx.log(ngx.DEBUG, "[ai-proxy] truncated sse frame tail") - dat = fmt("%s%s", kong.ctx.plugin.truncated_frame, dat) - kong.ctx.plugin.truncated_frame = nil + break end - local s1, _ = str_find(dat, ":") -- find where the cut point is - - if s1 and s1 ~= 1 then - local field = str_sub(dat, 1, s1-1) -- returns "data" from data: hello world - local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world + -- find where the cut point is. As str_find cannot specify the end of the search range, + -- the returned s1 may be larger than the end of msg. 
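+          -- (a ':' found at or beyond end_of_msg belongs to a later line in this frame, so the check below skips it)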
+ local s1 = str_find(frame, ":", start, true) + if s1 and s1 ~= 1 and s1 < end_of_msg then + local field = str_sub(frame, start, s1 - 1) -- returns "data" from data: hello world + local j = s1 + 1 + while j <= end_of_msg - 1 and str_byte(frame, j) == SPACE do + -- consume spaces. + j = j + 1 + end + local value = str_sub(frame, j, end_of_msg - 1) -- returns "hello world" from data: hello world -- for now not checking if the value is already been set if field == "event" then struct.event = value @@ -548,8 +708,19 @@ function _M.frame_to_events(frame, content_type) elseif field == "data" then struct.data = value end -- if end -- if + + start = end_of_msg + 1 + -- When start > #frame, str_byte returns nil, so we don't need to check out of bounds here. + if str_byte(frame, start) == NEWLINE then + -- End of the SSE chunk. This is faster than calling str_find '\n' again. + events[n_events] = struct + n_events = n_events + 1 + struct = {} + start = start + 1 + end + end - end + end -- else return events end @@ -568,6 +739,28 @@ function _M.to_ollama(request_table, model) end + + -- If tool-call arguments arrived as a raw JSON string, + -- decode them back into a table. + -- Ollama is inconsistent in its request/response formats. + if request_table.messages then + for _, message in ipairs(request_table.messages) do + if message["tool_calls"] and type(message["tool_calls"]) == "table" then + for _, tool in ipairs(message["tool_calls"]) do + if tool['function'] + and tool['function']['arguments'] + and type(tool['function']['arguments']) == "string" then + + tool['function']['arguments'] = cjson.decode(tool['function']['arguments']) + end + end + end + end + end + + + -- handle tools + input.tools = request_table.tools + -- common parameters input.stream = request_table.stream or false -- for future capability input.model = model.name or request_table.name @@ -597,7 +790,7 @@ function _M.from_ollama(response_string, model_info, route_type) return nil, "failed to decode ollama response" end - output, _, analytics = handle_stream_event(response_table, model_info, route_type) + output, _, analytics = handle_ollama_stream_event(response_table, model_info, route_type) elseif route_type == "stream/llm/v1/completions" then local response_table, err = cjson.decode(response_string.data) @@ -605,7 +798,7 @@ return nil, "failed to decode ollama response" end - output, _, analytics = handle_stream_event(response_table, model_info, route_type) + output, _, analytics = handle_ollama_stream_event(response_table, model_info, route_type) else local response_table, err = cjson.decode(response_string) @@ -624,18 +817,40 @@ -- common fields output.model = response_table.model - output.created = response_table.created_at + output.created, err = _M.iso_8601_to_epoch(response_table.created_at) + if err then + ngx.log(ngx.WARN, "failed to convert created_at to epoch: ", err, ", fallback to 1970-01-01T00:00:00Z") + output.created = 0 + end -- analytics output.usage = { completion_tokens = response_table.eval_count or 0, prompt_tokens = response_table.prompt_eval_count or 0, - total_tokens = (response_table.eval_count or 0) + + total_tokens = (response_table.eval_count or 0) + (response_table.prompt_eval_count or 0), } if route_type == "llm/v1/chat" then output.object = "chat.completion" + + -- handle tools conversion + if ollama_message_has_tools(response_table.message) then + + -- If tools are an actual JSON object, + --
flatten to a raw string. + -- Ollama is inconsistent in its response formats. + for _, tool in ipairs(response_table.message['tool_calls']) do + if tool['function'] + and tool['function']['arguments'] + and type(tool['function']['arguments']) == "table" then + + tool['function']['arguments'] = cjson.encode(tool['function']['arguments']) + end + end + + end + output.choices = { { finish_reason = response_table.finish_reason or stop_reason, @@ -700,25 +915,40 @@ function _M.merge_model_options(kong_request, conf_m) new_conf_m[k] = v else -- string values - local tmpl_start, tmpl_end = str_find(v or "", '%$%((.-)%)') - if tmpl_start then - local tmpl = str_sub(v, tmpl_start+2, tmpl_end-1) -- strip surrounding $( and ) - local splitted = split(tmpl, '.') - if #splitted ~= 2 then - return nil, "cannot parse expression for field '" .. v .. "'" + local result = v + local has_error = false + + -- Use gsub to replace all template variables in one pass + result = result:gsub("%$%((.-)%)", function(tmpl) + if has_error then return "" end + + local splitted, count = splitn(tmpl, '.', 3) + if count ~= 2 then + has_error = true + err = "cannot parse expression for field '" .. v .. "'" + return "" end - local evaluated, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) - if err then - return nil, err + + local evaluated, eval_err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) + if eval_err then + has_error = true + err = eval_err + return "" end if not evaluated then - return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided" + has_error = true + err = splitted[1] .. " key " .. splitted[2] .. " was not provided" + return "" end - -- replace place holder with evaluated - new_conf_m[k] = str_sub(v, 1, tmpl_start - 1) .. evaluated .. 
str_sub(v, tmpl_end + 1) - else -- not a tmplate, just copy - new_conf_m[k] = v + + return evaluated + end) + + if has_error then + return nil, err end + + new_conf_m[k] = result end end @@ -993,7 +1223,7 @@ local function count_prompt(content, tokens_factor) return nil, "Invalid request format" end end - else + else return nil, "Invalid request format" end return count, nil @@ -1010,9 +1240,9 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor) if query_body.choices then -- Calculate the cost based on the content type for _, choice in ipairs(query_body.choices) do - if choice.message and choice.message.content then + if choice.message and choice.message.content then query_cost = query_cost + (count_words(choice.message.content) * tokens_factor) - elseif choice.text then + elseif choice.text then query_cost = query_cost + (count_words(choice.text) * tokens_factor) end end @@ -1049,6 +1279,49 @@ function _M.override_upstream_url(parsed_url, conf, model) end end +function _M.convert_structured_output_tool(response) + -- bounds check + if not response + or not response.choices + or not response.choices[1] + or not response.choices[1].message + or not response.choices[1].message.tool_calls + then + return nil + end + + local tool_calls = response.choices[1].message.tool_calls + + for i, tool_call in ipairs(tool_calls) do + if tool_call['function'] + and tool_call['function'].name == _M._CONST.STRUCTURED_OUTPUT_TOOL_NAME + then + -- delete the tool and move it to the message params + table_remove(response.choices[1].message.tool_calls, i) + response.choices[1].message.content = tool_call['function'].arguments or '[]' + response.choices[1].finish_reason = "stop" + + if #response.choices[1].message.tool_calls == 0 then + response.choices[1].message.tool_calls = nil + end + + return response + end + end + + return nil +end + +function _M.is_tool_result_message(message) + return message + and message.role + and message.role == "tool" + and message.tool_call_id + and type(message.tool_call_id) == "string" + and message.content + and type(message.content) == "string" +end + -- for unit tests if _G.TEST then _M._count_words = count_words diff --git a/kong/llm/plugin/base.lua b/kong/llm/plugin/base.lua index f97d7484193..9205d102447 100644 --- a/kong/llm/plugin/base.lua +++ b/kong/llm/plugin/base.lua @@ -1,5 +1,7 @@ local deflate_gzip = require("kong.tools.gzip").deflate_gzip local ai_plugin_ctx = require("kong.llm.plugin.ctx") +local kong_global = require "kong.global" + local get_global_ctx, _ = ai_plugin_ctx.get_global_accessors("_base") @@ -72,13 +74,16 @@ function MetaPlugin:configure(sub_plugin, configs) run_stage(STAGES.SETUP, sub_plugin, configs) end -function MetaPlugin:access(sub_plugin, conf) +function MetaPlugin:access(sub_plugin, conf, kong_plugin_t) ngx.ctx.ai_namespaced_ctx = ngx.ctx.ai_namespaced_ctx or {} ngx.ctx.ai_executed_filters = ngx.ctx.ai_executed_filters or {} if sub_plugin.enable_balancer_retry then kong.service.set_target_retry_callback(function() ngx.ctx.ai_executed_filters = {} + kong_global.set_named_ctx(kong, "plugin", kong_plugin_t, ngx.ctx) + kong_global.set_namespaced_log(kong, sub_plugin.name, ngx.ctx) + ngx.ctx.plugin_id = conf.__plugin_id MetaPlugin:retry(sub_plugin, conf) @@ -110,6 +115,7 @@ function MetaPlugin:body_filter(sub_plugin, conf) -- check if a response is already sent in access phase by any filter local sent, source = get_global_ctx("response_body_sent") if sent then + run_stage(STAGES.RES_PRE_PROCESSING, sub_plugin, conf) 
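+      -- a filter already produced the response body earlier (e.g. in the access phase); run only the response pre-processing stage and skip the rest of body_filter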
kong.log.debug("response already sent from source: ", source, " skipping body_filter") return end @@ -238,7 +244,7 @@ function _M:as_kong_plugin() if self.filters[STAGES.REQ_INTROSPECTION] or self.filters[STAGES.REQ_TRANSFORMATION] then Plugin.access = function(_, conf) - return MetaPlugin:access(self, conf) + return MetaPlugin:access(self, conf, Plugin) end end @@ -268,4 +274,4 @@ function _M:as_kong_plugin() return Plugin end -return _M +return _M \ No newline at end of file diff --git a/kong/llm/plugin/ctx.lua b/kong/llm/plugin/ctx.lua index 99176d46341..d6ac7b068bd 100644 --- a/kong/llm/plugin/ctx.lua +++ b/kong/llm/plugin/ctx.lua @@ -10,6 +10,7 @@ local schemas = { response_body_sent = "boolean", llm_format_adapter = "table", preserve_mode = "boolean", + structured_output_mode = "boolean", -- set by request phases to modify certain response transformations }, } @@ -145,7 +146,7 @@ end local EMPTY_REQUEST_T = _M.immutable_table({}) -function _M.set_request_body_table_inuse(t, source) +function _M.set_request_body_table_inuse(t, source, should_set_body) assert(source, "source is missing") -- merge overlay keys into the key itself @@ -161,6 +162,11 @@ function _M.set_request_body_table_inuse(t, source) _M.set_namespaced_ctx("_global", "request_body_table", t) ngx.ctx.ai_request_body_table_source = source + + -- optionally set the request body to Kong service + if should_set_body then + kong.service.request.set_body(t, "application/json") + end end function _M.get_request_body_table_inuse() diff --git a/kong/llm/plugin/shared-filters/normalize-request.lua b/kong/llm/plugin/shared-filters/normalize-request.lua index d45dc9a99dd..eb2e12f037c 100644 --- a/kong/llm/plugin/shared-filters/normalize-request.lua +++ b/kong/llm/plugin/shared-filters/normalize-request.lua @@ -19,10 +19,7 @@ local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_ local get_global_ctx, set_global_ctx = ai_plugin_ctx.get_global_accessors(_M.NAME) -local _KEYBASTION = setmetatable({}, { - __mode = "k", - __index = ai_shared.cloud_identity_function, -}) +local _KEYBASTION = ai_shared.get_key_bastion() local function bail(code, msg) if code == 400 and msg then @@ -124,8 +121,10 @@ local function validate_and_transform(conf) local multipart = ai_plugin_ctx.get_namespaced_ctx("parse-request", "multipart_request") -- check the incoming format is the same as the configured LLM format local compatible, err = llm.is_compatible(request_table, route_type) - if not multipart and not compatible then - return bail(400, err) + if conf.llm_format == "openai" then + if not multipart and not compatible then + return bail(400, err) + end end -- check if the user has asked for a stream, and/or if @@ -213,7 +212,7 @@ local function validate_and_transform(conf) return bail(500, "LLM request failed before proxying") end - -- now re-configure the request for this operation type + -- now configure the request for this operation type local ok, err = ai_driver.configure_request(conf, identity_interface and identity_interface.interface) if not ok then diff --git a/kong/llm/plugin/shared-filters/normalize-sse-chunk.lua b/kong/llm/plugin/shared-filters/normalize-sse-chunk.lua index 212fd777a44..cfa0c472c81 100644 --- a/kong/llm/plugin/shared-filters/normalize-sse-chunk.lua +++ b/kong/llm/plugin/shared-filters/normalize-sse-chunk.lua @@ -79,8 +79,7 @@ local function handle_streaming_frame(conf, chunk, finished) end end - local finish_reason - + local finish_reason = nil for _, event in ipairs(events) do -- TODO: currently 
only subset of driver follow the body, err, metadata pattern @@ -88,38 +87,41 @@ local function handle_streaming_frame(conf, chunk, finished) local model_t = ai_plugin_ctx.get_request_model_table_inuse() local formatted, _, metadata = ai_driver.from_format(event, model_t, "stream/" .. conf.route_type) - if formatted then - frame_buffer:put("data: ") - frame_buffer:put(formatted or "") - frame_buffer:put((formatted ~= ai_shared._CONST.SSE_TERMINATOR) and "\n\n" or "") - end - if formatted and formatted ~= ai_shared._CONST.SSE_TERMINATOR then -- only stream relevant frames back to the user -- append the "choice" to the buffer, for logging later. this actually works! local event_t, err = cjson.decode(formatted) if not err then + local token_t + if event_t.choices and #event_t.choices > 0 then finish_reason = event_t.choices[1].finish_reason end - local token_t = get_token_text(event_t) + token_t = get_token_text(event_t) - -- either enabled in ai-proxy plugin, or required by other plugin - if body_buffer then + if body_buffer and token_t then body_buffer:put(token_t) end end end + if formatted then + if formatted == ai_shared._CONST.SSE_TERMINATOR and not get_global_ctx("sample_event") then + frame_buffer:put("data: ") + frame_buffer:put(formatted) + frame_buffer:put("\n\n") + end + end + if conf.logging and conf.logging.log_statistics and metadata then -- gemini metadata specifically, works differently if conf.model.provider == "gemini" then - ai_plugin_o11y.metrics_set("llm_prompt_tokens_count", metadata.prompt_tokens or 0) - ai_plugin_o11y.metrics_set("llm_completion_tokens_count", metadata.completion_tokens or 0) + ai_plugin_o11y.metrics_set("input_tokens_count", metadata.prompt_tokens or 0) + ai_plugin_o11y.metrics_set("output_tokens_count", metadata.completion_tokens or 0) else - ai_plugin_o11y.metrics_add("llm_prompt_tokens_count", metadata.prompt_tokens or 0) - ai_plugin_o11y.metrics_add("llm_completion_tokens_count", metadata.completion_tokens or 0) + ai_plugin_o11y.metrics_add("input_tokens_count", metadata.prompt_tokens or 0) + ai_plugin_o11y.metrics_add("output_tokens_count", metadata.completion_tokens or 0) end end end @@ -140,8 +142,8 @@ local function handle_streaming_frame(conf, chunk, finished) if finished then local response = body_buffer and body_buffer:get() - local prompt_tokens_count = ai_plugin_o11y.metrics_get("llm_prompt_tokens_count") - local completion_tokens_count = ai_plugin_o11y.metrics_get("llm_completion_tokens_count") + local prompt_tokens_count = ai_plugin_o11y.metrics_get("input_tokens_count") + local completion_tokens_count = ai_plugin_o11y.metrics_get("output_tokens_count") if conf.logging and conf.logging.log_statistics then -- no metadata populated in the event streams, do our estimation @@ -151,7 +153,7 @@ local function handle_streaming_frame(conf, chunk, finished) -- -- essentially, every 4 characters is a token, with minimum of 1*4 per event completion_tokens_count = math.ceil(#strip(response) / 4) - ai_plugin_o11y.metrics_set("llm_completion_tokens_count", completion_tokens_count) + ai_plugin_o11y.metrics_set("output_tokens_count", completion_tokens_count) end end @@ -182,7 +184,7 @@ local function handle_streaming_frame(conf, chunk, finished) usage = { prompt_tokens = prompt_tokens_count, completion_tokens = completion_tokens_count, - total_tokens = ai_plugin_o11y.metrics_get("llm_total_tokens_count"), + total_tokens = ai_plugin_o11y.metrics_get("total_tokens_count"), } } @@ -220,13 +222,9 @@ function _M:run(conf) conf = 
ai_plugin_ctx.get_namespaced_ctx("ai-proxy-advanced-balance", "selected_target") or conf end - -- TODO: check if ai-response-transformer let response.source become not service - if kong.response.get_source() == "service" then - - handle_streaming_frame(conf, ngx.arg[1], ngx.arg[2]) - end + handle_streaming_frame(conf, ngx.arg[1], ngx.arg[2]) return true end -return _M \ No newline at end of file +return _M diff --git a/kong/llm/plugin/shared-filters/parse-json-response.lua b/kong/llm/plugin/shared-filters/parse-json-response.lua index 473c616de8c..aa7c4279476 100644 --- a/kong/llm/plugin/shared-filters/parse-json-response.lua +++ b/kong/llm/plugin/shared-filters/parse-json-response.lua @@ -13,9 +13,16 @@ local get_global_ctx, set_global_ctx = ai_plugin_ctx.get_global_accessors(_M.NAM function _M:run(_) - ai_plugin_o11y.record_request_end() + if get_global_ctx("response_body") or + get_global_ctx("stream_mode") or + kong.response.get_source() ~= "service" + then + return true + end - if get_global_ctx("response_body") or get_global_ctx("stream_mode") or kong.response.get_source() ~= "service" then + local content_type = kong.service.response.get_header("Content-Type") or "application/json" + -- gemini vertex ai return response header content-type = "text/html" in json case + if content_type:sub(1, 16) ~= "application/json" and content_type:sub(1, 9) ~= "text/html" then return true end diff --git a/kong/llm/plugin/shared-filters/parse-request.lua b/kong/llm/plugin/shared-filters/parse-request.lua index b0eca5aa22b..0bd857faec4 100644 --- a/kong/llm/plugin/shared-filters/parse-request.lua +++ b/kong/llm/plugin/shared-filters/parse-request.lua @@ -20,7 +20,7 @@ local get_global_ctx, set_global_ctx = ai_plugin_ctx.get_global_accessors(_M.NAM function _M:run(conf) - -- Thie might be called again in retry, simply skip it as we already parsed the request + -- This might be called again in retry, simply skip it as we already parsed the request if ngx.get_phase() == "balancer" then return true end @@ -52,7 +52,6 @@ function _M:run(conf) set_ctx("multipart_request", true) end - local adapter local llm_format = conf.llm_format if llm_format and llm_format ~= "openai" then diff --git a/kong/llm/plugin/shared-filters/parse-sse-chunk.lua b/kong/llm/plugin/shared-filters/parse-sse-chunk.lua index abec7581e07..37947130154 100644 --- a/kong/llm/plugin/shared-filters/parse-sse-chunk.lua +++ b/kong/llm/plugin/shared-filters/parse-sse-chunk.lua @@ -34,6 +34,8 @@ local function handle_streaming_frame(conf, chunk, finished) local events = ai_shared.frame_to_events(chunk, normalized_content_type) if not events then + -- unrecognized frame, need to reset the current events + set_ctx("current_events", {}) return end @@ -49,6 +51,9 @@ local function handle_streaming_frame(conf, chunk, finished) kong.log.debug("using existing body buffer created by: ", source) -- TODO: implement the ability to decode the frame based on content type + else + -- empty frame, need to reset the current events + set_ctx("current_events", {}) end end @@ -70,4 +75,4 @@ function _M:run(conf) return true end -return _M \ No newline at end of file +return _M diff --git a/kong/llm/plugin/shared-filters/save-request-body.lua b/kong/llm/plugin/shared-filters/save-request-body.lua new file mode 100644 index 00000000000..86bceb8862e --- /dev/null +++ b/kong/llm/plugin/shared-filters/save-request-body.lua @@ -0,0 +1,45 @@ +local ai_plugin_ctx = require("kong.llm.plugin.ctx") + +local _M = { + NAME = "save-request-body", + STAGE = 
"REQ_INTROSPECTION", + DESCRIPTION = "save the raw request body if needed", +} + +local FILTER_OUTPUT_SCHEMA = { + raw_request_body = "string", +} + +local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA) + + +function _M:run(conf) + -- This might be called again in retry, simply skip it as we already parsed the request + if ngx.get_phase() == "balancer" then + return true + end + + local log_request_body = false + if conf.logging then + log_request_body = not not conf.logging.log_payloads + elseif conf.targets ~= nil then + -- For ai-proxy-advanced, we store the raw request body if any target wants to log the request payload + for _, target in ipairs(conf.targets) do + if target.logging and target.logging.log_payloads then + log_request_body = true + break + end + end + end + + if log_request_body then + -- This is the raw request body which is sent by the client. The request_body_table key is similar to this, + -- but it is in openai format which is converted from other LLM vendors' format in parse_xxx_request. + -- Note we save the unmodified request body (not even json decode then encode). + set_ctx("raw_request_body", kong.request.get_raw_body(conf.max_request_body_size)) + end + + return true +end + +return _M diff --git a/kong/llm/plugin/shared-filters/serialize-analytics.lua b/kong/llm/plugin/shared-filters/serialize-analytics.lua index 390d528450f..0ac5df7d4aa 100644 --- a/kong/llm/plugin/shared-filters/serialize-analytics.lua +++ b/kong/llm/plugin/shared-filters/serialize-analytics.lua @@ -17,6 +17,8 @@ function _M:run(conf) return true end + ai_plugin_o11y.record_request_end() + local provider_name, request_model do local model_t = ai_plugin_ctx.get_request_model_table_inuse() @@ -82,8 +84,8 @@ function _M:run(conf) -- payloads if conf.logging and conf.logging.log_payloads then - -- can't use kong.service.get_raw_body because it also fall backs to get_body_file which isn't available in log phase - kong.log.set_serialize_value(string.format("ai.%s.payload.request", ai_plugin_o11y.NAMESPACE), ngx.req.get_body_data()) + local request_body = ai_plugin_ctx.get_namespaced_ctx("save-request-body", "raw_request_body") + kong.log.set_serialize_value(string.format("ai.%s.payload.request", ai_plugin_o11y.NAMESPACE), request_body) kong.log.set_serialize_value(string.format("ai.%s.payload.response", ai_plugin_o11y.NAMESPACE), get_global_ctx("response_body")) end diff --git a/kong/pdk/request.lua b/kong/pdk/request.lua index 5c5f6e63e00..827c3fdf982 100644 --- a/kong/pdk/request.lua +++ b/kong/pdk/request.lua @@ -682,6 +682,7 @@ local function new(self) local before_content = phase_checker.new(PHASES.rewrite, PHASES.access, + PHASES.balancer, PHASES.response, PHASES.error, PHASES.admin_api) @@ -700,7 +701,7 @@ local function new(self) -- and has performance implications. -- -- @function kong.request.get_raw_body - -- @phases rewrite, access, response, admin_api + -- @phases rewrite, access, balancer, response, admin_api -- @max_allowed_file_size[opt] number the max allowed file size to be read from, -- 0 means unlimited, but the size of this body will still be limited -- by Nginx's client_max_body_size. @@ -806,7 +807,7 @@ local function new(self) -- what MIME type the body was parsed as. -- -- @function kong.request.get_body - -- @phases rewrite, access, response, admin_api + -- @phases rewrite, access, balancer, response, admin_api -- @tparam[opt] string mimetype The MIME type. 
-- @tparam[opt] number max_args Sets a limit on the maximum number of parsed -- @tparam[opt] number max_allowed_file_size the max allowed file size to be read from diff --git a/kong/pdk/service/request.lua b/kong/pdk/service/request.lua index b1c5b4ba26a..8945d7a15aa 100644 --- a/kong/pdk/service/request.lua +++ b/kong/pdk/service/request.lua @@ -585,7 +585,7 @@ local function new(self) :put('Content-Disposition: form-data; name="') :put(k) :put('"\r\n\r\n') - :put(args[k]) + :put(tostring(args[k])) :put("\r\n") end out:put("--") diff --git a/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua b/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua index c93376452c0..2c35a1ce0dc 100644 --- a/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua +++ b/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua @@ -78,10 +78,8 @@ function _M:run(conf) -- Re-assign it to trigger GC of the old one and save memory. request_body_table = execute(cycle_aware_deep_copy(request_body_table), conf) - kong.service.request.set_body(request_body_table, "application/json") - set_ctx("decorated", true) - ai_plugin_ctx.set_request_body_table_inuse(request_body_table, _M.NAME) + ai_plugin_ctx.set_request_body_table_inuse(request_body_table, _M.NAME, true) return true end diff --git a/kong/plugins/ai-prompt-decorator/schema.lua b/kong/plugins/ai-prompt-decorator/schema.lua index a0a83bc71c0..d1556ba9379 100644 --- a/kong/plugins/ai-prompt-decorator/schema.lua +++ b/kong/plugins/ai-prompt-decorator/schema.lua @@ -5,7 +5,7 @@ local prompt_record = { required = false, fields = { { role = { type = "string", required = true, one_of = { "system", "assistant", "user" }, default = "system" }}, - { content = { type = "string", required = true, len_min = 1, len_max = 500 } }, + { content = { type = "string", required = true, len_min = 1, len_max = 100000 } }, } } diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 4a7eaff055d..1a8ba787892 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -6,7 +6,7 @@ local PRIORITY = 770 local AIPlugin = ai_plugin_base.define(NAME, PRIORITY) local SHARED_FILTERS = { - "parse-request", "normalize-request", "enable-buffering", + "parse-request", "save-request-body", "normalize-request", "enable-buffering", "normalize-response-header", "parse-sse-chunk", "normalize-sse-chunk", "parse-json-response", "normalize-json-response", "serialize-analytics", diff --git a/kong/plugins/ai-request-transformer/filters/transform-request.lua b/kong/plugins/ai-request-transformer/filters/transform-request.lua index 3270132e809..cbf6bb7a528 100644 --- a/kong/plugins/ai-request-transformer/filters/transform-request.lua +++ b/kong/plugins/ai-request-transformer/filters/transform-request.lua @@ -21,10 +21,7 @@ local FILTER_OUTPUT_SCHEMA = { local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA) -local _KEYBASTION = setmetatable({}, { - __mode = "k", - __index = ai_shared.cloud_identity_function, -}) +local _KEYBASTION = ai_shared.get_key_bastion() local function bad_request(msg) diff --git a/kong/plugins/ai-response-transformer/filters/transform-response.lua b/kong/plugins/ai-response-transformer/filters/transform-response.lua index 10b714304ec..3f9c27d40d4 100644 --- a/kong/plugins/ai-response-transformer/filters/transform-response.lua +++ b/kong/plugins/ai-response-transformer/filters/transform-response.lua @@ -23,10 +23,7 @@ local FILTER_OUTPUT_SCHEMA = { local _, set_ctx = 
ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA) local _, set_global_ctx = ai_plugin_ctx.get_global_accessors(_M.NAME) -local _KEYBASTION = setmetatable({}, { - __mode = "k", - __index = ai_shared.cloud_identity_function, -}) +local _KEYBASTION = ai_shared.get_key_bastion() local function bad_request(msg) diff --git a/kong/tools/aws_stream.lua b/kong/tools/aws_stream.lua index ebefc2c2656..23f868e4efe 100644 --- a/kong/tools/aws_stream.lua +++ b/kong/tools/aws_stream.lua @@ -36,16 +36,15 @@ local _HEADER_EXTRACTORS = { -- @param is_hex boolean specify if the chunk bytes are already decoded to hex -- @usage -- local stream_parser = stream:new("00000120af0310f.......", true) --- local next, err = stream_parser:next_message() +-- if stream_parser:has_complete_message() then +-- local next, err = stream_parser:next_message() +-- -- do something with the next message +-- end function Stream:new(chunk, is_hex) local self = {} -- override 'self' to be the new object/class setmetatable(self, Stream) - - if #chunk < ((is_hex and 32) or 16) then - return nil, "cannot parse a chunk less than 16 bytes long" - end - - self.read_count = 0 + + self.read_count = 0 self.chunk = buf.new() self.chunk:put((is_hex and chunk) or to_hex(chunk)) @@ -79,6 +78,9 @@ function Stream:next_bytes(count) end local bytes = self.chunk:get(count * 2) + if #bytes < count * 2 then + return nil, "not enough bytes in buffer when trying to read " .. count .. " bytes, only " .. #bytes / 2 .. " bytes available" + end self.read_count = (count) + self.read_count return bytes @@ -106,6 +108,115 @@ function Stream:next_int(size) return tonumber(int, 16), int end +--- Extract a single header from the stream +-- @return table|nil header containing key and value, or nil on error +-- @return number bytes consumed for this header +-- @return string error message if failed +function Stream:extract_header() + -- the next 8-bit int is the "header key length" + local header_key_len, _, err = self:next_int(8) + if err then + return nil, 0, "failed to read header key length: " .. err + end + + -- Validate header key length + if header_key_len < 1 or header_key_len > 255 then + return nil, 0, "invalid header key length: " .. tostring(header_key_len) + end + + local header_key = self:next_utf_8(header_key_len) + if not header_key then + return nil, 0, "failed to read header key" + end + + local bytes_consumed = 1 + header_key_len -- key length byte + key bytes + + -- next 8-bits is the header type, which is an enum + local header_type, _, err = self:next_int(8) + if err then + return nil, bytes_consumed, "failed to read header type: " .. err + end + + bytes_consumed = bytes_consumed + 1 -- header type byte + + -- Validate header type is in valid range (0-9 according to AWS spec) + if header_type < 0 or header_type > 9 then + return nil, bytes_consumed, "invalid header type: " .. tostring(header_type) + end + + -- depending on the header type, depends on how long the header should max out at + local extractor = _HEADER_EXTRACTORS[header_type] + if not extractor then + return nil, bytes_consumed, "unsupported header type: " .. tostring(header_type) + end + + local header_value, header_value_len = extractor(self) + if not header_value then + return nil, bytes_consumed, "failed to extract header value for type: " .. 
tostring(header_type) + end + + bytes_consumed = bytes_consumed + header_value_len + + return { + key = header_key, + value = header_value, + type = header_type + }, bytes_consumed, nil +end + +--- returns the length of the chunk in bytes +-- @return number length of the chunk in bytes +function Stream:bytes() + if not self.chunk then + return nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" + end + + -- because the chunk is hex-encoded, we divide by 2 to get the actual byte count + return #self.chunk / 2 +end + +local function hex_char_to_int(c) + -- The caller should ensure that `c` is a valid hex character + if c < 58 then + c = c - 48 -- '0' to '9' + else + c = c - 87 -- 'a' to 'f' + end + return tonumber(c) +end + +function Stream:has_complete_message() + if not self.chunk then + return nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" + end + + local n = self:bytes() + -- check if we have at least the 4 bytes for the message length + if n < 4 then + return false + end + + local ptr, _ = self.chunk:ref() + local msg_len = 0 + for i = 0, 3 do + msg_len = msg_len * 256 + hex_char_to_int(ptr[i * 2]) * 16 + hex_char_to_int(ptr[i * 2 + 1]) + end + return n >= msg_len +end + +function Stream:add(chunk, is_hex) + if not self.chunk then + return nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" + end + + if type(chunk) ~= "string" then + return nil, "data must be a string" + end + + -- add the data to the chunk + self.chunk:put((is_hex and chunk) or to_hex(chunk)) +end + --- returns the next message in the chunk, as a table. --- can be used as an iterator. -- @return table formatted next message from the given constructor chunk @@ -114,8 +225,8 @@ function Stream:next_message() return nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" end - if #self.chunk < 1 then - return false + if not self:has_complete_message() then + return nil, "not enough bytes in buffer for a complete message" end -- get the message length and pull that many bytes @@ -125,13 +236,13 @@ function Stream:next_message() -- whole message at correct offset local msg_len, _, err = self:next_int(32) if err then - return err + return nil, err end -- get the headers length local headers_len, _, err = self:next_int(32) if err then - return err + return nil, err end -- get the preamble checksum diff --git a/spec/01-unit/35-aws_stream_spec.lua b/spec/01-unit/35-aws_stream_spec.lua new file mode 100644 index 00000000000..bae31c4743e --- /dev/null +++ b/spec/01-unit/35-aws_stream_spec.lua @@ -0,0 +1,77 @@ +local aws_stream = require("kong.tools.aws_stream") + + +describe("aws stream", function() + it("reject incomplete message", function() + local frame = "\0\0\0\44" .. -- total length + "\0\0\0\0" .. -- headers length (0) + "\0\0\0\0" .. -- crc + ("1234567"):rep(4) -- payload, 4 bytes are missing + local stream, err = aws_stream:new(frame) + assert.is_nil(err) + assert.equal(40, stream:bytes()) + local msg, err = stream:next_message() + assert.is_nil(msg) + assert.equal(err, "not enough bytes in buffer for a complete message") + + local frame = "\0\0\0\40" .. -- total length + "\0\0\0\0" .. -- headers length (0) + "\0\0\0\0" .. -- crc + ("1234567"):rep(4) .. 
-- payload + "\0\1" -- incomplete length for the next message + local stream, err = aws_stream:new(frame) + assert.is_nil(err) + stream:next_message() + local msg, err = stream:next_message() + assert.is_nil(msg) + assert.equal(err, "not enough bytes in buffer for a complete message") + end) + + it("reject out of bound", function() + -- test for low-level api + local frame = "\0\0\0\40" .. -- total length + "\0\0\0\0" .. -- headers length (0) + "\0\0\0\0" .. -- crc + ("1234567"):rep(4) -- payload + local stream, err = aws_stream:new(frame) + assert.is_nil(err) + local bytes, err = stream:next_bytes(44) -- request more bytes than available + assert.is_nil(bytes) + assert.equal(err, "not enough bytes in buffer when trying to read 44 bytes, only 40 bytes available") + + local stream = aws_stream:new(frame) + local bytes = stream:next_bytes(40) -- request exact number of bytes + assert.is_not_nil(bytes) + bytes, err = stream:next_bytes(1) -- request one more byte + assert.is_nil(bytes) + assert.equal(err, "not enough bytes in buffer when trying to read 1 bytes, only 0 bytes available") + end) + + it("check completion", function() + local frame = "\0\0\0\44" .. -- total length + "\0\0\0\0" .. -- headers length (0) + "\0\0\0\0" .. -- crc + ("1234567"):rep(4) -- payload, 4 bytes are missing + local stream = aws_stream:new(frame) + assert.is_false(stream:has_complete_message()) + stream:add("1234") -- add the missing bytes + assert.is_true(stream:has_complete_message()) + local msg = stream:next_message() + assert.same(("1234567"):rep(4), msg.body) + + local frame = "\0\0\0\40" .. -- total length + "\0\0\0\0" .. -- headers length (0) + "\0\0\0\0" .. -- crc + ("1234567"):rep(4) .. -- payload + "\0\0\0\40" -- incomplete length for the next message + local stream = aws_stream:new(frame) + assert.is_true(stream:has_complete_message()) + stream:next_message() + assert.is_false(stream:has_complete_message()) + local remain = "\0\0\0\0" .. -- headers length (0) + "\0\0\0\0" .. -- crc + ("1234567"):rep(4) -- payload + stream:add(remain) -- add the missing bytes + assert.is_true(stream:has_complete_message()) + end) +end) diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 04e69adb046..0a2a6e65d84 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -5,6 +5,7 @@ local cjson = require("cjson.safe") local fmt = string.format local llm = require("kong.llm") local ai_shared = require("kong.llm.drivers.shared") +local pdk_log = require "kong.pdk.log" local SAMPLE_LLM_V2_CHAT_MULTIMODAL_IMAGE_URL = { messages = { @@ -174,6 +175,117 @@ local SAMPLE_OPENAI_TOOLS_REQUEST = { }, } +local SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY = { + messages = { + { + role = "user", + content = "What is the current temperature in London? Get it, and then convert it to Fahrenheit.\n\nAlso, side note, what country is London in?" 
+ }, + { + role = "assistant", + content = nil, + tool_calls = { + { + id = "call_WQY6Dj7ncvDFjE5EuTeGzWGc", + type = "function", + ['function'] = { + name = "get_weather", + arguments = "{\"location\": \"London\"}" + } + }, + { + id = "call_KJeuUux7YsaWo7zVUlnklzae", + type = "function", + ['function'] = { + name = "where_is_city", + arguments = "{\"city_name\": \"London\", \"continent\": \"Europe\"}" + } + } + } + }, + { + role = "tool", + tool_call_id = "call_WQY6Dj7ncvDFjE5EuTeGzWGc", + content = "35" + }, + { + role = "tool", + tool_call_id = "call_KJeuUux7YsaWo7zVUlnklzae", + content = "{\"country\": \"England\", \"continent\": \"Europe\"}" + }, + { + role = "assistant", + content = "The current temperature in London is 35 degrees Celcius, which is 95 degrees Fahrenheit. London is in England, which is part of Europe." + } + }, + tools = { + { + type = "function", + ['function'] = { + name = "where_is_city", + description = "Returns what country a city is in.", + parameters = { + type = "object", + properties = { + city_name = { + type = "string", + description = "City e.g. Bogotá" + }, + continent = { + type = "string", + description = "Continent e.g. South America" + } + }, + required = { + "city_name" + }, + additionalProperties = false + } + } + }, + { + type = "function", + ['function'] = { + name = "get_weather", + description = "Get current temperature for a given location in Celcius.", + parameters = { + type = "object", + properties = { + location = { + type = "string", + description = "City e.g. Bogotá" + } + }, + required = { + "location" + }, + additionalProperties = false + } + } + }, + { + type = "function", + ['function'] = { + name = "convert_celcius_to_fahrenheit", + description = "Convert Celcius to Fahrenheit.", + parameters = { + type = "object", + properties = { + degrees_celcius = { + type = "integer", + description = "Temperature in DEGREES CELCIUS" + } + }, + required = { + "degrees_celcius" + }, + additionalProperties = false + } + } + } + } +} + local SAMPLE_GEMINI_TOOLS_RESPONSE = { candidates = { { content = { @@ -191,6 +303,43 @@ local SAMPLE_GEMINI_TOOLS_RESPONSE = { } }, } +local SAMPLE_GEMINI_TOOLS_RESPONSE_WITH_CHATTER = { + responseId = "chatcmpl-12345", + modelVersion = "gemini-2.5-pro-2025_08_09", + createTime = "2025-08-04T23:01:58.396462Z", + candidates = { { + content = { + role = "model", + parts = { { + functionCall = { + name = "sql_execute", + args = { + product_name = "NewPhone" + } + } + }, + { + text = "And now I will call the function to determine your location." + }, + { + functionCall = { + name = "from_place", + args = { + city = "London", + country = "England" + } + } + }, + { + text = "It has been called." 
+ } } + }, + finishReason = "STOP", + } }, +} + + + local SAMPLE_BEDROCK_TOOLS_RESPONSE = { metrics = { latencyMs = 3781 @@ -410,60 +559,912 @@ local FORMATS = { config = { name = "bedrock", provider = "bedrock", - options = { - max_tokens = 8192, - temperature = 0.8, - top_k = 1, - top_p = 0.6, - }, - }, - }, - }, -} + options = { + max_tokens = 8192, + temperature = 0.8, + top_k = 1, + top_p = 0.6, + }, + }, + }, + }, +} + +local STREAMS = { + openai = { + ["llm/v1/chat"] = { + name = "gpt-4", + provider = "openai", + }, + ["llm/v1/completions"] = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + }, + }, + cohere = { + ["llm/v1/chat"] = { + name = "command", + provider = "cohere", + }, + ["llm/v1/completions"] = { + name = "command-light", + provider = "cohere", + }, + }, +} + +local expected_stream_choices = { + ["llm/v1/chat"] = { + [1] = { + delta = { + content = "the answer", + }, + finish_reason = ngx.null, + index = 0, + logprobs = ngx.null, + }, + }, + ["llm/v1/completions"] = { + [1] = { + text = "the answer", + finish_reason = ngx.null, + index = 0, + logprobs = ngx.null, + }, + }, +} + +describe(PLUGIN_NAME .. ": (unit)", function() + setup(function() + package.loaded["kong.llm.drivers.shared"] = nil + _G.TEST = true + ai_shared = require("kong.llm.drivers.shared") + end) + + before_each(function() + _G.kong = {} + kong.log = pdk_log.new(kong) + -- provide clean context for each test + kong.ctx = { + plugin = {}, + } + end) + + teardown(function() + _G.TEST = nil + end) + + it("resolves referenceable plugin configuration from request context", function() + local fake_request = { + ["get_header"] = function(header_name) + local headers = { + ["from_header_1"] = "header_value_here_1", + ["from_header_2"] = "header_value_here_2", + } + return headers[header_name] + end, + + ["get_uri_captures"] = function() + return { + ["named"] = { + ["uri_cap_1"] = "cap_value_here_1", + ["uri_cap_2"] = "cap_value_here_2", + }, + } + end, + + ["get_query_arg"] = function(query_arg_name) + local query_args = { + ["arg_1"] = "arg_value_here_1", + ["arg_2"] = "arg_value_here_2", + } + return query_args[query_arg_name] + end, + } + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "$(headers.from_header_1)", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = "$(uri_captures.uri_cap_1)", + azure_deployment_id = "$(headers.from_header_1)", + azure_api_version = "$(query_params.arg_1)", + upstream_url = "https://$(uri_captures.uri_cap_1).example.com", + bedrock = { + aws_region = "$(uri_captures.uri_cap_1)", + } + }, + }, + } + + local result, err = ai_shared.merge_model_options(fake_request, fake_config) + assert.is_falsy(err) + assert.same(result.model.name, "header_value_here_1") + assert.same(result.model.options, { + azure_api_version = 'arg_value_here_1', + azure_deployment_id = 'header_value_here_1', + azure_instance = 'cap_value_here_1', + max_tokens = 256, + temperature = 1, + upstream_url = "https://cap_value_here_1.example.com", + bedrock = { + aws_region = "cap_value_here_1", + }, + }) + end) + + it("returns appropriate error when referenceable plugin configuration is missing from request context", function() + local fake_request = { + ["get_header"] = function(header_name) + local headers = { + ["from_header_1"] = "header_value_here_1", + ["from_header_2"] = "header_value_here_2", + } + return headers[header_name] + end, + + 
["get_uri_captures"] = function() + return { + ["named"] = { + ["uri_cap_1"] = "cap_value_here_1", + ["uri_cap_2"] = "cap_value_here_2", + }, + } + end, + + ["get_query_arg"] = function(query_arg_name) + local query_args = { + ["arg_1"] = "arg_value_here_1", + ["arg_2"] = "arg_value_here_2", + } + return query_args[query_arg_name] + end, + } + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = "$(uri_captures.uri_cap_3)", + azure_deployment_id = "$(headers.from_header_1)", + azure_api_version = "$(query_params.arg_1)", + }, + }, + } + + local _, err = ai_shared.merge_model_options(fake_request, fake_config) + assert.same("uri_captures key uri_cap_3 was not provided", err) + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = "$(uri_captures.uri_cap_1)", + azure_deployment_id = "$(headers.from_header_1)", + azure_api_version = "$(query_params.arg_1)", + bedrock = { + aws_region = "$(uri_captures.uri_cap_3)", + } + }, + }, + } + + local _, err = ai_shared.merge_model_options(fake_request, fake_config) + assert.same("uri_captures key uri_cap_3 was not provided", err) + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = "$(uri_captures_uri_cap_1)", + }, + }, + } + + local _, err = ai_shared.merge_model_options(fake_request, fake_config) + assert.same("cannot parse expression for field '$(uri_captures_uri_cap_1)'", err) + end) + + -- generic tests + it("throws correct error when format is not supported", function() + local driver = require("kong.llm.drivers.mistral") -- one-shot, random example of provider with only prompt support + + local model_config = { + route_type = "llm/v1/chatnopenotsupported", + name = "mistral-tiny", + provider = "mistral", + options = { + max_tokens = 512, + temperature = 0.5, + mistral_format = "ollama", + }, + } + + local request_json = pl_file.read("spec/fixtures/ai-proxy/unit/requests/llm-v1-chat.json") + local request_table, err = cjson.decode(request_json) + assert.is_falsy(err) + + -- send it + local actual_request_table, content_type, err = driver.to_format(request_table, model_config, model_config.route_type) + assert.is_nil(actual_request_table) + assert.is_nil(content_type) + assert.equal(err, "no transformer available to format mistral://llm/v1/chatnopenotsupported/ollama") + end) + + + it("produces a correct default config merge", function() + local formatted, err = ai_shared.merge_config_defaults( + SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS, + { + max_tokens = 1024, + top_p = 0.5, + }, + "llm/v1/chat" + ) + + formatted.messages = nil -- not needed for config merge + + assert.is_nil(err) + assert.same({ + max_tokens = 1024, + temperature = 0.1, + top_p = 0.5, + some_extra_param = "string_val", + another_extra_param = 0.5, + }, formatted) + end) + + describe("count_words", function() + local c = ai_shared._count_words + + it("normal prompts", function() + assert.same(10, c(string.rep("apple ", 10))) + end) + + it("multi-modal prompts", function() 
+ assert.same(10, c({ + { + type = "text", + text = string.rep("apple ", 10), + }, + })) + + assert.same(20, c({ + { + type = "text", + text = string.rep("apple ", 10), + }, + { + type = "text", + text = string.rep("banana ", 10), + }, + })) + + assert.same(10, c({ + { + type = "not_text", + text = string.rep("apple ", 10), + }, + { + type = "text", + text = string.rep("banana ", 10), + }, + { + type = "text", + -- somehow malformed + }, + })) + end) + end) + + describe("gemini multimodal", function() + local gemini_driver + + setup(function() + _G._TEST = true + package.loaded["kong.llm.drivers.gemini"] = nil + gemini_driver = require("kong.llm.drivers.gemini") + end) + + teardown(function() + _G._TEST = nil + end) + + it("transforms a text type prompt to gemini GOOD", function() + local gemini_prompt, err = gemini_driver._openai_part_to_gemini_part( + { + ["type"] = "text", + ["text"] = "What is in this picture?", + }) + + assert.not_nil(gemini_prompt) + assert.is_nil(err) + + assert.same(gemini_prompt, + { + ["text"] = "What is in this picture?", + }) + end) + + it("transforms a text type prompt to gemini BAD MISSING TEXT FIELD", function() + local gemini_prompt, err = gemini_driver._openai_part_to_gemini_part( + { + ["type"] = "text", + ["bad_text_field"] = "What is in this picture?", + }) + + assert.is_nil(gemini_prompt) + assert.not_nil(err) + + assert.same("message part type is 'text' but is missing .text block", err) + end) + + it("transforms an image_url type prompt when data is a URL to gemini GOOD", function() + local gemini_prompt, err = gemini_driver._openai_part_to_gemini_part( + { + ["type"] = "image_url", + ["image_url"] = { + ["url"] = "https://example.local/image.jpg", + }, + }) + + assert.not_nil(gemini_prompt) + assert.is_nil(err) + + assert.same(gemini_prompt, + { + ["fileData"] = { + ["fileUri"] = "https://example.local/image.jpg", + ["mimeType"] = "image/generic", + }, + }) + end) + + it("transforms an image_url type prompt when data is a URL to gemini BAD MISSING IMAGE FIELD", function() + local gemini_prompt, err = gemini_driver._openai_part_to_gemini_part( + { + ["type"] = "image_url", + ["image_url"] = "https://example.local/image.jpg", + }) + + assert.is_nil(gemini_prompt) + assert.not_nil(err) + + assert.same("message part type is 'image_url' but is missing .image_url.url block", err) + end) + + it("fails to transform a non-mapped multimodal entity type", function() + local gemini_prompt, err = gemini_driver._openai_part_to_gemini_part( + { + ["type"] = "doesnt_exist", + ["doesnt_exist"] = "https://example.local/video.mp4", + }) + + assert.is_nil(gemini_prompt) + assert.not_nil(err) + + assert.same("cannot transform part of type 'doesnt_exist' to Gemini format", err) + end) + + it("transforms 'describe this image' via URL from openai to gemini", function() + local gemini_prompt, _, err = gemini_driver._to_gemini_chat_openai(SAMPLE_LLM_V2_CHAT_MULTIMODAL_IMAGE_URL) + + assert.is_nil(err) + assert.not_nil(gemini_prompt) + + gemini_prompt.generationConfig = nil -- not needed for comparison + + assert.same({ + ["contents"] = { + { + ["role"] = "user", + ["parts"] = { + { + ["text"] = "What is in this picture?", + }, + { + ["fileData"] = { + ["fileUri"] = "https://example.local/image.jpg", + ["mimeType"] = "image/generic", + }, + } + }, + }, + { + ["role"] = "model", + ["parts"] = { + { + ["text"] = "A picture of a cat.", + }, + }, + }, + { + ["role"] = "user", + ["parts"] = { + { + ["text"] = "Now draw it wearing a party-hat.", + }, + }, + }, + } + }, 
gemini_prompt) + end) + + it("transforms 'describe this image' via base64 from openai to gemini", function() + local gemini_prompt, _, err = gemini_driver._to_gemini_chat_openai(SAMPLE_LLM_V2_CHAT_MULTIMODAL_IMAGE_B64) + + assert.is_nil(err) + assert.not_nil(gemini_prompt) + + gemini_prompt.generationConfig = nil -- not needed for comparison + + assert.same({ + ["contents"] = { + { + ["role"] = "user", + ["parts"] = { + { + ["text"] = "What is in this picture?", + }, + { + ["inlineData"] = { + ["data"] = "Y2F0X3BuZ19oZXJlX2xvbAo=", + ["mimeType"] = "image/png", + }, + } + }, + }, + { + ["role"] = "model", + ["parts"] = { + { + ["text"] = "A picture of a cat.", + }, + }, + }, + { + ["role"] = "user", + ["parts"] = { + { + ["text"] = "Now draw it wearing a party-hat.", + }, + }, + }, + } + }, gemini_prompt) + end) + + end) + + + describe("gemini tools", function() + local gemini_driver + + setup(function() + _G._TEST = true + package.loaded["kong.llm.drivers.gemini"] = nil + gemini_driver = require("kong.llm.drivers.gemini") + end) + + teardown(function() + _G._TEST = nil + end) + + it("transforms openai tools to gemini tool declarations GOOD", function() + local gemini_tools = gemini_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools) + + assert.not_nil(gemini_tools) + assert.same(gemini_tools, { + { + function_declarations = { + { + description = "Check a product is in stock.", + name = "check_stock", + parameters = { + properties = { + product_name = { + type = "string" + } + }, + required = { + "product_name" + }, + type = "object" + } + } + } + } + }) + end) + + it("transforms openai tools to gemini tools NO_TOOLS", function() + local gemini_tools = gemini_driver._to_tools(SAMPLE_LLM_V1_CHAT) + + assert.is_nil(gemini_tools) + end) + + it("transforms openai tools to gemini tools NIL", function() + local gemini_tools = gemini_driver._to_tools(nil) + + assert.is_nil(gemini_tools) + end) + + it("transforms openai tool_calls to gemini functionCalls GOOD", function() + local gemini_tools = gemini_driver._extract_function_calls(SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY.messages[2]) + + assert.not_nil(gemini_tools) + assert.same(gemini_tools, { + { + functionCall = { + name = "get_weather", + args = { + location = "London" + } + } + }, + { + functionCall = { + name = "where_is_city", + args = { + city_name = "London", + continent = "Europe" + } + } + } + }) + end) + + it("transforms openai tool_calls to gemini functionCalls NO TOOL CALLS", function() + local gemini_tools = gemini_driver._extract_function_calls(SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY.messages[1]) + + assert.is_nil(gemini_tools) + end) + + it("transforms openai tool_calls to gemini functionCalls NIL", function() + local gemini_tools = gemini_driver._extract_function_calls(nil) + + assert.is_nil(gemini_tools) + end) + + it("transforms gemini tools to openai tools GOOD", function() + local openai_tools = gemini_driver._extract_response_tool_calls(SAMPLE_GEMINI_TOOLS_RESPONSE.candidates[1]) + + assert.not_nil(openai_tools) + + for _, v in ipairs(openai_tools) do + assert.is_not_nil(v.id) + v.id = nil -- remove random id for comparison + + assert.is_not_nil(v['function'].arguments) + v['function'].arguments = cjson.decode(v['function'].arguments) -- decode arguments to stop flaky tests + end + + assert.same(openai_tools, { + { + ['type'] = "function", + ['function'] = { + ['name'] = "sql_execute", + ['arguments'] = cjson.decode("{\"product_name\":\"NewPhone\"}") + } + } + }) + + --- + + openai_tools = 
gemini_driver._extract_response_tool_calls(SAMPLE_GEMINI_TOOLS_RESPONSE_WITH_CHATTER.candidates[1]) + + assert.not_nil(openai_tools) + + for _, v in ipairs(openai_tools) do + assert.is_not_nil(v.id) + v.id = nil -- remove random id for comparison + + assert.is_not_nil(v['function'].arguments) + v['function'].arguments = cjson.decode(v['function'].arguments) -- decode arguments to stop flaky tests + end + + assert.same(openai_tools, { + { + ['type'] = "function", + ['function'] = { + ['name'] = "sql_execute", + ['arguments'] = cjson.decode("{\"product_name\":\"NewPhone\"}") + } + }, + { + ['type'] = 'function', + ['function'] = { + ['name'] = 'from_place', + ['arguments'] = cjson.decode('{"country":"England","city":"London"}') + } + } + }) + end) + + it("transforms openai tool_call results to gemini functionResponses GOOD", function() + local gemini_toolresult = gemini_driver._openai_toolresult_to_gemini_toolresult( + SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY.messages[2], + SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY.messages[3]) + + assert.not_nil(gemini_toolresult) + assert.same(gemini_toolresult, { + ['functionResponse'] = { + ['name'] = 'get_weather', + ['response'] = { + ['result'] = '35' + } + } + }) + + gemini_toolresult = gemini_driver._openai_toolresult_to_gemini_toolresult( + SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY.messages[2], + SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY.messages[4]) + + assert.not_nil(gemini_toolresult) + assert.same(gemini_toolresult, { + ['functionResponse'] = { + ['name'] = 'where_is_city', + ['response'] = { + ['country'] = 'England', + ['continent'] = 'Europe' + } + } + }) + end) + + it("transforms openai tools to gemini tools from whole response multiple tool use GOOD", function() + local gemini_response, _, err = gemini_driver._to_gemini_chat_openai(SAMPLE_OPENAI_TOOLS_REQUEST_WITH_REPLY, {}, "llm/v1/chat") + + assert.is_nil(err) + assert.not_nil(gemini_response) + assert.not_nil(gemini_response.contents) + + assert.same(gemini_response.contents, { + { + ['parts'] = { + { + text = "What is the current temperature in London? Get it, and then convert it to Fahrenheit.\n\nAlso, side note, what country is London in?" + } + }, + ['role'] = 'user' + }, + { + ['parts'] = { + { + ['functionCall'] = { + ['name'] = "get_weather", + ['args'] = { + ['location'] = "London" + } + } + }, + { + ['functionCall'] = { + ['name'] = "where_is_city", + ['args'] = { + ['city_name'] = "London", + ['continent'] = "Europe" + } + } + } + }, + ['role'] = 'model' + }, + { + ['parts'] = { + { + ['functionResponse'] = { + ['name'] = "get_weather", + ['response'] = { + ['result'] = "35" + } + } + }, + { + ['functionResponse'] = { + ['name'] = "where_is_city", + ['response'] = { + ['country'] = "England", + ['continent'] = "Europe" + } + } + } + }, + ['role'] = 'user' + }, + { + ['parts'] = { + { + text = "The current temperature in London is 35 degrees Celcius, which is 95 degrees Fahrenheit. London is in England, which is part of Europe." 
+ } + }, + ['role'] = 'model' + } + }) + end) + + it("transforms gemini tools to openai tools from whole response GOOD", function() + local openai_response = gemini_driver._from_gemini_chat_openai(SAMPLE_GEMINI_TOOLS_RESPONSE, {}, "llm/v1/chat") + + assert.not_nil(openai_response) + + openai_response = cjson.decode(openai_response) + + for _, v in ipairs(openai_response.choices[1].message.tool_calls) do + assert.is_not_nil(v.id) + v.id = nil -- remove random id for comparison + + assert.is_not_nil(v['function'].arguments) + v['function'].arguments = cjson.decode(v['function'].arguments) -- decode arguments to stop flaky tests + end + + assert.same(openai_response.choices[1].message.tool_calls, { + { + ['type'] = "function", + ['function'] = { + ['name'] = "sql_execute", + ['arguments'] = cjson.decode("{\"product_name\":\"NewPhone\"}") + } + } + }) + + --- + + openai_response = gemini_driver._from_gemini_chat_openai(SAMPLE_GEMINI_TOOLS_RESPONSE_WITH_CHATTER, {}, "llm/v1/chat") + + assert.not_nil(openai_response) + + openai_response = cjson.decode(openai_response) + + for _, v in ipairs(openai_response.choices[1].message.tool_calls) do + assert.is_not_nil(v.id) + v.id = nil -- remove random id for comparison + + assert.is_not_nil(v['function'].arguments) + v['function'].arguments = cjson.decode(v['function'].arguments) -- decode arguments to stop flaky tests + end + + assert.same(openai_response.choices[1].message.tool_calls, { + { + ['type'] = "function", + ['function'] = { + ['name'] = "sql_execute", + ['arguments'] = cjson.decode("{\"product_name\":\"NewPhone\"}") + } + }, + { + ['type'] = 'function', + ['function'] = { + ['name'] = 'from_place', + ['arguments'] = cjson.decode('{"country":"England","city":"London"}') + } + } + }) + + assert.same(openai_response.choices[1].message.content, "And now I will call the function to determine your location.\\nIt has been called.") + + assert.same(openai_response.choices[1].finish_reason, "tool_calls") + assert.is_not_nil(openai_response.id) + assert.same(openai_response.id, "chatcmpl-12345") + assert.same(openai_response.model, "gemini-2.5-pro-2025_08_09") + assert.same(openai_response.created, 1754348518.0) + end) + end) + + describe("bedrock tools", function() + local bedrock_driver + + setup(function() + _G._TEST = true + package.loaded["kong.llm.drivers.bedrock"] = nil + bedrock_driver = require("kong.llm.drivers.bedrock") + end) + + teardown(function() + _G._TEST = nil + end) + + it("transforms openai tools to bedrock tools GOOD", function() + local bedrock_tools = bedrock_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools) + + assert.not_nil(bedrock_tools) + assert.same(bedrock_tools, { + { + toolSpec = { + description = "Check a product is in stock.", + inputSchema = { + json = { + properties = { + product_name = { + type = "string" + } + }, + required = { + "product_name" + }, + type = "object" + } + }, + name = "check_stock" + } + } + }) + end) + + it("transforms openai tools to bedrock tools NO_TOOLS", function() + local bedrock_tools = bedrock_driver._to_tools(SAMPLE_LLM_V1_CHAT) + + assert.is_nil(bedrock_tools) + end) + + it("transforms openai tools to bedrock tools NIL", function() + local bedrock_tools = bedrock_driver._to_tools(nil) + + assert.is_nil(bedrock_tools) + end) + + it("transforms bedrock tools to openai tools GOOD", function() + local openai_tools = bedrock_driver._from_tool_call_response(SAMPLE_BEDROCK_TOOLS_RESPONSE.output.message.content) + + assert.not_nil(openai_tools) + + assert.same(openai_tools[1]['function'], { 
+ name = "sumArea", + arguments = "{\"areas\":[121,212,313]}" + }) + end) + + it("transforms guardrails into bedrock generation config", function() + local model_info = { + route_type = "llm/v1/chat", + name = "some-model", + provider = "bedrock", + } + local bedrock_guardrails = bedrock_driver._to_bedrock_chat_openai(SAMPLE_LLM_V1_CHAT_WITH_GUARDRAILS, model_info, "llm/v1/chat") -local STREAMS = { - openai = { - ["llm/v1/chat"] = { - name = "gpt-4", - provider = "openai", - }, - ["llm/v1/completions"] = { - name = "gpt-3.5-turbo-instruct", - provider = "openai", - }, - }, - cohere = { - ["llm/v1/chat"] = { - name = "command", - provider = "cohere", - }, - ["llm/v1/completions"] = { - name = "command-light", - provider = "cohere", - }, - }, -} + assert.not_nil(bedrock_guardrails) + + assert.same(bedrock_guardrails.guardrailConfig, { + ['guardrailIdentifier'] = 'yu5xwvfp4sud', + ['guardrailVersion'] = '1', + ['trace'] = 'enabled', + }) + end) + end) +end) -local expected_stream_choices = { - ["llm/v1/chat"] = { - [1] = { - delta = { - content = "the answer", - }, - finish_reason = ngx.null, - index = 0, - logprobs = ngx.null, - }, - }, - ["llm/v1/completions"] = { - [1] = { - text = "the answer", - finish_reason = ngx.null, - index = 0, - logprobs = ngx.null, - }, - }, -} describe(PLUGIN_NAME .. ": (unit)", function() setup(function() @@ -472,6 +1473,15 @@ describe(PLUGIN_NAME .. ": (unit)", function() ai_shared = require("kong.llm.drivers.shared") end) + before_each(function() + _G.kong = {} + kong.log = pdk_log.new(kong) + -- provide clean context for each test + kong.ctx = { + plugin = {}, + } + end) + teardown(function() _G.TEST = nil end) @@ -511,7 +1521,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() header_value = "azure-key", }, model = { - name = "$(headers.from_header_1)", + name = "gpt-3.5-turbo", provider = "azure", options = { max_tokens = 256, @@ -519,31 +1529,22 @@ describe(PLUGIN_NAME .. ": (unit)", function() azure_instance = "$(uri_captures.uri_cap_1)", azure_deployment_id = "$(headers.from_header_1)", azure_api_version = "$(query_params.arg_1)", - upstream_url = "https://$(uri_captures.uri_cap_1).example.com", - bedrock = { - aws_region = "$(uri_captures.uri_cap_1)", - } }, }, } local result, err = ai_shared.merge_model_options(fake_request, fake_config) assert.is_falsy(err) - assert.same(result.model.name, "header_value_here_1") assert.same(result.model.options, { - azure_api_version = 'arg_value_here_1', - azure_deployment_id = 'header_value_here_1', - azure_instance = 'cap_value_here_1', - max_tokens = 256, - temperature = 1, - upstream_url = "https://cap_value_here_1.example.com", - bedrock = { - aws_region = "cap_value_here_1", - }, + ['azure_api_version'] = 'arg_value_here_1', + ['azure_deployment_id'] = 'header_value_here_1', + ['azure_instance'] = 'cap_value_here_1', + ['max_tokens'] = 256, + ['temperature'] = 1, }) end) - it("returns appropriate error when referenceable plugin configuration is missing from request context", function() + it("resolves referenceable model name from request context", function() local fake_request = { ["get_header"] = function(header_name) local headers = { @@ -578,72 +1579,21 @@ describe(PLUGIN_NAME .. 
": (unit)", function() header_value = "azure-key", }, model = { - name = "gpt-3.5-turbo", - provider = "azure", - options = { - max_tokens = 256, - temperature = 1.0, - azure_instance = "$(uri_captures.uri_cap_3)", - azure_deployment_id = "$(headers.from_header_1)", - azure_api_version = "$(query_params.arg_1)", - }, - }, - } - - local _, err = ai_shared.merge_model_options(fake_request, fake_config) - assert.same("uri_captures key uri_cap_3 was not provided", err) - - local fake_config = { - route_type = "llm/v1/chat", - auth = { - header_name = "api-key", - header_value = "azure-key", - }, - model = { - name = "gpt-3.5-turbo", - provider = "azure", - options = { - max_tokens = 256, - temperature = 1.0, - azure_instance = "$(uri_captures.uri_cap_1)", - azure_deployment_id = "$(headers.from_header_1)", - azure_api_version = "$(query_params.arg_1)", - bedrock = { - aws_region = "$(uri_captures.uri_cap_3)", - } - }, - }, - } - - local _, err = ai_shared.merge_model_options(fake_request, fake_config) - assert.same("uri_captures key uri_cap_3 was not provided", err) - - local fake_config = { - route_type = "llm/v1/chat", - auth = { - header_name = "api-key", - header_value = "azure-key", - }, - model = { - name = "gpt-3.5-turbo", + name = "$(uri_captures.uri_cap_2)", provider = "azure", options = { max_tokens = 256, temperature = 1.0, - azure_instance = "$(uri_captures_uri_cap_1)", + azure_instance = "string-1", + azure_deployment_id = "string-2", + azure_api_version = "string-3", }, }, } - local _, err = ai_shared.merge_model_options(fake_request, fake_config) - assert.same("cannot parse expression for field '$(uri_captures_uri_cap_1)'", err) - end) - - it("llm/v1/chat message is compatible with llm/v1/chat route", function() - local compatible, err = llm.is_compatible(SAMPLE_LLM_V1_CHAT, "llm/v1/chat") - - assert.is_truthy(compatible) - assert.is_nil(err) + local result, err = ai_shared.merge_model_options(fake_request, fake_config) + assert.is_falsy(err) + assert.same("cap_value_here_2", result.model.name) end) it("llm/v1/chat message is not compatible with llm/v1/completions route", function() @@ -669,7 +1619,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() for i, j in pairs(FORMATS) do - describe(i .. " format tests", function() + describe("#" .. i .. " format tests", function() for k, l in pairs(j) do @@ -761,8 +1711,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.is_nil(err) -- compare the tables - assert.same(expected_response_table.choices[1].message, actual_response_table.choices[1].message) - assert.same(actual_response_table.model, expected_response_table.model) + assert.same(expected_response_table, actual_response_table) end) end) end @@ -958,7 +1907,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("transforms complete-json type", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin")) - local events = ai_shared._frame_to_events(input, "text/event-stream") -- not "truncated json mode" like Gemini + local events = ai_shared._frame_to_events(input, "application/stream+json") -- not "truncated json mode" like Gemini local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json")) local expected_events = cjson.decode(expected) @@ -990,6 +1939,23 @@ describe(PLUGIN_NAME .. 
": (unit)", function() end end) + it("transforms flat json type separated with \\r", function() + local input = pl_file.read("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/input.json") + -- Cohere's response is \n delimited JSON events + -- https://docs.cohere.com/v1/reference/chat-stream + -- We simulate a flat json input separated with `\r` by replacing \n with \r + input = input:gsub("\n", "\r") + local events = ai_shared._frame_to_events(input, "application/stream+json") + + local expected = pl_file.read("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/expected-output.json") + local expected_events = cjson.decode(expected) + + assert.equal(#events, #expected_events) + for i, _ in ipairs(expected_events) do + assert.same(cjson.decode(events[i].data), cjson.decode(expected_events[i].data)) + end + end) + end) describe("count_words", function() @@ -1198,88 +2164,23 @@ describe(PLUGIN_NAME .. ": (unit)", function() { ["text"] = "A picture of a cat.", }, - }, - }, - { - ["role"] = "user", - ["parts"] = { - { - ["text"] = "Now draw it wearing a party-hat.", - }, - }, - }, - } - }, gemini_prompt) - end) - - end) - - - describe("gemini tools", function() - local gemini_driver - - setup(function() - _G._TEST = true - package.loaded["kong.llm.drivers.gemini"] = nil - gemini_driver = require("kong.llm.drivers.gemini") - end) - - teardown(function() - _G._TEST = nil - end) - - it("transforms openai tools to gemini tools GOOD", function() - local gemini_tools = gemini_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools) - - assert.not_nil(gemini_tools) - assert.same(gemini_tools, { - { - function_declarations = { - { - description = "Check a product is in stock.", - name = "check_stock", - parameters = { - properties = { - product_name = { - type = "string" - } - }, - required = { - "product_name" - }, - type = "object" - } - } - } + }, + }, + { + ["role"] = "user", + ["parts"] = { + { + ["text"] = "Now draw it wearing a party-hat.", + }, + }, + }, } - }) - end) - - it("transforms openai tools to gemini tools NO_TOOLS", function() - local gemini_tools = gemini_driver._to_tools(SAMPLE_LLM_V1_CHAT) - - assert.is_nil(gemini_tools) - end) - - it("transforms openai tools to gemini tools NIL", function() - local gemini_tools = gemini_driver._to_tools(nil) - - assert.is_nil(gemini_tools) + }, gemini_prompt) end) - it("transforms gemini tools to openai tools GOOD", function() - local openai_tools = gemini_driver._from_gemini_chat_openai(SAMPLE_GEMINI_TOOLS_RESPONSE, {}, "llm/v1/chat") - - assert.not_nil(openai_tools) - - openai_tools = cjson.decode(openai_tools) - assert.same(openai_tools.choices[1].message.tool_calls[1]['function'], { - name = "sql_execute", - arguments = "{\"product_name\":\"NewPhone\"}" - }) - end) end) + describe("bedrock tools", function() local bedrock_driver @@ -1870,6 +2771,27 @@ describe(PLUGIN_NAME .. ": (unit)", function() local expected_events = cjson.decode(expected) assert.same(expected_events, events) + + local len = #input + -- Fuzz with possible truncations. We choose to truncate into three parts, so we can test + -- a case that the frame is truncated into more than two parts, to avoid unexpected cleanup + -- of the truncation state. 
+ for i = 2, len - 1 do + for j = i + 1, len do + local events = {} + local delimiters = {} + delimiters[1] = {1, i - 1} + delimiters[2] = {i, j - 1} + delimiters[3] = {j, len} + for k = 1, #delimiters do + local output = ai_shared._frame_to_events(input:sub(delimiters[k][1], delimiters[k][2]), "text/event-stream") + for _, event in ipairs(output or {}) do + table.insert(events, event) + end + end + assert.same(expected_events, events, "failed when the frame is truncated in " .. cjson.encode(delimiters)) + end + end end) it("transforms application/vnd.amazon.eventstream (AWS) type", function() @@ -1884,6 +2806,27 @@ describe(PLUGIN_NAME .. ": (unit)", function() -- tables are random ordered, so we need to compare each serialized event assert.same(cjson.decode(events[i].data), cjson.decode(expected_events[i].data)) end + + local len = #input + -- fuzz with random truncations + for i = 1, len / 2, 10 do + local events = {} + + for j = 0, 2 do + local stop = i * j + i + if j == 2 then + -- the last truncated frame + stop = len + end + local output = ai_shared._frame_to_events(input:sub(i * j + 1, stop), "application/vnd.amazon.eventstream") + for _, event in ipairs(output or {}) do + table.insert(events, event) + end + end + for i, _ in ipairs(expected_events) do + assert.same(cjson.decode(events[i].data), cjson.decode(expected_events[i].data), "failed when the frame is truncated at " .. i) + end + end end) end) @@ -2260,16 +3203,25 @@ end) describe("json_array_iterator", function() local json_array_iterator + local JSON_ARRAY_TYPE lazy_setup(function() _G.TEST = true package.loaded["kong.llm.drivers.shared"] = nil - json_array_iterator = require("kong.llm.drivers.shared")._json_array_iterator + json_array_iterator = require("kong.llm.drivers.shared").json_array_iterator + JSON_ARRAY_TYPE = require("kong.llm.drivers.shared").JSON_ARRAY_TYPE end) -- Helper function to collect all elements from iterator - local function collect_elements(input) + local function collect_elements(input, jsonl) local elements = {} - local iter = json_array_iterator(input) + local json_array_type + if jsonl then + json_array_type = JSON_ARRAY_TYPE.JSONL + else + json_array_type = JSON_ARRAY_TYPE.GEMINI + end + + local iter = json_array_iterator(input, nil, json_array_type) local next_element = iter() while next_element do table.insert(elements, next_element) @@ -2341,13 +3293,13 @@ describe("json_array_iterator", function() local iter -- First chunk (split within string) - iter = json_array_iterator('["hel', state) + iter = json_array_iterator('["hel', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.is_nil(element) -- Should return nil as string is incomplete state = new_state -- Second chunk (complete string) - iter = json_array_iterator('lo"]', state) + iter = json_array_iterator('lo"]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('"hello"', element) end) @@ -2362,12 +3314,12 @@ describe("json_array_iterator", function() local iter -- Split during escape sequence - iter = json_array_iterator('["he\\', state) + iter = json_array_iterator('["he\\', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.is_nil(element) state = new_state - iter = json_array_iterator('\\nllo"]', state) + iter = json_array_iterator('\\nllo"]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('"he\\\\nllo"', element) end) @@ -2382,12 +3334,12 @@ describe("json_array_iterator", function() local iter -- Split between object definition - iter = 
json_array_iterator('[{"name": "Jo', state) + iter = json_array_iterator('[{"name": "Jo', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.is_nil(element) state = new_state - iter = json_array_iterator('hn"}, {"age": 30}]', state) + iter = json_array_iterator('hn"}, {"age": 30}]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('{"name": "John"}', element) @@ -2405,12 +3357,12 @@ describe("json_array_iterator", function() local iter -- Split between nested array - iter = json_array_iterator('[[1, 2', state) + iter = json_array_iterator('[[1, 2', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.is_nil(element) state = new_state - iter = json_array_iterator('], [3, 4]]', state) + iter = json_array_iterator('], [3, 4]]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('[1, 2]', element) @@ -2428,12 +3380,12 @@ describe("json_array_iterator", function() local iter -- Split at comma - iter = json_array_iterator('[1,', state) + iter = json_array_iterator('[1,', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.are.same('1', element) state = new_state - iter = json_array_iterator(' 2]', state) + iter = json_array_iterator(' 2]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('2', element) end) @@ -2448,13 +3400,13 @@ describe("json_array_iterator", function() local iter -- Split at comma - iter = json_array_iterator('[{"message":"hello world"}, {"message":"goodbye,', state) + iter = json_array_iterator('[{"message":"hello world"}, {"message":"goodbye,', state, JSON_ARRAY_TYPE.GEMINI) local element, _ = iter() assert.are.same('{"message":"hello world"}',element) local element, _ = iter() assert.is_nil(element) - iter = json_array_iterator(' world"}]', state) + iter = json_array_iterator(' world"}]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('{"message":"goodbye, world"}', element) end) @@ -2469,17 +3421,17 @@ describe("json_array_iterator", function() local iter -- Complex nested structure split - iter = json_array_iterator('[{"users": [{"id": 1', state) + iter = json_array_iterator('[{"users": [{"id": 1', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.is_nil(element) state = new_state - iter = json_array_iterator(', "name": "John"}]}, {"status": "', state) + iter = json_array_iterator(', "name": "John"}]}, {"status": "', state, JSON_ARRAY_TYPE.GEMINI) local element, new_state = iter() assert.are.same('{"users": [{"id": 1, "name": "John"}]}', element) state = new_state - iter = json_array_iterator('active"}]', state) + iter = json_array_iterator('active"}]', state, JSON_ARRAY_TYPE.GEMINI) element = iter() assert.are.same('{"status": "active"}', element) end) @@ -2487,7 +3439,7 @@ describe("json_array_iterator", function() it("should error on invalid start", function() assert.has_error(function() - json_array_iterator('{1, 2, 3}')() + json_array_iterator('{1, 2, 3}', nil, JSON_ARRAY_TYPE.GEMINI)() end, "Invalid start: expected '['") end) @@ -2503,4 +3455,450 @@ describe("json_array_iterator", function() elements[2] ) end) + + it("#jsonl should handle complex nested jsonl structures", function() + local input = '{"users": [{"id": 1, "name": "John"}, {"id": 2, "name": "Jane"}]}\n{"status": "active"}' + local elements = collect_elements(input, true) + assert.are.same( + '{\"users\": [{\"id\": 1, \"name\": \"John\"}, {\"id\": 2, \"name\": \"Jane\"}]}', + elements[1] + ) + assert.are.same( + '{\"status\": \"active\"}', + elements[2] + ) + 
end) + + describe("#jsonl incremental parsing", function() + it("should handle split between object braces", function() + local state = { + started = false, + pos = 1, + input = '', + eof = false, + } + local iter + + -- Split between object definition + iter = json_array_iterator('{"name": "Jo', state, JSON_ARRAY_TYPE.JSONL) + local element, new_state = iter() + assert.is_nil(element) + state = new_state + + iter = json_array_iterator('hn"}\n{"age": 30}', state, JSON_ARRAY_TYPE.JSONL) + element = iter() + assert.are.same('{"name": "John"}', element) + + element = iter() + assert.are.same('{"age": 30}', element) + end) + + it("should handle split between array brackets", function() + local state = { + started = false, + pos = 1, + input = '', + eof = false, + } + local iter + + -- Split between nested array + iter = json_array_iterator('[1, 2', state, JSON_ARRAY_TYPE.JSONL) + local element, new_state = iter() + assert.is_nil(element) + state = new_state + + iter = json_array_iterator(']\n[3, 4]', state, JSON_ARRAY_TYPE.JSONL) + element = iter() + assert.are.same('[1, 2]', element) + + element = iter() + assert.are.same('[3, 4]', element) + end) + + it("should not split between literal \n", function() + local state = { + started = false, + pos = 1, + input = '', + eof = false, + } + local iter + + -- Split at comma + iter = json_array_iterator('{"message":"hello world"}\n{"message":"goodbye\\n', state, JSON_ARRAY_TYPE.JSONL) + local element, _ = iter() + assert.are.same('{"message":"hello world"}',element) + local element, _ = iter() + assert.is_nil(element) + + iter = json_array_iterator(' world"}', state, JSON_ARRAY_TYPE.JSONL) + element = iter() + assert.are.same('{"message":"goodbye\\n world"}', element) + end) + + it("should handle split within complex nested structure", function() + local state = { + started = false, + pos = 1, + input = '', + eof = false, + } + local iter + + -- Complex nested structure split + iter = json_array_iterator('{"users": [{"id": 1', state, JSON_ARRAY_TYPE.JSONL) + local element, new_state = iter() + assert.is_nil(element) + state = new_state + + iter = json_array_iterator(', "name": "John"}]}\n{"status": "', state, JSON_ARRAY_TYPE.JSONL) + local element, new_state = iter() + assert.are.same('{"users": [{"id": 1, "name": "John"}]}', element) + state = new_state + + iter = json_array_iterator('active"}', state, JSON_ARRAY_TYPE.JSONL) + element = iter() + assert.are.same('{"status": "active"}', element) + end) + end) +end) + +describe("iso_8601_to_epoch", function() + local shared = require "kong.llm.drivers.shared" + + describe("Basic functionality", function() + it("should convert Unix epoch start correctly", function() + local result = shared.iso_8601_to_epoch("1970-01-01T00:00:00.000Z") + assert.are.equal(0, result) + end) + + it("should convert a simple timestamp correctly", function() + -- 2000-01-01T00:00:00.000Z should be 946684800 + local result = shared.iso_8601_to_epoch("2000-01-01T00:00:00.000Z") + assert.are.equal(946684800, result) + end) + + it("should handle time components correctly", function() + -- 1970-01-01T01:00:00.000Z should be 3600 (1 hour) + local result = shared.iso_8601_to_epoch("1970-01-01T01:00:00.000Z") + assert.are.equal(3600, result) + end) + + it("should handle full time specification", function() + -- 1970-01-01T12:34:56.000Z should be 12*3600 + 34*60 + 56 = 45296 + local result = shared.iso_8601_to_epoch("1970-01-01T12:34:56.000Z") + assert.are.equal(45296, result) + end) + end) + + describe("Leap year handling", function() 
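+    -- Worked example for the expected values in these tests (day-count arithmetic):
+    -- 2000-01-01T00:00:00Z is 30 years after the epoch with 7 leap days in between
+    -- (1972..1996), i.e. (30 * 365 + 7) * 86400 = 10957 * 86400 = 946684800;
+    -- 2000-02-29 adds the 31 days of January and 28 days of February, i.e.
+    -- (10957 + 59) * 86400 = 951782400.
+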
+ it("should handle leap year correctly", function() + -- 2000 is a leap year, Feb 29th should exist + -- 2000-02-29T00:00:00.000Z should be 951782400 + local result = shared.iso_8601_to_epoch("2000-02-29T00:00:00.000Z") + assert.are.equal(951782400, result) + end) + + it("should handle century year that is not leap year", function() + -- 2100 is not a leap year (divisible by 100 but not 400) + -- 2100-02-01T00:00:00.000Z should account for Feb having only 28 days + local result = shared.iso_8601_to_epoch("2100-02-01T00:00:00.000Z") + local non_leap_year = shared.iso_8601_to_epoch("2101-02-01T00:00:00.000Z") + -- The difference should be exactly one year + local diff = non_leap_year - result + assert.are.equal(365 * 86400, diff) -- 366 days in seconds + end) + + it("should handle year 2000 leap year edge case", function() + -- Year 2000 is divisible by 400, so it IS a leap year + -- Test March 1st in a leap year vs non-leap year + local leap_year = shared.iso_8601_to_epoch("2000-02-01T00:00:00.000Z") + local non_leap_year = shared.iso_8601_to_epoch("2001-02-01T00:00:00.000Z") + + -- The difference should be exactly one year plus one day (leap day) + local diff = non_leap_year - leap_year + assert.are.equal(366 * 86400, diff) -- 366 days in seconds + end) + end) + + describe("Month boundaries", function() + it("should handle month transitions correctly", function() + local jan31 = shared.iso_8601_to_epoch("2000-01-31T00:00:00.000Z") + local feb01 = shared.iso_8601_to_epoch("2000-02-01T00:00:00.000Z") + + -- Should be exactly 1 day difference + assert.are.equal(86400, feb01 - jan31) + end) + + it("should handle December to January transition", function() + local dec31 = shared.iso_8601_to_epoch("1999-12-31T00:00:00.000Z") + local jan01 = shared.iso_8601_to_epoch("2000-01-01T00:00:00.000Z") + + -- Should be exactly 1 day difference + assert.are.equal(86400, jan01 - dec31) + end) + end) + + describe("Known timestamps", function() + it("should convert Y2K timestamp correctly", function() + -- Y2K: 2000-01-01T00:00:00.000Z = 946684800 + local result = shared.iso_8601_to_epoch("2000-01-01T00:00:00.000Z") + assert.are.equal(946684800, result) + end) + + it("should convert a recent timestamp correctly", function() + -- 2024-01-01T00:00:00.000Z = 1704067200 + local result = shared.iso_8601_to_epoch("2024-01-01T00:00:00.000Z") + assert.are.equal(1704067200, result) + end) + + it("should handle timestamps with various time components", function() + -- 2020-06-15T14:30:45.000Z + local result = shared.iso_8601_to_epoch("2020-06-15T14:30:45.000Z") + -- Expected: 1592231445 + assert.are.equal(1592231445, result) + end) + end) + + describe("Edge cases", function() + it("should handle single digit months and days", function() + local result = shared.iso_8601_to_epoch("2020-01-01T01:01:01.000Z") + assert.is_true(result > 0) + end) + + it("should handle end of month correctly", function() + -- Test various month endings + local results = {} + results[1] = shared.iso_8601_to_epoch("2020-01-31T23:59:59.000Z") + results[2] = shared.iso_8601_to_epoch("2020-02-29T23:59:59.000Z") -- Leap year + results[3] = shared.iso_8601_to_epoch("2020-04-30T23:59:59.000Z") + + -- Each should be valid positive numbers + for i, result in ipairs(results) do + assert.is_true(result > 0, "Result " .. i .. 
" should be positive") + end + end) + end) + + describe("Input validation behavior", function() + it("should handle malformed input gracefully", function() + -- This will return nil values from string.match, causing tonumber to return nil + -- The function should handle this or we should add validation + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("invalid-timestamp")) + end, "Invalid ISO 8601 timestamp format") + end) + + it("should handle missing Z suffix", function() + -- Test what happens with timestamp without Z + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("2020-01-01T00:00:00.000")) + end, "Invalid ISO 8601 timestamp format") + end) + + it("should account for invalid mday", function() + -- Invalid date like 2020-02-30 should not be accepted + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("2020-02-30T00:00:00.000Z")) + end, "Invalid day for the given month") + end) + + it("should handle invalid month, year and day", function() + -- Invalid month like 2020-13-01 should not be accepted + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("2020-13-01T00:00:00.000Z")) + end, "Invalid date/time components") + + -- Invalid year like 2020-01-32 should not be accepted + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("2020-01-32T00:00:00.000Z")) + end, "Invalid date/time components") + -- Invalid year like 2020-00-01 should not be accepted + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("2020-00-01T00:00:00.000Z")) + end, "Invalid date/time components") + + -- Invalid year like 2020-01-00 should not be accepted + assert.error_matches(function() + assert(shared.iso_8601_to_epoch("2020-01-00T00:00:00.000Z")) + end, "Invalid date/time components") + + end) + end) +end) + +describe("upstream_url capture groups", function() + local mock_request + + lazy_setup(function() + end) + + before_each(function() + -- Mock Kong request object for testing capture groups + mock_request = { + get_uri_captures = function() + return { + named = { + api = "api", + chat = "chat", + completions = "completions" + }, + unnamed = { + [0] = "/api/chat", + [1] = "api", + [2] = "chat" + } + } + end, + get_header = function(key) + if key == "x-test-header" then + return "test-value" + end + return nil + end, + get_query_arg = function(key) + if key == "test_param" then + return "param-value" + end + return nil + end + } + end) + + describe("merge_model_options function", function() + local shared = require "kong.llm.drivers.shared" + + it("resolves capture group templates in upstream_url", function() + local conf_m = { + upstream_url = "http://127.0.0.1:11434/$(uri_captures.api)/$(uri_captures.chat)" + } + + local result, err = shared.merge_model_options(mock_request, conf_m) + + assert.is_nil(err) + assert.not_nil(result) + assert.equal("http://127.0.0.1:11434/api/chat", result.upstream_url) + end) + + it("resolves multiple capture groups in same string", function() + local conf_m = { + upstream_url = "http://127.0.0.1:11434/$(uri_captures.api)/$(uri_captures.chat)", + custom_path = "/$(uri_captures.api)-$(uri_captures.chat)-endpoint" + } + + local result, err = shared.merge_model_options(mock_request, conf_m) + + assert.is_nil(err) + assert.not_nil(result) + assert.equal("http://127.0.0.1:11434/api/chat", result.upstream_url) + assert.equal("/api-chat-endpoint", result.custom_path) + end) + + it("resolves capture groups in nested tables", function() + local conf_m = { + model = { + options = { + 
upstream_url = "http://127.0.0.1:11434/$(uri_captures.api)/$(uri_captures.chat)", + custom_endpoint = "/$(uri_captures.api)/v1" + } + } + } + + local result, err = shared.merge_model_options(mock_request, conf_m) + + assert.is_nil(err) + assert.not_nil(result) + assert.equal("http://127.0.0.1:11434/api/chat", result.model.options.upstream_url) + assert.equal("/api/v1", result.model.options.custom_endpoint) + end) + + end) + + describe("real route scenario tests", function() + local shared = require "kong.llm.drivers.shared" + + it("simulates llama2-chat route with capture groups", function() + -- Simulate the actual route configuration from the user's example + local targets_config = { + { + model = { + options = { + upstream_url = "http://127.0.0.1:11434/$(uri_captures.api)/$(uri_captures.chat)", + llama2_format = "ollama" + }, + provider = "llama2" + } + } + } + + -- Process the first target's model options + local result, err = shared.merge_model_options(mock_request, targets_config[1].model.options) + + assert.is_nil(err) + assert.not_nil(result) + assert.equal("http://127.0.0.1:11434/api/chat", result.upstream_url) + assert.equal("ollama", result.llama2_format) + end) + + it("handles route path ~/(?[a-z]+)/(?[a-z]+)$ correctly", function() + -- Mock a request that would match the route pattern ~/(?[a-z]+)/(?[a-z]+)$ + -- For request path "/api/chat" + local request_with_captures = { + get_uri_captures = function() + return { + named = { + api = "api", + chat = "chat" + }, + unnamed = { + [0] = "/api/chat", + [1] = "api", + [2] = "chat" + } + } + end + } + + local conf_m = { + upstream_url = "http://127.0.0.1:11434/$(uri_captures.api)/$(uri_captures.chat)" + } + + local result, err = shared.merge_model_options(request_with_captures, conf_m) + + assert.is_nil(err) + assert.not_nil(result) + assert.equal("http://127.0.0.1:11434/api/chat", result.upstream_url) + end) + + it("handles different capture group values dynamically", function() + -- Test with different capture values that could match the regex pattern + local request_with_v1_completions = { + get_uri_captures = function() + return { + named = { + api = "v1", + chat = "completions" + }, + unnamed = { + [0] = "/v1/completions", + [1] = "v1", + [2] = "completions" + } + } + end + } + + local conf_m = { + upstream_url = "http://127.0.0.1:11434/$(uri_captures.api)/$(uri_captures.chat)" + } + + local result, err = shared.merge_model_options(request_with_v1_completions, conf_m) + + assert.is_nil(err) + assert.not_nil(result) + assert.equal("http://127.0.0.1:11434/v1/completions", result.upstream_url) + end) + end) + end) \ No newline at end of file diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index 99d4026769b..8860d7cfde1 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -81,6 +81,7 @@ for _, strategy in helpers.all_strategies() do listen ]]..MOCK_PORT..[[; default_type 'application/json'; + client_body_buffer_size 64k; # ensure we can test with larger payloads location = "/llm/v1/chat/good" { @@ -114,6 +115,17 @@ for _, strategy in helpers.all_strategies() do } } + location = "/llm/v1/chat/good-with-payloads-preserved" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.req.read_body() + + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/good.json")) + } + } + location = 
"/llm/v1/chat/bad_upstream_response" { content_by_lua_block { local pl_file = require "pl.file" @@ -334,6 +346,52 @@ for _, strategy in helpers.all_strategies() do } -- + -- 200 chat good with one option + local chat_good_with_no_upstream_port = assert(bp.routes:insert { + service = empty_service, + protocols = { "http", "https" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/good_with_no_upstream_port" }, + snis = { "example.test" }, + }) + bp.plugins:insert { + name = PLUGIN_NAME, + id = "6e7c40f6-ce96-48e4-a366-d109c169e44a", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + logging = { + log_payloads = false, + log_statistics = true, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + allow_override = true, + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host .."/llm/v1/chat/good", + input_cost = 10.0, + output_cost = 10.0, + }, + }, + }, + } + + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } + + -- 200 chat good with statistics disabled local chat_good_no_stats = assert(bp.routes:insert { service = empty_service, @@ -416,6 +474,48 @@ for _, strategy in helpers.all_strategies() do } -- + -- 200 chat good with all logging enabled, perserved without transformation + local chat_good_log_payloads_preserved = assert(bp.routes:insert { + service = empty_service, + protocols = { "http", "https" }, + strip_path = true, + paths = { "/llm/v1/chat/good-with-payloads-preserved" }, + snis = { "example.test" }, + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_log_payloads_preserved.id }, + config = { + route_type = "preserve", + max_request_body_size = 64 * 1024, -- 64KB, ensure we can test with larger payloads + logging = { + log_payloads = true, + log_statistics = true, + }, + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_log_payloads_preserved.id }, + config = { + path = FILE_LOG_PATH_WITH_PAYLOADS, + }, + } + -- + -- 200 chat bad upstream response with one option local chat_bad_upstream = assert(bp.routes:insert { service = empty_service, @@ -952,8 +1052,8 @@ for _, strategy in helpers.all_strategies() do local _, message = next(log_message.ai) -- test request bodies - assert.matches('"content":"What is 1 + 1?"', message.payload.request, nil, true) - assert.matches('"role":"user"', message.payload.request, nil, true) + assert.matches('"content": "What is 1 + 1?"', message.payload.request, nil, true) + assert.matches('"role": "user"', message.payload.request, nil, true) -- test response bodies assert.matches('"content": "The sum of 1 + 1 is 2.",', message.payload.response, nil, true) @@ -1022,6 +1122,24 @@ for _, strategy in helpers.all_strategies() do assert.res_status(200 , r) end) + it("authorized request with client header auth and no self upstream port", function() + local r = client:post("/openai/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + 
["Authorization"] = "Bearer openai-key", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(502 , r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + it("authorized request with client right header auth with no allow_override", function() local r = client:get("/openai/llm/v1/chat/good-no-allow-override", { headers = { @@ -1047,7 +1165,6 @@ for _, strategy in helpers.all_strategies() do assert.res_status(200 , r) end) - end) describe("openai llm/v1/chat", function() @@ -1636,14 +1753,33 @@ for _, strategy in helpers.all_strategies() do local _, message = next(log_message.ai) -- test request bodies - assert.matches('"text":"What\'s in this image?"', message.payload.request, nil, true) - assert.matches('"role":"user"', message.payload.request, nil, true) + assert.matches('"text": "What\'s in this image?"', message.payload.request, nil, true) + assert.matches('"role": "user"', message.payload.request, nil, true) -- test response bodies assert.matches('"content": "The sum of 1 + 1 is 2.",', message.payload.response, nil, true) assert.matches('"role": "assistant"', message.payload.response, nil, true) assert.matches('"id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2"', message.payload.response, nil, true) end) + + it("logs huge payloads", function() + local request_body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json") + local obj = cjson.decode(request_body) + local padding = string.rep("x", 32 * 1024) -- make it larger than 32k + obj.messages[1].content = obj.messages[1].content .. " " .. padding + request_body = cjson.encode(obj) + local r = client:post("/llm/v1/chat/good-with-payloads-preserved", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = request_body, + }) + assert.res_status(200, r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + local _, message = next(log_message.ai) + assert.matches(padding, message.payload.request, nil, true) + end) end) describe("one-shot request", function() diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index 71174601fc8..75e282b6149 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -3,8 +3,36 @@ local cjson = require "cjson" local pl_file = require "pl.file" local deepcompare = require("pl.tablex").deepcompare +local strip = require("kong.tools.string").strip + +local ANTHROPIC_MOCK = pl_file.read("spec/fixtures/ai-proxy/mock_servers/anthropic.lua.txt") local PLUGIN_NAME = "ai-proxy" +local FILE_LOG_PATH_NO_LOGS = os.tmpname() + + +local function wait_for_json_log_entry(FILE_LOG_PATH) + local json + + assert + .with_timeout(10) + .ignore_exceptions(true) + .eventually(function() + local data = assert(pl_file.read(FILE_LOG_PATH)) + + data = strip(data) + assert(#data > 0, "log file is empty") + + data = data:match("%b{}") + assert(data, "log file does not contain JSON") + + json = cjson.decode(data) + end) + .has_no_error("log file contains a valid JSON entry") + + return json +end + for _, strategy in helpers.all_strategies() do describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. 
"]", function() local client @@ -20,191 +48,7 @@ for _, strategy in helpers.all_strategies() do http_mock = {}, } - fixtures.http_mock.anthropic = [[ - server { - server_name anthropic; - listen ]]..MOCK_PORT..[[; - - default_type 'application/json'; - - - location = "/llm/v1/chat/good" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - local token = ngx.req.get_headers()["x-api-key"] - if token == "anthropic-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - if err or (not body.messages) then - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) - else - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json")) - end - else - ngx.status = 401 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) - end - } - } - - location = "/llm/v1/chat/bad_upstream_response" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - local token = ngx.req.get_headers()["x-api-key"] - if token == "anthropic-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - if err or (not body.messages) then - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) - else - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_upstream_response.json")) - end - else - ngx.status = 401 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) - end - } - } - - location = "/llm/v1/chat/no_usage_upstream_response" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - local token = ngx.req.get_headers()["x-api-key"] - if token == "anthropic-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - if err or (not body.messages) then - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) - else - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/no_usage_response.json")) - end - else - ngx.status = 401 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) - end - } - } - - location = "/llm/v1/chat/malformed_usage_upstream_response" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - local token = ngx.req.get_headers()["x-api-key"] - if token == "anthropic-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - if err or (not body.messages) then - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) - else - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/malformed_usage_response.json")) - end - else - ngx.status = 401 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) - end - } - } - - location = "/llm/v1/chat/bad_request" { - content_by_lua_block { - local pl_file = require "pl.file" - - ngx.status = 400 - 
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) - } - } - - location = "/llm/v1/chat/internal_server_error" { - content_by_lua_block { - local pl_file = require "pl.file" - - ngx.status = 500 - ngx.header["content-type"] = "text/html" - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/internal_server_error.html")) - } - } - - location = "/llm/v1/chat/tool_choice" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - ngx.req.read_body() - local function assert_ok(ok, err) - if not ok then - ngx.status = 500 - ngx.say(err) - ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR) - end - return ok - end - local body = assert_ok(ngx.req.get_body_data()) - body = assert_ok(json.decode(body)) - local tool_choice = body.tool_choice - ngx.header["tool-choice"] = json.encode(tool_choice) - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json")) - } - } - - location = "/llm/v1/completions/good" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - local token = ngx.req.get_headers()["x-api-key"] - if token == "anthropic-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - if err or (not body.prompt) then - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) - else - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/good.json")) - end - else - ngx.status = 401 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/unauthorized.json")) - end - } - } - - location = "/llm/v1/completions/bad_request" { - content_by_lua_block { - local pl_file = require "pl.file" - - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) - } - } - - } - ]] + fixtures.http_mock.anthropic = string.format(ANTHROPIC_MOCK, MOCK_PORT) local empty_service = assert(bp.services:insert { name = "empty_service", @@ -243,6 +87,43 @@ for _, strategy in helpers.all_strategies() do }, } + local chat_good_with_no_upstream_port = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/anthropic/llm/v1/chat/good_with_no_upstream_port" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "x-api-key", + header_value = "anthropic-key", + allow_override = true, + }, + model = { + name = "claude-2.1", + provider = "anthropic", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host.."/llm/v1/chat/good", + anthropic_version = "2023-06-01", + }, + }, + }, + } + + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } + local chat_good_no_allow_override = assert(bp.routes:insert { service = empty_service, protocols = { "http" }, @@ -634,6 +515,24 @@ for _, strategy in helpers.all_strategies() do }, json.choices[1].message) end) + it("good request with no upstream port", function() + local r = client:get("/anthropic/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = 
"application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(502 , r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + + it("good request with client right header auth", function() local r = client:get("/anthropic/llm/v1/chat/good", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua index 2efe5fa0e9a..fd59cffaade 100644 --- a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua @@ -1,8 +1,34 @@ local helpers = require "spec.helpers" local cjson = require "cjson" local pl_file = require "pl.file" +local strip = require("kong.tools.string").strip + local PLUGIN_NAME = "ai-proxy" +local FILE_LOG_PATH_NO_LOGS = os.tmpname() + + +local function wait_for_json_log_entry(FILE_LOG_PATH) + local json + + assert + .with_timeout(10) + .ignore_exceptions(true) + .eventually(function() + local data = assert(pl_file.read(FILE_LOG_PATH)) + + data = strip(data) + assert(#data > 0, "log file is empty") + + data = data:match("%b{}") + assert(data, "log file does not contain JSON") + + json = cjson.decode(data) + end) + .has_no_error("log file contains a valid JSON entry") + + return json +end for _, strategy in helpers.all_strategies() do describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() local client @@ -168,6 +194,43 @@ for _, strategy in helpers.all_strategies() do }, }, } + + local chat_good_with_no_upstream_port = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/cohere/llm/v1/chat/good_with_no_upstream_port" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer cohere-key", + allow_override = true, + }, + model = { + name = "command", + provider = "cohere", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host.. 
"/llm/v1/chat/good", + }, + }, + }, + } + + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } + local chat_good_no_allow_override = assert(bp.routes:insert { service = empty_service, protocols = { "http" }, @@ -455,6 +518,23 @@ for _, strategy in helpers.all_strategies() do }, json.choices[1].message) end) + it("good request with no upstream port", function() + local r = client:get("/cohere/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good.json"), + }) + + assert.res_status(502 , r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + it("good request with right client auth", function() local r = client:get("/cohere/llm/v1/chat/good", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua index 110baeeda50..cd57ec62b3e 100644 --- a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua @@ -2,6 +2,34 @@ local helpers = require "spec.helpers" local cjson = require "cjson" local pl_file = require "pl.file" +local strip = require("kong.tools.string").strip + +local FILE_LOG_PATH_NO_LOGS = os.tmpname() + + +local function wait_for_json_log_entry(FILE_LOG_PATH) + local json + + assert + .with_timeout(10) + .ignore_exceptions(true) + .eventually(function() + local data = assert(pl_file.read(FILE_LOG_PATH)) + + data = strip(data) + assert(#data > 0, "log file is empty") + + data = data:match("%b{}") + assert(data, "log file does not contain JSON") + + json = cjson.decode(data) + end) + .has_no_error("log file contains a valid JSON entry") + + return json +end + + local PLUGIN_NAME = "ai-proxy" for _, strategy in helpers.all_strategies() do @@ -233,6 +261,43 @@ for _, strategy in helpers.all_strategies() do }, } + local chat_good_with_no_upstream_port = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/azure/llm/v1/chat/good_with_no_upstream_port" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + allow_override = true, + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host .."/llm/v1/chat/good", + azure_instance = "001-kong-t", + azure_deployment_id = "gpt-3.5-custom", + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } + local chat_good_no_allow_override = assert(bp.routes:insert { service = empty_service, protocols = { "http" }, @@ -627,6 +692,25 @@ for _, strategy in helpers.all_strategies() do }, json.choices[1].message) end) + it("good request", function() + local r = client:get("/azure/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = 
pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 502 + + assert.res_status(502 , r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + it("good request with client right auth", function() local r = client:get("/azure/llm/v1/chat/good", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua index 458505dc9d5..eaf155e01ea 100644 --- a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua @@ -1,9 +1,36 @@ local helpers = require "spec.helpers" local cjson = require "cjson" local pl_file = require "pl.file" +local strip = require("kong.tools.string").strip + local PLUGIN_NAME = "ai-proxy" +local FILE_LOG_PATH_NO_LOGS = os.tmpname() + + +local function wait_for_json_log_entry(FILE_LOG_PATH) + local json + + assert + .with_timeout(10) + .ignore_exceptions(true) + .eventually(function() + local data = assert(pl_file.read(FILE_LOG_PATH)) + + data = strip(data) + assert(#data > 0, "log file is empty") + + data = data:match("%b{}") + assert(data, "log file does not contain JSON") + + json = cjson.decode(data) + end) + .has_no_error("log file contains a valid JSON entry") + + return json +end + for _, strategy in helpers.all_strategies() do describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() local client @@ -115,6 +142,45 @@ for _, strategy in helpers.all_strategies() do }, } + local chat_good_with_no_upstream_port = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/mistral/llm/v1/chat/good_with_no_upstream_port" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer mistral-key", + allow_override = true, + }, + model = { + name = "mistralai/Mistral-7B-Instruct-v0.1-instruct", + provider = "mistral", + options = { + max_tokens = 256, + temperature = 1.0, + mistral_format = "openai", + upstream_url = "http://"..helpers.mock_upstream_host .."/v1/chat/completions", + }, + }, + }, + } + + + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } + + local chat_good_no_allow_override = assert(bp.routes:insert { service = empty_service, protocols = { "http" }, @@ -380,6 +446,25 @@ for _, strategy in helpers.all_strategies() do }, json.choices[1].message) end) + it("good request with no upstream port", function() + local r = client:get("/mistral/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + + -- validate that the request succeeded, response status 502 + assert.res_status(502 , r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + + it("good request with client right auth", function() local r = 
client:get("/mistral/llm/v1/chat/good", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua index 4b49979159b..31a7186e63e 100644 --- a/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/07-llama2_integration_spec.lua @@ -2,8 +2,36 @@ local helpers = require "spec.helpers" local cjson = require "cjson" local pl_file = require "pl.file" +local strip = require("kong.tools.string").strip + + local PLUGIN_NAME = "ai-proxy" +local FILE_LOG_PATH_NO_LOGS = os.tmpname() + + +local function wait_for_json_log_entry(FILE_LOG_PATH) + local json + + assert + .with_timeout(10) + .ignore_exceptions(true) + .eventually(function() + local data = assert(pl_file.read(FILE_LOG_PATH)) + + data = strip(data) + assert(#data > 0, "log file is empty") + + data = data:match("%b{}") + assert(data, "log file does not contain JSON") + + json = cjson.decode(data) + end) + .has_no_error("log file contains a valid JSON entry") + + return json +end + for _, strategy in helpers.all_strategies() do describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() local client @@ -113,6 +141,41 @@ for _, strategy in helpers.all_strategies() do }, }, } + + local chat_good_with_no_upstream_port = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/raw/llm/v1/chat/completions_no_upstream_port" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer llama2-key", + }, + model = { + name = "llama-2-7b-chat-hf", + provider = "llama2", + options = { + max_tokens = 256, + temperature = 1.0, + llama2_format = "raw", + upstream_url = "http://"..helpers.mock_upstream_host .."/raw/llm/v1/chat", + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } -- -- 200 completions good with one option @@ -216,6 +279,23 @@ for _, strategy in helpers.all_strategies() do assert.equals(json.choices[1].message.content, "Is a well known font.") end) + it("runs good request in chat format with no upstream port", function() + local r = client:get("/raw/llm/v1/chat/completions_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/llama2/raw/requests/good-chat.json"), + }) + + assert.res_status(502, r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + it("runs good request in completions format", function() local r = client:get("/raw/llm/v1/completions", { headers = { diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua index bdac132d114..d2b378e10d0 100644 --- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -19,10 +19,10 @@ local _EXPECTED_CHAT_STATS = { }, usage = { prompt_tokens = 18, - completion_tokens = 13, -- this was from estimation - total_tokens = 31, + completion_tokens = 7, -- this was from estimation + total_tokens = 25, 
time_per_token = 1, - cost = 0.00031, + cost = 0.00025, }, } @@ -130,6 +130,8 @@ for _, strategy in helpers.all_strategies() do for i, EVENT in ipairs(_EVENT_CHUNKS) do ngx.print(fmt("%s\n\n", EVENT)) + ngx.sleep(0.01) -- simulate delay for latency assertion + ngx.flush(true) end end else @@ -233,10 +235,10 @@ for _, strategy in helpers.all_strategies() do -- GOOD RESPONSE ngx.status = 200 - ngx.header["Content-Type"] = "text/event-stream" + ngx.header["Content-Type"] = "application/stream+json" for i, EVENT in ipairs(_EVENT_CHUNKS) do - ngx.print(fmt("%s\n\n", EVENT)) + ngx.print(fmt("%s\n", EVENT)) end end else @@ -313,6 +315,37 @@ for _, strategy in helpers.all_strategies() do } } + location = "/gemini/llm/v1/chat/good" { + content_by_lua_block { + local _EVENT_CHUNKS = { + [1] = '[', + [2] = '{"candidates": [{"content": {"parts": [{"text": "Gemini"}], "role": "model"}, "finishReason": "NULL", "index": 0}]},', + [3] = '{"candidates": [{"content": {"parts": [{"text": " is"}], "role": "model"}, "finishReason": "NULL", "index": 0}]},', + [4] = '{"candidates": [{"content": {"parts": [{"text": " a"}], "role": "model"}, "finishReason": "NULL", "index": 0}]},', + [5] = '{"candidates": [{"content": {"parts": [{"text": " powerful"}], "role": "model"}, "finishReason": "NULL", "index": 0}]},', + [6] = '{"candidates": [{"content": {"parts": [{"text": " model"}], "role": "model"}, "finishReason": "NULL", "index": 0}]},', + [7] = '{"candidates": [{"content": {"parts": [{"text": "."}], "role": "model"}, "finishReason": "STOP", "index": 0}]}', + [8] = ']', + } + + local fmt = string.format + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + + -- GOOD RESPONSE + ngx.status = 200 + ngx.header["Content-Type"] = "application/json" + + for i, EVENT in ipairs(_EVENT_CHUNKS) do + ngx.print(EVENT) + ngx.flush(true) + ngx.sleep(0.01) -- simulate delay for latency assertion + end + } + } + location = "/openai/llm/v1/chat/bad" { content_by_lua_block { local fmt = string.format @@ -397,6 +430,67 @@ for _, strategy in helpers.all_strategies() do end } } + + location = "/bedrock/llm/v1/chat/streaming/good" { + content_by_lua_block { + local pl_file = require "pl.file" + local _EVENT_CHUNKS = {} + + for i=1,4 do + local encoded = pl_file.read("spec/fixtures/ai-proxy/bedrock/chunks-normal/chunk-" .. i .. ".bin") + local decoded = ngx.decode_base64(encoded) + _EVENT_CHUNKS[i] = decoded + end + + local fmt = string.format + local json = require("cjson.safe") + + ngx.req.read_body() + + -- GOOD RESPONSE + ngx.status = 200 + ngx.header["Content-Type"] = "application/vnd.amazon.eventstream" + + for i, EVENT in ipairs(_EVENT_CHUNKS) do + ngx.print(fmt("%s", EVENT)) + end + } + } + + location = "/bedrock/llm/v1/chat/streaming/incomplete" { + content_by_lua_block { + local pl_file = require "pl.file" + local all_chunks = "" + + for i=1,4 do + local encoded = pl_file.read("spec/fixtures/ai-proxy/bedrock/chunks-normal/chunk-" .. i .. ".bin") + local decoded = ngx.decode_base64(encoded) + all_chunks = all_chunks .. 
decoded + end + + local json = require("cjson.safe") + + ngx.req.read_body() + + -- GOOD RESPONSE + ngx.status = 200 + ngx.header["Content-Type"] = "application/vnd.amazon.eventstream" + + local len = #all_chunks + local i = math.floor(len / 3) + for j = 0, 2 do + local stop = i * j + i + if j == 2 then + -- the last truncated frame + stop = len + end + local to_send = all_chunks:sub(i * j + 1, stop) + ngx.print(to_send) + ngx.flush(true) + ngx.sleep(0.1) -- simulate delay + end + } + } } ]] @@ -423,6 +517,10 @@ for _, strategy in helpers.all_strategies() do header_name = "Authorization", header_value = "Bearer openai-key", }, + logging = { + log_payloads = true, + log_statistics = true, + }, model = { name = "gpt-3.5-turbo", provider = "openai", @@ -440,7 +538,7 @@ for _, strategy in helpers.all_strategies() do name = "file-log", route = { id = openai_chat_good.id }, config = { - path = "/dev/stdout", + path = FILE_LOG_PATH_WITH_PAYLOADS, }, } -- @@ -565,6 +663,49 @@ for _, strategy in helpers.all_strategies() do } -- + -- 200 chat gemini + local gemini_chat_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/gemini/llm/v1/chat/good" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = gemini_chat_good.id }, + config = { + route_type = "llm/v1/chat", + logging = { + log_payloads = false, + log_statistics = true, + }, + model = { + name = "gemini-1.5-flash", + provider = "gemini", + options = { + max_tokens = 512, + temperature = 0.6, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/gemini/llm/v1/chat/good", + input_cost = 20.0, + output_cost = 20.0, + }, + }, + auth = { + header_name = "x-goog-api-key", + header_value = "123", + allow_override = false, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = gemini_chat_good.id }, + config = { + path = "/dev/stdout", + }, + } + -- + -- 400 chat openai local openai_chat_bad = assert(bp.routes:insert { service = empty_service, @@ -777,6 +918,88 @@ for _, strategy in helpers.all_strategies() do } -- + -- 200 chat bedrock streaming + local bedrock_chat_streaming_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/bedrock/llm/v1/chat/streaming/good" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bedrock_chat_streaming_good.id }, + config = { + route_type = "llm/v1/chat", + logging = { + log_payloads = true, + log_statistics = true, + }, + model = { + name = "anthropic.claude-3-sonnet-20240229-v1:0", + provider = "bedrock", + options = { + max_tokens = 512, + temperature = 0.7, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/bedrock/llm/v1/chat/streaming/good", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + auth = { + allow_override = false, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = bedrock_chat_streaming_good.id }, + config = { + path = FILE_LOG_PATH_WITH_PAYLOADS, + }, + } + -- + + -- 200 chat bedrock streaming with incomplete events + local bedrock_chat_streaming_incomplete = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/bedrock/llm/v1/chat/streaming/incomplete" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bedrock_chat_streaming_incomplete.id }, + config = { + route_type = "llm/v1/chat", + logging = { + log_payloads = true, + log_statistics = true, + }, + model = { + name = 
"anthropic.claude-3-sonnet-20240229-v1:0", + provider = "bedrock", + options = { + max_tokens = 512, + temperature = 0.7, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/bedrock/llm/v1/chat/streaming/incomplete", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + auth = { + allow_override = false, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = bedrock_chat_streaming_incomplete.id }, + config = { + path = FILE_LOG_PATH_WITH_PAYLOADS, + }, + } + -- + helpers.setenv("AWS_REGION", "us-east-1") -- start kong @@ -848,6 +1071,8 @@ for _, strategy in helpers.all_strategies() do end if buffer then + -- ensure the last chunk ends properly + assert.equal("\n\n", buffer:sub(-2)) -- we need to rip each message from this chunk for s in buffer:gmatch("[^\r\n]+") do local s_copy = s @@ -871,6 +1096,11 @@ for _, strategy in helpers.all_strategies() do assert.equal(buf:tostring(), "The answer to 1 + 1 is 2.") -- to verifiy not enable `kong.service.request.enable_buffering()` assert.logfile().has.no.line("/kong_buffered_http", true, 10) + + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + local actual_stats = log_message.ai.proxy + local actual_llm_latency = actual_stats.meta.llm_latency + assert.is_true(actual_llm_latency > 60) -- 6 events, each with 10ms latency end) it("good stream request openai with partial split chunks", function() @@ -958,7 +1188,8 @@ for _, strategy in helpers.all_strategies() do assert.is_true(actual_llm_latency >= 0) assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token))) assert.match_re(actual_request_log, [[.*content.*What is 1 \+ 1.*]]) - assert.match_re(actual_response_log, [[.*content.*The answer.*]]) + -- include the quotes to match the whole string + assert.match_re(actual_response_log, [[.*"The answer to 1 \+ 1 is 2\.".*]]) -- to verifiy not enable `kong.service.request.enable_buffering()` assert.logfile().has.no.line("/kong_buffered_http", true, 10) end) @@ -1091,6 +1322,74 @@ for _, strategy in helpers.all_strategies() do assert.logfile().has.no.line("/kong_buffered_http", true, 10) end) + it("good stream request gemini", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + assert.is_nil(err) + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. 
+ local res, err = httpc:request({ + path = "/gemini/llm/v1/chat/good", + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + assert.is_nil(err) + end + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + assert.is_falsy(err and err ~= "closed") + end + + if buffer then + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + if s_copy then + assert.equal("gemini-1.5-flash", s_copy.model) + end + + buf:put(s_copy + and s_copy.choices + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + table.insert(events, s) + end + end + until not buffer + + assert.equal(7, #events) + assert.equal("Gemini is a powerful model.", buf:tostring()) + -- to verifiy not enable `kong.service.request.enable_buffering()` + assert.logfile().has.no.line("/kong_buffered_http", true, 10) + end) + it("bad request is returned to the client not-streamed", function() local httpc = http.new() @@ -1191,6 +1490,7 @@ for _, strategy in helpers.all_strategies() do s = cjson.decode(s_copy) if s and s.choices then + assert.equal("gemini-1.5-flash", s.model) func_name = s.choices[1].delta.tool_calls[1]['function'].name func_args = s.choices[1].delta.tool_calls[1]['function'].arguments end @@ -1391,6 +1691,153 @@ for _, strategy in helpers.all_strategies() do assert.logfile().has.no.line("/kong_buffered_http", true, 10) end) + it("good stream request bedrock streaming", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + assert.is_nil(err) + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. + local res, err = httpc:request({ + path = "/bedrock/llm/v1/chat/streaming/good", + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + assert.is_nil(err) + end + + assert.equal(200, res.status) + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + assert.is_falsy(err and err ~= "closed") + end + + if buffer then + -- ensure the last chunk ends properly + assert.equal("\n\n", buffer:sub(-2)) + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + buf:put(s_copy + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 17) -- 16 data events + [DONE] + assert.equal(buf:tostring(), "1 + 1 = 2\n\nThis is one of the most basic arithmetic equations. 
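+      -- The bedrock mock replays four binary AWS eventstream frames
+      -- (chunks-normal/chunk-*.bin, base64-decoded); the assertions below expect
+      -- them to be converted into 'data: ...' SSE events ending with [DONE]
+      -- (16 data events + [DONE] = 17).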
It represents the addition of two units, resulting in a sum of two.") + + -- test analytics on this item + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + assert.same("127.0.0.1", log_message.client_ip) + assert.is_number(log_message.request.size) + assert.is_number(log_message.response.size) + + -- to verify not enable `kong.service.request.enable_buffering()` + assert.logfile().has.no.line("/kong_buffered_http", true, 10) + end) + + it("good stream request bedrock streaming with incomplete events", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + assert.is_nil(err) + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. + local res, err = httpc:request({ + path = "/bedrock/llm/v1/chat/streaming/good", + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + assert.is_nil(err) + end + + assert.equal(200, res.status) + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + assert.is_falsy(err and err ~= "closed") + end + + if buffer then + -- ensure the last chunk ends properly + assert.equal("\n\n", buffer:sub(-2)) + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + buf:put(s_copy + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 17) -- 16 data events + [DONE] + assert.equal(buf:tostring(), "1 + 1 = 2\n\nThis is one of the most basic arithmetic equations. 
It represents the addition of two units, resulting in a sum of two.") + + -- test analytics on this item + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + assert.same("127.0.0.1", log_message.client_ip) + assert.is_number(log_message.request.size) + assert.is_number(log_message.response.size) + end) + end) end) diff --git a/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua b/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua index 16ca0c548cf..5e3aef9063b 100644 --- a/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/10-huggingface_integration_spec.lua @@ -2,8 +2,37 @@ local helpers = require("spec.helpers") local cjson = require("cjson") local pl_file = require("pl.file") +local strip = require("kong.tools.string").strip + + local PLUGIN_NAME = "ai-proxy" +local FILE_LOG_PATH_NO_LOGS = os.tmpname() + + +local function wait_for_json_log_entry(FILE_LOG_PATH) + local json + + assert + .with_timeout(10) + .ignore_exceptions(true) + .eventually(function() + local data = assert(pl_file.read(FILE_LOG_PATH)) + + data = strip(data) + assert(#data > 0, "log file is empty") + + data = data:match("%b{}") + assert(data, "log file does not contain JSON") + + json = cjson.decode(data) + end) + .has_no_error("log file contains a valid JSON entry") + + return json +end + + for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() @@ -27,7 +56,7 @@ for _, strategy in helpers.all_strategies() do default_type 'application/json'; - location = "/v1/chat/completions" { + location = "/llm/v1/chat/good" { content_by_lua_block { local pl_file = require "pl.file" local json = require("cjson.safe") @@ -129,11 +158,51 @@ for _, strategy in helpers.all_strategies() do use_cache = false, wait_for_model = true, }, - upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good", + }, + }, + }, + }) + + local chat_good_with_no_upstream_port = assert(bp.routes:insert({ + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/huggingface/llm/v1/chat/good_with_no_upstream_port" }, + })) + bp.plugins:insert({ + name = PLUGIN_NAME, + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer huggingface-key", + }, + model = { + name = "mistralai/Mistral-7B-Instruct-v0.2", + provider = "huggingface", + options = { + max_tokens = 256, + temperature = 1.0, + huggingface = { + use_cache = false, + wait_for_model = true, + }, + upstream_url = "http://" .. helpers.mock_upstream_host, }, }, }, }) + + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_NO_LOGS, + }, + } + local completions_good = assert(bp.routes:insert({ service = empty_service, protocols = { "http" }, @@ -252,7 +321,7 @@ for _, strategy in helpers.all_strategies() do use_cache = false, wait_for_model = false, }, - upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT.."/model-loading", + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. 
MOCK_PORT.."/model-loading/v1/chat/completions", }, }, }, @@ -283,7 +352,7 @@ for _, strategy in helpers.all_strategies() do use_cache = false, wait_for_model = false, }, - upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT.."/model-timeout", + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT.."/model-timeout/v1/chat/completions", }, }, }, @@ -334,14 +403,29 @@ for _, strategy in helpers.all_strategies() do assert.equals(json.object, "chat.completion") assert.is_table(json.choices) - --print("json: ", inspect(json)) assert.is_string(json.choices[1].message.content) assert.same( " The sum of 1 + 1 is 2. This is a basic arithmetic operation and the answer is always the same: adding one to one results in having two in total.", json.choices[1].message.content ) end) + it("good request with no upstream port", function() + local r = client:get("/huggingface/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/huggingface/llm-v1-chat/requests/good.json"), + }) + assert.res_status(502 , r) + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) end) + describe("huggingface llm/v1/completions", function() it("good request", function() local r = client:get("/huggingface/llm/v1/completions/good", { diff --git a/spec/03-plugins/38-ai-proxy/11-gemini_integration_spec.lua b/spec/03-plugins/38-ai-proxy/11-gemini_integration_spec.lua index 93273515bc0..fe3ce547419 100644 --- a/spec/03-plugins/38-ai-proxy/11-gemini_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/11-gemini_integration_spec.lua @@ -34,12 +34,12 @@ local function wait_for_json_log_entry(FILE_LOG_PATH) return json end -local _EXPECTED_CHAT_STATS = { +local _EXPECTED_CHAT_STATS_GEMINI = { meta = { plugin_id = '17434c15-2c7c-4c2f-b87a-58880533a3c1', provider_name = 'gemini', - request_model = 'gemini-1.5-pro', - response_model = 'gemini-1.5-pro', + request_model = 'gemini-1.5-flash', + response_model = 'gemini-1.5-flash-002', llm_latency = 1, }, usage = { @@ -52,54 +52,46 @@ local _EXPECTED_CHAT_STATS = { } for _, strategy in helpers.all_strategies() do + local gemini_driver + if strategy ~= "cassandra" then describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() local client - local MOCK_PORT + local MOCK_PORTS = { + _GEMINI = 0, + _ANTHROPIC = 0, + } lazy_setup(function() - MOCK_PORT = helpers.get_available_port() - + _G._TEST = true + package.loaded["kong.llm.drivers.gemini"] = nil + gemini_driver = require("kong.llm.drivers.gemini") + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) - + -- set up gemini mock fixtures local fixtures = { http_mock = {}, } - fixtures.http_mock.gemini = [[ - server { - server_name gemini; - listen ]] .. MOCK_PORT .. 
[[; - - default_type 'application/json'; - - location = "/v1/chat/completions" { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") - - local token = ngx.req.get_headers()["authorization"] - if token == "Bearer gemini-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good.json")) - end - } - } - } - ]] + local GEMINI_MOCK = pl_file.read("spec/fixtures/ai-proxy/mock_servers/gemini.lua.txt") + MOCK_PORTS._GEMINI = helpers.get_available_port() + fixtures.http_mock.gemini = string.format(GEMINI_MOCK, MOCK_PORTS._GEMINI) + + local ANTHROPIC_MOCK = pl_file.read("spec/fixtures/ai-proxy/mock_servers/anthropic.lua.txt") + MOCK_PORTS._ANTHROPIC = helpers.get_available_port() + fixtures.http_mock.anthropic = string.format(ANTHROPIC_MOCK, MOCK_PORTS._ANTHROPIC) local empty_service = assert(bp.services:insert({ name = "empty_service", host = "localhost", --helpers.mock_upstream_host, - port = 8080, --MOCK_PORT, + port = 8080, --MOCK_PORTS._GEMINI, path = "/", })) + ---- + -- GEMINI MODELS + ---- -- 200 chat good with one option local chat_good = assert(bp.routes:insert({ service = empty_service, @@ -122,12 +114,12 @@ for _, strategy in helpers.all_strategies() do log_statistics = true, }, model = { - name = "gemini-1.5-pro", + name = "gemini-1.5-flash", provider = "gemini", options = { max_tokens = 256, temperature = 1.0, - upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/v1/chat/completions", + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORTS._GEMINI .. "/v1/chat/completions", input_cost = 15.0, output_cost = 15.0, }, @@ -142,6 +134,47 @@ for _, strategy in helpers.all_strategies() do }, } + local chat_good_with_no_upstream_port = assert(bp.routes:insert({ + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/gemini/llm/v1/chat/good_with_no_upstream_port" }, + })) + bp.plugins:insert({ + name = PLUGIN_NAME, + id = "17434c15-2c7c-4c2f-b87a-58880533a3ca", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer gemini-key", + }, + logging = { + log_payloads = true, + log_statistics = true, + }, + model = { + name = "gemini-1.5-flash", + provider = "gemini", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://" .. helpers.mock_upstream_host .. "/v1/chat/completions", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + }, + }) + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_with_no_upstream_port.id }, + config = { + path = FILE_LOG_PATH_WITH_PAYLOADS, + }, + } + -- 200 chat good with variable local chat_good_with_var = assert(bp.routes:insert({ service = empty_service, @@ -166,12 +199,173 @@ for _, strategy in helpers.all_strategies() do name = "$(uri_captures.model)", provider = "gemini", options = { - upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/v1/chat/completions", + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORTS._GEMINI .. 
"/v1/chat/completions/$(uri_captures.model)", }, }, }, }) + -- 200 chat good with query param auth using ai-proxy-advanced and ai-response-transformer + local chat_query_auth = assert(bp.routes:insert({ + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/gemini/llm/v1/chat/query-auth" }, + })) + bp.plugins:insert({ + name = "ai-proxy-advanced", + id = "27544c15-3c8c-5c3f-c98a-69990644a4d2", + route = { id = chat_query_auth.id }, + config = { + targets = { + { + route_type = "llm/v1/chat", + auth = { + param_name = "key", + param_value = "gemini-query-key", + param_location = "query", + }, + logging = { + log_payloads = true, + log_statistics = true, + }, + model = { + name = "gemini-1.5-flash", + provider = "gemini", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORTS._GEMINI .. "/v1/chat/completions/query-auth", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + }, + }, + }, + }) + bp.plugins:insert({ + name = "ai-response-transformer", + id = "37655d26-4d9d-6d4f-d09b-70001755b5e3", + route = { id = chat_query_auth.id }, + config = { + prompt = "Mask all emails and phone numbers in my JSON message with '*'. Return me ONLY the resulting JSON.", + parse_llm_response_json_instructions = false, + llm = { + route_type = "llm/v1/chat", + auth = { + param_name = "key", + param_value = "gemini-query-key", + param_location = "query", + }, + logging = { + log_payloads = true, + log_statistics = true, + }, + model = { + provider = "gemini", + name = "gemini-1.5-flash", + options = { + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORTS._GEMINI .. "/v1/chat/completions/query-auth", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + }, + }, + }) + bp.plugins:insert { + name = "file-log", + route = { id = chat_query_auth.id }, + config = { + path = FILE_LOG_PATH_WITH_PAYLOADS, + }, + } + + -- 400 chat fails Model Armor "Floor". + -- NOT related to the "ai-gcp-model-armor" plugin. + local chat_fail_model_armor = assert(bp.routes:insert({ + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/gemini/llm/v1/chat/fail-model-armor" }, + })) + bp.plugins:insert({ + name = "ai-proxy-advanced", + id = "27544c15-3c8c-5c3f-c98a-69990644aaaa", + route = { id = chat_fail_model_armor.id }, + config = { + targets = { + { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer gemini-key", + }, + logging = { + log_payloads = false, + log_statistics = false, + }, + model = { + name = "gemini-2.5-flash", + provider = "gemini", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORTS._GEMINI .. 
"/v1/chat/completions/fail-model-armor", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + }, + }, + }, + }) + + ---- + -- ANTHROPIC MODELS + ---- + local chat_good_anthropic = assert(bp.routes:insert({ + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/anthropic/llm/v1/chat/good" }, + })) + bp.plugins:insert({ + name = PLUGIN_NAME, + id = "17434c15-2c7c-4c2f-b87a-58880533a3c2", + route = { id = chat_good_anthropic.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "x-api-key", + header_value = "anthropic-key", + }, + logging = { + log_payloads = true, + log_statistics = true, + }, + model = { + name = "claude-2.1", + provider = "gemini", + options = { + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORTS._ANTHROPIC .. "/llm/v1/chat/good", + input_cost = 15.0, + output_cost = 15.0, + }, + }, + }, + }) + bp.plugins:insert { + name = "file-log", + route = { id = chat_good_anthropic.id }, + config = { + path = FILE_LOG_PATH_WITH_PAYLOADS, + }, + } + + -- TODO: mock gcp client to test vertex mode + -- start kong assert(helpers.start_kong({ -- set the strategy @@ -201,68 +395,360 @@ for _, strategy in helpers.all_strategies() do end end) - describe("gemini llm/v1/chat", function() - it("good request", function() - local r = client:get("/gemini/llm/v1/chat/good", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", + describe("gemini models", function() + describe("gemini (gemini) llm/v1/chat", function() + it("good request", function() + local r = client:get("/gemini/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.model, "gemini-1.5-flash-002") + assert.equals(json.object, "chat.completion") + assert.equals(json.choices[1].finish_reason, "stop") + + assert.is_table(json.choices) + assert.is_string(json.choices[1].message.content) + assert.same("Everything is okay.", json.choices[1].message.content) + + -- test stats from file-log + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + assert.same("127.0.0.1", log_message.client_ip) + assert.is_number(log_message.request.size) + assert.is_number(log_message.response.size) + + local actual_stats = log_message.ai.proxy + + local actual_llm_latency = actual_stats.meta.llm_latency + local actual_time_per_token = actual_stats.usage.time_per_token + local time_per_token = actual_llm_latency / actual_stats.usage.completion_tokens + + local actual_request_log = actual_stats.payload.request + local actual_response_log = actual_stats.payload.response + actual_stats.payload = nil + + actual_stats.meta.llm_latency = 1 + actual_stats.usage.time_per_token = 1 + + assert.same(_EXPECTED_CHAT_STATS_GEMINI, actual_stats) + assert.is_true(actual_llm_latency >= 0) + assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token))) + assert.match_re(actual_request_log, [[.*content.*What is 1 \+ 1.*]]) + assert.match_re(actual_response_log, [[.*content.*Everything is okay.*]]) + end) + it("good request with no upstream port", function() + local r = 
client:get("/gemini/llm/v1/chat/good_with_no_upstream_port", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + -- validate that the request succeeded, response status 200 + assert.res_status(502, r) + + -- test stats from file-log + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + assert.same("127.0.0.1", log_message.client_ip) + local tries = log_message.tries + assert.is_table(tries) + assert.equal(tries[1].port, 80) + end) + + it("good request with model name from variable", function() + local r = client:get("/gemini/llm/v1/chat/good/gemini-2.0-flash", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200, r) + local json = cjson.decode(body) + assert.equals("gemini-2.0-flash-079", json.model) + end) + + it("bad request fails gcp model armor floor settings", function() + local r = client:get("/gemini/llm/v1/chat/fail-model-armor", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + -- the body doesn't matter - the mock server always returns the error we want + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + -- validate that the request succeeded, response status 400 + local body = assert.res_status(400, r) + local json = cjson.decode(body) + + assert.same(json, { + error = true, + message = "Blocked by Model Armor Floor Setting: The prompt violated Responsible AI Safety settings (Harassment), Prompt Injection and Jailbreak filters.", + reason = "MODEL_ARMOR" + }) + end) + end) + + describe("gemini (gemini) llm/v1/chat with query param auth", function() + it("good request with query parameter authentication", function() + local r = client:get("/gemini/llm/v1/chat/query-auth", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200, r) + assert.same("Everything is okay.", body) + end) + end) + end) + + describe("anthropic models", function() + describe("gemini (anthropic) llm/v1/chat", function() + it("good request #tt", function() + local r = client:get("/anthropic/llm/v1/chat/good", { + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), + }) + -- validate that the request succeeded, response status 200 + local body = assert.res_status(200, r) + local json = cjson.decode(body) + + -- check this is in the 'kong' response format + assert.equals(json.model, "claude-2.1") + assert.equals(json.object, "chat.completion") + assert.equals(json.choices[1].finish_reason, "stop") + + assert.is_table(json.choices) + assert.is_string(json.choices[1].message.content) + assert.same("The sum of 1 + 1 is 2.", json.choices[1].message.content) + + -- test stats from file-log + local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) + assert.same("127.0.0.1", log_message.client_ip) + assert.is_number(log_message.request.size) + 
assert.is_number(log_message.response.size) + + local actual_stats = log_message.ai.proxy + + local actual_llm_latency = actual_stats.meta.llm_latency + local actual_time_per_token = actual_stats.usage.time_per_token + local time_per_token = actual_llm_latency / actual_stats.usage.completion_tokens + + local actual_request_log = actual_stats.payload.request + local actual_response_log = actual_stats.payload.response + actual_stats.payload = nil + + actual_stats.meta.llm_latency = 1 + actual_stats.usage.time_per_token = 1 + + assert.is_true(actual_llm_latency >= 0) + assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token))) + assert.match_re(actual_request_log, [[.*content".*What is 1 \+ 1.*]]) + assert.match_re(actual_response_log, [[.*content.*The sum of 1 \+ 1 is 2.*]]) + end) + end) + end) + + describe("#utilities", function() + + it("should parse gemini model names into coordinates", function() + -- gemini no stream + local model_name = "gemini-1.5-flash" + local coordinates = gemini_driver.get_model_coordinates(model_name, false) + + assert.same({ + publisher = "google", + operation = "generateContent", + }, coordinates) + + -- gemini stream + model_name = "gemini-1.5-flash" + coordinates = gemini_driver.get_model_coordinates(model_name, true) + assert.same({ + publisher = "google", + operation = "streamGenerateContent", + }, coordinates) + + -- claude no stream + model_name = "claude-3.5-sonnet-20240229" + coordinates = gemini_driver.get_model_coordinates(model_name, false) + assert.same({ + publisher = "anthropic", + operation = "rawPredict", + }, coordinates) + + -- claude stream + model_name = "claude-3.5-sonnet-20240229" + coordinates = gemini_driver.get_model_coordinates(model_name, true) + assert.same({ + publisher = "anthropic", + operation = "streamRawPredict", + }, coordinates) + + -- ai21/jamba + model_name = "jamba-1.0" + coordinates = gemini_driver.get_model_coordinates(model_name, false) + assert.same({ + publisher = "ai21", + operation = "rawPredict", + }, coordinates) + + -- mistral + model_name = "mistral-large-2407" + coordinates = gemini_driver.get_model_coordinates(model_name, false) + assert.same({ + publisher = "mistral", + operation = "rawPredict", + }, coordinates) + + -- non-text model + model_name = "text-embedding-004" + coordinates = gemini_driver.get_model_coordinates(model_name, false) + assert.same({ + publisher = "google", + operation = "generateContent", -- doesn't matter, not used + }, coordinates) + + model_name = "imagen-4.0-generate-preview-06-06" + coordinates = gemini_driver.get_model_coordinates(model_name, false) + assert.same({ + publisher = "google", + operation = "generateContent", -- doesn't matter, not used + }, coordinates) + + end) + + it("should provide correct gemini (vertex) URL pattern", function() + -- err + local _, err = gemini_driver._get_gemini_vertex_url({ + provider = "gemini", + name = "gemini-1.5-flash", + }, "llm/v1/chat", false) + + assert.equals("model.options.gemini.* options must be set for vertex mode", err) + + local gemini_options = { + gemini = { + api_endpoint = "gemini.local", + project_id = "test-project", + location_id = "us-central1", + }, + } + + -- gemini no stream + local url = gemini_driver._get_gemini_vertex_url({ + provider = "gemini", + name = "gemini-1.5-flash", + options = gemini_options, + }, "llm/v1/chat", false) + + 
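-- expected vertex URL shape: https://<api_endpoint>/v1/projects/<project_id>/locations/<location_id>/publishers/<publisher>/models/<model>:<operation> +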
assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash:generateContent", url) + + -- gemini stream + url = gemini_driver._get_gemini_vertex_url({ + provider = "gemini", + name = "gemini-1.5-flash", + options = gemini_options, + }, "llm/v1/chat", true) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-flash:streamGenerateContent", url) + + -- claude no stream + url = gemini_driver._get_gemini_vertex_url({ + provider = "anthropic", + name = "claude-3.5-sonnet-20240229", + options = gemini_options, + }, "llm/v1/chat", false) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/anthropic/models/claude-3.5-sonnet-20240229:rawPredict", url) + + -- claude stream + url = gemini_driver._get_gemini_vertex_url({ + provider = "anthropic", + name = "claude-3.5-sonnet-20240229", + options = { + gemini = { + api_endpoint = "gemini.local", + project_id = "test-project", + location_id = "us-central1", + }, }, - body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), - }) - -- validate that the request succeeded, response status 200 - local body = assert.res_status(200, r) - local json = cjson.decode(body) - - -- check this is in the 'kong' response format - assert.equals(json.model, "gemini-1.5-pro") - assert.equals(json.object, "chat.completion") - assert.equals(json.choices[1].finish_reason, "stop") - - assert.is_table(json.choices) - assert.is_string(json.choices[1].message.content) - assert.same("Everything is okay.", json.choices[1].message.content) - - -- test stats from file-log - local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS) - assert.same("127.0.0.1", log_message.client_ip) - assert.is_number(log_message.request.size) - assert.is_number(log_message.response.size) - - local actual_stats = log_message.ai.proxy - - local actual_llm_latency = actual_stats.meta.llm_latency - local actual_time_per_token = actual_stats.usage.time_per_token - local time_per_token = actual_llm_latency / actual_stats.usage.completion_tokens - - local actual_request_log = actual_stats.payload.request - local actual_response_log = actual_stats.payload.response - actual_stats.payload = nil - - actual_stats.meta.llm_latency = 1 - actual_stats.usage.time_per_token = 1 - - assert.same(_EXPECTED_CHAT_STATS, actual_stats) - assert.is_true(actual_llm_latency >= 0) - assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token))) - assert.match_re(actual_request_log, [[.*contents.*What is 1 \+ 1.*]]) - assert.match_re(actual_response_log, [[.*content.*Everything is okay.*]]) + }, "llm/v1/chat", true) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/anthropic/models/claude-3.5-sonnet-20240229:streamRawPredict", url) + + -- ai21/jamba + url = gemini_driver._get_gemini_vertex_url({ + provider = "ai21", + name = "jamba-1.0", + options = gemini_options, + }, "llm/v1/chat", false) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/ai21/models/jamba-1.0:rawPredict", url) + + -- mistral + url = gemini_driver._get_gemini_vertex_url({ + provider = "mistral", + name = "mistral-large-2407", + options = gemini_options, + }, "llm/v1/chat", false) + 
assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/mistral/models/mistral-large-2407:rawPredict", url) + + -- non-text model + url = gemini_driver._get_gemini_vertex_url({ + provider = "google", + name = "text-embedding-004", + options = gemini_options, + }, "llm/v1/embeddings", false) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/google/models/text-embedding-004:predict", url) + + url = gemini_driver._get_gemini_vertex_url({ + provider = "google", + name = "imagen-4.0-generate-preview-06-06", + options = gemini_options, + }, "image/v1/images/generations", false) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:generateContent", url) + + url = gemini_driver._get_gemini_vertex_url({ + provider = "google", + name = "imagen-4.0-generate-preview-06-06", + options = gemini_options, + }, "image/v1/images/edits", false) + assert.equals("https://gemini.local/v1/projects/test-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:generateContent", url) + end) - it("good request with model name from variable", function() - local r = client:get("/gemini/llm/v1/chat/good/gemni-2.0-flash", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", + it("should detect vertex mode automatically", function() + local model = { + name = "gemini-1.5-flash", + options = { + gemini = { + api_endpoint = "gemini.local", + project_id = "test-project", + location_id = "us-central1", + }, }, - body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), - }) - -- validate that the request succeeded, response status 200 - local body = assert.res_status(200, r) - local json = cjson.decode(body) - assert.equals("gemni-2.0-flash", json.model) + } + + assert.is_true(gemini_driver._is_vertex_mode(model)) + + model = { + name = "gemini-1.5-flash", + } + + assert.is_falsy(gemini_driver._is_vertex_mode(model)) end) end) end) end -end +end \ No newline at end of file diff --git a/spec/03-plugins/38-ai-proxy/12-native_unit_spec.lua b/spec/03-plugins/38-ai-proxy/12-native_unit_spec.lua index 5f62593035a..d04543c5c60 100644 --- a/spec/03-plugins/38-ai-proxy/12-native_unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/12-native_unit_spec.lua @@ -52,6 +52,116 @@ local _NATIVE_ADAPTERS = { }, }, }, + ["GOOD_FULL_RERANK"] = { + INPUT_FILE = "spec/fixtures/ai-proxy/native/bedrock/request/rerank.json", + CLEANUP_BEFORE_COMPARE = function(expected_response, real_response) + expected_response.model = "cohere.command-r-v1:0" + return expected_response, real_response + end, + MOCK_KONG = { + request = { + get_path = function() + return "/rerank" + end, + get_uri_captures = function() + return { + named = { + ["model"] = "cohere.command-r-v1:0", + ["operation"] = "converse", + }, + } + end, + }, + }, + }, + ["GOOD_RETRIE_AND_GENERATE_STREAM"] = { + INPUT_FILE = "spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerateStream.json", + CLEANUP_BEFORE_COMPARE = function(expected_response, real_response) + expected_response.model = "cohere.command-r-v1:0" + return expected_response, real_response + end, + MOCK_KONG = { + request = { + get_path = function() + return "/retrieveAndGenerateStream" + end, + get_uri_captures = function() + return { + named = { + ["model"] = "cohere.command-r-v1:0", + ["operation"] = "retrieve-and-generate-stream", + }, + } + end, + }, + 
}, + }, + ["GOOD_RETRIE_AND_GENERATE"] = { + INPUT_FILE = "spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerate.json", + CLEANUP_BEFORE_COMPARE = function(expected_response, real_response) + expected_response.model = "cohere.command-r-v1:0" + return expected_response, real_response + end, + MOCK_KONG = { + request = { + get_path = function() + return "/retrieveAndGenerate" + end, + get_uri_captures = function() + return { + named = { + ["model"] = "cohere.command-r-v1:0", + ["operation"] = "retrieve-and-generate", + }, + } + end, + }, + }, + }, + ["GOOD_CONVERSE"] = { + INPUT_FILE = "spec/fixtures/ai-proxy/native/bedrock/request/converse.json", + CLEANUP_BEFORE_COMPARE = function(expected_response, real_response) + expected_response.model = "cohere.command-r-v1:0" + return expected_response, real_response + end, + MOCK_KONG = { + request = { + get_path = function() + return "/model/cohere.command-r-v1%3A0/converse" + end, + get_uri_captures = function() + return { + named = { + ["model"] = "cohere.command-r-v1:0", + ["operation"] = "converse", + }, + } + end, + }, + }, + }, + ["GOOD_CONVERSE_STREAM"] = { + INPUT_FILE = "spec/fixtures/ai-proxy/native/bedrock/request/converse-stream.json", + CLEANUP_BEFORE_COMPARE = function(expected_response, real_response) + expected_response.model = "cohere.command-r-v1:0" + return expected_response, real_response + end, + MOCK_KONG = { + request = { + get_path = function() + return "/model/cohere.command-r-v1%3A0/converse-stream" + end, + get_uri_captures = function() + return { + named = { + ["model"] = "cohere.command-r-v1:0", + ["operation"] = "converse-stream", + }, + } + end, + }, + }, + }, }, }, gemini = { @@ -199,6 +309,98 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.same(target_response, response) end) + it(fmt("adapters.%s good full with rerank", adapter_name), function() + if adapter_name == "bedrock" then + local target_response = cjson.decode(pl_file.read("spec/fixtures/ai-proxy/native/bedrock/request/rerank.json")) + + local test_manifest = adapter_manifest.TESTS.GOOD_FULL_RERANK + + package.loaded[adapter_manifest.CLASS] = nil + _G.TEST = true + local adapter = require(adapter_manifest.CLASS) + + adapter = adapter:new() + + local request = cjson.decode(pl_file.read(test_manifest.INPUT_FILE)) + local response = adapter:to_kong_req(request, test_manifest.MOCK_KONG) + assert.same(adapter.forward_path, "/rerank") + assert.same(response.messages, target_response.queries) + end + end) + it(fmt("adapters.%s good retrieve and generate stream", adapter_name), function() + if adapter_name == "bedrock" then + local target_response = cjson.decode(pl_file.read("spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerateStream.json")) + + local test_manifest = adapter_manifest.TESTS.GOOD_RETRIE_AND_GENERATE_STREAM + + package.loaded[adapter_manifest.CLASS] = nil + _G.TEST = true + local adapter = require(adapter_manifest.CLASS) + + adapter = adapter:new() + + local request = cjson.decode(pl_file.read(test_manifest.INPUT_FILE)) + local response = adapter:to_kong_req(request, test_manifest.MOCK_KONG) + assert.same(adapter.forward_path, "/retrieveAndGenerateStream") + assert.same(response.prompt, target_response.input.text) + assert.same(response.stream, true) + end + end) + it(fmt("adapters.%s good retrieve and generate", adapter_name), function() + if adapter_name == "bedrock" then + local target_response = cjson.decode(pl_file.read("spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerate.json")) + + local 
test_manifest = adapter_manifest.TESTS.GOOD_RETRIE_AND_GENERATE + + package.loaded[adapter_manifest.CLASS] = nil + _G.TEST = true + local adapter = require(adapter_manifest.CLASS) + + adapter = adapter:new() + + local request = cjson.decode(pl_file.read(test_manifest.INPUT_FILE)) + local response = adapter:to_kong_req(request, test_manifest.MOCK_KONG) + assert.same(adapter.forward_path, "/retrieveAndGenerate") + assert.same(response.prompt, target_response.input.text) + assert.same(response.stream, false) + end + end) + + it(fmt("adapters.%s good converse", adapter_name), function() + if adapter_name == "bedrock" then + + local test_manifest = adapter_manifest.TESTS.GOOD_CONVERSE + + package.loaded[adapter_manifest.CLASS] = nil + _G.TEST = true + local adapter = require(adapter_manifest.CLASS) + + adapter = adapter:new() + + local request = cjson.decode(pl_file.read(test_manifest.INPUT_FILE)) + local response = adapter:to_kong_req(request, test_manifest.MOCK_KONG) + assert.same(adapter.forward_path, "/model/%s/converse") + assert.same(response.stream, false) + end + end) + + it(fmt("adapters.%s good converse stream", adapter_name), function() + if adapter_name == "bedrock" then + + local test_manifest = adapter_manifest.TESTS.GOOD_CONVERSE_STREAM + + package.loaded[adapter_manifest.CLASS] = nil + _G.TEST = true + local adapter = require(adapter_manifest.CLASS) + + adapter = adapter:new() + + local request = cjson.decode(pl_file.read(test_manifest.INPUT_FILE)) + local response = adapter:to_kong_req(request, test_manifest.MOCK_KONG) + assert.same(adapter.forward_path, "/model/%s/converse-stream") + assert.same(response.stream, true) + end + end) end end) diff --git a/spec/03-plugins/38-ai-proxy/fuzz/03-gemini_spec.lua b/spec/03-plugins/38-ai-proxy/fuzz/03-gemini_spec.lua new file mode 100644 index 00000000000..6a61b8f3949 --- /dev/null +++ b/spec/03-plugins/38-ai-proxy/fuzz/03-gemini_spec.lua @@ -0,0 +1,16 @@ +local cjson = require "cjson.safe" +local stream_response_fuzzer = require "spec.03-plugins.38-ai-proxy.fuzz.stream_response" +local ai_shared = require("kong.llm.drivers.shared") + +stream_response_fuzzer.setup(getfenv()) +-- Used in Gemini +local assert_fn = function(expected, actual, msg) + -- tables are random ordered, so we need to compare each serialized event + assert.same(cjson.decode(expected.data), cjson.decode(actual.data), msg) +end +stream_response_fuzzer.run_case("gemini", + "spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/input.json", + "spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/expected-output.json", + ai_shared._CONST.GEMINI_STREAM_CONTENT_TYPE, + assert_fn +) diff --git a/spec/03-plugins/38-ai-proxy/fuzz/04-cohere_spec.lua b/spec/03-plugins/38-ai-proxy/fuzz/04-cohere_spec.lua new file mode 100644 index 00000000000..4031c9d4d02 --- /dev/null +++ b/spec/03-plugins/38-ai-proxy/fuzz/04-cohere_spec.lua @@ -0,0 +1,15 @@ +local cjson = require "cjson.safe" +local stream_response_fuzzer = require "spec.03-plugins.38-ai-proxy.fuzz.stream_response" + +stream_response_fuzzer.setup(getfenv()) +-- Used in Gemini +local assert_fn = function(expected, actual, msg) + -- tables are random ordered, so we need to compare each serialized event + assert.same(cjson.decode(expected.data), cjson.decode(actual.data), msg) +end +stream_response_fuzzer.run_case("cohere", + "spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/input.json", + "spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/expected-output.json", + "application/stream+json", + 
assert_fn +) diff --git a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua index 6f2a981cf86..73dab199131 100644 --- a/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/02-integration_spec.lua @@ -87,7 +87,7 @@ local _EXPECTED_CHAT_STATS_GEMINI = { plugin_id = '71083e79-4921-4f9f-97a4-ee7810b6cd8b', provider_name = 'gemini', request_model = 'UNSPECIFIED', - response_model = 'gemini-1.5-flash', + response_model = 'gemini-1.5-flash-002', llm_latency = 1 }, usage = { diff --git a/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua b/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua index 95e078a6178..f9662391348 100644 --- a/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua +++ b/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua @@ -42,6 +42,17 @@ local general_chat_request = { }, } +local full_chat_request = { + model = "gpt-4o", + temperature = 0.6, + messages = { + [1] = { + role = "user", + content = "What is 1+1?" + }, + }, +} + local injector_conf_prepend = { prompts = { prepend = { @@ -173,6 +184,25 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.same(expected_request_copy, decorated_request) end) + + it("preserves model and temperature fields when decorating", function() + local request_copy = deepcopy(full_chat_request) + local expected_request_copy = deepcopy(full_chat_request) + + -- combine the tables manually, and check the code does the same + table.insert(expected_request_copy.messages, 1, injector_conf_prepend.prompts.prepend[1]) + table.insert(expected_request_copy.messages, 2, injector_conf_prepend.prompts.prepend[2]) + table.insert(expected_request_copy.messages, 3, injector_conf_prepend.prompts.prepend[3]) + + local decorated_request, err = access_handler._execute(request_copy, injector_conf_prepend) + + assert.is_nil(err) + assert.same(decorated_request, expected_request_copy) + -- Ensure model and temperature are preserved + assert.equal("gpt-4o", decorated_request.model) + assert.equal(0.6, decorated_request.temperature) + end) + end) end) diff --git a/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua b/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua index 89acb6fd211..5975613c7b2 100644 --- a/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua +++ b/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua @@ -22,6 +22,18 @@ local openai_flat_chat = { }, } +local openai_full_chat = { + model = "gpt-4o", + temperature = 0.6, + max_tokens = 150, + messages = { + { + role = "user", + content = "What is 1+1?", + }, + }, +} + for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. 
"]", function() @@ -242,6 +254,40 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.match_re(ctx, [[.*Append text 3 here.*]]) assert.match_re(ctx, [[.*Append text 4 here.*]]) end) + + + it("preserves model and temperature fields when decorating - integration", function() + local r = client:get("/", { + headers = { + host = "prepend.decorate.local", + ["Content-Type"] = "application/json" + }, + body = cjson.encode(openai_full_chat), + }) + + -- get the REQUEST body, that left Kong for the upstream, using the echo system + assert.response(r).has.status(200) + local request = assert.response(r).has.jsonbody() + request = cjson.decode(request.post_data.text) + + -- Verify the messages are decorated correctly + assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1]) + assert.same({ content = "Prepend text 2 here.", role = "system" }, request.messages[2]) + assert.same({ content = "What is 1+1?", role = "user" }, request.messages[3]) + + -- Verify that model, temperature, and max_tokens are preserved + assert.equal("gpt-4o", request.model) + assert.equal(0.6, request.temperature) + assert.equal(150, request.max_tokens) + + -- check ngx.ctx was set properly for later AI chain filters + local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx") + ctx = ngx.unescape_uri(ctx) + assert.match_re(ctx, [[.*decorate-prompt.*]]) + assert.match_re(ctx, [[.*decorated = true.*]]) + assert.match_re(ctx, [[.*Prepend text 1 here.*]]) + assert.match_re(ctx, [[.*Prepend text 2 here.*]]) + end) end) end) diff --git a/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json b/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json index 174220a4a21..df0f162deff 100644 --- a/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json +++ b/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json @@ -9,7 +9,7 @@ "stop_reason": "end_turn", "stop_sequence": "string", "usage": { - "input_tokens": 0, - "output_tokens": 0 + "input_tokens": 100, + "output_tokens": 150 } } diff --git a/spec/fixtures/ai-proxy/bedrock/llm-v1-chat/response/good_with_function.json b/spec/fixtures/ai-proxy/bedrock/llm-v1-chat/response/good_with_function.json new file mode 100644 index 00000000000..4d1472ff053 --- /dev/null +++ b/spec/fixtures/ai-proxy/bedrock/llm-v1-chat/response/good_with_function.json @@ -0,0 +1 @@ +{"metrics":{"latencyMs":998},"output":{"message":{"content":[{"text":"The weather in Paris is currently 16.4°C."}],"role":"assistant"}},"stopReason":"end_turn","usage":{"inputTokens":39,"outputTokens":14,"totalTokens":53}} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/bedrock/llm-v1-chat/responses/good_with_function.json b/spec/fixtures/ai-proxy/bedrock/llm-v1-chat/responses/good_with_function.json new file mode 100644 index 00000000000..4d1472ff053 --- /dev/null +++ b/spec/fixtures/ai-proxy/bedrock/llm-v1-chat/responses/good_with_function.json @@ -0,0 +1 @@ +{"metrics":{"latencyMs":998},"output":{"message":{"content":[{"text":"The weather in Paris is currently 16.4°C."}],"role":"assistant"}},"stopReason":"end_turn","usage":{"inputTokens":39,"outputTokens":14,"totalTokens":53}} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails-model-armor-floor.json b/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails-model-armor-floor.json new file mode 100644 index 00000000000..122fa773e6e --- /dev/null +++ 
b/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails-model-armor-floor.json @@ -0,0 +1,12 @@ +{ + "usageMetadata": { + "trafficType": "ON_DEMAND" + }, + "createTime": "2025-09-13T22:15:16.223845Z", + "promptFeedback": { + "blockReasonMessage": "Blocked by Model Armor Floor Setting: The prompt violated Responsible AI Safety settings (Harassment), Prompt Injection and Jailbreak filters.", + "blockReason": "MODEL_ARMOR" + }, + "responseId": "9OzFaOXUDdyZgLUPmd2_sA0", + "modelVersion": "gemini-2.5-flash" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good-gemini-2.0-flash.json b/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good-gemini-2.0-flash.json new file mode 100644 index 00000000000..53e712d08e5 --- /dev/null +++ b/spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good-gemini-2.0-flash.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Everything is okay." + } + ] + }, + "finishReason": "STOP", + "avgLogprobs": -0.013348851691592823 + } + ], + "usageMetadata": { + "promptTokenCount": 2, + "candidatesTokenCount": 11, + "totalTokenCount": 13 + }, + "modelVersion": "gemini-2.0-flash-079" +} diff --git a/spec/fixtures/ai-proxy/mock_servers/anthropic.lua.txt b/spec/fixtures/ai-proxy/mock_servers/anthropic.lua.txt new file mode 100644 index 00000000000..c0196024086 --- /dev/null +++ b/spec/fixtures/ai-proxy/mock_servers/anthropic.lua.txt @@ -0,0 +1,183 @@ +server { + server_name anthropic; + listen %s; + + default_type 'application/json'; + + + location = "/llm/v1/chat/good" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["x-api-key"] + if token == "anthropic-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (not body.messages) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) + else + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location = "/llm/v1/chat/bad_upstream_response" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["x-api-key"] + if token == "anthropic-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (not body.messages) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) + else + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_upstream_response.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location = "/llm/v1/chat/no_usage_upstream_response" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["x-api-key"] + if token == "anthropic-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (not body.messages) then + ngx.status = 400 + 
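-- malformed request body: reply with the canned bad_request fixture +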
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) + else + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/no_usage_response.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location = "/llm/v1/chat/malformed_usage_upstream_response" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["x-api-key"] + if token == "anthropic-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (not body.messages) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) + else + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/malformed_usage_response.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json")) + end + } + } + + location = "/llm/v1/chat/bad_request" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json")) + } + } + + location = "/llm/v1/chat/internal_server_error" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 500 + ngx.header["content-type"] = "text/html" + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/internal_server_error.html")) + } + } + + location = "/llm/v1/chat/tool_choice" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local function assert_ok(ok, err) + if not ok then + ngx.status = 500 + ngx.say(err) + ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR) + end + return ok + end + local body = assert_ok(ngx.req.get_body_data()) + body = assert_ok(json.decode(body)) + local tool_choice = body.tool_choice + ngx.header["tool-choice"] = json.encode(tool_choice) + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json")) + } + } + + location = "/llm/v1/completions/good" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["x-api-key"] + if token == "anthropic-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (not body.prompt) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) + else + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/good.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/unauthorized.json")) + end + } + } + + location = "/llm/v1/completions/bad_request" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-completions/responses/bad_request.json")) + } + } + +} diff --git a/spec/fixtures/ai-proxy/mock_servers/gemini.lua.txt b/spec/fixtures/ai-proxy/mock_servers/gemini.lua.txt new file mode 100644 index 00000000000..94e3a18270d --- /dev/null +++ 
b/spec/fixtures/ai-proxy/mock_servers/gemini.lua.txt @@ -0,0 +1,87 @@ +server { + server_name gemini; + listen %s; + + default_type 'application/json'; + + location = "/v1/chat/completions" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["authorization"] + if token == "Bearer gemini-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good.json")) + end + } + } + + location = "/v1/chat/completions/gemini-2.0-flash" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["authorization"] + if token == "Bearer gemini-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good-gemini-2.0-flash.json")) + end + } + } + + location = "/v1/embeddings" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + local token = ngx.req.get_headers()["authorization"] + if token == "Bearer gemini-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-embeddings/responses/good.json")) + end + } + } + + location = "/v1/chat/completions/query-auth" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + + -- Check for query parameter authentication + local args = ngx.req.get_uri_args() + if args.key == "gemini-query-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/good.json")) + else + ngx.status = 401 + ngx.print('{"error": "Unauthorized"}') + end + } + } + + location = "/v1/chat/completions/fail-model-armor" { + content_by_lua_block { + local pl_file = require "pl.file" + + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/gemini/llm-v1-chat/responses/fails-model-armor-floor.json")) + } + } +} diff --git a/spec/fixtures/ai-proxy/native/bedrock/request/converse-stream.json b/spec/fixtures/ai-proxy/native/bedrock/request/converse-stream.json new file mode 100644 index 00000000000..b094c3555a0 --- /dev/null +++ b/spec/fixtures/ai-proxy/native/bedrock/request/converse-stream.json @@ -0,0 +1 @@ +{"toolConfig":{"tools":[{"toolSpec":{"inputSchema":{"json":{"type":"object","properties":{"sign":{"type":"string","description":"The call sign for the radio station for which you want the most popular song. 
Example calls signs are WZPZ, and WKRP."}},"required":["sign"]}},"description":"Get the most popular song played on a radio station.","name":"top_song"}}]},"messages":[{"content":[{"text":"What is the most popular song on WZPZ?"}],"role":"user"}]} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/native/bedrock/request/converse.json b/spec/fixtures/ai-proxy/native/bedrock/request/converse.json new file mode 100644 index 00000000000..b094c3555a0 --- /dev/null +++ b/spec/fixtures/ai-proxy/native/bedrock/request/converse.json @@ -0,0 +1 @@ +{"toolConfig":{"tools":[{"toolSpec":{"inputSchema":{"json":{"type":"object","properties":{"sign":{"type":"string","description":"The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."}},"required":["sign"]}},"description":"Get the most popular song played on a radio station.","name":"top_song"}}]},"messages":[{"content":[{"text":"What is the most popular song on WZPZ?"}],"role":"user"}]} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/native/bedrock/request/rerank.json b/spec/fixtures/ai-proxy/native/bedrock/request/rerank.json new file mode 100644 index 00000000000..7381f5b6808 --- /dev/null +++ b/spec/fixtures/ai-proxy/native/bedrock/request/rerank.json @@ -0,0 +1 @@ +{"rerankingConfiguration":{"bedrockRerankingConfiguration":{"modelConfiguration":{"modelArn":"arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0"},"numberOfResults":3},"type":"BEDROCK_RERANKING_MODEL"},"queries":[{"textQuery":{"text":"What emails have been about returning items?"},"type":"TEXT"}],"sources":[{"inlineDocumentSource":{"textDocument":{"text":"Hola, llevo una hora intentando acceder a mi cuenta y sigue diciendo que mi contraseña es incorrecta. ¿Puede ayudarme, por favor?"},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"Hi, I recently purchased a product from your website but I never received a confirmation email. Can you please look into this for me?"},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"مرحبًا، لدي سؤال حول سياسة إرجاع هذا المنتج. لقد اشتريته قبل بضعة أسابيع وهو معيب"},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"Good morning, I have been trying to reach your customer support team for the past week but I keep getting a busy signal. Can you please help me?"},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"Hallo, ich habe eine Frage zu meiner letzten Bestellung. Ich habe den falschen Artikel erhalten und muss ihn zurückschicken."},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"Hello, I have been trying to reach your customer support team for the past hour but I keep getting a busy signal. Can you please help me?"},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"Hi, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective."},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"早上好,关于我最近的订单,我有一个问题。我收到了错误的商品"},"type":"TEXT"},"type":"INLINE"},{"inlineDocumentSource":{"textDocument":{"text":"Hello, I have a question about the return policy for this product. 
I purchased it a few weeks ago and it is defective."},"type":"TEXT"},"type":"INLINE"}]} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerate.json b/spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerate.json new file mode 100644 index 00000000000..11880f57900 --- /dev/null +++ b/spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerate.json @@ -0,0 +1 @@ +{"retrieveAndGenerateConfiguration":{"knowledgeBaseConfiguration":{"modelArn":"cohere.command-r-v1:0","knowledgeBaseId":"NGYM70ITLR"},"type":"KNOWLEDGE_BASE"},"input":{"text":"tell me about the most popular song on WZPZ"}} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerateStream.json b/spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerateStream.json new file mode 100644 index 00000000000..e8ca6975700 --- /dev/null +++ b/spec/fixtures/ai-proxy/native/bedrock/request/retrieveAndGenerateStream.json @@ -0,0 +1 @@ +{"input":{"text":"tell me about the most popular song on WZPZ"},"retrieveAndGenerateConfiguration":{"type":"KNOWLEDGE_BASE","knowledgeBaseConfiguration":{"knowledgeBaseId":"NGYM70ITLR","modelArn":"cohere.command-r-v1:0"}}} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/ollama/llm-v1-chat/responses/good_with_function.json b/spec/fixtures/ai-proxy/ollama/llm-v1-chat/responses/good_with_function.json new file mode 100644 index 00000000000..0884b869481 --- /dev/null +++ b/spec/fixtures/ai-proxy/ollama/llm-v1-chat/responses/good_with_function.json @@ -0,0 +1 @@ +{"model":"llama3.2:1b","created_at":"2025-06-05T15:38:57.122562Z","message":{"role":"assistant","content":"The current weather in Paris is mostly sunny with a high of 17°C and a low of 7°C."}} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_with_function.json b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_with_function.json new file mode 100644 index 00000000000..aae1ab4aac7 --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_with_function.json @@ -0,0 +1 @@ +{"model":"gpt-4.1","tools":[{"function":{"parameters":{"required":["latitude","longitude"],"type":"object","properties":{"latitude":{"type":"number"},"longitude":{"type":"number"}},"additionalProperties":false},"description":"Get current temperature for provided coordinates in celsius.","strict":true,"name":"get_weather"},"type":"function"}],"messages":[{"role":"user","content":"What's the weather like in Paris today?"},{"role":"assistant","content":null,"annotations":[],"refusal":null,"tool_calls":[{"function":{"name":"get_weather","arguments":"{\"latitude\":48.8566,\"longitude\":2.3522}"},"type":"function","id":"call_nmMeFsxL9Inum214205TzjOj"}]},{"role":"tool","tool_call_id":"call_nmMeFsxL9Inum214205TzjOj","content":"16.1"}]} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/good_with_function.json b/spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/good_with_function.json new file mode 100644 index 00000000000..4515f97e25d --- /dev/null +++ b/spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/good_with_function.json @@ -0,0 +1,36 @@ +{ + "id": "chatcmpl-Bf0MVSE0QWTpllRYq7WF7Vryd9Cik", + "object": "chat.completion", + "created": 1749112539, + "model": "gpt-4.1-2025-04-14", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The current temperature in Paris today is 16.1°C. 
If you need more detailed weather information (like precipitation or wind), just let me know!", + "refusal": null, + "annotations": [] + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 94, + "completion_tokens": 32, + "total_tokens": 126, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default", + "system_fingerprint": "fp_51e1070cf2" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-chat.json index 969489f8a09..f36dc348a70 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-chat.json @@ -1,7 +1,7 @@ { "choices": [ { - "finish_reason": "stop_sequence", + "finish_reason": "stop", "index": 0, "message": { "content": "You cannot divide by zero because it is not a valid operation in mathematics.", @@ -10,5 +10,10 @@ } ], "model": "claude-2.1", - "object": "chat.completion" + "object": "chat.completion", + "usage": { + "completion_tokens": 16, + "prompt_tokens": 12, + "total_tokens": 28 + } } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-completions.json index 0c3eccb7331..abba1bb684d 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/anthropic/llm-v1-completions.json @@ -1,7 +1,7 @@ { "choices": [ { - "finish_reason": "stop_sequence", + "finish_reason": "stop", "index": 0, "text": "You cannot divide by zero because it is not a valid operation in mathematics." } diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json index 948d3fb4746..14ec1fccc5e 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json @@ -1,7 +1,7 @@ { "choices": [ { - "finish_reason": "end_turn", + "finish_reason": "stop", "index": 0, "message": { "content": "You cannot divide by zero because it is not a valid operation in mathematics.", diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json index 90a1656d2a3..4f8d5775f36 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json @@ -4,11 +4,18 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n", + "content": "Ah, vous voulez savoir le double de ce résultat? Eh bien, le double de 2 est **4**. 
\n", "role": "assistant" } } ], - "model": "gemini-pro", - "object": "chat.completion" + "created": 1751926593, + "model": "gemini-1.5-pro-002", + "object": "chat.completion", + "id": "QUdsaIKDBs-u3NoP5Iyb8QY", + "usage": { + "prompt_tokens": 223, + "completion_tokens": 12, + "total_tokens": 235 + } } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-chat.json index 08bb7a7ea85..ea7f5746526 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-chat.json @@ -13,7 +13,8 @@ "object": "chat.completion", "usage": { "completion_tokens": 139, - "prompt_tokens": 130, - "total_tokens": 269 - } + "prompt_tokens": 26, + "total_tokens": 165 + }, + "created": 1705306418 } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-completions.json index e8702be854d..b27c2277dde 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/llama2/ollama/llm-v1-completions.json @@ -8,8 +8,9 @@ "object": "text_completion", "model": "llama2", "usage": { - "completion_tokens": 139, - "prompt_tokens": 130, - "total_tokens": 269 - } + "completion_tokens": 12, + "prompt_tokens": 13, + "total_tokens": 25 + }, + "created": 1705306461 } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/mistral/ollama/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/mistral/ollama/llm-v1-chat.json index f5b5312282c..61133fce590 100644 --- a/spec/fixtures/ai-proxy/unit/expected-responses/mistral/ollama/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-responses/mistral/ollama/llm-v1-chat.json @@ -12,8 +12,9 @@ "model": "mistral-tiny", "object": "chat.completion", "usage": { - "completion_tokens": 139, - "prompt_tokens": 130, - "total_tokens": 269 - } + "completion_tokens": 100, + "prompt_tokens": 26, + "total_tokens": 126 + }, + "created": 1705306418 } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/anthropic/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/anthropic/llm-v1-chat.json index 624214cff5d..16c3979371d 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/anthropic/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/anthropic/llm-v1-chat.json @@ -6,7 +6,7 @@ "stop_reason": "stop_sequence", "model": "claude-2.1", "usage": { - "input_tokens": 0, - "output_tokens": 0 + "input_tokens": 12, + "output_tokens": 16 } } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json index e995bbd984d..1f17d9904c5 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json @@ -14,8 +14,8 @@ }, "stopReason": "end_turn", "usage": { - "completion_tokens": 119, - "prompt_tokens": 19, - "total_tokens": 138 + "outputTokens": 119, + "inputTokens": 19, + "totalTokens": 138 } } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/cohere/llm-v1-chat.json 
b/spec/fixtures/ai-proxy/unit/real-responses/cohere/llm-v1-chat.json index bbed8b91237..37eebf5df98 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/cohere/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/cohere/llm-v1-chat.json @@ -5,7 +5,7 @@ "version": "1" }, "billed_units": { - "input_tokens": 81, + "input_tokens": 102, "output_tokens": 258 } }, diff --git a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json index 96933d9835e..cfefc9615d9 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json @@ -4,12 +4,13 @@ "content": { "parts": [ { - "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n" + "text": "Ah, vous voulez savoir le double de ce résultat? Eh bien, le double de 2 est **4**. \n" } ], "role": "model" }, "finishReason": "STOP", + "avgLogprobs": -0.16819870471954346, "index": 0, "safetyRatings": [ { @@ -32,8 +33,24 @@ } ], "usageMetadata": { - "promptTokenCount": 14, - "candidatesTokenCount": 128, - "totalTokenCount": 142 - } -} + "promptTokenCount": 223, + "candidatesTokenCount": 12, + "totalTokenCount": 235, + "trafficType": "ON_DEMAND", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 223 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 12 + } + ] + }, + "modelVersion": "gemini-1.5-pro-002", + "createTime": "2025-07-07T22:16:33.098690Z", + "responseId": "QUdsaIKDBs-u3NoP5Iyb8QY" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-chat.json index 98a3bbfc201..6e12de55155 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-chat.json @@ -10,6 +10,6 @@ "load_duration": 1229365792, "prompt_eval_count": 26, "prompt_eval_duration": 167969000, - "eval_count": 100, + "eval_count": 139, "eval_duration": 2658646000 } diff --git a/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-completions.json index 644d407880d..e13da4188ad 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/llama2/ollama/llm-v1-completions.json @@ -1,7 +1,7 @@ { "model": "llama2", "created_at": "2024-01-15T08:14:21.967358Z", - "response": "Because I said so.", + "response": "You cannot divide by zero because it is not a valid operation in mathematics.", "done": true, "context": [ ], diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt index 2f7c45fe0a5..620076905eb 100644 --- a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-chat.txt @@ -1 +1,2 @@ data: {"choices": [{"delta": {"content": "the answer"},"finish_reason": null,"index": 0,"logprobs": null}],"created": 1711938086,"id": "chatcmpl-991aYb1iD8OSD54gcxZxv8uazlTZy","model": "gpt-4-0613","object": "chat.completion.chunk","system_fingerprint": null} + diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt 
b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt index e9e1b313fa1..75852f2c7be 100644 --- a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt @@ -1 +1,2 @@ data: {"choices": [{"finish_reason": null,"index": 0,"logprobs": null,"text": "the answer"}],"created": 1711938803,"id": "cmpl-991m7YSJWEnzrBqk41In8Xer9RIEB","model": "gpt-3.5-turbo-instruct","object": "text_completion"} + diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/expected-output.json new file mode 100644 index 00000000000..16ec9eb49f7 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/expected-output.json @@ -0,0 +1,11 @@ +[ + { + "data": "{\"event_type\": \"text-generation\", \"text\": \"Hello\"}" + }, + { + "data": "{\"event_type\": \"text-generation\", \"text\": \"!\"}" + }, + { + "data": "{\"event_type\": \"stream-end\", \"response\": {\"text\": \"Hello!\", \"generation_id\": \"29f14a5a-11de-4cae-9800-25e4747408ea\", \"chat_history\": [{\"role\": \"USER\", \"message\": \"hello world!\"}, {\"role\": \"CHATBOT\", \"message\": \"Hello!\"}], \"finish_reason\": \"COMPLETE\", \"meta\": {\"api_version\": {\"version\": \"1\"}, \"billed_units\": {\"input_tokens\": 3, \"output_tokens\": 9}, \"tokens\": {\"input_tokens\": 69, \"output_tokens\": 9}}}, \"finish_reason\": \"COMPLETE\"}" + } +] diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/input.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/input.json new file mode 100644 index 00000000000..d938bdc12cb --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/cohere/input.json @@ -0,0 +1,3 @@ +{ "event_type": "text-generation", "text": "Hello" } +{ "event_type": "text-generation", "text": "!" } +{ "event_type": "stream-end", "response": { "text": "Hello!", "generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea", "chat_history": [ { "role": "USER", "message": "hello world!" }, { "role": "CHATBOT", "message": "Hello!" 
} ], "finish_reason": "COMPLETE", "meta": { "api_version": { "version": "1" }, "billed_units": { "input_tokens": 3, "output_tokens": 9 }, "tokens": { "input_tokens": 69, "output_tokens": 9 } } }, "finish_reason": "COMPLETE" } diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/expected-output.json new file mode 100644 index 00000000000..162647509fa --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/expected-output.json @@ -0,0 +1,11 @@ +[ + { + "data": "{\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"blah\"}], \"role\": \"model\"}, \"index\": 0}], \"usageMetadata\": {\"promptTokenCount\": 4, \"candidatesTokenCount\": 1673, \"totalTokenCount\": 2918, \"promptTokensDetails\": [{\"modality\": \"TEXT\", \"tokenCount\": 4}], \"thoughtsTokenCount\": 1241}, \"modelVersion\": \"gemini-2.5-flash\", \"responseId\": \"VrOIaOHJEdCqqtsP5L_oqQ0\"}" + }, + { "data": "{\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"blahblah\"}], \"role\": \"model\"}, \"index\": 0}], \"usageMetadata\": {\"promptTokenCount\": 4, \"candidatesTokenCount\": 1723, \"totalTokenCount\": 2968, \"promptTokensDetails\": [{\"modality\": \"TEXT\", \"tokenCount\": 4}], \"thoughtsTokenCount\": 1241}, \"modelVersion\": \"gemini-2.5-flash\", \"responseId\": \"VrOIaOHJEdCqqtsP5L_oqQ0\"}" + }, + { + "data": "{\"candidates\": [{\"content\": {\"parts\": [{\"text\": \" that, to us, appear intelligent.\"}], \"role\": \"model\"}, \"finishReason\": \"STOP\", \"index\": 0}], \"usageMetadata\": {\"promptTokenCount\": 4, \"candidatesTokenCount\": 1731, \"totalTokenCount\": 2976, \"promptTokensDetails\": [{\"modality\": \"TEXT\", \"tokenCount\": 4}], \"thoughtsTokenCount\": 1241}, \"modelVersion\": \"gemini-2.5-flash\", \"responseId\": \"VrOIaOHJEdCqqtsP5L_oqQ0\"}" + }, + { "data": "[DONE]" } +] diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/input.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/input.json new file mode 100644 index 00000000000..9272dd70f0a --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/gemini/input.json @@ -0,0 +1,91 @@ +[{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "blah" + } + ], + "role": "model" + }, + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 1673, + "totalTokenCount": 2918, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 4 + } + ], + "thoughtsTokenCount": 1241 + }, + "modelVersion": "gemini-2.5-flash", + "responseId": "VrOIaOHJEdCqqtsP5L_oqQ0" +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "blahblah" + } + ], + "role": "model" + }, + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 1723, + "totalTokenCount": 2968, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 4 + } + ], + "thoughtsTokenCount": 1241 + }, + "modelVersion": "gemini-2.5-flash", + "responseId": "VrOIaOHJEdCqqtsP5L_oqQ0" +} +, +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": " that, to us, appear intelligent." 
+ } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 1731, + "totalTokenCount": 2976, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 4 + } + ], + "thoughtsTokenCount": 1241 + }, + "modelVersion": "gemini-2.5-flash", + "responseId": "VrOIaOHJEdCqqtsP5L_oqQ0" +} +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json index f515516c7ec..82ba44bbd37 100644 --- a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json @@ -7,5 +7,8 @@ }, { "data": "{ \"choices\": [ { \"delta\": {}, \"finish_reason\": \"stop\", \"index\": 0, \"logprobs\": null } ], \"created\": 1720136012, \"id\": \"chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY\", \"model\": \"gpt-4-0613\", \"object\": \"chat.completion.chunk\", \"system_fingerprint\": null}" + }, + { + "data": "[DONE]" } ] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin index efe2ad50c65..be73ef31e64 100644 --- a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin @@ -1,7 +1,5 @@ -data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} +data: { "choices": [ { "delta": { "content": "", "role": "assistant" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} +data: { "choices": [ { "delta": { "content": "2" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} +data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} +data: [DONE] -data: { "choices": [ { "delta": { "content": "2" }, "finish_reason": null, "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} - -data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1720136012, "id": "chatcmpl-9hQFArK1oMZcRwaIa86RGwrjVNmeY", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null} - -data: [DONE] \ No newline at end of file diff --git a/spec/helpers/ai/bedrock_mock.lua b/spec/helpers/ai/bedrock_mock.lua new file mode 100644 index 00000000000..cdc229c4f52 --- /dev/null +++ b/spec/helpers/ai/bedrock_mock.lua @@ -0,0 +1,49 @@ +-- +-- imports +-- + +local mocker = require("spec.fixtures.mocker") 
+ +local mock_bedrock_embeddings = require("spec.helpers.ai.embeddings_mock").mock_bedrock_embeddings +-- +-- private vars +-- + +-- +-- private functions +-- + +local mock_request_router = function(_self, url, opts) + if string.find(url, "model/.+/invoke") then + return mock_bedrock_embeddings(opts, url) + end + + return nil, "URL " .. url .. " is not supported by mocking" +end + +-- +-- public functions +-- + +local function setup(finally) + mocker.setup(finally, { + modules = { + { "resty.http", { + new = function() + return { + request_uri = mock_request_router, + } + end, + } }, + } + }) +end + +-- +-- module +-- + +return { + -- functions + setup = setup, +} diff --git a/spec/helpers/ai/gemini_mock.lua b/spec/helpers/ai/gemini_mock.lua new file mode 100644 index 00000000000..4d88c4d9eea --- /dev/null +++ b/spec/helpers/ai/gemini_mock.lua @@ -0,0 +1,94 @@ +-- gemini_mock.lua +-- +-- imports +-- + +local mocker = require("spec.fixtures.mocker") +local cjson = require("cjson.safe") + +local mock_gemini_embeddings = require("spec.helpers.ai.embeddings_mock").mock_gemini_embeddings +local mock_vertex_embeddings = require("spec.helpers.ai.embeddings_mock").mock_vertex_embeddings + +-- +-- private vars +-- + +-- +-- private functions +-- + +local mock_request_router = function(_self, url, opts) + if string.find(url, "404") then + return { + status = 404, + body = "404 response", + headers = {}, + } + end + + if string.find(url, "not%-json") then + return { + status = 200, + body = "not a json", + headers = {}, + } + end + + if string.find(url, "missing%-embeddings") then + return { + status = 200, + body = cjson.encode({ + predictions = { { + embeddings = { + statistics = { + truncated = false, + token_count = 8, + }, + }, + } }, + metadata = { + billableCharacterCount = 8, + }, + }), + headers = {}, + } + end + + if string.find(url, "%-aiplatform%.googleapis%.com/v1/projects/.+/locations/.+/publishers/.+/models/.+:predict") then + return mock_vertex_embeddings(opts, url) + end + + -- Public Gemini API pattern: https://generativelanguage.googleapis.com/v1beta/models/{model}:embedContent + if string.find(url, "generativelanguage%.googleapis%.com/v1beta/models/.+:embedContent") then + return mock_gemini_embeddings(opts, url) + end + + return nil, "URL " .. url .. " is not supported by gemini mocking" +end + +-- +-- public functions +-- + +local function setup(finally) + mocker.setup(finally, { + modules = { + { "resty.http", { + new = function() + return { + request_uri = mock_request_router, + } + end, + } }, + } + }) +end + +-- +-- module +-- + +return { + -- functions + setup = setup, +} diff --git a/spec/helpers/ai/huggingface_mock.lua b/spec/helpers/ai/huggingface_mock.lua new file mode 100644 index 00000000000..d69eb180151 --- /dev/null +++ b/spec/helpers/ai/huggingface_mock.lua @@ -0,0 +1,53 @@ +-- +-- imports +-- + +local mocker = require("spec.fixtures.mocker") + +local mock_huggingface_embeddings = require("spec.helpers.ai.embeddings_mock").mock_huggingface_embeddings + +-- +-- private vars +-- + +local api = "https://router.huggingface.co" +local embeddings_url = api .. "/hf-inference/models/distilbert-base-uncased/pipeline/feature-extraction" + +-- +-- private functions +-- + +local mock_request_router = function(_self, url, opts) + if url == embeddings_url then + return mock_huggingface_embeddings(opts) + end + + return nil, "URL " .. url .. 
" is not supported by mocking" +end + +-- +-- public functions +-- + +local function setup(finally) + mocker.setup(finally, { + modules = { + { "resty.http", { + new = function() + return { + request_uri = mock_request_router, + } + end, + } }, + } + }) +end + +-- +-- module +-- + +return { + -- functions + setup = setup, +} diff --git a/t/01-pdk/06-service-request/07-set_body.t b/t/01-pdk/06-service-request/07-set_body.t index 4e7b3ed564d..8c3b1e898ea 100644 --- a/t/01-pdk/06-service-request/07-set_body.t +++ b/t/01-pdk/06-service-request/07-set_body.t @@ -1075,6 +1075,7 @@ qq{ ngx.say((raw_body:gsub("\\r", ""))) ngx.say("foo: {", mpvalues.foo, "}") ngx.say("zzz: {", mpvalues.zzz, "}") + ngx.say("yes: {", mpvalues.yes, "}") } } } @@ -1089,6 +1090,7 @@ qq{ assert(pdk.service.request.set_body({ foo = "hello world", zzz = "goodbye world", + yes = true, })) } @@ -1115,6 +1117,10 @@ Content-Disposition: form-data; name="foo" hello world --xxyyzz +Content-Disposition: form-data; name="yes" + +true +--xxyyzz Content-Disposition: form-data; name="zzz" goodbye world @@ -1122,5 +1128,6 @@ goodbye world foo: {hello world} zzz: {goodbye world} +yes: {true} --- no_error_log [error]