@@ -10,11 +10,13 @@ import (
1010 "strings"
1111 "sync"
1212
13+ "github.com/bytedance/sonic"
1314 bifrost "github.com/maximhq/bifrost/core"
1415 "github.com/maximhq/bifrost/core/schemas"
1516 "github.com/maximhq/bifrost/framework/configstore"
1617 configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
1718 "github.com/maximhq/bifrost/framework/modelcatalog"
19+ "github.com/valyala/fasthttp"
1820)
1921
2022// PluginName is the name of the governance plugin
@@ -153,37 +155,57 @@ func (p *GovernancePlugin) GetName() string {
153155 return PluginName
154156}
155157
156- // TransportInterceptor intercepts requests before they are processed (governance decision point)
157- func (p * GovernancePlugin ) TransportInterceptor (ctx * context.Context , url string , headers map [string ]string , body map [string ]any ) (map [string ]string , map [string ]any , error ) {
158- var virtualKeyValue string
159- var err error
160-
161- for header , value := range headers {
162- if strings .ToLower (string (header )) == string (schemas .BifrostContextKeyVirtualKey ) {
163- virtualKeyValue = string (value )
164- break
158+ // HTTPTransportMiddleware intercepts requests before they are processed (governance decision point)
159+ func (p * GovernancePlugin ) HTTPTransportMiddleware (next fasthttp.RequestHandler ) fasthttp.RequestHandler {
160+ return func (ctx * fasthttp.RequestCtx ) {
161+ var virtualKeyValue string
162+ vkHeader := ctx .Request .Header .Peek ("x-bf-vk" )
163+ if string (vkHeader ) == "" {
164+ next (ctx )
165+ return
165166 }
167+ virtualKeyValue = string (vkHeader )
168+ // Get the virtual key from the store
169+ virtualKey , ok := p .store .GetVirtualKey (virtualKeyValue )
170+ if ! ok || virtualKey == nil || ! virtualKey .IsActive {
171+ next (ctx )
172+ return
173+ }
174+ headers , err := p .addMCPIncludeTools (nil , virtualKey )
175+ if err != nil {
176+ p .logger .Error ("failed to add MCP include tools: %v" , err )
177+ next (ctx )
178+ return
179+ }
180+ for header , value := range headers {
181+ ctx .Request .Header .Set (header , value )
182+ }
183+ if ctx .Request .Body () == nil {
184+ next (ctx )
185+ return
186+ }
187+ var payload map [string ]any
188+ err = sonic .Unmarshal (ctx .Request .Body (), & payload )
189+ if err != nil {
190+ p .logger .Error ("failed to marshal request body to check for virtual key: %v" , err )
191+ next (ctx )
192+ return
193+ }
194+ payload , err = p .loadBalanceProvider (payload , virtualKey )
195+ if err != nil {
196+ p .logger .Error ("failed to load balance provider: %v" , err )
197+ next (ctx )
198+ return
199+ }
200+ body , err := sonic .Marshal (payload )
201+ if err != nil {
202+ p .logger .Error ("failed to marshal request body to check for virtual key: %v" , err )
203+ next (ctx )
204+ return
205+ }
206+ ctx .Request .SetBody (body )
207+ next (ctx )
166208 }
167- if virtualKeyValue == "" {
168- return headers , body , nil
169- }
170-
171- virtualKey , ok := p .store .GetVirtualKey (virtualKeyValue )
172- if ! ok || virtualKey == nil || ! virtualKey .IsActive {
173- return headers , body , nil
174- }
175-
176- body , err = p .loadBalanceProvider (body , virtualKey )
177- if err != nil {
178- return headers , body , err
179- }
180-
181- headers , err = p .addMCPIncludeTools (headers , virtualKey )
182- if err != nil {
183- return headers , body , err
184- }
185-
186- return headers , body , nil
187209}
188210
189211func (p * GovernancePlugin ) loadBalanceProvider (body map [string ]any , virtualKey * configstoreTables.TableVirtualKey ) (map [string ]any , error ) {
0 commit comments