diff --git a/midware/middleware.go b/midware/middleware.go index f87ece6..caa1c2e 100644 --- a/midware/middleware.go +++ b/midware/middleware.go @@ -4,10 +4,13 @@ package midware import ( "bytes" + "fmt" "net/http" + "sort" "strings" "github.com/google/uuid" + "github.com/luthersystems/svc/static" ) // DefaultTraceHeader is the default header when TraceHeaders is given an empty @@ -35,19 +38,52 @@ type PathOverrides map[string]http.Handler // Wrap implements the Middleware interface. func (m PathOverrides) Wrap(next http.Handler) http.Handler { - return &pathOverridesHandler{m, next} + var prefixes []string + // public file system may have nested directories we want to access but we + // want to ensure that the /public/ handler handles the request + for path := range m { + if path != static.PublicPathPrefix && strings.HasPrefix(path, static.PublicPathPrefix) { + panic(fmt.Sprintf("PathOverride conflict: disallowed registration of nested public route: %s", path)) + } + if strings.HasSuffix(path, "/") { + prefixes = append(prefixes, path) + } + } + sort.Slice(prefixes, func(i, j int) bool { + return len(prefixes[i]) > len(prefixes[j]) + }) + + return &pathOverridesHandler{ + m: m, + prefixes: prefixes, + next: next, + } } type pathOverridesHandler struct { - m PathOverrides - next http.Handler + m PathOverrides + prefixes []string + next http.Handler } func (h *pathOverridesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if route, ok := h.m[r.URL.Path]; ok { - route.ServeHTTP(w, r) + path := r.URL.Path + + // Exact match + if handler, ok := h.m[path]; ok { + handler.ServeHTTP(w, r) return } + + // do longest match first + for _, prefix := range h.prefixes { + if strings.HasPrefix(path, prefix) { + h.m[prefix].ServeHTTP(w, r) + return + } + } + + // Default to next handler h.next.ServeHTTP(w, r) } diff --git a/midware/middleware_test.go b/midware/middleware_test.go index 538981e..2442665 100644 --- a/midware/middleware_test.go +++ b/midware/middleware_test.go @@ -7,19 +7,62 @@ import ( "net/http/httptest" "testing" + "github.com/luthersystems/svc/static" "github.com/stretchr/testify/assert" ) var basicHandler = staticBytes([]byte("applicationdata")) func TestPathOverrides(t *testing.T) { - basicOverride := &PathOverrides{"/override": staticBytes([]byte("overridden"))} - h := (basicOverride).Wrap(basicHandler) + basicOverride := &PathOverrides{ + "/override": staticBytes([]byte("overridden")), + "/api/": staticBytes([]byte("api handler")), + "/api/nested-api/": staticBytes([]byte("nested api handler")), + static.PublicPathPrefix: staticBytes([]byte("public handler")), + } + + h := basicOverride.Wrap(staticBytes([]byte("applicationdata"))) + testServer(t, h, func(t *testing.T, server *httptest.Server) { - assert.Equal(t, []byte("applicationdata"), testRequest(t, server, "GET", "/", nil, nil)) - assert.Equal(t, []byte("applicationdata"), testRequest(t, server, "GET", "/hello/world", nil, nil)) - assert.Equal(t, []byte("overridden"), testRequest(t, server, "GET", "/override", nil, nil)) - assert.Equal(t, []byte("applicationdata"), testRequest(t, server, "GET", "/override/2", nil, nil)) + t.Run("falls back to next handler on root", func(t *testing.T) { + assert.Equal(t, []byte("applicationdata"), testRequest(t, server, "GET", "/", nil, nil)) + }) + + t.Run("falls back to next handler on unmatched path", func(t *testing.T) { + assert.Equal(t, []byte("applicationdata"), testRequest(t, server, "GET", "/hello/world", nil, nil)) + }) + + t.Run("exact match override works", func(t *testing.T) { + assert.Equal(t, []byte("overridden"), testRequest(t, server, "GET", "/override", nil, nil)) + }) + + t.Run("non-exact override should fall back", func(t *testing.T) { + assert.Equal(t, []byte("applicationdata"), testRequest(t, server, "GET", "/override/2", nil, nil)) + }) + + t.Run("prefix match with /api/ works", func(t *testing.T) { + assert.Equal(t, []byte("api handler"), testRequest(t, server, "GET", "/api/user/42", nil, nil)) + }) + + t.Run("prefix match with /api/nested-api/ chooses longest path (/api/nested-api/)", func(t *testing.T) { + assert.Equal(t, []byte("nested api handler"), testRequest(t, server, "GET", "/api/nested-api/user/42", nil, nil)) + }) + + t.Run("prefix match with /public/ works", func(t *testing.T) { + assert.Equal(t, []byte("public handler"), testRequest(t, server, "GET", "/public/assets/logo.png", nil, nil)) + }) + + }) + + t.Run("panic on disallowed nested /public route", func(t *testing.T) { + assert.PanicsWithValue(t, + "PathOverride conflict: disallowed registration of nested public route: /public/nested/", + func() { + _ = PathOverrides{ + static.PublicPathPrefix: staticBytes([]byte("good")), + "/public/nested/": staticBytes([]byte("bad")), + }.Wrap(staticBytes([]byte("fallback"))) + }) }) } diff --git a/oracle/config.go b/oracle/config.go index a64f1df..d2e7bbd 100644 --- a/oracle/config.go +++ b/oracle/config.go @@ -7,6 +7,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/luthersystems/lutherauth-sdk-go/jwk" "github.com/luthersystems/svc/opttrace" + "github.com/luthersystems/svc/static" ) // DefaultConfig returns a default config. @@ -68,6 +69,8 @@ type Config struct { depTxForwarder *CookieForwarder // fakeIDP is for testing auth. fakeIDP *FakeIDP + // publicContentHandlers configures endpoints to serve public content. + publicContentHandlers *http.ServeMux } const ( @@ -83,6 +86,18 @@ func (c *Config) SetSwaggerHandler(h http.Handler) { c.swaggerHandler = h } +// SetPublicContentHandler sets the handler for /public/ routes. +func (c *Config) SetPublicContentHandler(handler http.Handler) { + if c == nil { + return + } + if c.publicContentHandlers == nil { + c.publicContentHandlers = http.NewServeMux() + } + // pattern MUST be kept in line with static.PublicHandler method + c.publicContentHandlers.Handle(static.PublicPathPrefix, handler) +} + // SetOTLPEndpoint is a helper to set the OTLP trace endpoint. func (c *Config) SetOTLPEndpoint(endpoint string) { if c == nil || endpoint == "" { diff --git a/oracle/oracle.go b/oracle/oracle.go index 0359dfe..d14b6f2 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -86,6 +86,10 @@ type Oracle struct { // claims gets app claims from grpc contexts. claims *claims.GRPCClaims + + // publicContentHandlers configures endpoints to serve public static + // content. + publicContentHandlers *http.ServeMux } // option provides additional configuration to the oracle. Primarily for @@ -171,8 +175,9 @@ func newOracle(config *Config, opts ...option) (*Oracle, error) { return nil, fmt.Errorf("invalid config: %w", err) } oracle := &Oracle{ - cfg: *config, - swaggerHandler: config.swaggerHandler, + cfg: *config, + swaggerHandler: config.swaggerHandler, + publicContentHandlers: config.publicContentHandlers, } oracle.logBase = logrus.StandardLogger().WithFields(nil) for _, opt := range opts { diff --git a/oracle/oraclerun.go b/oracle/oraclerun.go index 8748483..7f29325 100644 --- a/oracle/oraclerun.go +++ b/oracle/oraclerun.go @@ -18,6 +18,7 @@ import ( "github.com/luthersystems/svc/grpclogging" "github.com/luthersystems/svc/logmon" "github.com/luthersystems/svc/midware" + "github.com/luthersystems/svc/static" "github.com/luthersystems/svc/svcerr" "github.com/luthersystems/svc/txctx" "github.com/prometheus/client_golang/prometheus" @@ -116,7 +117,7 @@ func (orc *Oracle) txctxInterceptor(ctx context.Context, req interface{}, info * return resp, err } -func (orc *Oracle) grpcGateway(swaggerHandler http.Handler) (*runtime.ServeMux, http.Handler) { +func (orc *Oracle) grpcGateway(swaggerHandler http.Handler, publicContentHandler *http.ServeMux) (*runtime.ServeMux, http.Handler) { jsonapi := orc.grpcGatewayMux() pathOverides := midware.PathOverrides{ healthCheckPath: orc.healthCheckHandler(), @@ -124,6 +125,10 @@ func (orc *Oracle) grpcGateway(swaggerHandler http.Handler) (*runtime.ServeMux, if swaggerHandler != nil { pathOverides[swaggerPath] = swaggerHandler } + if publicContentHandler != nil { + pathOverides[static.PublicPathPrefix] = publicContentHandler + } + middleware := midware.Chain{ // The trace header middleware appears early in the chain // because of how important it is that they happen for essentially all @@ -233,7 +238,7 @@ func (orc *Oracle) StartGateway(ctx context.Context, grpcConfig GrpcGatewayConfi return fmt.Errorf("grpc dial: %w", err) } - mux, httpHandler := orc.grpcGateway(orc.swaggerHandler) + mux, httpHandler := orc.grpcGateway(orc.swaggerHandler, orc.publicContentHandlers) if err := grpcConfig.RegisterServiceClient(ctx, grpcConn, mux); err != nil { return fmt.Errorf("register service client: %w", err) } diff --git a/static/README.md b/static/README.md new file mode 100644 index 0000000..c34b749 --- /dev/null +++ b/static/README.md @@ -0,0 +1,99 @@ +# static + +Serve embedded static files from the Oracle using a simple convention. +Designed for use in conjunction with the `oracle` package. + +## Usage + +### 1. Embed your public directory + +Create a folder named `public` inside the package from which you configure your Oracle. +Add any public files you wish to serve. + +Use the `embed` package to include them in your Go binary: + +```go +//go:embed public/** +var PublicFS embed.FS +``` + +### 2. Mount it using the provided handler + +Use `PublicHandler` to create an `http.Handler` for your embedded files: + +```go +handler, err := static.PublicHandler(PublicFS) +``` + +This will serve files under the `/public/` URL path. + +--- + +### Optional: Panic wrapper + +For convenience, create a panic-on-failure helper in your app: + +```go +func PublicHandlerOrPanic(fs embed.FS) http.Handler { + h, err := static.PublicHandler(fs) + if err != nil { + panic(err) + } + return h +} +``` + +--- + +### 3. Register the handler with Oracle + +Pass the handler to `SetPublicContentHandler` method of the Oracle: + +```go +cfg.SetPublicContentHandler(api.PublicContentHandlerOrPanic()) +``` + +### 4. Access the files in the browser to verify successful setup + +Visit: +``` +/public/ +``` + +--- + +## Example Usage + +```go +func (r *startCmd) Run() error { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("could not get working directory: %v", err) + } + log.Printf("Process running from: %s", dir) + + cfg := svc.DefaultConfig() + // ... + cfg.SetPublicContentHandler(api.PublicContentHandlerOrPanic()) + // ... + return oracle.Run(r.ctx, &oracle.Config{ + Config: *cfg, + PortalConfig: r.PortalConfig, + }) +} +``` + +With: + +```go +//go:embed public/** +var publicFS embed.FS + +func PublicContentHandlerOrPanic() http.Handler { + h, err := static.PublicHandler(publicFS) + if err != nil { + panic(err) + } + return h +} +``` diff --git a/static/static.go b/static/static.go new file mode 100644 index 0000000..8a7c0eb --- /dev/null +++ b/static/static.go @@ -0,0 +1,58 @@ +// Package static provides HTTP handlers for serving embedded static content. +// +// It supports serving files from a subdirectory within an embed.FS at a specified +// URL prefix, such as mounting embedded "public/**" content at the "/public/" path. +// +// This package is typically used to expose browser-accessible static files like +// JavaScript bundles, CSS, or HTML generated at build time. +// +// Example usage: +// +// //go:embed public/** +// var PublicFS embed.FS +// +// http.Handle("/public/", static.PublicHandler(PublicFS)) +package static + +import ( + "embed" + "fmt" + "io/fs" + "net/http" + "strings" +) + +const PublicFSDirSegment = "public" +const PublicPathPrefix = "/" + PublicFSDirSegment + "/" + +// PublicHandler returns an http.Handler that serves embedded files under the +// "public/" subdirectory of the provided embed.FS. This content MUST be served +// under the /public pattern +func PublicHandler(staticFS embed.FS) (http.Handler, error) { + return publicContentHandler(staticFS, PublicFSDirSegment, PublicFSDirSegment) +} + +// staticContentHandlerreturns an http.Handler that serves embedded files from a +// subdirectory within the embed.FS (e.g., "static") and maps them to a given URL prefix. +// +// For example: +// - Embedded files live under embed.FS path "static/**" +// - You want to serve them at the URL prefix "/assets/" +// +// Call: +// +// staticContentHandler(staticFS, "static", "assets") +// +// Then a request to /assets/index.html will serve embedded file "static/index.html". +func publicContentHandler(embeddedFS embed.FS, subdir, urlPrefix string) (http.Handler, error) { + + cleanStaticDir := strings.Trim(subdir, "/") + cleanURLPrefix := strings.Trim(urlPrefix, "/") + subFS, err := fs.Sub(embeddedFS, cleanStaticDir) + if err != nil { + return nil, fmt.Errorf("cannot create sub FS: %w", err) + } + + prefix := fmt.Sprintf("/%s/", cleanURLPrefix) + return http.StripPrefix(prefix, http.FileServer(http.FS(subFS))), nil +}