diff --git a/ent/client.go b/ent/client.go index a1925e46d..722bcd687 100644 --- a/ent/client.go +++ b/ent/client.go @@ -1753,8 +1753,7 @@ func (c *PaymentOrderClient) QueryTransactions(_m *PaymentOrder) *TransactionLog // Hooks returns the client hooks. func (c *PaymentOrderClient) Hooks() []Hook { - hooks := c.hooks.PaymentOrder - return append(hooks[:len(hooks):len(hooks)], paymentorder.Hooks[:]...) + return c.hooks.PaymentOrder } // Interceptors returns the client interceptors. diff --git a/ent/paymentorder/paymentorder.go b/ent/paymentorder/paymentorder.go index 867cde5bb..09c8b623a 100644 --- a/ent/paymentorder/paymentorder.go +++ b/ent/paymentorder/paymentorder.go @@ -6,7 +6,6 @@ import ( "fmt" "time" - "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/google/uuid" @@ -223,13 +222,7 @@ func ValidColumn(column string) bool { return false } -// Note that the variables below are initialized by the runtime -// package on the initialization of the application. Therefore, -// it should be imported in the main as follows: -// -// import _ "github.com/paycrest/aggregator/ent/runtime" var ( - Hooks [1]ent.Hook // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. diff --git a/ent/paymentorder_create.go b/ent/paymentorder_create.go index 80846c6b4..a3c638b63 100644 --- a/ent/paymentorder_create.go +++ b/ent/paymentorder_create.go @@ -602,9 +602,7 @@ func (_c *PaymentOrderCreate) Mutation() *PaymentOrderMutation { // Save creates the PaymentOrder in the database. func (_c *PaymentOrderCreate) Save(ctx context.Context) (*PaymentOrder, error) { - if err := _c.defaults(); err != nil { - return nil, err - } + _c.defaults() return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) } @@ -631,74 +629,44 @@ func (_c *PaymentOrderCreate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_c *PaymentOrderCreate) defaults() error { +func (_c *PaymentOrderCreate) defaults() { if _, ok := _c.mutation.CreatedAt(); !ok { - if paymentorder.DefaultCreatedAt == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultCreatedAt (forgotten import ent/runtime?)") - } v := paymentorder.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) } if _, ok := _c.mutation.UpdatedAt(); !ok { - if paymentorder.DefaultUpdatedAt == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultUpdatedAt (forgotten import ent/runtime?)") - } v := paymentorder.DefaultUpdatedAt() _c.mutation.SetUpdatedAt(v) } if _, ok := _c.mutation.AmountPaid(); !ok { - if paymentorder.DefaultAmountPaid == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultAmountPaid (forgotten import ent/runtime?)") - } v := paymentorder.DefaultAmountPaid() _c.mutation.SetAmountPaid(v) } if _, ok := _c.mutation.AmountReturned(); !ok { - if paymentorder.DefaultAmountReturned == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultAmountReturned (forgotten import ent/runtime?)") - } v := paymentorder.DefaultAmountReturned() _c.mutation.SetAmountReturned(v) } if _, ok := _c.mutation.PercentSettled(); !ok { - if paymentorder.DefaultPercentSettled == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultPercentSettled (forgotten import ent/runtime?)") - } v := paymentorder.DefaultPercentSettled() _c.mutation.SetPercentSettled(v) } if _, ok := _c.mutation.SenderFee(); !ok { - if paymentorder.DefaultSenderFee == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultSenderFee (forgotten import ent/runtime?)") - } v := paymentorder.DefaultSenderFee() _c.mutation.SetSenderFee(v) } if _, ok := _c.mutation.NetworkFee(); !ok { - if paymentorder.DefaultNetworkFee == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultNetworkFee (forgotten import ent/runtime?)") - } v := paymentorder.DefaultNetworkFee() _c.mutation.SetNetworkFee(v) } if _, ok := _c.mutation.ProtocolFee(); !ok { - if paymentorder.DefaultProtocolFee == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultProtocolFee (forgotten import ent/runtime?)") - } v := paymentorder.DefaultProtocolFee() _c.mutation.SetProtocolFee(v) } if _, ok := _c.mutation.OrderPercent(); !ok { - if paymentorder.DefaultOrderPercent == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultOrderPercent (forgotten import ent/runtime?)") - } v := paymentorder.DefaultOrderPercent() _c.mutation.SetOrderPercent(v) } if _, ok := _c.mutation.FeePercent(); !ok { - if paymentorder.DefaultFeePercent == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultFeePercent (forgotten import ent/runtime?)") - } v := paymentorder.DefaultFeePercent() _c.mutation.SetFeePercent(v) } @@ -723,13 +691,9 @@ func (_c *PaymentOrderCreate) defaults() error { _c.mutation.SetOrderType(v) } if _, ok := _c.mutation.ID(); !ok { - if paymentorder.DefaultID == nil { - return fmt.Errorf("ent: uninitialized paymentorder.DefaultID (forgotten import ent/runtime?)") - } v := paymentorder.DefaultID() _c.mutation.SetID(v) } - return nil } // check runs all checks and user-defined validators on the builder. diff --git a/ent/paymentorder_update.go b/ent/paymentorder_update.go index b0da9d6fe..f95ea2934 100644 --- a/ent/paymentorder_update.go +++ b/ent/paymentorder_update.go @@ -891,9 +891,7 @@ func (_u *PaymentOrderUpdate) RemoveTransactions(v ...*TransactionLog) *PaymentO // Save executes the query and returns the number of nodes affected by the update operation. func (_u *PaymentOrderUpdate) Save(ctx context.Context) (int, error) { - if err := _u.defaults(); err != nil { - return 0, err - } + _u.defaults() return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) } @@ -920,15 +918,11 @@ func (_u *PaymentOrderUpdate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_u *PaymentOrderUpdate) defaults() error { +func (_u *PaymentOrderUpdate) defaults() { if _, ok := _u.mutation.UpdatedAt(); !ok { - if paymentorder.UpdateDefaultUpdatedAt == nil { - return fmt.Errorf("ent: uninitialized paymentorder.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") - } v := paymentorder.UpdateDefaultUpdatedAt() _u.mutation.SetUpdatedAt(v) } - return nil } // check runs all checks and user-defined validators on the builder. @@ -2342,9 +2336,7 @@ func (_u *PaymentOrderUpdateOne) Select(field string, fields ...string) *Payment // Save executes the query and returns the updated PaymentOrder entity. func (_u *PaymentOrderUpdateOne) Save(ctx context.Context) (*PaymentOrder, error) { - if err := _u.defaults(); err != nil { - return nil, err - } + _u.defaults() return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) } @@ -2371,15 +2363,11 @@ func (_u *PaymentOrderUpdateOne) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_u *PaymentOrderUpdateOne) defaults() error { +func (_u *PaymentOrderUpdateOne) defaults() { if _, ok := _u.mutation.UpdatedAt(); !ok { - if paymentorder.UpdateDefaultUpdatedAt == nil { - return fmt.Errorf("ent: uninitialized paymentorder.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") - } v := paymentorder.UpdateDefaultUpdatedAt() _u.mutation.SetUpdatedAt(v) } - return nil } // check runs all checks and user-defined validators on the builder. diff --git a/ent/runtime/runtime.go b/ent/runtime/runtime.go index 225d6459a..a8f6e58fb 100644 --- a/ent/runtime/runtime.go +++ b/ent/runtime/runtime.go @@ -158,8 +158,6 @@ func init() { // network.DefaultGatewayContractAddress holds the default value on creation for the gateway_contract_address field. network.DefaultGatewayContractAddress = networkDescGatewayContractAddress.Default.(string) paymentorderMixin := schema.PaymentOrder{}.Mixin() - paymentorderHooks := schema.PaymentOrder{}.Hooks() - paymentorder.Hooks[0] = paymentorderHooks[0] paymentorderMixinFields0 := paymentorderMixin[0].Fields() _ = paymentorderMixinFields0 paymentorderFields := schema.PaymentOrder{}.Fields() diff --git a/routers/index.go b/routers/index.go index 6fff01f83..235462477 100644 --- a/routers/index.go +++ b/routers/index.go @@ -133,6 +133,7 @@ func senderRoutes(route *gin.Engine) { v1 := route.Group("/v1/sender/") v1.Use(middleware.OrdersReadinessMiddleware()) v1.Use(middleware.DynamicAuthMiddleware) + v1.Use(middleware.DomainWhitelistMiddleware) v1.Use(middleware.OnlySenderMiddleware) v1.POST("orders", senderCtrl.InitiatePaymentOrder) @@ -144,6 +145,7 @@ func senderRoutes(route *gin.Engine) { v2 := route.Group("/v2/sender/") v2.Use(middleware.OrdersReadinessMiddleware()) v2.Use(middleware.DynamicAuthMiddleware) + v2.Use(middleware.DomainWhitelistMiddleware) v2.Use(middleware.OnlySenderMiddleware) v2.POST("orders", senderCtrl.InitiatePaymentOrderV2) diff --git a/routers/middleware/domain_whitelist.go b/routers/middleware/domain_whitelist.go new file mode 100644 index 000000000..f61c942a8 --- /dev/null +++ b/routers/middleware/domain_whitelist.go @@ -0,0 +1,47 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/paycrest/aggregator/ent" + u "github.com/paycrest/aggregator/utils" +) + +// DomainWhitelistMiddleware enforces the sender profile's domain_whitelist. +// Must run after auth middleware so "sender" is in context. +// - No sender in context: pass through (e.g. provider routes). +// - Empty whitelist: allow (backward compatibility). +// - Non-empty whitelist: allow only if Origin/Referer domain is in whitelist; otherwise 403. +func DomainWhitelistMiddleware(c *gin.Context) { + val, exists := c.Get("sender") + if !exists || val == nil { + c.Next() + return + } + profile, ok := val.(*ent.SenderProfile) + if !ok || profile == nil { + c.Next() + return + } + whitelist := profile.DomainWhitelist + if len(whitelist) == 0 { + c.Next() + return + } + origin := c.GetHeader("Origin") + referer := c.GetHeader("Referer") + domain := u.ExtractRequestDomain(origin, referer) + if u.IsDomainAllowed(domain, whitelist) { + c.Next() + return + } + // When whitelist is set but no domain could be extracted, block (e.g. missing Origin/Referer). + if domain == "" { + u.APIResponse(c, http.StatusForbidden, "error", "Origin or Referer required when domain whitelist is configured", nil) + c.Abort() + return + } + u.APIResponse(c, http.StatusForbidden, "error", "Domain not allowed", nil) + c.Abort() +} diff --git a/routers/middleware/domain_whitelist_test.go b/routers/middleware/domain_whitelist_test.go new file mode 100644 index 000000000..eec18863f --- /dev/null +++ b/routers/middleware/domain_whitelist_test.go @@ -0,0 +1,109 @@ +// Domain whitelist middleware tests. +// Run from the aggregator module root: go test ./routers/middleware/... -run TestDomainWhitelist -v +package middleware + +import ( + "net/http" + "testing" + + "github.com/gin-gonic/gin" + "github.com/paycrest/aggregator/ent" + "github.com/paycrest/aggregator/utils/test" + "github.com/stretchr/testify/assert" +) + +func TestDomainWhitelistMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("no sender in context allows request", func(t *testing.T) { + router := gin.New() + router.GET("/test", DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + w, _ := test.PerformRequest(t, "GET", "/test", nil, nil, router) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("sender with empty whitelist allows any origin", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: nil}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + headers := map[string]string{"Origin": "https://any-domain.com"} + w, _ := test.PerformRequest(t, "GET", "/test", nil, headers, router) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("sender with empty whitelist slice allows any origin", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: []string{}}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + headers := map[string]string{"Origin": "https://any-domain.com"} + w, _ := test.PerformRequest(t, "GET", "/test", nil, headers, router) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("sender with whitelist allows whitelisted origin", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: []string{"example.com"}}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + headers := map[string]string{"Origin": "https://example.com"} + w, _ := test.PerformRequest(t, "GET", "/test", nil, headers, router) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("sender with whitelist allows whitelisted subdomain", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: []string{"example.com"}}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + headers := map[string]string{"Origin": "https://app.example.com"} + w, _ := test.PerformRequest(t, "GET", "/test", nil, headers, router) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("sender with whitelist blocks non-whitelisted origin", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: []string{"example.com"}}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + headers := map[string]string{"Origin": "https://evil.com"} + w, _ := test.PerformRequest(t, "GET", "/test", nil, headers, router) + assert.Equal(t, http.StatusForbidden, w.Code) + body := decodeResponseBody(t, w) + assert.Equal(t, "error", body["status"]) + assert.Equal(t, "Domain not allowed", body["message"]) + }) + + t.Run("sender with whitelist but no origin or referer returns 403", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: []string{"example.com"}}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + w, _ := test.PerformRequest(t, "GET", "/test", nil, nil, router) + assert.Equal(t, http.StatusForbidden, w.Code) + body := decodeResponseBody(t, w) + assert.Equal(t, "error", body["status"]) + assert.Equal(t, "Origin or Referer required when domain whitelist is configured", body["message"]) + }) + + t.Run("sender with whitelist allows referer when origin missing", func(t *testing.T) { + router := gin.New() + router.GET("/test", setSenderProfile(&ent.SenderProfile{DomainWhitelist: []string{"example.com"}}), DomainWhitelistMiddleware, func(c *gin.Context) { + c.Status(http.StatusOK) + }) + headers := map[string]string{"Referer": "https://example.com/page"} + w, _ := test.PerformRequest(t, "GET", "/test", nil, headers, router) + assert.Equal(t, http.StatusOK, w.Code) + }) +} + +// setSenderProfile returns a handler that sets the given sender profile in context. +func setSenderProfile(profile *ent.SenderProfile) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set("sender", profile) + c.Next() + } +} diff --git a/utils/domain.go b/utils/domain.go new file mode 100644 index 000000000..ce37fee3a --- /dev/null +++ b/utils/domain.go @@ -0,0 +1,67 @@ +package utils + +import ( + "net/url" + "strings" +) + +// ExtractRequestDomain returns the host (domain) from the request's Origin or Referer header. +// Prefers Origin, falls back to Referer. Returns empty string if neither is present or parseable. +func ExtractRequestDomain(origin, referer string) string { + if origin != "" { + if host := hostFromURL(origin); host != "" { + return host + } + } + if referer != "" { + if host := hostFromURL(referer); host != "" { + return host + } + } + return "" +} + +// Returns empty string on parse error. +func hostFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + // Ensure scheme for url.Parse + if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") { + rawURL = "https://" + rawURL + } + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + return strings.ToLower(strings.TrimSpace(u.Hostname())) +} + +// IsDomainAllowed checks if requestHost is allowed by whitelist. +// - Empty whitelist: allow any domain (backward compatibility). +// - Exact match: requestHost equals a whitelist entry (normalized lowercase). +// - Subdomain match: requestHost is a subdomain of a whitelist entry (e.g. "app.example.com" matches "example.com"). +func IsDomainAllowed(requestHost string, whitelist []string) bool { + if len(whitelist) == 0 { + return true + } + requestHost = strings.ToLower(strings.TrimSpace(requestHost)) + if requestHost == "" { + return false + } + for _, allowed := range whitelist { + allowed = strings.ToLower(strings.TrimSpace(allowed)) + if allowed == "" { + continue + } + if requestHost == allowed { + return true + } + // Subdomain: requestHost must end with "."+allowed + if strings.HasSuffix(requestHost, "."+allowed) { + return true + } + } + return false +} diff --git a/utils/domain_test.go b/utils/domain_test.go new file mode 100644 index 000000000..dc1a93cfc --- /dev/null +++ b/utils/domain_test.go @@ -0,0 +1,147 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractRequestDomain(t *testing.T) { + tests := []struct { + name string + origin string + referer string + want string + }{ + { + name: "empty both", + origin: "", + referer: "", + want: "", + }, + { + name: "origin with https", + origin: "https://app.example.com/path", + referer: "", + want: "app.example.com", + }, + { + name: "origin with http and port", + origin: "http://localhost:3000", + referer: "", + want: "localhost", + }, + { + name: "origin preferred over referer", + origin: "https://origin.example.com", + referer: "https://referer.example.com", + want: "origin.example.com", + }, + { + name: "fallback to referer when origin empty", + origin: "", + referer: "https://referer.example.com/foo?q=1", + want: "referer.example.com", + }, + { + name: "invalid origin fallback to referer", + origin: "://bad", + referer: "https://good.example.com", + want: "good.example.com", + }, + { + name: "origin with trailing slash", + origin: "https://api.example.com/", + referer: "", + want: "api.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractRequestDomain(tt.origin, tt.referer) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestIsDomainAllowed(t *testing.T) { + tests := []struct { + name string + requestHost string + whitelist []string + want bool + }{ + { + name: "empty whitelist allows any", + requestHost: "any.example.com", + whitelist: nil, + want: true, + }, + { + name: "empty whitelist slice allows any", + requestHost: "evil.com", + whitelist: []string{}, + want: true, + }, + { + name: "exact match", + requestHost: "example.com", + whitelist: []string{"example.com"}, + want: true, + }, + { + name: "exact match multiple entries", + requestHost: "allowed.com", + whitelist: []string{"other.com", "allowed.com"}, + want: true, + }, + { + name: "subdomain match", + requestHost: "app.example.com", + whitelist: []string{"example.com"}, + want: true, + }, + { + name: "subdomain match deep", + requestHost: "api.app.example.com", + whitelist: []string{"example.com"}, + want: true, + }, + { + name: "no match", + requestHost: "evil.com", + whitelist: []string{"example.com"}, + want: false, + }, + { + name: "suffix not subdomain", + requestHost: "notexample.com", + whitelist: []string{"example.com"}, + want: false, + }, + { + name: "case normalized", + requestHost: "Example.COM", + whitelist: []string{"example.com"}, + want: true, + }, + { + name: "whitelist entry case normalized", + requestHost: "example.com", + whitelist: []string{"Example.COM"}, + want: true, + }, + { + name: "empty request host with non-empty whitelist", + requestHost: "", + whitelist: []string{"example.com"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsDomainAllowed(tt.requestHost, tt.whitelist) + assert.Equal(t, tt.want, got) + }) + } +}