Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion transports/bifrost-http/handlers/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware {
return func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
logger.Debug("CorsMiddleware: %s", ctx.Request.URI().Path())
logger.Debug("CorsMiddleware: %s", string(ctx.Path()))
origin := string(ctx.Request.Header.Peek("Origin"))
allowed := IsOriginAllowed(origin, config.ClientConfig.AllowedOrigins)
// Check if origin is allowed (localhost always allowed + configured origins)
Expand Down
18 changes: 16 additions & 2 deletions transports/bifrost-http/handlers/middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@ package handlers
import (
"testing"

"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)

// mockLogger is a mock implementation of schemas.Logger for testing
type mockLogger struct{}

func (m *mockLogger) Debug(format string, args ...any) {}
func (m *mockLogger) Info(format string, args ...any) {}
func (m *mockLogger) Warn(format string, args ...any) {}
func (m *mockLogger) Error(format string, args ...any) {}
func (m *mockLogger) Fatal(format string, args ...any) {}
func (m *mockLogger) SetLevel(level schemas.LogLevel) {}
func (m *mockLogger) SetOutputType(outputType schemas.LoggerOutputType) {}

// TestCorsMiddleware_LocalhostOrigins tests that localhost origins are always allowed
func TestCorsMiddleware_LocalhostOrigins(t *testing.T) {
config := &lib.Config{
Expand All @@ -16,6 +28,8 @@ func TestCorsMiddleware_LocalhostOrigins(t *testing.T) {
},
}

SetLogger(&mockLogger{})

localhostOrigins := []string{
"http://localhost:3000",
"https://localhost:3000",
Expand All @@ -42,10 +56,10 @@ func TestCorsMiddleware_LocalhostOrigins(t *testing.T) {
if string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != origin {
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", origin, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
}
if string(ctx.Response.Header.Peek("Access-Control-Allow-Methods")) != "GET, POST, PUT, DELETE, PATCH, OPTIONS" {
if string(ctx.Response.Header.Peek("Access-Control-Allow-Methods")) != "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" {
t.Errorf("Access-Control-Allow-Methods header not set correctly")
}
if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With" {
if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout" {
t.Errorf("Access-Control-Allow-Headers header not set correctly")
}
if string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")) != "true" {
Expand Down