From ea0d5e75643bae7e850b8c49b40735df172eb260 Mon Sep 17 00:00:00 2001 From: rinfx Date: Thu, 9 Jan 2025 22:04:51 +0800 Subject: [PATCH] Improve ai plugins (#1657) Co-authored-by: Kent Dong --- .../extensions/ai-cache/config/config.go | 4 +-- .../extensions/ai-prompt-template/main.go | 2 +- .../extensions/ai-proxy/provider/failover.go | 9 ++++--- .../extensions/ai-security-guard/main.go | 26 +++++-------------- .../wasm-go/extensions/ai-statistics/main.go | 24 +++++++++++------ 5 files changed, 31 insertions(+), 34 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 4bd6e2a18f..80c6147374 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { - c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } c.ResponseTemplate = json.Get("responseTemplate").String() if c.ResponseTemplate == "" { - c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } if json.Get("enableSemanticCache").Exists() { diff --git a/plugins/wasm-go/extensions/ai-prompt-template/main.go b/plugins/wasm-go/extensions/ai-prompt-template/main.go index 9a806c76c3..da95df7082 100644 --- a/plugins/wasm-go/extensions/ai-prompt-template/main.go +++ b/plugins/wasm-go/extensions/ai-prompt-template/main.go @@ -34,7 +34,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper. 
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action { templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable") - if templateEnable != "true" { + if templateEnable == "false" { ctx.DontReadRequestBody() return types.ActionContinue } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index e1b0b9819f..686b7d522d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -4,14 +4,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" - "github.com/google/uuid" "math/rand" "net/http" "strings" "time" - + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/google/uuid" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" @@ -551,7 +551,8 @@ func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.Ht } func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { - return ctx.GetContext(c.failover.ctxApiTokenInUse).(string) + token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string) + return token } func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 0e0a747fa1..c6cb475c14 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -41,9 +41,9 @@ const ( LowRisk = "low" NoRisk = "none" - OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}` - OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` - OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` + OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` + OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` DefaultRequestCheckService = "llm_query_moderation" @@ -262,8 +262,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] log.Debugf("checking request body...") startTime := time.Now().UnixMilli() content := gjson.GetBytes(body, config.requestContentJsonPath).String() 
- model := gjson.GetBytes(body, "model").String() - ctx.SetContext("requestModel", model) log.Debugf("Raw request content is: %s", content) if len(content) == 0 { log.Info("request content is empty. skip") @@ -308,11 +306,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) } else if gjson.GetBytes(body, "stream").Bool() { randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) } else { randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage)) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } ctx.DontReadResponseBody() @@ -369,15 +367,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] return types.ActionPause } -func convertHeaders(hs [][2]string) map[string][]string { - ret := make(map[string][]string) - for _, h := range hs { - k, v := strings.ToLower(h[0]), h[1] - ret[k] = append(ret[k], v) - } - return ret -} - func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { if !config.checkResponse { log.Debugf("response checking is disabled") @@ -398,7 +387,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ startTime := time.Now().UnixMilli() contentType, _ := proxywasm.GetHttpResponseHeader("content-type") isStreamingResponse := strings.Contains(contentType, "event-stream") - model := ctx.GetStringContext("requestModel", "unknown") var content string if isStreamingResponse { content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) @@ -449,11 +437,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) } else if isStreamingResponse { randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) } else { randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage)) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) } config.incrementCounter("ai_sec_response_deny", 1) diff --git a/plugins/wasm-go/extensions/ai-statistics/main.go b/plugins/wasm-go/extensions/ai-statistics/main.go index 363f59194e..1516c3ca5f 100644 --- a/plugins/wasm-go/extensions/ai-statistics/main.go +++ b/plugins/wasm-go/extensions/ai-statistics/main.go @@ 
-36,6 +36,7 @@ const ( RouteName = "route" ClusterName = "cluster" APIName = "api" + ConsumerKey = "x-mse-consumer" // Source Type FixedValue = "fixed_value" @@ -81,8 +82,8 @@ type AIStatisticsConfig struct { shouldBufferStreamingBody bool } -func generateMetricName(route, cluster, model, metricName string) string { - return fmt.Sprintf("route.%s.upstream.%s.model.%s.metric.%s", route, cluster, model, metricName) +func generateMetricName(route, cluster, model, consumer, metricName string) string { + return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName) } func getRouteName() (string, error) { @@ -115,6 +116,9 @@ func getClusterName() (string, error) { } func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) { + if inc == 0 { + return + } counter, ok := config.counterMetrics[metricName] if !ok { counter = proxywasm.DefineCounterMetric(metricName) @@ -158,6 +162,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo ctx.SetContext(ClusterName, cluster) ctx.SetUserAttribute(APIName, api) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) + if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" { + ctx.SetContext(ConsumerKey, consumer) + } // Set user defined log & span attributes which type is fixed_value setAttributeBySource(ctx, config, FixedValue, nil, log) @@ -388,6 +395,7 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper var ok bool var route, cluster, model string var inputToken, outputToken uint64 + consumer := ctx.GetStringContext(ConsumerKey, "none") route, ok = ctx.GetContext(RouteName).(string) if !ok { log.Warnf("RouteName typd assert failed, skip metric record") @@ -421,8 +429,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") return } - config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken) - config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken) + config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken) + config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken) // Generate duration metrics var llmFirstTokenDuration, llmServiceDuration uint64 @@ -433,8 +441,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper log.Warnf("LLMFirstTokenDuration typd assert failed") return } - config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration) - config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1) + config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration) + config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1) } if ctx.GetUserAttribute(LLMServiceDuration) != nil { llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration)) @@ -442,8 +450,8 @@ func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper log.Warnf("LLMServiceDuration typd assert failed") return } - config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration) - config.incrementCounter(generateMetricName(route, cluster, model, 
LLMDurationCount), 1) + config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration) + config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1) } }
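
Note for reviewers (commentary, not part of the applied patch): the ai-statistics change above adds a consumer dimension to every metric name, taken from the x-mse-consumer request header and defaulting to "none", and makes incrementCounter skip zero-valued increments. The sketch below mirrors the patched generateMetricName format string and the new zero guard; the in-memory counter map and all sample values ("my-route", "example-model", "consumer-a", etc.) are illustrative stand-ins for the proxy-wasm counter API, not code from the plugin.

package main

import "fmt"

// Mirrors the patched generateMetricName: metrics now carry a consumer
// dimension between the model and metric segments.
func generateMetricName(route, cluster, model, consumer, metricName string) string {
	return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s",
		route, cluster, model, consumer, metricName)
}

// Stand-in for AIStatisticsConfig.incrementCounter: the patch adds an early
// return so a zero increment never defines or bumps a counter.
func incrementCounter(counters map[string]uint64, metricName string, inc uint64) {
	if inc == 0 {
		return
	}
	counters[metricName] += inc
}

func main() {
	counters := map[string]uint64{}
	name := generateMetricName("my-route", "my-cluster", "example-model", "consumer-a", "input_token")
	incrementCounter(counters, name, 128)
	incrementCounter(counters, name, 0) // skipped by the new zero guard
	fmt.Println(name, "=", counters[name])
	// Prints: route.my-route.upstream.my-cluster.model.example-model.consumer.consumer-a.metric.input_token = 128
}

The same defensive pattern motivates the failover.go change: the comma-ok assertion token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string) yields the empty string instead of panicking when no token has been stored in the context.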