diff --git a/go.mod b/go.mod index 308d58811..4fd6e22dc 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/modelcontextprotocol/go-sdk v1.1.0 github.com/openai/openai-go v1.12.0 github.com/pdfcpu/pdfcpu v0.11.1 + github.com/pires/go-proxyproto v0.8.1 github.com/prometheus/client_golang v1.23.2 github.com/stretchr/testify v1.11.1 github.com/vektah/gqlparser/v2 v2.5.31 diff --git a/go.sum b/go.sum index cd44f81b7..eb0b8305f 100644 --- a/go.sum +++ b/go.sum @@ -164,6 +164,8 @@ github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde h1:x0TT0RDC7UhA github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0= github.com/pdfcpu/pdfcpu v0.11.1 h1:htHBSkGH5jMKWC6e0sihBFbcKZ8vG1M67c8/dJxhjas= github.com/pdfcpu/pdfcpu v0.11.1/go.mod h1:pP3aGga7pRvwFWAm9WwFvo+V68DfANi9kxSQYioNYcw= +github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= +github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pkg/probod/api_config.go b/pkg/probod/api_config.go index d2e42a4f1..376b62415 100644 --- a/pkg/probod/api_config.go +++ b/pkg/probod/api_config.go @@ -14,14 +14,23 @@ package probod +import ( + "net" +) + type ( corsConfig struct { AllowedOrigins []string `json:"allowed-origins"` } + proxyProtocolConfig struct { + TrustedProxies []net.IP `json:"trusted-proxies"` + } + apiConfig struct { - Addr string `json:"addr"` - Cors corsConfig `json:"cors"` - ExtraHeaderFields map[string]string `json:"extra-header-fields"` + Addr string `json:"addr"` + ProxyProtocol proxyProtocolConfig `json:"proxy-protocol"` + Cors corsConfig `json:"cors"` + ExtraHeaderFields map[string]string `json:"extra-header-fields"` } ) diff --git a/pkg/probod/probod.go b/pkg/probod/probod.go index e7946386d..99a7ad425 100644 --- a/pkg/probod/probod.go +++ b/pkg/probod/probod.go @@ -27,6 +27,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/s3" + proxyproto "github.com/pires/go-proxyproto" "github.com/prometheus/client_golang/prometheus" "go.gearno.de/kit/httpclient" "go.gearno.de/kit/httpserver" @@ -81,8 +82,9 @@ type ( } trustCenterConfig struct { - HTTPAddr string `json:"http-addr"` - HTTPSAddr string `json:"https-addr"` + HTTPAddr string `json:"http-addr"` + HTTPSAddr string `json:"https-addr"` + ProxyProtocol proxyProtocolConfig `json:"proxy-protocol"` } ) @@ -578,6 +580,18 @@ func (impl *Implm) runApiServer( span.RecordError(err) return fmt.Errorf("cannot listen on %q: %w", apiServer.Addr, err) } + + if len(impl.cfg.Api.ProxyProtocol.TrustedProxies) > 0 { + policy := rejectProxyHeaderFrom(impl.cfg.Api.ProxyProtocol.TrustedProxies...) + + listener = &proxyproto.Listener{ + Listener: listener, + ReadHeaderTimeout: 10 * time.Second, + ConnPolicy: policy, + } + + l.Info("using proxy protocol", log.Any("trusted-proxies", impl.cfg.Api.ProxyProtocol.TrustedProxies)) + } defer listener.Close() serverErrCh := make(chan error, 1) @@ -732,6 +746,18 @@ func (impl *Implm) runTrustCenterServer( } defer listener.Close() + if len(impl.cfg.TrustCenter.ProxyProtocol.TrustedProxies) > 0 { + policy := rejectProxyHeaderFrom(impl.cfg.TrustCenter.ProxyProtocol.TrustedProxies...) + + listener = &proxyproto.Listener{ + Listener: listener, + ReadHeaderTimeout: 10 * time.Second, + ConnPolicy: policy, + } + + l.Info("using proxy protocol for trust center HTTP server", log.Any("trusted-proxies", impl.cfg.TrustCenter.ProxyProtocol.TrustedProxies)) + } + if err := httpServer.Serve(listener); err != nil && err != http.ErrServerClosed { return fmt.Errorf("cannot serve http requests: %w", err) } @@ -789,6 +815,18 @@ func (impl *Implm) runTrustCenterServer( } defer listener.Close() + if len(impl.cfg.TrustCenter.ProxyProtocol.TrustedProxies) > 0 { + policy := rejectProxyHeaderFrom(impl.cfg.TrustCenter.ProxyProtocol.TrustedProxies...) + + listener = &proxyproto.Listener{ + Listener: listener, + ReadHeaderTimeout: 10 * time.Second, + ConnPolicy: policy, + } + + l.Info("using proxy protocol for trust center HTTPS server", log.Any("trusted-proxies", impl.cfg.TrustCenter.ProxyProtocol.TrustedProxies)) + } + if err := httpsServer.ServeTLS(listener, "", ""); err != nil && err != http.ErrServerClosed { return fmt.Errorf("cannot serve https requests: %w", err) } @@ -829,3 +867,34 @@ func (impl *Implm) runTrustCenterServer( return ctx.Err() } + +func rejectProxyHeaderFrom(trustedIPs ...net.IP) proxyproto.ConnPolicyFunc { + return func(connOpts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { + ip, err := ipFromAddr(connOpts.Upstream) + if err != nil { + return proxyproto.REJECT, err + } + + for _, trustedIP := range trustedIPs { + if trustedIP.Equal(ip) { + return proxyproto.USE, nil + } + } + + return proxyproto.REJECT, nil + } +} + +func ipFromAddr(upstream net.Addr) (net.IP, error) { + upstreamString, _, err := net.SplitHostPort(upstream.String()) + if err != nil { + return nil, err + } + + upstreamIP := net.ParseIP(upstreamString) + if nil == upstreamIP { + return nil, fmt.Errorf("proxyproto: invalid IP address") + } + + return upstreamIP, nil +}