From a96d5b997e2d87dc5ea5211d56a6ca45e7b066c9 Mon Sep 17 00:00:00 2001 From: Victor Santos Date: Wed, 12 Nov 2025 21:19:57 -0300 Subject: [PATCH 01/14] Implement Infisical proxy server with caching and debug capabilities - Added `proxy.go` to handle proxy server commands, including starting the server and printing cache debug information. - Introduced a caching mechanism in `cache.go` to store and manage HTTP responses, supporting token-based cache invalidation. - Implemented resync logic to refresh cached entries based on expiration. - Added command-line flags for configuring the proxy server's domain, listen address, resync interval, and cache TTL. - Included a debug endpoint for development mode to retrieve cache information. --- packages/cmd/proxy.go | 458 ++++++++++++++++++++++++++++++++++++++++ packages/proxy/cache.go | 279 ++++++++++++++++++++++++ 2 files changed, 737 insertions(+) create mode 100644 packages/cmd/proxy.go create mode 100644 packages/proxy/cache.go diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go new file mode 100644 index 00000000..865a17b3 --- /dev/null +++ b/packages/cmd/proxy.go @@ -0,0 +1,458 @@ +package cmd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/proxy" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var proxyCmd = &cobra.Command{ + Example: `infisical proxy start`, + Short: "Used to run Infisical proxy server", + Use: "proxy", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, +} + +var proxyStartCmd = &cobra.Command{ + Example: `infisical proxy start --domain=https://app.infisical.com --listen-address=localhost:8081`, + Short: "Start the Infisical proxy server", + Use: "start", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: startProxyServer, +} + +var proxyDebugCmd = &cobra.Command{ + Example: `infisical proxy debug --listen-address=localhost:8081`, + Short: "Print cache debug information (dev mode only)", + Use: "debug", + DisableFlagsInUseLine: true, + Args: cobra.NoArgs, + Run: printCacheDebug, +} + +func startResyncLoop(ctx context.Context, cache *proxy.Cache, domainURL *url.URL, httpClient *http.Client, resyncInterval int, cacheTTL int) { + ticker := time.NewTicker(time.Duration(resyncInterval) * time.Minute) + defer ticker.Stop() + + log.Info(). + Int("resyncInterval", resyncInterval). + Int("cacheTTL", cacheTTL). + Msg("Resync loop started") + + for { + select { + case <-ticker.C: + log.Info().Msg("Starting resync cycle") + cacheTTLDuration := time.Duration(cacheTTL) * time.Minute + requests := cache.GetExpiredRequests(cacheTTLDuration) + + refetched := 0 + evicted := 0 + + for cacheKey, request := range requests { + // --- Reconstruct the request -- + + targetURL := *domainURL + parsedURI, err := url.Parse(request.RequestURI) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Failed to parse requestURI during resync") + continue + } + + targetURL.Path = domainURL.Path + parsedURI.Path + targetURL.RawQuery = parsedURI.RawQuery + + proxyReq, err := http.NewRequest(request.Method, targetURL.String(), nil) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Str("targetURL", targetURL.String()). + Msg("Failed to create proxy request during resync") + continue + } + + proxy.CopyHeaders(proxyReq.Header, request.Headers) + + resp, err := httpClient.Do(proxyReq) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Network error during resync - keeping stale entry") + // Keep stale entry for high availability + continue + } + + // --- Handle response -- + + if resp.StatusCode == http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Msg("Failed to read response body during resync") + continue + } + + // Update only response data (IndexEntry doesn't change during resync) + cache.UpdateResponse(cacheKey, resp.StatusCode, resp.Header, bodyBytes) + refetched++ + + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Successfully refetched and updated cache entry") + } else if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + // Evict entry on 401/403 + cache.EvictEntry(cacheKey) + evicted++ + resp.Body.Close() + + log.Info(). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Int("statusCode", resp.StatusCode). + Msg("Evicted cache entry due to authorization failure") + } else { + // Other error status codes - keep stale entry + resp.Body.Close() + log.Warn(). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Int("statusCode", resp.StatusCode). + Msg("Unexpected status code during resync - keeping stale entry") + } + } + + log.Info(). + Int("expiredEntries", len(requests)). + Int("refetched", refetched). + Int("evicted", evicted). + Msg("Resync cycle completed") + + case <-ctx.Done(): + log.Info().Msg("Resync loop stopped") + return + } + } +} + +func startProxyServer(cmd *cobra.Command, args []string) { + domain, err := cmd.Flags().GetString("domain") + if err != nil { + util.HandleError(err, "Unable to parse domain flag") + } + + if domain == "" { + util.PrintErrorMessageAndExit("Domain flag is required") + } + + listenAddress, err := cmd.Flags().GetString("listen-address") + if err != nil { + util.HandleError(err, "Unable to parse listen-address flag") + } + + if listenAddress == "" { + util.PrintErrorMessageAndExit("Listen-address flag is required") + } + + resyncInterval, err := cmd.Flags().GetInt("resync-interval") + if err != nil { + util.HandleError(err, "Unable to parse resync-interval flag") + } + + cacheTTL, err := cmd.Flags().GetInt("cache-ttl") + if err != nil { + util.HandleError(err, "Unable to parse cache-ttl flag") + } + + domainURL, err := url.Parse(domain) + if err != nil { + util.HandleError(err, fmt.Sprintf("Invalid domain URL: %s", domain)) + } + + httpClient := &http.Client{ + Timeout: 30 * time.Second, + } + + cache := proxy.NewCache() + devMode := util.CLI_VERSION == "devel" + mux := http.NewServeMux() + + // Debug endpoint (dev mode only) + if devMode { + mux.HandleFunc("/_debug/cache", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + debugInfo := cache.GetDebugInfo() + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(debugInfo); err != nil { + log.Error().Err(err).Msg("Failed to encode cache debug info") + http.Error(w, "Failed to encode debug info", http.StatusInternalServerError) + return + } + }) + log.Info().Msg("Dev mode enabled: debug endpoint available at /_debug/cache") + } + + proxyHandler := func(w http.ResponseWriter, r *http.Request) { + // Skip debug endpoints - they're handled by mux + if strings.HasPrefix(r.URL.Path, "/_debug/") { + http.NotFound(w, r) + return + } + + token := proxy.ExtractTokenFromRequest(r) + + isCacheable := proxy.IsCacheableRequest(r.URL.Path, r.Method) + + // -- Cache Check -- + + if isCacheable && token != "" { + cacheKey := proxy.GenerateCacheKey(r.Method, r.URL.Path, r.URL.RawQuery, token) + + if cachedResp, found := cache.Get(cacheKey); found { + log.Info(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Cache hit - serving from cache") + + proxy.CopyHeaders(w.Header(), cachedResp.Header) + w.WriteHeader(cachedResp.StatusCode) + _, err := io.Copy(w, cachedResp.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to copy cached response body") + return + } + return + } + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Cache miss - forwarding request") + } + + // -- Proxy Request -- + + targetURL := *domainURL + targetURL.Path = domainURL.Path + r.URL.Path + targetURL.RawQuery = r.URL.RawQuery + + proxyReq, err := http.NewRequest(r.Method, targetURL.String(), r.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to create proxy request") + http.Error(w, fmt.Sprintf("Failed to create proxy request: %v", err), http.StatusInternalServerError) + return + } + + proxy.CopyHeaders(proxyReq.Header, r.Header) + + log.Info(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("target", targetURL.String()). + Msg("Forwarding request") + + resp, err := httpClient.Do(proxyReq) + if err != nil { + log.Error().Err(err).Msg("Failed to forward request") + http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // Read response body into memory for caching (if needed) and serving + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to read response body") + http.Error(w, fmt.Sprintf("Failed to read response body: %v", err), http.StatusInternalServerError) + return + } + + // -- Proxy Response -- + + proxy.CopyHeaders(w.Header(), resp.Header) + + w.WriteHeader(resp.StatusCode) + + _, err = w.Write(bodyBytes) + if err != nil { + log.Error().Err(err).Msg("Failed to write response body") + return + } + + // -- Cache Set -- + + if isCacheable && token != "" && resp.StatusCode == http.StatusOK { + cacheKey := proxy.GenerateCacheKey(r.Method, r.URL.Path, r.URL.RawQuery, token) + + queryParams := r.URL.Query() + projectId := queryParams.Get("projectId") + environment := queryParams.Get("environment") + secretPath := queryParams.Get("secretPath") + if secretPath == "" { + secretPath = "/" + } + + indexEntry := proxy.IndexEntry{ + CacheKey: cacheKey, + SecretPath: secretPath, + EnvironmentSlug: environment, + ProjectId: projectId, + } + + cachedResp := &http.Response{ + StatusCode: resp.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(bodyBytes)), + } + + proxy.CopyHeaders(cachedResp.Header, resp.Header) + + cache.Set(cacheKey, r, cachedResp, token, indexEntry) + + log.Info(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Response cached successfully") + } + + log.Info(). + Str("method", r.Method). + Str("path", r.URL.Path). + Int("status", resp.StatusCode). + Msg("Request forwarded successfully") + } + + // Add proxy handler to mux + mux.HandleFunc("/", proxyHandler) + + server := &http.Server{ + Addr: listenAddress, + Handler: mux, + } + + resyncCtx, resyncCancel := context.WithCancel(context.Background()) + defer resyncCancel() + + go startResyncLoop(resyncCtx, cache, domainURL, httpClient, resyncInterval, cacheTTL) + + // Handle graceful shutdown + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigCh + log.Info().Msgf("Received signal %v, shutting down proxy server...", sig) + + // Cancel resync goroutine + resyncCancel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + log.Error().Err(err).Msg("Error during server shutdown") + os.Exit(1) + } + + log.Info().Msg("Proxy server shutdown complete") + os.Exit(0) + }() + + log.Info().Msgf("Infisical proxy server starting on %s", listenAddress) + log.Info().Msgf("Forwarding requests to %s", domain) + + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + util.HandleError(err, "Failed to start proxy server") + } +} + +func printCacheDebug(cmd *cobra.Command, args []string) { + if util.CLI_VERSION != "devel" { + util.PrintErrorMessageAndExit("This command is only available in dev mode (when CLI_VERSION is 'devel').") + } + + listenAddress, err := cmd.Flags().GetString("listen-address") + if err != nil { + util.HandleError(err, "Unable to parse listen-address flag") + } + + if listenAddress == "" { + util.PrintErrorMessageAndExit("Listen-address flag is required") + } + + baseURL := "http://" + listenAddress + if strings.HasPrefix(listenAddress, ":") { + baseURL = "http://localhost" + listenAddress + } + + debugURL := baseURL + "/_debug/cache" + resp, err := http.Get(debugURL) + if err != nil { + util.HandleError(err, fmt.Sprintf("Failed to connect to proxy at %s. Make sure the proxy is running in dev mode (CLI_VERSION='devel')", listenAddress)) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + util.PrintErrorMessageAndExit(fmt.Sprintf("Failed to get cache debug info: %s", string(body))) + } + + var debugInfo proxy.CacheDebugInfo + if err := json.NewDecoder(resp.Body).Decode(&debugInfo); err != nil { + util.HandleError(err, "Failed to decode cache debug info") + } + + output, err := json.MarshalIndent(debugInfo, "", " ") + if err != nil { + util.HandleError(err, "Failed to marshal cache debug info") + } + + fmt.Println("Cache Debug Information:") + fmt.Println(string(output)) +} + +func init() { + proxyStartCmd.Flags().String("domain", "", "Domain of your Infisical instance (e.g., https://app.infisical.com for cloud, https://my-self-hosted-instance.com for self-hosted)") + proxyStartCmd.Flags().String("listen-address", "localhost:8081", "The address for the proxy server to listen on. Defaults to localhost:8081") + proxyStartCmd.Flags().Int("resync-interval", 10, "Interval in minutes for resyncing cached secrets. Defaults to 10 minutes.") + proxyStartCmd.Flags().Int("cache-ttl", 60, "TTL in minutes for individual cache entries. Defaults to 60 minutes.") + + proxyDebugCmd.Flags().String("listen-address", "localhost:8081", "The address where the proxy server is listening. Defaults to localhost:8081") + + proxyCmd.AddCommand(proxyStartCmd) + proxyCmd.AddCommand(proxyDebugCmd) + rootCmd.AddCommand(proxyCmd) +} diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go new file mode 100644 index 00000000..f6f41622 --- /dev/null +++ b/packages/proxy/cache.go @@ -0,0 +1,279 @@ +package proxy + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "io" + "net/http" + "strings" + "sync" + "time" +) + +type IndexEntry struct { + CacheKey string + SecretPath string + EnvironmentSlug string + ProjectId string +} + +type CachedRequest struct { + Method string + RequestURI string + Headers http.Header + CachedAt time.Time +} + +type CachedResponse struct { + StatusCode int + Header http.Header + BodyBytes []byte +} + +type CacheEntry struct { + Request *CachedRequest + Response *CachedResponse +} + +// Cache is an in-memory cache for HTTP responses +type Cache struct { + entries map[string]*CacheEntry // main store: cacheKey -> cache entry (request + response) + tokenIndex map[string]map[string]IndexEntry // secondary index: token -> map[cacheKey]IndexEntry, used for token invalidation + mu sync.RWMutex // for thread-safe access +} + +func NewCache() *Cache { + return &Cache{ + entries: make(map[string]*CacheEntry), + tokenIndex: make(map[string]map[string]IndexEntry), + } +} + +// Only GET requests to /v3/secrets/* and /v4/secrets/* routes are cacheable +func IsCacheableRequest(path string, method string) bool { + if method != http.MethodGet { + return false + } + + return (strings.HasPrefix(path, "/api/v3/secrets/") || strings.HasPrefix(path, "/api/v4/secrets/")) || + path == "/api/v3/secrets" || path == "/api/v4/secrets" +} + +func (c *Cache) Get(cacheKey string) (*http.Response, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.entries[cacheKey] + if !exists { + return nil, false + } + + resp := &http.Response{ + StatusCode: entry.Response.StatusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(entry.Response.BodyBytes)), + } + + CopyHeaders(resp.Header, entry.Response.Header) + + return resp, true +} + +func (c *Cache) Set(cacheKey string, req *http.Request, resp *http.Response, token string, indexEntry IndexEntry) { + c.mu.Lock() + defer c.mu.Unlock() + + // We can't use the response body directly because it will be closed by the time we need to use it + var bodyBytes []byte + if resp.Body != nil { + bodyBytes, _ = io.ReadAll(resp.Body) + } + + // Extract request metadata + requestURI := req.URL.RequestURI() + requestHeaders := make(http.Header) + CopyHeaders(requestHeaders, req.Header) + + // Extract response data + responseHeader := make(http.Header) + CopyHeaders(responseHeader, resp.Header) + + entry := &CacheEntry{ + Request: &CachedRequest{ + Method: req.Method, + RequestURI: requestURI, + Headers: requestHeaders, + CachedAt: time.Now(), + }, + Response: &CachedResponse{ + StatusCode: resp.StatusCode, + Header: responseHeader, + BodyBytes: bodyBytes, + }, + } + + c.entries[cacheKey] = entry + + // Update secondary index for token + if c.tokenIndex[token] == nil { + c.tokenIndex[token] = make(map[string]IndexEntry) + } + c.tokenIndex[token][cacheKey] = indexEntry +} + +// UpdateResponse updates only the response data and cachedAt timestamp for an existing cache entry +// This is used during resync when the request parameters (and thus IndexEntry) haven't changed +func (c *Cache) UpdateResponse(cacheKey string, statusCode int, header http.Header, bodyBytes []byte) { + c.mu.Lock() + defer c.mu.Unlock() + + entry, exists := c.entries[cacheKey] + if !exists { + return + } + + // Deep copy response header + responseHeader := make(http.Header) + CopyHeaders(responseHeader, header) + + // Deep copy bodyBytes + bodyBytesCopy := make([]byte, len(bodyBytes)) + copy(bodyBytesCopy, bodyBytes) + + entry.Response.StatusCode = statusCode + entry.Response.Header = responseHeader + entry.Response.BodyBytes = bodyBytesCopy + entry.Request.CachedAt = time.Now() +} + +func CopyHeaders(dst, src http.Header) { + for key, values := range src { + for _, value := range values { + dst.Add(key, value) + } + } +} + +func ExtractTokenFromRequest(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } + + // Parse "Bearer " + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return "" + } + + return parts[1] +} + +// GenerateCacheKey generates a cache key for a request by hashing the method, path, query, and token +func GenerateCacheKey(method, path, query, token string) string { + data := method + path + query + token + hash := sha256.Sum256([]byte(data)) + return hex.EncodeToString(hash[:]) +} + +// GetExpiredRequests returns only expired request data for resync +func (c *Cache) GetExpiredRequests(cacheTTL time.Duration) map[string]*CachedRequest { + c.mu.RLock() + defer c.mu.RUnlock() + + now := time.Now() + requests := make(map[string]*CachedRequest, 0) + + for key, entry := range c.entries { + // Only include entries where cache-ttl has expired + age := now.Sub(entry.Request.CachedAt) + if age <= cacheTTL { + continue + } + + // Create a deep copy of request data only + requestCopy := &CachedRequest{ + Method: entry.Request.Method, + RequestURI: entry.Request.RequestURI, + Headers: make(http.Header), + CachedAt: entry.Request.CachedAt, + } + + CopyHeaders(requestCopy.Headers, entry.Request.Headers) + + requests[key] = requestCopy + } + + return requests +} + +func (c *Cache) EvictEntry(cacheKey string) { + c.mu.Lock() + defer c.mu.Unlock() + + entry, exists := c.entries[cacheKey] + if !exists { + return + } + + token := ExtractTokenFromRequest(&http.Request{Header: entry.Request.Headers}) + + // Remove from main store + delete(c.entries, cacheKey) + + // Remove from token index + if token != "" { + if tokenEntries, ok := c.tokenIndex[token]; ok { + delete(tokenEntries, cacheKey) + if len(tokenEntries) == 0 { + delete(c.tokenIndex, token) + } + } + } +} + +// CacheDebugInfo contains debug information about the cache +type CacheDebugInfo struct { + TotalEntries int `json:"totalEntries"` + TotalTokens int `json:"totalTokens"` + TotalSizeBytes int64 `json:"totalSizeBytes"` + EntriesByToken map[string]int `json:"entriesByToken"` + CacheKeys []string `json:"cacheKeys"` + TokenIndex map[string][]IndexEntry `json:"tokenIndex"` +} + +// GetDebugInfo returns debug information about the cache (dev mode only) +func (c *Cache) GetDebugInfo() CacheDebugInfo { + c.mu.RLock() + defer c.mu.RUnlock() + + var totalSize int64 + entriesByToken := make(map[string]int) + tokenIndex := make(map[string][]IndexEntry) + cacheKeys := make([]string, 0, len(c.entries)) + + // Calculate sizes + for cacheKey, entry := range c.entries { + cacheKeys = append(cacheKeys, cacheKey) + totalSize += int64(len(entry.Response.BodyBytes)) + } + + // Build token index and count entries per token + for token, entries := range c.tokenIndex { + entriesByToken[token] = len(entries) + tokenIndex[token] = make([]IndexEntry, 0, len(entries)) + for _, entry := range entries { + tokenIndex[token] = append(tokenIndex[token], entry) + } + } + + return CacheDebugInfo{ + TotalEntries: len(c.entries), + TotalTokens: len(c.tokenIndex), + TotalSizeBytes: totalSize, + EntriesByToken: entriesByToken, + CacheKeys: cacheKeys, + TokenIndex: tokenIndex, + } +} From a71a99a8a1577577ef8eb42b88df7520daba4542 Mon Sep 17 00:00:00 2001 From: Victor Santos Date: Thu, 13 Nov 2025 21:06:18 -0300 Subject: [PATCH 02/14] Refactor proxy server and caching mechanism - Removed the `startResyncLoop` function from `proxy.go` and moved it to a new `resync.go` file for better organization. - Enhanced the caching system in `cache.go` to include a compound path index for improved cache entry management and eviction after mutation calls. - Introduced a new method to handle resync responses, including rate limit handling and entry eviction based on HTTP status codes. - Updated the proxy server to utilize a streaming client for long-lived connections and improved logging for cache hits and misses. - Added functionality to purge cache entries based on mutation paths across all tokens. --- packages/cmd/proxy.go | 286 ++++++++++++++-------------- packages/proxy/cache.go | 390 ++++++++++++++++++++++++++++++++++++--- packages/proxy/resync.go | 300 ++++++++++++++++++++++++++++++ 3 files changed, 813 insertions(+), 163 deletions(-) create mode 100644 packages/proxy/resync.go diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index 865a17b3..0c6642cf 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -46,121 +46,6 @@ var proxyDebugCmd = &cobra.Command{ Run: printCacheDebug, } -func startResyncLoop(ctx context.Context, cache *proxy.Cache, domainURL *url.URL, httpClient *http.Client, resyncInterval int, cacheTTL int) { - ticker := time.NewTicker(time.Duration(resyncInterval) * time.Minute) - defer ticker.Stop() - - log.Info(). - Int("resyncInterval", resyncInterval). - Int("cacheTTL", cacheTTL). - Msg("Resync loop started") - - for { - select { - case <-ticker.C: - log.Info().Msg("Starting resync cycle") - cacheTTLDuration := time.Duration(cacheTTL) * time.Minute - requests := cache.GetExpiredRequests(cacheTTLDuration) - - refetched := 0 - evicted := 0 - - for cacheKey, request := range requests { - // --- Reconstruct the request -- - - targetURL := *domainURL - parsedURI, err := url.Parse(request.RequestURI) - if err != nil { - log.Error(). - Err(err). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Msg("Failed to parse requestURI during resync") - continue - } - - targetURL.Path = domainURL.Path + parsedURI.Path - targetURL.RawQuery = parsedURI.RawQuery - - proxyReq, err := http.NewRequest(request.Method, targetURL.String(), nil) - if err != nil { - log.Error(). - Err(err). - Str("cacheKey", cacheKey). - Str("targetURL", targetURL.String()). - Msg("Failed to create proxy request during resync") - continue - } - - proxy.CopyHeaders(proxyReq.Header, request.Headers) - - resp, err := httpClient.Do(proxyReq) - if err != nil { - log.Error(). - Err(err). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Msg("Network error during resync - keeping stale entry") - // Keep stale entry for high availability - continue - } - - // --- Handle response -- - - if resp.StatusCode == http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - log.Error(). - Err(err). - Str("cacheKey", cacheKey). - Msg("Failed to read response body during resync") - continue - } - - // Update only response data (IndexEntry doesn't change during resync) - cache.UpdateResponse(cacheKey, resp.StatusCode, resp.Header, bodyBytes) - refetched++ - - log.Debug(). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Msg("Successfully refetched and updated cache entry") - } else if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { - // Evict entry on 401/403 - cache.EvictEntry(cacheKey) - evicted++ - resp.Body.Close() - - log.Info(). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Int("statusCode", resp.StatusCode). - Msg("Evicted cache entry due to authorization failure") - } else { - // Other error status codes - keep stale entry - resp.Body.Close() - log.Warn(). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Int("statusCode", resp.StatusCode). - Msg("Unexpected status code during resync - keeping stale entry") - } - } - - log.Info(). - Int("expiredEntries", len(requests)). - Int("refetched", refetched). - Int("evicted", evicted). - Msg("Resync cycle completed") - - case <-ctx.Done(): - log.Info().Msg("Resync loop stopped") - return - } - } -} - func startProxyServer(cmd *cobra.Command, args []string) { domain, err := cmd.Flags().GetString("domain") if err != nil { @@ -199,6 +84,11 @@ func startProxyServer(cmd *cobra.Command, args []string) { Timeout: 30 * time.Second, } + // Create a separate client for streaming endpoints (no timeout for long-lived connections) + streamingClient := &http.Client{ + Timeout: 0, + } + cache := proxy.NewCache() devMode := util.CLI_VERSION == "devel" mux := http.NewServeMux() @@ -232,6 +122,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { token := proxy.ExtractTokenFromRequest(r) isCacheable := proxy.IsCacheableRequest(r.URL.Path, r.Method) + isStreaming := isStreamingEndpoint(r.URL.Path) // -- Cache Check -- @@ -240,10 +131,8 @@ func startProxyServer(cmd *cobra.Command, args []string) { if cachedResp, found := cache.Get(cacheKey); found { log.Info(). - Str("method", r.Method). - Str("path", r.URL.Path). - Str("cacheKey", cacheKey). - Msg("Cache hit - serving from cache") + Str("hash", cacheKey). + Msg("Cache hit") proxy.CopyHeaders(w.Header(), cachedResp.Header) w.WriteHeader(cachedResp.StatusCode) @@ -255,20 +144,34 @@ func startProxyServer(cmd *cobra.Command, args []string) { return } - log.Debug(). - Str("method", r.Method). - Str("path", r.URL.Path). - Str("cacheKey", cacheKey). - Msg("Cache miss - forwarding request") + log.Info(). + Str("hash", cacheKey). + Msg("Cache miss") } // -- Proxy Request -- + // Read request body for mutation eviction (PATCH/DELETE) or restore for forwarding + var requestBodyBytes []byte + if r.Body != nil { + requestBodyBytes, err = io.ReadAll(r.Body) + if err != nil { + log.Error().Err(err).Msg("Failed to read request body") + http.Error(w, fmt.Sprintf("Failed to read request body: %v", err), http.StatusInternalServerError) + return + } + } + targetURL := *domainURL targetURL.Path = domainURL.Path + r.URL.Path targetURL.RawQuery = r.URL.RawQuery - proxyReq, err := http.NewRequest(r.Method, targetURL.String(), r.Body) + var bodyReader io.Reader + if requestBodyBytes != nil { + bodyReader = bytes.NewReader(requestBodyBytes) + } + + proxyReq, err := http.NewRequest(r.Method, targetURL.String(), bodyReader) if err != nil { log.Error().Err(err).Msg("Failed to create proxy request") http.Error(w, fmt.Sprintf("Failed to create proxy request: %v", err), http.StatusInternalServerError) @@ -277,13 +180,19 @@ func startProxyServer(cmd *cobra.Command, args []string) { proxy.CopyHeaders(proxyReq.Header, r.Header) - log.Info(). + log.Debug(). Str("method", r.Method). Str("path", r.URL.Path). Str("target", targetURL.String()). Msg("Forwarding request") - resp, err := httpClient.Do(proxyReq) + // Use streaming client for SSE/streaming endpoints, regular client for others + clientToUse := httpClient + if isStreaming { + clientToUse = streamingClient + } + + resp, err := clientToUse.Do(proxyReq) if err != nil { log.Error().Err(err).Msg("Failed to forward request") http.Error(w, fmt.Sprintf("Failed to forward request: %v", err), http.StatusBadGateway) @@ -291,7 +200,44 @@ func startProxyServer(cmd *cobra.Command, args []string) { } defer resp.Body.Close() - // Read response body into memory for caching (if needed) and serving + // -- Proxy Response -- + + proxy.CopyHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + + // For streaming endpoints, stream directly instead of buffering + if isStreaming { + // Flush headers immediately for SSE + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + + // Stream with periodic flushing for SSE events + buf := make([]byte, 1024) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + log.Error().Err(writeErr).Msg("Failed to write streaming response") + return + } + // Flush after each write to send SSE events immediately + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + if err == io.EOF { + break + } + if err != nil { + log.Error().Err(err).Msg("Failed to read streaming response") + return + } + } + return + } + + // For non-streaming endpoints, read into memory for caching and serving bodyBytes, err := io.ReadAll(resp.Body) if err != nil { log.Error().Err(err).Msg("Failed to read response body") @@ -299,31 +245,93 @@ func startProxyServer(cmd *cobra.Command, args []string) { return } - // -- Proxy Response -- - - proxy.CopyHeaders(w.Header(), resp.Header) - - w.WriteHeader(resp.StatusCode) - _, err = w.Write(bodyBytes) if err != nil { log.Error().Err(err).Msg("Failed to write response body") return } + // -- Secret Mutation Purging -- + + if (r.Method == http.MethodPatch || r.Method == http.MethodDelete) && + proxy.IsSecretsEndpoint(r.URL.Path) && + resp.StatusCode >= 200 && resp.StatusCode < 300 { + var projectId, environment, secretPath string + + if len(requestBodyBytes) > 0 { + var bodyData map[string]interface{} + if err := json.Unmarshal(requestBodyBytes, &bodyData); err == nil { + // Support both v3 (workspaceId/workspaceSlug) and v4 (projectId) + if projId, ok := bodyData["projectId"].(string); ok { + projectId = projId + } else if workspaceId, ok := bodyData["workspaceId"].(string); ok { + projectId = workspaceId + } else if workspaceSlug, ok := bodyData["workspaceSlug"].(string); ok { + projectId = workspaceSlug + } + if env, ok := bodyData["environment"].(string); ok { + environment = env + } + if path, ok := bodyData["secretPath"].(string); ok { + secretPath = path + } + } + } + + if secretPath == "" { + secretPath = "/" + } + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("projectId", projectId). + Str("environment", environment). + Str("secretPath", secretPath). + Msg("Attempting mutation purging across all tokens") + + purgedCount := cache.PurgeByMutation(projectId, environment, secretPath) + + if purgedCount == 1 { + log.Info(). + Str("mutationPath", secretPath). + Msg("Entry purged") + } else { + log.Info(). + Int("purgedCount", purgedCount). + Str("mutationPath", secretPath). + Msg("Entries purged") + } + } + // -- Cache Set -- if isCacheable && token != "" && resp.StatusCode == http.StatusOK { cacheKey := proxy.GenerateCacheKey(r.Method, r.URL.Path, r.URL.RawQuery, token) queryParams := r.URL.Query() + // Support both v3 (workspaceId/workspaceSlug) and v4 (projectId) projectId := queryParams.Get("projectId") + if projectId == "" { + projectId = queryParams.Get("workspaceId") + } + if projectId == "" { + projectId = queryParams.Get("workspaceSlug") + } environment := queryParams.Get("environment") secretPath := queryParams.Get("secretPath") if secretPath == "" { secretPath = "/" } + if r.URL.Path == "/api/v3/secrets" || r.URL.Path == "/api/v4/secrets" || + r.URL.Path == "/api/v3/secrets/raw" || r.URL.Path == "/api/v4/secrets/raw" { + recursive := queryParams.Get("recursive") + if recursive == "true" { + secretPath = secretPath + "*" + } + } + indexEntry := proxy.IndexEntry{ CacheKey: cacheKey, SecretPath: secretPath, @@ -341,14 +349,14 @@ func startProxyServer(cmd *cobra.Command, args []string) { cache.Set(cacheKey, r, cachedResp, token, indexEntry) - log.Info(). + log.Debug(). Str("method", r.Method). Str("path", r.URL.Path). Str("cacheKey", cacheKey). Msg("Response cached successfully") } - log.Info(). + log.Debug(). Str("method", r.Method). Str("path", r.URL.Path). Int("status", resp.StatusCode). @@ -366,7 +374,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { resyncCtx, resyncCancel := context.WithCancel(context.Background()) defer resyncCancel() - go startResyncLoop(resyncCtx, cache, domainURL, httpClient, resyncInterval, cacheTTL) + go proxy.StartResyncLoop(resyncCtx, cache, domainURL, httpClient, resyncInterval, cacheTTL) // Handle graceful shutdown sigCh := make(chan os.Signal, 1) @@ -444,6 +452,10 @@ func printCacheDebug(cmd *cobra.Command, args []string) { fmt.Println(string(output)) } +func isStreamingEndpoint(path string) bool { + return strings.HasPrefix(path, "/api/v1/events/") +} + func init() { proxyStartCmd.Flags().String("domain", "", "Domain of your Infisical instance (e.g., https://app.infisical.com for cloud, https://my-self-hosted-instance.com for self-hosted)") proxyStartCmd.Flags().String("listen-address", "localhost:8081", "The address for the proxy server to listen on. Defaults to localhost:8081") diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go index f6f41622..df257581 100644 --- a/packages/proxy/cache.go +++ b/packages/proxy/cache.go @@ -38,26 +38,31 @@ type CacheEntry struct { // Cache is an in-memory cache for HTTP responses type Cache struct { - entries map[string]*CacheEntry // main store: cacheKey -> cache entry (request + response) - tokenIndex map[string]map[string]IndexEntry // secondary index: token -> map[cacheKey]IndexEntry, used for token invalidation - mu sync.RWMutex // for thread-safe access + entries map[string]*CacheEntry // main store: cacheKey -> cache entry (request + response) + tokenIndex map[string]map[string]IndexEntry // secondary index: token -> map[cacheKey]IndexEntry, used for token invalidation + compoundPathIndex map[string]map[string]map[string]map[string]map[string]struct{} // token -> projectID -> envSlug -> secretPath -> cacheKey -> struct{}, used for evictions after mutation calls + mu sync.RWMutex // for thread-safe access } func NewCache() *Cache { return &Cache{ - entries: make(map[string]*CacheEntry), - tokenIndex: make(map[string]map[string]IndexEntry), + entries: make(map[string]*CacheEntry), + tokenIndex: make(map[string]map[string]IndexEntry), + compoundPathIndex: make(map[string]map[string]map[string]map[string]map[string]struct{}), } } -// Only GET requests to /v3/secrets/* and /v4/secrets/* routes are cacheable +func IsSecretsEndpoint(path string) bool { + return (strings.HasPrefix(path, "/api/v3/secrets/") || strings.HasPrefix(path, "/api/v4/secrets/")) || + path == "/api/v3/secrets" || path == "/api/v4/secrets" +} + func IsCacheableRequest(path string, method string) bool { if method != http.MethodGet { return false } - return (strings.HasPrefix(path, "/api/v3/secrets/") || strings.HasPrefix(path, "/api/v4/secrets/")) || - path == "/api/v3/secrets" || path == "/api/v4/secrets" + return IsSecretsEndpoint(path) } func (c *Cache) Get(cacheKey string) (*http.Response, bool) { @@ -120,6 +125,21 @@ func (c *Cache) Set(cacheKey string, req *http.Request, resp *http.Response, tok c.tokenIndex[token] = make(map[string]IndexEntry) } c.tokenIndex[token][cacheKey] = indexEntry + + // Update compound path index + if c.compoundPathIndex[token] == nil { + c.compoundPathIndex[token] = make(map[string]map[string]map[string]map[string]struct{}) + } + if c.compoundPathIndex[token][indexEntry.ProjectId] == nil { + c.compoundPathIndex[token][indexEntry.ProjectId] = make(map[string]map[string]map[string]struct{}) + } + if c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug] == nil { + c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug] = make(map[string]map[string]struct{}) + } + if c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug][indexEntry.SecretPath] == nil { + c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug][indexEntry.SecretPath] = make(map[string]struct{}) + } + c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug][indexEntry.SecretPath][cacheKey] = struct{}{} } // UpdateResponse updates only the response data and cachedAt timestamp for an existing cache entry @@ -161,7 +181,6 @@ func ExtractTokenFromRequest(r *http.Request) string { return "" } - // Parse "Bearer " parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { return "" @@ -177,6 +196,25 @@ func GenerateCacheKey(method, path, query, token string) string { return hex.EncodeToString(hash[:]) } +func matchesPath(storedPath, queryPath string) bool { + if strings.HasSuffix(storedPath, "/*") { + base := strings.TrimSuffix(storedPath, "/*") + + if queryPath == base { + return true + } + + // Check if queryPath is under base (e.g., base="/test", queryPath="/test/sub") + return strings.HasPrefix(queryPath+"/", base+"/") + } + + if storedPath == queryPath { + return true + } + + return false +} + // GetExpiredRequests returns only expired request data for resync func (c *Cache) GetExpiredRequests(cacheTTL time.Duration) map[string]*CachedRequest { c.mu.RLock() @@ -192,7 +230,6 @@ func (c *Cache) GetExpiredRequests(cacheTTL time.Duration) map[string]*CachedReq continue } - // Create a deep copy of request data only requestCopy := &CachedRequest{ Method: entry.Request.Method, RequestURI: entry.Request.RequestURI, @@ -222,25 +259,260 @@ func (c *Cache) EvictEntry(cacheKey string) { // Remove from main store delete(c.entries, cacheKey) - // Remove from token index + // Remove from token index and get IndexEntry for compound index cleanup + var indexEntry IndexEntry if token != "" { if tokenEntries, ok := c.tokenIndex[token]; ok { + indexEntry = tokenEntries[cacheKey] delete(tokenEntries, cacheKey) if len(tokenEntries) == 0 { delete(c.tokenIndex, token) } } } + + // Remove from compound path index + if token == "" || indexEntry.ProjectId == "" || indexEntry.EnvironmentSlug == "" || indexEntry.SecretPath == "" { + return + } + + projectMap := c.compoundPathIndex[token] + if projectMap == nil { + return + } + + envMap := projectMap[indexEntry.ProjectId] + if envMap == nil { + // Orphaned project entry + delete(projectMap, indexEntry.ProjectId) + if len(projectMap) == 0 { + delete(c.compoundPathIndex, token) + } + return + } + + pathsMap := envMap[indexEntry.EnvironmentSlug] + if pathsMap == nil { + // Orphaned environment entry + delete(envMap, indexEntry.EnvironmentSlug) + if len(envMap) == 0 { + delete(projectMap, indexEntry.ProjectId) + } + if len(projectMap) == 0 { + delete(c.compoundPathIndex, token) + } + return + } + + cacheKeys := pathsMap[indexEntry.SecretPath] + if cacheKeys == nil { + // Orphaned path entry + delete(pathsMap, indexEntry.SecretPath) + if len(pathsMap) == 0 { + delete(envMap, indexEntry.EnvironmentSlug) + } + if len(envMap) == 0 { + delete(projectMap, indexEntry.ProjectId) + } + if len(projectMap) == 0 { + delete(c.compoundPathIndex, token) + } + return + } + + delete(cacheKeys, cacheKey) + + // If no more cacheKeys for this path, remove the path entry + if len(cacheKeys) == 0 { + delete(pathsMap, indexEntry.SecretPath) + } + + // Clean up empty nested maps + if len(pathsMap) == 0 { + delete(envMap, indexEntry.EnvironmentSlug) + } + if len(envMap) == 0 { + delete(projectMap, indexEntry.ProjectId) + } + if len(projectMap) == 0 { + delete(c.compoundPathIndex, token) + } +} + +func (c *Cache) GetAllTokens() []string { + c.mu.RLock() + defer c.mu.RUnlock() + + tokens := make([]string, 0, len(c.tokenIndex)) + for token := range c.tokenIndex { + tokens = append(tokens, token) + } + return tokens +} + +// GetFirstRequestForToken gets the first request (any, regardless of expiration) for a token +func (c *Cache) GetFirstRequestForToken(token string) (cacheKey string, request *CachedRequest, found bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + tokenEntries, exists := c.tokenIndex[token] + if !exists || len(tokenEntries) == 0 { + return "", nil, false + } + + // Get the first cacheKey from the token's entries + for key := range tokenEntries { + entry, exists := c.entries[key] + if !exists { + // Delete orphan cache entry + delete(tokenEntries, key) + continue + } + + requestCopy := &CachedRequest{ + Method: entry.Request.Method, + RequestURI: entry.Request.RequestURI, + Headers: make(http.Header), + CachedAt: entry.Request.CachedAt, + } + + CopyHeaders(requestCopy.Headers, entry.Request.Headers) + + return key, requestCopy, true + } + + return "", nil, false +} + +func (c *Cache) EvictAllEntriesForToken(token string) int { + c.mu.Lock() + defer c.mu.Unlock() + + tokenEntries, exists := c.tokenIndex[token] + if !exists { + return 0 + } + + evictedCount := len(tokenEntries) + + // Delete all entries from main store + for cacheKey := range tokenEntries { + delete(c.entries, cacheKey) + } + + // Delete token from token index + delete(c.tokenIndex, token) + + return evictedCount +} + +func (c *Cache) RemoveTokenFromIndex(token string) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.tokenIndex, token) +} + +// PurgeByMutation purges cache entries across ALL tokens that match the mutation path +func (c *Cache) PurgeByMutation(projectID, envSlug, mutationPath string) int { + c.mu.Lock() + defer c.mu.Unlock() + + purgedCount := 0 + + // Iterate through all tokens in the compound index + for token, projectMap := range c.compoundPathIndex { + envMap, ok := projectMap[projectID] + if !ok { + continue + } + + pathsMap, ok := envMap[envSlug] + if !ok { + continue + } + + // Iterate through all paths and check matches + for storedPath, cacheKeys := range pathsMap { + if matchesPath(storedPath, mutationPath) { + for cacheKey := range cacheKeys { + // Remove from main store + delete(c.entries, cacheKey) + + // Remove from token index + if tokenEntries, ok := c.tokenIndex[token]; ok { + delete(tokenEntries, cacheKey) + if len(tokenEntries) == 0 { + delete(c.tokenIndex, token) + } + } + + purgedCount++ + } + delete(pathsMap, storedPath) + } + } + + // Clean up empty nested maps for this token + if len(pathsMap) == 0 { + delete(envMap, envSlug) + } + if len(envMap) == 0 { + delete(projectMap, projectID) + } + if len(projectMap) == 0 { + delete(c.compoundPathIndex, token) + } + } + + return purgedCount +} + +// CompoundPathIndexDebugInfo represents the compound path index structure +type CompoundPathIndexDebugInfo struct { + Token string `json:"token"` + Projects map[string]ProjectDebugInfo `json:"projects"` + TotalPaths int `json:"totalPaths"` + TotalKeys int `json:"totalKeys"` +} + +// ProjectDebugInfo represents project-level debug info +type ProjectDebugInfo struct { + ProjectID string `json:"projectId"` + Environments map[string]EnvironmentDebugInfo `json:"environments"` + TotalPaths int `json:"totalPaths"` + TotalKeys int `json:"totalKeys"` +} + +// EnvironmentDebugInfo represents environment-level debug info +type EnvironmentDebugInfo struct { + EnvironmentSlug string `json:"environmentSlug"` + Paths map[string]PathDebugInfo `json:"paths"` + TotalKeys int `json:"totalKeys"` +} + +// CacheKeyDebugInfo represents a cache key with its timestamp +type CacheKeyDebugInfo struct { + CacheKey string `json:"cacheKey"` + CachedAt time.Time `json:"cachedAt"` +} + +// PathDebugInfo represents path-level debug info +type PathDebugInfo struct { + SecretPath string `json:"secretPath"` + CacheKeys []CacheKeyDebugInfo `json:"cacheKeys"` + KeyCount int `json:"keyCount"` } // CacheDebugInfo contains debug information about the cache type CacheDebugInfo struct { - TotalEntries int `json:"totalEntries"` - TotalTokens int `json:"totalTokens"` - TotalSizeBytes int64 `json:"totalSizeBytes"` - EntriesByToken map[string]int `json:"entriesByToken"` - CacheKeys []string `json:"cacheKeys"` - TokenIndex map[string][]IndexEntry `json:"tokenIndex"` + TotalEntries int `json:"totalEntries"` + TotalTokens int `json:"totalTokens"` + TotalSizeBytes int64 `json:"totalSizeBytes"` + EntriesByToken map[string]int `json:"entriesByToken"` + CacheKeys []CacheKeyDebugInfo `json:"cacheKeys"` + TokenIndex map[string][]IndexEntry `json:"tokenIndex"` + CompoundPathIndex []CompoundPathIndexDebugInfo `json:"compoundPathIndex"` } // GetDebugInfo returns debug information about the cache (dev mode only) @@ -251,11 +523,14 @@ func (c *Cache) GetDebugInfo() CacheDebugInfo { var totalSize int64 entriesByToken := make(map[string]int) tokenIndex := make(map[string][]IndexEntry) - cacheKeys := make([]string, 0, len(c.entries)) + cacheKeys := make([]CacheKeyDebugInfo, 0, len(c.entries)) - // Calculate sizes + // Calculate sizes and build cache keys with timestamps for cacheKey, entry := range c.entries { - cacheKeys = append(cacheKeys, cacheKey) + cacheKeys = append(cacheKeys, CacheKeyDebugInfo{ + CacheKey: cacheKey, + CachedAt: entry.Request.CachedAt, + }) totalSize += int64(len(entry.Response.BodyBytes)) } @@ -268,12 +543,75 @@ func (c *Cache) GetDebugInfo() CacheDebugInfo { } } + // Build compound path index debug info + compoundPathIndex := make([]CompoundPathIndexDebugInfo, 0, len(c.compoundPathIndex)) + for token, projectMap := range c.compoundPathIndex { + projects := make(map[string]ProjectDebugInfo) + totalPaths := 0 + totalKeys := 0 + + for projectID, envMap := range projectMap { + environments := make(map[string]EnvironmentDebugInfo) + projectTotalPaths := 0 + projectTotalKeys := 0 + + for envSlug, pathsMap := range envMap { + paths := make(map[string]PathDebugInfo) + envTotalKeys := 0 + + for secretPath, cacheKeys := range pathsMap { + keys := make([]CacheKeyDebugInfo, 0, len(cacheKeys)) + for cacheKey := range cacheKeys { + + if entry, exists := c.entries[cacheKey]; exists { + keys = append(keys, CacheKeyDebugInfo{ + CacheKey: cacheKey, + CachedAt: entry.Request.CachedAt, + }) + } + } + paths[secretPath] = PathDebugInfo{ + SecretPath: secretPath, + CacheKeys: keys, + KeyCount: len(cacheKeys), + } + envTotalKeys += len(cacheKeys) + projectTotalPaths++ + } + + environments[envSlug] = EnvironmentDebugInfo{ + EnvironmentSlug: envSlug, + Paths: paths, + TotalKeys: envTotalKeys, + } + projectTotalKeys += envTotalKeys + } + + projects[projectID] = ProjectDebugInfo{ + ProjectID: projectID, + Environments: environments, + TotalPaths: projectTotalPaths, + TotalKeys: projectTotalKeys, + } + totalPaths += projectTotalPaths + totalKeys += projectTotalKeys + } + + compoundPathIndex = append(compoundPathIndex, CompoundPathIndexDebugInfo{ + Token: token, + Projects: projects, + TotalPaths: totalPaths, + TotalKeys: totalKeys, + }) + } + return CacheDebugInfo{ - TotalEntries: len(c.entries), - TotalTokens: len(c.tokenIndex), - TotalSizeBytes: totalSize, - EntriesByToken: entriesByToken, - CacheKeys: cacheKeys, - TokenIndex: tokenIndex, + TotalEntries: len(c.entries), + TotalTokens: len(c.tokenIndex), + TotalSizeBytes: totalSize, + EntriesByToken: entriesByToken, + CacheKeys: cacheKeys, + TokenIndex: tokenIndex, + CompoundPathIndex: compoundPathIndex, } } diff --git a/packages/proxy/resync.go b/packages/proxy/resync.go new file mode 100644 index 00000000..9569b85c --- /dev/null +++ b/packages/proxy/resync.go @@ -0,0 +1,300 @@ +package proxy + +import ( + "context" + "encoding/json" + "io" + "math/rand" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "time" + + "github.com/rs/zerolog/log" +) + +// parseRateLimitSeconds extracts retry-after seconds from rate limit error message +// Expected format: "Rate limit exceeded. Please try again in 57 seconds" +// Returns default of 10 seconds if parsing fails +func parseRateLimitSeconds(body []byte) int { + var errorResponse struct { + Message string `json:"message"` + } + + var seconds int = 10 + + if err := json.Unmarshal(body, &errorResponse); err != nil { + return seconds + } + + re := regexp.MustCompile(`(\d+)\s+seconds?`) + matches := re.FindStringSubmatch(errorResponse.Message) + if len(matches) < 2 { + return 10 + } + + seconds, err := strconv.Atoi(matches[1]) + if err != nil { + return 10 + } + + return seconds +} + +func handleResyncResponse(cache *Cache, cacheKey string, requestURI string, resp *http.Response) (refetched bool, evicted bool, rateLimited bool, retryAfterSeconds int) { + switch resp.StatusCode { + case http.StatusOK: + bodyBytes, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Msg("Failed to read response body during resync") + return false, false, false, 0 + } + + // Update only response data (IndexEntry doesn't change during resync) + cache.UpdateResponse(cacheKey, resp.StatusCode, resp.Header, bodyBytes) + + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", requestURI). + Msg("Successfully refetched and updated cache entry") + return true, false, false, 0 + case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound: + // Evict entry on 401/403/404 + cache.EvictEntry(cacheKey) + resp.Body.Close() + + log.Info(). + Str("hash", cacheKey). + Msg("Entry evicted") + return false, true, false, 0 + case http.StatusTooManyRequests: + bodyBytes, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", cacheKey). + Msg("Failed to read rate limit response body, using default 10 seconds") + return false, false, true, 10 + } + + retryAfter := parseRateLimitSeconds(bodyBytes) + + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", requestURI). + Int("retryAfterSeconds", retryAfter). + Msg("Rate limited during resync") + return false, false, true, retryAfter + default: + // Other error status codes - keep stale entry + resp.Body.Close() + log.Debug(). + Str("cacheKey", cacheKey). + Str("requestURI", requestURI). + Int("statusCode", resp.StatusCode). + Msg("Unexpected status code during resync - keeping stale entry") + return false, false, false, 0 + } +} + +func reconstructProxyRequest(domainURL *url.URL, request *CachedRequest) (*http.Request, error) { + targetURL := *domainURL + parsedURI, err := url.Parse(request.RequestURI) + if err != nil { + return nil, err + } + + targetURL.Path = domainURL.Path + parsedURI.Path + targetURL.RawQuery = parsedURI.RawQuery + + proxyReq, err := http.NewRequest(request.Method, targetURL.String(), nil) + if err != nil { + return nil, err + } + + CopyHeaders(proxyReq.Header, request.Headers) + return proxyReq, nil +} + +// StartResyncLoop starts the background resync loop for cache entries +func StartResyncLoop(ctx context.Context, cache *Cache, domainURL *url.URL, httpClient *http.Client, resyncInterval int, cacheTTL int) { + ticker := time.NewTicker(time.Duration(resyncInterval) * time.Minute) + defer ticker.Stop() + + log.Info(). + Int("resyncInterval", resyncInterval). + Int("cacheTTL", cacheTTL). + Msg("Resync loop started") + + for { + select { + case <-ticker.C: + log.Info().Msg("Starting resync cycle") + cacheTTLDuration := time.Duration(cacheTTL) * time.Minute + + // -- Token testing phase (before expired requests) -- + + tokens := cache.GetAllTokens() + tokensEvicted := 0 + + for _, token := range tokens { + // Add jitter + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + + cacheKey, request, found := cache.GetFirstRequestForToken(token) + if !found { + cache.RemoveTokenFromIndex(token) + log.Debug(). + Str("token", token). + Msg("Removed orphaned token entry during token testing") + continue + } + + proxyReq, err := reconstructProxyRequest(domainURL, request) + if err != nil { + log.Error(). + Err(err). + Str("token", token). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Failed to reconstruct request during token testing") + continue + } + + resp, err := httpClient.Do(proxyReq) + if err != nil { + // Keep entries for high availability + + log.Error(). + Err(err). + Str("token", token). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Network error during token testing - keeping entries") + + continue + } + + // If 401, evict all entries for this token + if resp.StatusCode == http.StatusUnauthorized { + evictedCount := cache.EvictAllEntriesForToken(token) + resp.Body.Close() + tokensEvicted++ + + if evictedCount == 1 { + log.Info(). + Str("token", token). + Msg("Entry evicted") + } else { + log.Info(). + Int("evictedCount", evictedCount). + Str("token", token). + Msg("Entries evicted") + } + } else { + resp.Body.Close() + } + } + + if tokensEvicted > 0 { + log.Debug(). + Int("tokensEvicted", tokensEvicted). + Msg("Token testing phase completed") + } + + // -- Expired entries processing phase -- + + cycleStartTime := time.Now() + resyncIntervalDuration := time.Duration(resyncInterval) * time.Minute + + requests := cache.GetExpiredRequests(cacheTTLDuration) + + // Convert map to slice and sort by CachedAt (oldest first) + type orderedEntry struct { + cacheKey string + request *CachedRequest + } + ordered := make([]orderedEntry, 0, len(requests)) + for key, req := range requests { + ordered = append(ordered, orderedEntry{key, req}) + } + sort.Slice(ordered, func(i, j int) bool { + return ordered[i].request.CachedAt.Before(ordered[j].request.CachedAt) + }) + + refetched := 0 + evicted := 0 + + for _, entry := range ordered { + // Add jitter + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + + proxyReq, err := reconstructProxyRequest(domainURL, entry.request) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", entry.cacheKey). + Str("requestURI", entry.request.RequestURI). + Msg("Failed to parse requestURI during resync") + continue + } + + resp, err := httpClient.Do(proxyReq) + if err != nil { + // Keep stale entry for high availability + + log.Error(). + Err(err). + Str("cacheKey", entry.cacheKey). + Str("requestURI", entry.request.RequestURI). + Msg("Network error during resync - keeping stale entry") + + continue + } + + refetchedResult, evictedResult, rateLimited, retryAfterSeconds := handleResyncResponse(cache, entry.cacheKey, entry.request.RequestURI, resp) + if refetchedResult { + refetched++ + } + if evictedResult { + evicted++ + } + + // Handle rate limiting + if rateLimited { + pauseDuration := time.Duration(retryAfterSeconds+2) * time.Second // 2 seconds buffer + timeUntilNextTick := resyncIntervalDuration - time.Since(cycleStartTime) + + if pauseDuration <= timeUntilNextTick { + log.Info(). + Int("pauseSeconds", retryAfterSeconds+2). + Msg("Rate limited, pausing resync") + time.Sleep(pauseDuration) + } else { + log.Warn(). + Int("pauseSeconds", retryAfterSeconds+2). + Msg("Rate limit pause exceeds resync interval, remaining entries will be processed next cycle. Increase the resync-interval value and the cache-ttl value to prevent this behavior.") + break + } + } + } + + log.Info(). + Int("expiredEntries", len(requests)). + Int("refetched", refetched). + Int("evicted", evicted). + Msg("Resync cycle completed") + + case <-ctx.Done(): + log.Info().Msg("Resync loop stopped") + return + } + } +} From 0216e540d1af7b85d0d2f7a9c86d115e3e5045e2 Mon Sep 17 00:00:00 2001 From: Victor Santos Date: Thu, 13 Nov 2025 21:43:39 -0300 Subject: [PATCH 03/14] Refactor cache management and response handling in proxy server - Updated the `compoundPathIndex` comment in `cache.go` to clarify its purpose for purging after mutation calls. - Changed the locking mechanism in `GetFirstRequestForToken` to use a write lock for thread safety. - Enhanced the `EvictAllEntriesForToken` and `RemoveTokenFromIndex` methods to delete entries from the `compoundPathIndex` when a token is evicted. - Improved response handling in `handleResyncResponse` by ensuring the response body is closed properly in all cases to prevent resource leaks. --- packages/proxy/cache.go | 11 ++++++++--- packages/proxy/resync.go | 6 ++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go index df257581..d4a9572d 100644 --- a/packages/proxy/cache.go +++ b/packages/proxy/cache.go @@ -40,7 +40,7 @@ type CacheEntry struct { type Cache struct { entries map[string]*CacheEntry // main store: cacheKey -> cache entry (request + response) tokenIndex map[string]map[string]IndexEntry // secondary index: token -> map[cacheKey]IndexEntry, used for token invalidation - compoundPathIndex map[string]map[string]map[string]map[string]map[string]struct{} // token -> projectID -> envSlug -> secretPath -> cacheKey -> struct{}, used for evictions after mutation calls + compoundPathIndex map[string]map[string]map[string]map[string]map[string]struct{} // token -> projectID -> envSlug -> secretPath -> cacheKey -> struct{}, used for purging after mutation calls mu sync.RWMutex // for thread-safe access } @@ -352,8 +352,8 @@ func (c *Cache) GetAllTokens() []string { // GetFirstRequestForToken gets the first request (any, regardless of expiration) for a token func (c *Cache) GetFirstRequestForToken(token string) (cacheKey string, request *CachedRequest, found bool) { - c.mu.RLock() - defer c.mu.RUnlock() + c.mu.Lock() + defer c.mu.Unlock() tokenEntries, exists := c.tokenIndex[token] if !exists || len(tokenEntries) == 0 { @@ -403,6 +403,9 @@ func (c *Cache) EvictAllEntriesForToken(token string) int { // Delete token from token index delete(c.tokenIndex, token) + // Delete token from compound path index + delete(c.compoundPathIndex, token) + return evictedCount } @@ -411,6 +414,8 @@ func (c *Cache) RemoveTokenFromIndex(token string) { defer c.mu.Unlock() delete(c.tokenIndex, token) + + delete(c.compoundPathIndex, token) } // PurgeByMutation purges cache entries across ALL tokens that match the mutation path diff --git a/packages/proxy/resync.go b/packages/proxy/resync.go index 9569b85c..f99980a6 100644 --- a/packages/proxy/resync.go +++ b/packages/proxy/resync.go @@ -44,10 +44,11 @@ func parseRateLimitSeconds(body []byte) int { } func handleResyncResponse(cache *Cache, cacheKey string, requestURI string, resp *http.Response) (refetched bool, evicted bool, rateLimited bool, retryAfterSeconds int) { + defer resp.Body.Close() + switch resp.StatusCode { case http.StatusOK: bodyBytes, err := io.ReadAll(resp.Body) - resp.Body.Close() if err != nil { log.Error(). Err(err). @@ -67,7 +68,6 @@ func handleResyncResponse(cache *Cache, cacheKey string, requestURI string, resp case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound: // Evict entry on 401/403/404 cache.EvictEntry(cacheKey) - resp.Body.Close() log.Info(). Str("hash", cacheKey). @@ -75,7 +75,6 @@ func handleResyncResponse(cache *Cache, cacheKey string, requestURI string, resp return false, true, false, 0 case http.StatusTooManyRequests: bodyBytes, err := io.ReadAll(resp.Body) - resp.Body.Close() if err != nil { log.Error(). Err(err). @@ -94,7 +93,6 @@ func handleResyncResponse(cache *Cache, cacheKey string, requestURI string, resp return false, false, true, retryAfter default: // Other error status codes - keep stale entry - resp.Body.Close() log.Debug(). Str("cacheKey", cacheKey). Str("requestURI", requestURI). From 8a1ca9fb0564be4b3e64cc1dd61d400a08e3d0c7 Mon Sep 17 00:00:00 2001 From: Victor Santos Date: Fri, 14 Nov 2025 10:08:31 -0300 Subject: [PATCH 04/14] Add error logging for cache purging mutation request failures in proxy server - Enhanced the `startProxyServer` function in `proxy.go` to log errors when parsing mutation request bodies fails, ensuring better visibility into potential cache issues. --- packages/cmd/proxy.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index 0c6642cf..dd4a91c4 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -275,6 +275,12 @@ func startProxyServer(cmd *cobra.Command, args []string) { if path, ok := bodyData["secretPath"].(string); ok { secretPath = path } + } else { + log.Error(). + Err(err). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Failed to parse mutation request body for cache purging - cache may serve stale data") } } From e073b76a1a7dd0f2f59a4c4443c04015bd28ee3d Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Tue, 6 Jan 2026 19:41:05 +0100 Subject: [PATCH 05/14] Update proxy.go --- packages/cmd/proxy.go | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index dd4a91c4..4dd34059 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -44,6 +44,7 @@ var proxyDebugCmd = &cobra.Command{ DisableFlagsInUseLine: true, Args: cobra.NoArgs, Run: printCacheDebug, + Hidden: true, } func startProxyServer(cmd *cobra.Command, args []string) { From cb452c7118dc0ca755e1cd9dc99925a00fd8a6e3 Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 02:52:30 +0100 Subject: [PATCH 06/14] feat: infisical proxy --- packages/cmd/proxy.go | 117 +++++- packages/proxy/cache.go | 552 ++++++++++++++++----------- packages/proxy/resync.go | 331 ++++++++-------- packages/util/agent.go | 58 --- packages/util/cache/cache-storage.go | 132 +++++++ packages/util/helper.go | 56 +++ 6 files changed, 794 insertions(+), 452 deletions(-) delete mode 100644 packages/util/agent.go diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index 4dd34059..5a5e60c9 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "context" + "crypto/tls" "encoding/json" "fmt" "io" @@ -14,8 +15,10 @@ import ( "syscall" "time" + "github.com/Infisical/infisical-merge/packages/crypto" "github.com/Infisical/infisical-merge/packages/proxy" "github.com/Infisical/infisical-merge/packages/util" + "github.com/Infisical/infisical-merge/packages/util/cache" "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -62,18 +65,56 @@ func startProxyServer(cmd *cobra.Command, args []string) { util.HandleError(err, "Unable to parse listen-address flag") } + tlsEnabled, err := cmd.Flags().GetBool("tls-enabled") + if err != nil { + util.HandleError(err, "Unable to parse tls-enabled flag") + } + + tlsCertFile, err := cmd.Flags().GetString("tls-cert-file") + if err != nil { + util.HandleError(err, "Unable to parse tls-cert-file flag") + } + + tlsKeyFile, err := cmd.Flags().GetString("tls-key-file") + if err != nil { + util.HandleError(err, "Unable to parse tls-key-file flag") + } + + if tlsEnabled && (tlsCertFile == "" || tlsKeyFile == "") { + util.PrintErrorMessageAndExit("`tls-cert-file` and `tls-key-file` are required when `tls-enabled` is set to true") + } + if listenAddress == "" { util.PrintErrorMessageAndExit("Listen-address flag is required") } - resyncInterval, err := cmd.Flags().GetInt("resync-interval") + evictionStrategy, err := cmd.Flags().GetString("eviction-strategy") + if err != nil { + util.HandleError(err, "Unable to parse eviction-strategy flag") + } + + if evictionStrategy != "optimistic" { + util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid eviction-strategy '%s'. Currently only 'optimistic' is supported.", evictionStrategy)) + } + + accessTokenCheckIntervalStr, err := cmd.Flags().GetString("access-token-check-interval") + if err != nil { + util.HandleError(err, "Unable to parse access-token-check-interval flag") + } + + accessTokenCheckInterval, err := util.ParseTimeDurationString(accessTokenCheckIntervalStr, true) + if err != nil { + util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid access-token-check-interval format '%s'. Use formats like 5m, 1h, 1d", accessTokenCheckIntervalStr)) + } + + staticSecretsRefreshIntervalStr, err := cmd.Flags().GetString("static-secrets-refresh-interval") if err != nil { - util.HandleError(err, "Unable to parse resync-interval flag") + util.HandleError(err, "Unable to parse static-secrets-refresh-interval flag") } - cacheTTL, err := cmd.Flags().GetInt("cache-ttl") + staticSecretsRefreshInterval, err := util.ParseTimeDurationString(staticSecretsRefreshIntervalStr, true) if err != nil { - util.HandleError(err, "Unable to parse cache-ttl flag") + util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid static-secrets-refresh-interval format '%s'. Use formats like 30m, 1h, 1d", staticSecretsRefreshIntervalStr)) } domainURL, err := url.Parse(domain) @@ -90,12 +131,29 @@ func startProxyServer(cmd *cobra.Command, args []string) { Timeout: 0, } - cache := proxy.NewCache() - devMode := util.CLI_VERSION == "devel" + // Create in-memory cache (no persistence, no encryption needed for ephemeral data) + // For persistent cache with encryption, use proxy.NewCacheWithOptions + + encryptionKey, err := crypto.GenerateRandomBytes(32) + if err != nil { + util.HandleError(err, "Failed to generate random encryption key") + } + + cache, err := proxy.NewCache(cache.EncryptedStorageOptions{ + InMemory: true, + EncryptionKey: [32]byte(encryptionKey), + }) + + if err != nil { + util.PrintErrorMessageAndExit(fmt.Sprintf("Failed to create cache: %v", err)) + } + + defer cache.Close() + mux := http.NewServeMux() // Debug endpoint (dev mode only) - if devMode { + if util.IsDevelopmentMode() { mux.HandleFunc("/_debug/cache", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -373,15 +431,28 @@ func startProxyServer(cmd *cobra.Command, args []string) { // Add proxy handler to mux mux.HandleFunc("/", proxyHandler) + var tlsConfig *tls.Config + if tlsEnabled { + cert, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile) + if err != nil { + util.HandleError(err, fmt.Sprintf("Failed to load TLS certificate and key: %s", err)) + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + server := &http.Server{ - Addr: listenAddress, - Handler: mux, + Addr: listenAddress, + Handler: mux, + TLSConfig: tlsConfig, } resyncCtx, resyncCancel := context.WithCancel(context.Background()) defer resyncCancel() - go proxy.StartResyncLoop(resyncCtx, cache, domainURL, httpClient, resyncInterval, cacheTTL) + go proxy.StartBackgroundLoops(resyncCtx, cache, domainURL, httpClient, evictionStrategy, accessTokenCheckInterval, staticSecretsRefreshInterval) // Handle graceful shutdown sigCh := make(chan os.Signal, 1) @@ -406,12 +477,22 @@ func startProxyServer(cmd *cobra.Command, args []string) { os.Exit(0) }() - log.Info().Msgf("Infisical proxy server starting on %s", listenAddress) - log.Info().Msgf("Forwarding requests to %s", domain) + if tlsEnabled { + log.Info().Msgf("Infisical proxy server starting on %s with TLS enabled", listenAddress) + } else { + log.Info().Msgf("Infisical proxy server starting on %s", listenAddress) + } - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - util.HandleError(err, "Failed to start proxy server") + if tlsEnabled { + if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + util.HandleError(err, "Failed to start proxy server with TLS") + } + } else { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + util.HandleError(err, "Failed to start proxy server") + } } + log.Info().Msgf("Forwarding requests to %s", domain) } func printCacheDebug(cmd *cobra.Command, args []string) { @@ -466,8 +547,12 @@ func isStreamingEndpoint(path string) bool { func init() { proxyStartCmd.Flags().String("domain", "", "Domain of your Infisical instance (e.g., https://app.infisical.com for cloud, https://my-self-hosted-instance.com for self-hosted)") proxyStartCmd.Flags().String("listen-address", "localhost:8081", "The address for the proxy server to listen on. Defaults to localhost:8081") - proxyStartCmd.Flags().Int("resync-interval", 10, "Interval in minutes for resyncing cached secrets. Defaults to 10 minutes.") - proxyStartCmd.Flags().Int("cache-ttl", 60, "TTL in minutes for individual cache entries. Defaults to 60 minutes.") + proxyStartCmd.Flags().String("eviction-strategy", "optimistic", "Cache eviction strategy. 'optimistic' keeps cached data when Infisical is unreachable for high availability. Defaults to optimistic.") + proxyStartCmd.Flags().String("access-token-check-interval", "5m", "How often to validate that access tokens are still valid (e.g., 5m, 1h). Defaults to 5m.") + proxyStartCmd.Flags().String("static-secrets-refresh-interval", "1h", "How often to refresh cached secrets (e.g., 30m, 1h, 1d). Defaults to 1h.") + proxyStartCmd.Flags().String("tls-cert-file", "", "The path to the TLS certificate file for the proxy server. Required when `tls-enabled` is set to true (default)") + proxyStartCmd.Flags().String("tls-key-file", "", "The path to the TLS key file for the proxy server. Required when `tls-enabled` is set to true (default)") + proxyStartCmd.Flags().Bool("tls-enabled", true, "Whether to enable TLS for the proxy server. Defaults to true") proxyDebugCmd.Flags().String("listen-address", "localhost:8081", "The address where the proxy server is listening. Defaults to localhost:8081") diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go index d4a9572d..98d255b8 100644 --- a/packages/proxy/cache.go +++ b/packages/proxy/cache.go @@ -4,52 +4,114 @@ import ( "bytes" "crypto/sha256" "encoding/hex" + "fmt" "io" "net/http" "strings" "sync" "time" + + "github.com/Infisical/infisical-merge/packages/util/cache" + "github.com/rs/zerolog/log" +) + +// Storage key prefixes +const ( + prefixEntry = "entry:" + prefixToken = "token:" + prefixPath = "path:" ) type IndexEntry struct { - CacheKey string - SecretPath string - EnvironmentSlug string - ProjectId string + CacheKey string `json:"cacheKey"` + SecretPath string `json:"secretPath"` + EnvironmentSlug string `json:"environmentSlug"` + ProjectId string `json:"projectId"` } type CachedRequest struct { - Method string - RequestURI string - Headers http.Header - CachedAt time.Time + Method string `json:"method"` + RequestURI string `json:"requestUri"` + Headers http.Header `json:"headers"` + CachedAt time.Time `json:"cachedAt"` } type CachedResponse struct { - StatusCode int - Header http.Header - BodyBytes []byte + StatusCode int `json:"statusCode"` + Header http.Header `json:"header"` + BodyBytes []byte `json:"bodyBytes"` +} + +// StoredCacheEntry is the structure stored in EncryptedStorage +type StoredCacheEntry struct { + Request *CachedRequest `json:"request"` + Response *CachedResponse `json:"response"` + Token string `json:"token"` + Index IndexEntry `json:"index"` } -type CacheEntry struct { - Request *CachedRequest - Response *CachedResponse +// PathIndexMarker is a simple marker stored at path index keys +type PathIndexMarker struct { + CacheKey string `json:"cacheKey"` } -// Cache is an in-memory cache for HTTP responses +// Cache is an HTTP response cache fully backed by EncryptedStorage type Cache struct { - entries map[string]*CacheEntry // main store: cacheKey -> cache entry (request + response) - tokenIndex map[string]map[string]IndexEntry // secondary index: token -> map[cacheKey]IndexEntry, used for token invalidation - compoundPathIndex map[string]map[string]map[string]map[string]map[string]struct{} // token -> projectID -> envSlug -> secretPath -> cacheKey -> struct{}, used for purging after mutation calls - mu sync.RWMutex // for thread-safe access + storage *cache.EncryptedStorage + mu sync.RWMutex } -func NewCache() *Cache { - return &Cache{ - entries: make(map[string]*CacheEntry), - tokenIndex: make(map[string]map[string]IndexEntry), - compoundPathIndex: make(map[string]map[string]map[string]map[string]map[string]struct{}), +// NewCache creates a cache with the specified options +func NewCache(opts cache.EncryptedStorageOptions) (*Cache, error) { + storage, err := cache.NewEncryptedStorage(opts) + if err != nil { + return nil, fmt.Errorf("failed to create cache storage: %w", err) } + + return &Cache{ + storage: storage, + }, nil +} + +// Close closes the underlying storage +func (c *Cache) Close() error { + return c.storage.Close() +} + +// hashToken creates a short hash of the token for use in storage keys +// This avoids storing the full token in key names while still being unique +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:8]) // First 8 bytes = 16 hex chars +} + +// buildEntryKey builds the storage key for a cache entry +func buildEntryKey(cacheKey string) string { + return prefixEntry + cacheKey +} + +// buildTokenIndexKey builds the storage key for token index entry +func buildTokenIndexKey(token, cacheKey string) string { + return prefixToken + hashToken(token) + ":" + cacheKey +} + +// buildTokenIndexPrefix builds the prefix for all token index entries for a token +func buildTokenIndexPrefix(token string) string { + return prefixToken + hashToken(token) + ":" +} + +// buildPathIndexKey builds the storage key for path index entry +func buildPathIndexKey(token string, indexEntry IndexEntry) string { + // Escape colons in secretPath to avoid key parsing issues + escapedPath := strings.ReplaceAll(indexEntry.SecretPath, ":", "\\:") + return fmt.Sprintf("%s%s:%s:%s:%s:%s", + prefixPath, + hashToken(token), + indexEntry.ProjectId, + indexEntry.EnvironmentSlug, + escapedPath, + indexEntry.CacheKey, + ) } func IsSecretsEndpoint(path string) bool { @@ -69,8 +131,13 @@ func (c *Cache) Get(cacheKey string) (*http.Response, bool) { c.mu.RLock() defer c.mu.RUnlock() - entry, exists := c.entries[cacheKey] - if !exists { + var entry StoredCacheEntry + err := c.storage.Get(buildEntryKey(cacheKey), &entry) + if err != nil { + return nil, false + } + + if entry.Response == nil { return nil, false } @@ -89,10 +156,15 @@ func (c *Cache) Set(cacheKey string, req *http.Request, resp *http.Response, tok c.mu.Lock() defer c.mu.Unlock() - // We can't use the response body directly because it will be closed by the time we need to use it + // Read response body var bodyBytes []byte if resp.Body != nil { - bodyBytes, _ = io.ReadAll(resp.Body) + var err error + bodyBytes, err = io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to read response body") + bodyBytes = nil + } } // Extract request metadata @@ -104,7 +176,7 @@ func (c *Cache) Set(cacheKey string, req *http.Request, resp *http.Response, tok responseHeader := make(http.Header) CopyHeaders(responseHeader, resp.Header) - entry := &CacheEntry{ + entry := StoredCacheEntry{ Request: &CachedRequest{ Method: req.Method, RequestURI: requestURI, @@ -116,40 +188,37 @@ func (c *Cache) Set(cacheKey string, req *http.Request, resp *http.Response, tok Header: responseHeader, BodyBytes: bodyBytes, }, + Token: token, + Index: indexEntry, } - c.entries[cacheKey] = entry - - // Update secondary index for token - if c.tokenIndex[token] == nil { - c.tokenIndex[token] = make(map[string]IndexEntry) + // Store main entry + if err := c.storage.Set(buildEntryKey(cacheKey), entry); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to store cache entry") + return } - c.tokenIndex[token][cacheKey] = indexEntry - // Update compound path index - if c.compoundPathIndex[token] == nil { - c.compoundPathIndex[token] = make(map[string]map[string]map[string]map[string]struct{}) + // Store token index entry + tokenIndexKey := buildTokenIndexKey(token, cacheKey) + if err := c.storage.Set(tokenIndexKey, indexEntry); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to store token index entry") } - if c.compoundPathIndex[token][indexEntry.ProjectId] == nil { - c.compoundPathIndex[token][indexEntry.ProjectId] = make(map[string]map[string]map[string]struct{}) - } - if c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug] == nil { - c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug] = make(map[string]map[string]struct{}) - } - if c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug][indexEntry.SecretPath] == nil { - c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug][indexEntry.SecretPath] = make(map[string]struct{}) + + // Store path index entry + pathIndexKey := buildPathIndexKey(token, indexEntry) + if err := c.storage.Set(pathIndexKey, PathIndexMarker{CacheKey: cacheKey}); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to store path index entry") } - c.compoundPathIndex[token][indexEntry.ProjectId][indexEntry.EnvironmentSlug][indexEntry.SecretPath][cacheKey] = struct{}{} } // UpdateResponse updates only the response data and cachedAt timestamp for an existing cache entry -// This is used during resync when the request parameters (and thus IndexEntry) haven't changed func (c *Cache) UpdateResponse(cacheKey string, statusCode int, header http.Header, bodyBytes []byte) { c.mu.Lock() defer c.mu.Unlock() - entry, exists := c.entries[cacheKey] - if !exists { + var entry StoredCacheEntry + err := c.storage.Get(buildEntryKey(cacheKey), &entry) + if err != nil { return } @@ -165,6 +234,11 @@ func (c *Cache) UpdateResponse(cacheKey string, statusCode int, header http.Head entry.Response.Header = responseHeader entry.Response.BodyBytes = bodyBytesCopy entry.Request.CachedAt = time.Now() + + // Update in storage + if err := c.storage.Set(buildEntryKey(cacheKey), entry); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to update cache entry") + } } func CopyHeaders(dst, src http.Header) { @@ -221,15 +295,34 @@ func (c *Cache) GetExpiredRequests(cacheTTL time.Duration) map[string]*CachedReq defer c.mu.RUnlock() now := time.Now() - requests := make(map[string]*CachedRequest, 0) + requests := make(map[string]*CachedRequest) + + // Get all entry keys + entryKeys, err := c.storage.GetKeysByPrefix(prefixEntry) + if err != nil { + log.Error().Err(err).Msg("Failed to get entry keys for expired requests check") + return requests + } + + for _, key := range entryKeys { + var entry StoredCacheEntry + if err := c.storage.Get(key, &entry); err != nil { + continue + } + + if entry.Request == nil { + continue + } - for key, entry := range c.entries { // Only include entries where cache-ttl has expired age := now.Sub(entry.Request.CachedAt) if age <= cacheTTL { continue } + // Extract cacheKey from storage key (remove prefix) + cacheKey := strings.TrimPrefix(key, prefixEntry) + requestCopy := &CachedRequest{ Method: entry.Request.Method, RequestURI: entry.Request.RequestURI, @@ -239,7 +332,7 @@ func (c *Cache) GetExpiredRequests(cacheTTL time.Duration) map[string]*CachedReq CopyHeaders(requestCopy.Headers, entry.Request.Headers) - requests[key] = requestCopy + requests[cacheKey] = requestCopy } return requests @@ -249,104 +342,76 @@ func (c *Cache) EvictEntry(cacheKey string) { c.mu.Lock() defer c.mu.Unlock() - entry, exists := c.entries[cacheKey] - if !exists { + c.evictEntryUnsafe(cacheKey) +} + +// evictEntryUnsafe evicts an entry without acquiring the lock (caller must hold lock) +func (c *Cache) evictEntryUnsafe(cacheKey string) { + // Get the entry to find its token and index info + var entry StoredCacheEntry + if err := c.storage.Get(buildEntryKey(cacheKey), &entry); err != nil { return } - token := ExtractTokenFromRequest(&http.Request{Header: entry.Request.Headers}) - - // Remove from main store - delete(c.entries, cacheKey) - - // Remove from token index and get IndexEntry for compound index cleanup - var indexEntry IndexEntry - if token != "" { - if tokenEntries, ok := c.tokenIndex[token]; ok { - indexEntry = tokenEntries[cacheKey] - delete(tokenEntries, cacheKey) - if len(tokenEntries) == 0 { - delete(c.tokenIndex, token) - } - } + // Remove main entry + if err := c.storage.Delete(buildEntryKey(cacheKey)); err != nil { + log.Error().Err(err).Str("cacheKey", cacheKey).Msg("Failed to delete cache entry") } - // Remove from compound path index - if token == "" || indexEntry.ProjectId == "" || indexEntry.EnvironmentSlug == "" || indexEntry.SecretPath == "" { - return + // Remove token index entry + tokenIndexKey := buildTokenIndexKey(entry.Token, cacheKey) + if err := c.storage.Delete(tokenIndexKey); err != nil { + log.Debug().Err(err).Str("cacheKey", cacheKey).Msg("Failed to delete token index entry") } - projectMap := c.compoundPathIndex[token] - if projectMap == nil { - return + // Remove path index entry + pathIndexKey := buildPathIndexKey(entry.Token, entry.Index) + if err := c.storage.Delete(pathIndexKey); err != nil { + log.Debug().Err(err).Str("cacheKey", cacheKey).Msg("Failed to delete path index entry") } +} - envMap := projectMap[indexEntry.ProjectId] - if envMap == nil { - // Orphaned project entry - delete(projectMap, indexEntry.ProjectId) - if len(projectMap) == 0 { - delete(c.compoundPathIndex, token) - } - return - } +// GetAllTokens returns all unique tokens that have cached entries +func (c *Cache) GetAllTokens() []string { + c.mu.RLock() + defer c.mu.RUnlock() - pathsMap := envMap[indexEntry.EnvironmentSlug] - if pathsMap == nil { - // Orphaned environment entry - delete(envMap, indexEntry.EnvironmentSlug) - if len(envMap) == 0 { - delete(projectMap, indexEntry.ProjectId) - } - if len(projectMap) == 0 { - delete(c.compoundPathIndex, token) - } - return + // Get all token index keys and extract unique token hashes + tokenKeys, err := c.storage.GetKeysByPrefix(prefixToken) + if err != nil { + log.Error().Err(err).Msg("Failed to get token index keys") + return nil } - cacheKeys := pathsMap[indexEntry.SecretPath] - if cacheKeys == nil { - // Orphaned path entry - delete(pathsMap, indexEntry.SecretPath) - if len(pathsMap) == 0 { - delete(envMap, indexEntry.EnvironmentSlug) - } - if len(envMap) == 0 { - delete(projectMap, indexEntry.ProjectId) - } - if len(projectMap) == 0 { - delete(c.compoundPathIndex, token) - } - return - } + // We need to get unique tokens, but we only have hashes in the keys + // We need to look up the actual token from entries + tokenHashToToken := make(map[string]string) - delete(cacheKeys, cacheKey) + for _, key := range tokenKeys { + // Key format: token:{tokenHash}:{cacheKey} + parts := strings.SplitN(strings.TrimPrefix(key, prefixToken), ":", 2) + if len(parts) < 2 { + continue + } + tokenHash := parts[0] + cacheKey := parts[1] - // If no more cacheKeys for this path, remove the path entry - if len(cacheKeys) == 0 { - delete(pathsMap, indexEntry.SecretPath) - } + if _, exists := tokenHashToToken[tokenHash]; exists { + continue // Already found this token + } - // Clean up empty nested maps - if len(pathsMap) == 0 { - delete(envMap, indexEntry.EnvironmentSlug) - } - if len(envMap) == 0 { - delete(projectMap, indexEntry.ProjectId) - } - if len(projectMap) == 0 { - delete(c.compoundPathIndex, token) + // Get the entry to find the actual token + var entry StoredCacheEntry + if err := c.storage.Get(buildEntryKey(cacheKey), &entry); err == nil { + tokenHashToToken[tokenHash] = entry.Token + } } -} -func (c *Cache) GetAllTokens() []string { - c.mu.RLock() - defer c.mu.RUnlock() - - tokens := make([]string, 0, len(c.tokenIndex)) - for token := range c.tokenIndex { + tokens := make([]string, 0, len(tokenHashToToken)) + for _, token := range tokenHashToToken { tokens = append(tokens, token) } + return tokens } @@ -355,17 +420,30 @@ func (c *Cache) GetFirstRequestForToken(token string) (cacheKey string, request c.mu.Lock() defer c.mu.Unlock() - tokenEntries, exists := c.tokenIndex[token] - if !exists || len(tokenEntries) == 0 { + tokenPrefix := buildTokenIndexPrefix(token) + tokenKeys, err := c.storage.GetKeysByPrefix(tokenPrefix) + if err != nil || len(tokenKeys) == 0 { return "", nil, false } // Get the first cacheKey from the token's entries - for key := range tokenEntries { - entry, exists := c.entries[key] - if !exists { - // Delete orphan cache entry - delete(tokenEntries, key) + for _, key := range tokenKeys { + // Key format: token:{tokenHash}:{cacheKey} + parts := strings.SplitN(strings.TrimPrefix(key, prefixToken), ":", 2) + if len(parts) < 2 { + continue + } + cacheKey := parts[1] + + var entry StoredCacheEntry + if err := c.storage.Get(buildEntryKey(cacheKey), &entry); err != nil { + // Delete orphan index entry + c.storage.Delete(key) + continue + } + + if entry.Request == nil { + c.storage.Delete(key) continue } @@ -378,44 +456,52 @@ func (c *Cache) GetFirstRequestForToken(token string) (cacheKey string, request CopyHeaders(requestCopy.Headers, entry.Request.Headers) - return key, requestCopy, true + return cacheKey, requestCopy, true } return "", nil, false } +// EvictAllEntriesForToken evicts all cache entries for a given token func (c *Cache) EvictAllEntriesForToken(token string) int { c.mu.Lock() defer c.mu.Unlock() - tokenEntries, exists := c.tokenIndex[token] - if !exists { + tokenPrefix := buildTokenIndexPrefix(token) + tokenKeys, err := c.storage.GetKeysByPrefix(tokenPrefix) + if err != nil { return 0 } - evictedCount := len(tokenEntries) - - // Delete all entries from main store - for cacheKey := range tokenEntries { - delete(c.entries, cacheKey) - } + evictedCount := 0 - // Delete token from token index - delete(c.tokenIndex, token) + for _, key := range tokenKeys { + // Key format: token:{tokenHash}:{cacheKey} + parts := strings.SplitN(strings.TrimPrefix(key, prefixToken), ":", 2) + if len(parts) < 2 { + continue + } + cacheKey := parts[1] - // Delete token from compound path index - delete(c.compoundPathIndex, token) + c.evictEntryUnsafe(cacheKey) + evictedCount++ + } return evictedCount } +// RemoveTokenFromIndex removes all index entries for a token (without deleting main entries) func (c *Cache) RemoveTokenFromIndex(token string) { c.mu.Lock() defer c.mu.Unlock() - delete(c.tokenIndex, token) + tokenPrefix := buildTokenIndexPrefix(token) + c.storage.DeleteByPrefix(tokenPrefix) - delete(c.compoundPathIndex, token) + // Also delete path index entries for this token + // Path keys start with path:{tokenHash}:... + pathPrefix := prefixPath + hashToken(token) + ":" + c.storage.DeleteByPrefix(pathPrefix) } // PurgeByMutation purges cache entries across ALL tokens that match the mutation path @@ -425,48 +511,34 @@ func (c *Cache) PurgeByMutation(projectID, envSlug, mutationPath string) int { purgedCount := 0 - // Iterate through all tokens in the compound index - for token, projectMap := range c.compoundPathIndex { - envMap, ok := projectMap[projectID] - if !ok { - continue - } + // Get all path index keys + pathKeys, err := c.storage.GetKeysByPrefix(prefixPath) + if err != nil { + log.Error().Err(err).Msg("Failed to get path index keys for mutation purge") + return 0 + } - pathsMap, ok := envMap[envSlug] - if !ok { + for _, key := range pathKeys { + // Key format: path:{tokenHash}:{projectId}:{envSlug}:{escapedSecretPath}:{cacheKey} + withoutPrefix := strings.TrimPrefix(key, prefixPath) + parts := strings.SplitN(withoutPrefix, ":", 5) + if len(parts) < 5 { continue } - // Iterate through all paths and check matches - for storedPath, cacheKeys := range pathsMap { - if matchesPath(storedPath, mutationPath) { - for cacheKey := range cacheKeys { - // Remove from main store - delete(c.entries, cacheKey) - - // Remove from token index - if tokenEntries, ok := c.tokenIndex[token]; ok { - delete(tokenEntries, cacheKey) - if len(tokenEntries) == 0 { - delete(c.tokenIndex, token) - } - } + keyProjectID := parts[1] + keyEnvSlug := parts[2] + keySecretPath := strings.ReplaceAll(parts[3], "\\:", ":") // Unescape colons + keyCacheKey := parts[4] - purgedCount++ - } - delete(pathsMap, storedPath) - } + // Check if this entry matches the mutation criteria + if keyProjectID != projectID || keyEnvSlug != envSlug { + continue } - // Clean up empty nested maps for this token - if len(pathsMap) == 0 { - delete(envMap, envSlug) - } - if len(envMap) == 0 { - delete(projectMap, projectID) - } - if len(projectMap) == 0 { - delete(c.compoundPathIndex, token) + if matchesPath(keySecretPath, mutationPath) { + c.evictEntryUnsafe(keyCacheKey) + purgedCount++ } } @@ -528,29 +600,77 @@ func (c *Cache) GetDebugInfo() CacheDebugInfo { var totalSize int64 entriesByToken := make(map[string]int) tokenIndex := make(map[string][]IndexEntry) - cacheKeys := make([]CacheKeyDebugInfo, 0, len(c.entries)) + cacheKeys := make([]CacheKeyDebugInfo, 0) + totalEntries := 0 - // Calculate sizes and build cache keys with timestamps - for cacheKey, entry := range c.entries { - cacheKeys = append(cacheKeys, CacheKeyDebugInfo{ - CacheKey: cacheKey, - CachedAt: entry.Request.CachedAt, - }) - totalSize += int64(len(entry.Response.BodyBytes)) + // Get all entry keys + entryKeys, err := c.storage.GetKeysByPrefix(prefixEntry) + if err != nil { + log.Error().Err(err).Msg("Failed to get entry keys for debug info") + return CacheDebugInfo{} } - // Build token index and count entries per token - for token, entries := range c.tokenIndex { - entriesByToken[token] = len(entries) - tokenIndex[token] = make([]IndexEntry, 0, len(entries)) - for _, entry := range entries { - tokenIndex[token] = append(tokenIndex[token], entry) + // Maps for building compound path index debug info + // tokenHash -> projectID -> envSlug -> secretPath -> []CacheKeyDebugInfo + pathIndexData := make(map[string]map[string]map[string]map[string][]CacheKeyDebugInfo) + tokenHashToToken := make(map[string]string) + + for _, key := range entryKeys { + var entry StoredCacheEntry + if err := c.storage.Get(key, &entry); err != nil { + continue } + + cacheKey := strings.TrimPrefix(key, prefixEntry) + tokenHash := hashToken(entry.Token) + tokenHashToToken[tokenHash] = entry.Token + + // Count entries per token + entriesByToken[entry.Token]++ + + // Add to token index + if tokenIndex[entry.Token] == nil { + tokenIndex[entry.Token] = make([]IndexEntry, 0) + } + tokenIndex[entry.Token] = append(tokenIndex[entry.Token], entry.Index) + + // Calculate size + if entry.Response != nil { + totalSize += int64(len(entry.Response.BodyBytes)) + } + + // Add to cache keys list + if entry.Request != nil { + cacheKeys = append(cacheKeys, CacheKeyDebugInfo{ + CacheKey: cacheKey, + CachedAt: entry.Request.CachedAt, + }) + } + + totalEntries++ + + // Build path index data + if pathIndexData[tokenHash] == nil { + pathIndexData[tokenHash] = make(map[string]map[string]map[string][]CacheKeyDebugInfo) + } + if pathIndexData[tokenHash][entry.Index.ProjectId] == nil { + pathIndexData[tokenHash][entry.Index.ProjectId] = make(map[string]map[string][]CacheKeyDebugInfo) + } + if pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug] == nil { + pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug] = make(map[string][]CacheKeyDebugInfo) + } + keyInfo := CacheKeyDebugInfo{CacheKey: cacheKey} + if entry.Request != nil { + keyInfo.CachedAt = entry.Request.CachedAt + } + pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug][entry.Index.SecretPath] = + append(pathIndexData[tokenHash][entry.Index.ProjectId][entry.Index.EnvironmentSlug][entry.Index.SecretPath], keyInfo) } // Build compound path index debug info - compoundPathIndex := make([]CompoundPathIndexDebugInfo, 0, len(c.compoundPathIndex)) - for token, projectMap := range c.compoundPathIndex { + compoundPathIndex := make([]CompoundPathIndexDebugInfo, 0) + for tokenHash, projectMap := range pathIndexData { + token := tokenHashToToken[tokenHash] projects := make(map[string]ProjectDebugInfo) totalPaths := 0 totalKeys := 0 @@ -564,23 +684,13 @@ func (c *Cache) GetDebugInfo() CacheDebugInfo { paths := make(map[string]PathDebugInfo) envTotalKeys := 0 - for secretPath, cacheKeys := range pathsMap { - keys := make([]CacheKeyDebugInfo, 0, len(cacheKeys)) - for cacheKey := range cacheKeys { - - if entry, exists := c.entries[cacheKey]; exists { - keys = append(keys, CacheKeyDebugInfo{ - CacheKey: cacheKey, - CachedAt: entry.Request.CachedAt, - }) - } - } + for secretPath, keyInfos := range pathsMap { paths[secretPath] = PathDebugInfo{ SecretPath: secretPath, - CacheKeys: keys, - KeyCount: len(cacheKeys), + CacheKeys: keyInfos, + KeyCount: len(keyInfos), } - envTotalKeys += len(cacheKeys) + envTotalKeys += len(keyInfos) projectTotalPaths++ } @@ -611,8 +721,8 @@ func (c *Cache) GetDebugInfo() CacheDebugInfo { } return CacheDebugInfo{ - TotalEntries: len(c.entries), - TotalTokens: len(c.tokenIndex), + TotalEntries: totalEntries, + TotalTokens: len(tokenHashToToken), TotalSizeBytes: totalSize, EntriesByToken: entriesByToken, CacheKeys: cacheKeys, diff --git a/packages/proxy/resync.go b/packages/proxy/resync.go index f99980a6..95bf696e 100644 --- a/packages/proxy/resync.go +++ b/packages/proxy/resync.go @@ -15,6 +15,14 @@ import ( "github.com/rs/zerolog/log" ) +// maskToken masks a token showing only first 5 and last 5 characters +func maskToken(token string) string { + if len(token) <= 10 { + return "***" + } + return token[:5] + "..." + token[len(token)-5:] +} + // parseRateLimitSeconds extracts retry-after seconds from rate limit error message // Expected format: "Rate limit exceeded. Please try again in 57 seconds" // Returns default of 10 seconds if parsing fails @@ -121,177 +129,186 @@ func reconstructProxyRequest(domainURL *url.URL, request *CachedRequest) (*http. return proxyReq, nil } -// StartResyncLoop starts the background resync loop for cache entries -func StartResyncLoop(ctx context.Context, cache *Cache, domainURL *url.URL, httpClient *http.Client, resyncInterval int, cacheTTL int) { - ticker := time.NewTicker(time.Duration(resyncInterval) * time.Minute) - defer ticker.Stop() +// runAccessTokenValidation validates all cached tokens and evicts entries for invalid tokens +func runAccessTokenValidation(cache *Cache, domainURL *url.URL, httpClient *http.Client) { + log.Info().Msg("Starting access token validation") - log.Info(). - Int("resyncInterval", resyncInterval). - Int("cacheTTL", cacheTTL). - Msg("Resync loop started") + tokens := cache.GetAllTokens() + tokensEvicted := 0 - for { - select { - case <-ticker.C: - log.Info().Msg("Starting resync cycle") - cacheTTLDuration := time.Duration(cacheTTL) * time.Minute - - // -- Token testing phase (before expired requests) -- - - tokens := cache.GetAllTokens() - tokensEvicted := 0 - - for _, token := range tokens { - // Add jitter - time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) - - cacheKey, request, found := cache.GetFirstRequestForToken(token) - if !found { - cache.RemoveTokenFromIndex(token) - log.Debug(). - Str("token", token). - Msg("Removed orphaned token entry during token testing") - continue - } - - proxyReq, err := reconstructProxyRequest(domainURL, request) - if err != nil { - log.Error(). - Err(err). - Str("token", token). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Msg("Failed to reconstruct request during token testing") - continue - } - - resp, err := httpClient.Do(proxyReq) - if err != nil { - // Keep entries for high availability - - log.Error(). - Err(err). - Str("token", token). - Str("cacheKey", cacheKey). - Str("requestURI", request.RequestURI). - Msg("Network error during token testing - keeping entries") - - continue - } - - // If 401, evict all entries for this token - if resp.StatusCode == http.StatusUnauthorized { - evictedCount := cache.EvictAllEntriesForToken(token) - resp.Body.Close() - tokensEvicted++ - - if evictedCount == 1 { - log.Info(). - Str("token", token). - Msg("Entry evicted") - } else { - log.Info(). - Int("evictedCount", evictedCount). - Str("token", token). - Msg("Entries evicted") - } - } else { - resp.Body.Close() - } + for _, token := range tokens { + // Add jitter to avoid bursts + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + + cacheKey, request, found := cache.GetFirstRequestForToken(token) + if !found { + cache.RemoveTokenFromIndex(token) + log.Debug(). + Str("token", maskToken(token)). + Msg("Removed orphaned token entry during token validation") + continue + } + + proxyReq, err := reconstructProxyRequest(domainURL, request) + if err != nil { + log.Error(). + Err(err). + Str("token", maskToken(token)). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Failed to reconstruct request during token validation") + continue + } + + resp, err := httpClient.Do(proxyReq) + if err != nil || (resp != nil && resp.StatusCode >= 500) { + // Keep entries for high availability (optimistic eviction strategy) + if resp != nil { + resp.Body.Close() } + log.Error(). + Err(err). + Str("token", maskToken(token)). + Str("cacheKey", cacheKey). + Str("requestURI", request.RequestURI). + Msg("Network error during token validation - keeping entries (optimistic strategy)") + continue + } - if tokensEvicted > 0 { - log.Debug(). - Int("tokensEvicted", tokensEvicted). - Msg("Token testing phase completed") + // If 401, evict all entries for this token + if resp.StatusCode == http.StatusUnauthorized { + evictedCount := cache.EvictAllEntriesForToken(token) + resp.Body.Close() + tokensEvicted++ + + if evictedCount == 1 { + log.Info(). + Str("token", maskToken(token)). + Msg("Token invalid - entry evicted") + } else { + log.Info(). + Int("evictedCount", evictedCount). + Str("token", maskToken(token)). + Msg("Token invalid - entries evicted") } + } else { + resp.Body.Close() + } + } + + log.Info(). + Int("tokensChecked", len(tokens)). + Int("tokensEvicted", tokensEvicted). + Msg("Access token validation completed") +} - // -- Expired entries processing phase -- +// runStaticSecretsRefresh refreshes all cached secrets that have exceeded the refresh interval +func runStaticSecretsRefresh(cache *Cache, domainURL *url.URL, httpClient *http.Client, refreshInterval time.Duration) { + log.Info().Msg("Starting static secrets refresh") - cycleStartTime := time.Now() - resyncIntervalDuration := time.Duration(resyncInterval) * time.Minute + cycleStartTime := time.Now() - requests := cache.GetExpiredRequests(cacheTTLDuration) + requests := cache.GetExpiredRequests(refreshInterval) - // Convert map to slice and sort by CachedAt (oldest first) - type orderedEntry struct { - cacheKey string - request *CachedRequest - } - ordered := make([]orderedEntry, 0, len(requests)) - for key, req := range requests { - ordered = append(ordered, orderedEntry{key, req}) + // Convert map to slice and sort by CachedAt (oldest first) + type orderedEntry struct { + cacheKey string + request *CachedRequest + } + ordered := make([]orderedEntry, 0, len(requests)) + for key, req := range requests { + ordered = append(ordered, orderedEntry{key, req}) + } + sort.Slice(ordered, func(i, j int) bool { + return ordered[i].request.CachedAt.Before(ordered[j].request.CachedAt) + }) + + refetched := 0 + evicted := 0 + + for _, entry := range ordered { + // Add jitter to avoid bursts + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) + + proxyReq, err := reconstructProxyRequest(domainURL, entry.request) + if err != nil { + log.Error(). + Err(err). + Str("cacheKey", entry.cacheKey). + Str("requestURI", entry.request.RequestURI). + Msg("Failed to parse requestURI during secrets refresh") + continue + } + + resp, err := httpClient.Do(proxyReq) + if err != nil || (resp != nil && resp.StatusCode >= 500) { + // Keep stale entry for high availability (optimistic eviction strategy) + if resp != nil { + resp.Body.Close() } - sort.Slice(ordered, func(i, j int) bool { - return ordered[i].request.CachedAt.Before(ordered[j].request.CachedAt) - }) - - refetched := 0 - evicted := 0 - - for _, entry := range ordered { - // Add jitter - time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) - - proxyReq, err := reconstructProxyRequest(domainURL, entry.request) - if err != nil { - log.Error(). - Err(err). - Str("cacheKey", entry.cacheKey). - Str("requestURI", entry.request.RequestURI). - Msg("Failed to parse requestURI during resync") - continue - } - - resp, err := httpClient.Do(proxyReq) - if err != nil { - // Keep stale entry for high availability - - log.Error(). - Err(err). - Str("cacheKey", entry.cacheKey). - Str("requestURI", entry.request.RequestURI). - Msg("Network error during resync - keeping stale entry") - - continue - } - - refetchedResult, evictedResult, rateLimited, retryAfterSeconds := handleResyncResponse(cache, entry.cacheKey, entry.request.RequestURI, resp) - if refetchedResult { - refetched++ - } - if evictedResult { - evicted++ - } - - // Handle rate limiting - if rateLimited { - pauseDuration := time.Duration(retryAfterSeconds+2) * time.Second // 2 seconds buffer - timeUntilNextTick := resyncIntervalDuration - time.Since(cycleStartTime) - - if pauseDuration <= timeUntilNextTick { - log.Info(). - Int("pauseSeconds", retryAfterSeconds+2). - Msg("Rate limited, pausing resync") - time.Sleep(pauseDuration) - } else { - log.Warn(). - Int("pauseSeconds", retryAfterSeconds+2). - Msg("Rate limit pause exceeds resync interval, remaining entries will be processed next cycle. Increase the resync-interval value and the cache-ttl value to prevent this behavior.") - break - } - } + log.Error(). + Err(err). + Str("cacheKey", entry.cacheKey). + Str("requestURI", entry.request.RequestURI). + Msg("Network error during secrets refresh - keeping stale entry (optimistic strategy)") + continue + } + + refetchedResult, evictedResult, rateLimited, retryAfterSeconds := handleResyncResponse(cache, entry.cacheKey, entry.request.RequestURI, resp) + if refetchedResult { + refetched++ + } + if evictedResult { + evicted++ + } + + // Handle rate limiting + if rateLimited { + pauseDuration := time.Duration(retryAfterSeconds+2) * time.Second // 2 seconds buffer + timeUntilNextTick := refreshInterval - time.Since(cycleStartTime) + + if pauseDuration <= timeUntilNextTick { + log.Info(). + Int("pauseSeconds", retryAfterSeconds+2). + Msg("Rate limited, pausing secrets refresh") + time.Sleep(pauseDuration) + } else { + log.Warn(). + Int("pauseSeconds", retryAfterSeconds+2). + Msg("Rate limit pause exceeds refresh interval, remaining entries will be processed next cycle. Increase the static-secrets-refresh-interval value to prevent this behavior.") + break } + } + } + + log.Info(). + Int("expiredEntries", len(requests)). + Int("refetched", refetched). + Int("evicted", evicted). + Msg("Static secrets refresh completed") +} + +// StartBackgroundLoops starts the background loops for token validation and secrets refresh +func StartBackgroundLoops(ctx context.Context, cache *Cache, domainURL *url.URL, httpClient *http.Client, evictionStrategy string, accessTokenCheckInterval time.Duration, staticSecretsRefreshInterval time.Duration) { + tokenTicker := time.NewTicker(accessTokenCheckInterval) + secretsTicker := time.NewTicker(staticSecretsRefreshInterval) + defer tokenTicker.Stop() + defer secretsTicker.Stop() - log.Info(). - Int("expiredEntries", len(requests)). - Int("refetched", refetched). - Int("evicted", evicted). - Msg("Resync cycle completed") + log.Info(). + Str("evictionStrategy", evictionStrategy). + Str("accessTokenCheckInterval", accessTokenCheckInterval.String()). + Str("staticSecretsRefreshInterval", staticSecretsRefreshInterval.String()). + Msg("Background loops started") + for { + select { + case <-tokenTicker.C: + runAccessTokenValidation(cache, domainURL, httpClient) + case <-secretsTicker.C: + runStaticSecretsRefresh(cache, domainURL, httpClient, staticSecretsRefreshInterval) case <-ctx.Done(): - log.Info().Msg("Resync loop stopped") + log.Info().Msg("Background loops stopped") return } } diff --git a/packages/util/agent.go b/packages/util/agent.go deleted file mode 100644 index 585ab9ed..00000000 --- a/packages/util/agent.go +++ /dev/null @@ -1,58 +0,0 @@ -package util - -import ( - "fmt" - "strconv" - "time" -) - -// ParseTimeDurationString converts a string representation of a polling interval to a time.Duration -func ParseTimeDurationString(pollingInterval string, allowLessThanOneSecond bool) (time.Duration, error) { - length := len(pollingInterval) - if length < 2 { - return 0, fmt.Errorf("invalid format") - } - - splitIndex := length - for i := length - 1; i >= 0; i-- { - if pollingInterval[i] >= '0' && pollingInterval[i] <= '9' { - splitIndex = i + 1 - break - } - } - - if splitIndex == 0 || splitIndex == length { - return 0, fmt.Errorf("invalid format: must contain both number and unit") - } - - numberPart := pollingInterval[:splitIndex] - unit := pollingInterval[splitIndex:] - - number, err := strconv.Atoi(numberPart) - if err != nil { - return 0, err - } - - switch unit { - case "s": - if number < 60 && !IsDevelopmentMode() && !allowLessThanOneSecond { - return 0, fmt.Errorf("polling interval must be at least 60 seconds") - } - return time.Duration(number) * time.Second, nil - case "ms": - if number < 1000 && !IsDevelopmentMode() && !allowLessThanOneSecond { - return 0, fmt.Errorf("polling interval must be at least 1000 milliseconds") - } - return time.Duration(number) * time.Millisecond, nil - case "m": - return time.Duration(number) * time.Minute, nil - case "h": - return time.Duration(number) * time.Hour, nil - case "d": - return time.Duration(number) * 24 * time.Hour, nil - case "w": - return time.Duration(number) * 7 * 24 * time.Hour, nil - default: - return 0, fmt.Errorf("invalid time unit") - } -} diff --git a/packages/util/cache/cache-storage.go b/packages/util/cache/cache-storage.go index 1dee11f9..8a318242 100644 --- a/packages/util/cache/cache-storage.go +++ b/packages/util/cache/cache-storage.go @@ -181,6 +181,138 @@ func (s *EncryptedStorage) Delete(key string) error { }) } +// GetKeysByPrefix returns all keys that start with the given prefix (keys only, no values) +func (s *EncryptedStorage) GetKeysByPrefix(prefix string) ([]string, error) { + var keys []string + + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchValues = false // Keys only, much faster + it := txn.NewIterator(opts) + defer it.Close() + + prefixBytes := []byte(prefix) + for it.Seek(prefixBytes); it.ValidForPrefix(prefixBytes); it.Next() { + keys = append(keys, string(it.Item().Key())) + } + return nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to get keys by prefix: %w", err) + } + + return keys, nil +} + +// GetByPrefix returns all key-value pairs where the key starts with the given prefix +func (s *EncryptedStorage) GetByPrefix(prefix string, destFactory func() interface{}) (map[string]interface{}, error) { + result := make(map[string]interface{}) + + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchSize = 10 + it := txn.NewIterator(opts) + defer it.Close() + + prefixBytes := []byte(prefix) + for it.Seek(prefixBytes); it.ValidForPrefix(prefixBytes); it.Next() { + item := it.Item() + key := string(item.Key()) + + encrypted, err := item.ValueCopy(nil) + if err != nil { + return fmt.Errorf("failed to copy value for key %s: %w", key, err) + } + + decrypted, err := s.decrypt(encrypted) + if err != nil { + return fmt.Errorf("failed to decrypt value for key %s: %w", key, err) + } + + dest := destFactory() + if err := json.Unmarshal(decrypted, dest); err != nil { + return fmt.Errorf("failed to unmarshal value for key %s: %w", key, err) + } + + result[key] = dest + } + return nil + }) + + if err != nil { + return nil, err + } + + return result, nil +} + +// DeleteByPrefix deletes all keys that start with the given prefix +// Deletions are batched to avoid exceeding BadgerDB's transaction size limits +func (s *EncryptedStorage) DeleteByPrefix(prefix string) (int, error) { + const batchSize = 1000 // Process deletions in batches to avoid transaction size limits + + log.Debug().Str("prefix", prefix).Msg("Deleting by prefix") + + // First, collect all keys to delete + keysToDelete, err := s.GetKeysByPrefix(prefix) + if err != nil { + return 0, err + } + + if len(keysToDelete) == 0 { + return 0, nil + } + + deletedCount := 0 + + // Process deletions in batches + for i := 0; i < len(keysToDelete); i += batchSize { + end := i + batchSize + if end > len(keysToDelete) { + end = len(keysToDelete) + } + batch := keysToDelete[i:end] + + err = s.db.Update(func(txn *badger.Txn) error { + for _, key := range batch { + if err := txn.Delete([]byte(key)); err != nil { + return fmt.Errorf("failed to delete key %s: %w", key, err) + } + } + return nil + }) + + if err != nil { + return deletedCount, fmt.Errorf("failed to delete batch starting at index %d: %w", i, err) + } + + deletedCount += len(batch) + } + + return deletedCount, nil +} + +// Exists checks if a key exists in the storage +func (s *EncryptedStorage) Exists(key string) (bool, error) { + var exists bool + + err := s.db.View(func(txn *badger.Txn) error { + _, err := txn.Get([]byte(key)) + if err == badger.ErrKeyNotFound { + exists = false + return nil + } + if err != nil { + return err + } + exists = true + return nil + }) + + return exists, err +} + func (s *EncryptedStorage) Close() error { return s.db.Close() } diff --git a/packages/util/helper.go b/packages/util/helper.go index 30eebaa1..1e346cce 100644 --- a/packages/util/helper.go +++ b/packages/util/helper.go @@ -13,6 +13,7 @@ import ( "path" "runtime" "sort" + "strconv" "strings" "sync" "time" @@ -623,3 +624,58 @@ func OpenBrowser(url string) error { return cmd.Start() } + +// ParseTimeDurationString converts a string representation of a polling interval to a time.Duration +func ParseTimeDurationString(pollingInterval string, allowLessThanOneSecond bool) (time.Duration, error) { + length := len(pollingInterval) + if length < 2 { + return 0, fmt.Errorf("invalid format") + } + + splitIndex := length + for i := length - 1; i >= 0; i-- { + if pollingInterval[i] >= '0' && pollingInterval[i] <= '9' { + splitIndex = i + 1 + break + } + } + + if splitIndex == 0 || splitIndex == length { + return 0, fmt.Errorf("invalid format: must contain both number and unit") + } + + numberPart := pollingInterval[:splitIndex] + unit := pollingInterval[splitIndex:] + + number, err := strconv.Atoi(numberPart) + if err != nil { + return 0, err + } + + if number <= 0 { + return 0, fmt.Errorf("polling interval must be greater than 0") + } + + switch unit { + case "s": + if number < 60 && !IsDevelopmentMode() && !allowLessThanOneSecond { + return 0, fmt.Errorf("polling interval must be at least 60 seconds") + } + return time.Duration(number) * time.Second, nil + case "ms": + if number < 1000 && !IsDevelopmentMode() && !allowLessThanOneSecond { + return 0, fmt.Errorf("polling interval must be at least 1000 milliseconds") + } + return time.Duration(number) * time.Millisecond, nil + case "m": + return time.Duration(number) * time.Minute, nil + case "h": + return time.Duration(number) * time.Hour, nil + case "d": + return time.Duration(number) * 24 * time.Hour, nil + case "w": + return time.Duration(number) * 7 * 24 * time.Hour, nil + default: + return 0, fmt.Errorf("invalid time unit") + } +} From 4de006e8bedeea17c0788d41688e43db126954de Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 03:05:46 +0100 Subject: [PATCH 07/14] Update proxy.go --- packages/cmd/proxy.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index 5a5e60c9..58f2ed1c 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -483,6 +483,8 @@ func startProxyServer(cmd *cobra.Command, args []string) { log.Info().Msgf("Infisical proxy server starting on %s", listenAddress) } + log.Info().Msgf("Forwarding requests to %s", domain) + if tlsEnabled { if err := server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { util.HandleError(err, "Failed to start proxy server with TLS") @@ -492,7 +494,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { util.HandleError(err, "Failed to start proxy server") } } - log.Info().Msgf("Forwarding requests to %s", domain) + } func printCacheDebug(cmd *cobra.Command, args []string) { From c5889a029ffeb6e38a81e7d6b0edc675d7e5936c Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 03:59:30 +0100 Subject: [PATCH 08/14] requested changes --- packages/cmd/proxy.go | 30 ++++++++++++++----- packages/proxy/cache.go | 64 +++++++++++++++++++++++++++------------- packages/proxy/resync.go | 5 ++-- 3 files changed, 70 insertions(+), 29 deletions(-) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index 58f2ed1c..d265a884 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -347,6 +347,14 @@ func startProxyServer(cmd *cobra.Command, args []string) { secretPath = "/" } + if projectId == "" || environment == "" { + log.Warn(). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Missing projectId or environment for cache purging - skipping cache purge") + return + } + log.Debug(). Str("method", r.Method). Str("path", r.URL.Path). @@ -354,7 +362,6 @@ func startProxyServer(cmd *cobra.Command, args []string) { Str("environment", environment). Str("secretPath", secretPath). Msg("Attempting mutation purging across all tokens") - purgedCount := cache.PurgeByMutation(projectId, environment, secretPath) if purgedCount == 1 { @@ -412,13 +419,22 @@ func startProxyServer(cmd *cobra.Command, args []string) { proxy.CopyHeaders(cachedResp.Header, resp.Header) - cache.Set(cacheKey, r, cachedResp, token, indexEntry) + if indexEntry.ProjectId != "" && indexEntry.EnvironmentSlug != "" { - log.Debug(). - Str("method", r.Method). - Str("path", r.URL.Path). - Str("cacheKey", cacheKey). - Msg("Response cached successfully") + cache.Set(cacheKey, r, cachedResp, token, indexEntry) + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Secret response cached successfully") + } else { + log.Warn(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("cacheKey", cacheKey). + Msg("Secret response not cached because project ID or environment slug is empty") + } } log.Debug(). diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go index 98d255b8..9364529e 100644 --- a/packages/proxy/cache.go +++ b/packages/proxy/cache.go @@ -101,17 +101,27 @@ func buildTokenIndexPrefix(token string) string { } // buildPathIndexKey builds the storage key for path index entry +// Key format: path:{projectId}:{envSlug}:{tokenHash}:{escapedSecretPath}:{cacheKey} func buildPathIndexKey(token string, indexEntry IndexEntry) string { // Escape colons in secretPath to avoid key parsing issues escapedPath := strings.ReplaceAll(indexEntry.SecretPath, ":", "\\:") - return fmt.Sprintf("%s%s:%s:%s:%s:%s", + key := fmt.Sprintf("%s%s:%s:%s:%s:%s", prefixPath, - hashToken(token), indexEntry.ProjectId, indexEntry.EnvironmentSlug, + hashToken(token), escapedPath, indexEntry.CacheKey, ) + + log.Debug().Str("pathIndexKey", key).Msg("Built path index key") + + return key +} + +// buildPathIndexPrefixForProject builds the prefix for all path entries matching a project+env +func buildPathIndexPrefixForProject(projectId, envSlug string) string { + return fmt.Sprintf("%s%s:%s:", prefixPath, projectId, envSlug) } func IsSecretsEndpoint(path string) bool { @@ -491,6 +501,7 @@ func (c *Cache) EvictAllEntriesForToken(token string) int { } // RemoveTokenFromIndex removes all index entries for a token (without deleting main entries) +// This is a cleanup function called rarely for orphaned tokens func (c *Cache) RemoveTokenFromIndex(token string) { c.mu.Lock() defer c.mu.Unlock() @@ -499,9 +510,27 @@ func (c *Cache) RemoveTokenFromIndex(token string) { c.storage.DeleteByPrefix(tokenPrefix) // Also delete path index entries for this token - // Path keys start with path:{tokenHash}:... - pathPrefix := prefixPath + hashToken(token) + ":" - c.storage.DeleteByPrefix(pathPrefix) + // since path keys are prefixed by projectId:envSlug + // we need to scan all path keys to find those containing this token's hash + tokenHash := hashToken(token) + pathKeys, err := c.storage.GetKeysByPrefix(prefixPath) + if err != nil { + log.Debug().Err(err).Msg("Failed to get path keys for token index cleanup") + return + } + + for _, key := range pathKeys { + // Key format: path:{projectId}:{envSlug}:{tokenHash}:{secretPath}:{cacheKey} + withoutPrefix := strings.TrimPrefix(key, prefixPath) + parts := strings.SplitN(withoutPrefix, ":", 4) + if len(parts) < 3 { + continue + } + keyTokenHash := parts[2] + if keyTokenHash == tokenHash { + c.storage.Delete(key) + } + } } // PurgeByMutation purges cache entries across ALL tokens that match the mutation path @@ -511,30 +540,25 @@ func (c *Cache) PurgeByMutation(projectID, envSlug, mutationPath string) int { purgedCount := 0 - // Get all path index keys - pathKeys, err := c.storage.GetKeysByPrefix(prefixPath) + prefix := buildPathIndexPrefixForProject(projectID, envSlug) + pathKeys, err := c.storage.GetKeysByPrefix(prefix) if err != nil { log.Error().Err(err).Msg("Failed to get path index keys for mutation purge") return 0 } for _, key := range pathKeys { - // Key format: path:{tokenHash}:{projectId}:{envSlug}:{escapedSecretPath}:{cacheKey} - withoutPrefix := strings.TrimPrefix(key, prefixPath) - parts := strings.SplitN(withoutPrefix, ":", 5) - if len(parts) < 5 { + // Key format: path:{projectId}:{envSlug}:{tokenHash}:{escapedSecretPath}:{cacheKey} + // We already filtered by projectId:envSlug via prefix, so extract remaining parts + withoutPrefix := strings.TrimPrefix(key, prefix) + parts := strings.SplitN(withoutPrefix, ":", 3) + if len(parts) < 3 { continue } - keyProjectID := parts[1] - keyEnvSlug := parts[2] - keySecretPath := strings.ReplaceAll(parts[3], "\\:", ":") // Unescape colons - keyCacheKey := parts[4] - - // Check if this entry matches the mutation criteria - if keyProjectID != projectID || keyEnvSlug != envSlug { - continue - } + // parts[0] = tokenHash (not needed for matching) + keySecretPath := strings.ReplaceAll(parts[1], "\\:", ":") // Unescape colons + keyCacheKey := parts[2] if matchesPath(keySecretPath, mutationPath) { c.evictEntryUnsafe(keyCacheKey) diff --git a/packages/proxy/resync.go b/packages/proxy/resync.go index 95bf696e..47289cdf 100644 --- a/packages/proxy/resync.go +++ b/packages/proxy/resync.go @@ -15,6 +15,8 @@ import ( "github.com/rs/zerolog/log" ) +var rateLimitSecondsRegex = regexp.MustCompile(`(\d+)\s+seconds?`) + // maskToken masks a token showing only first 5 and last 5 characters func maskToken(token string) string { if len(token) <= 10 { @@ -37,8 +39,7 @@ func parseRateLimitSeconds(body []byte) int { return seconds } - re := regexp.MustCompile(`(\d+)\s+seconds?`) - matches := re.FindStringSubmatch(errorResponse.Message) + matches := rateLimitSecondsRegex.FindStringSubmatch(errorResponse.Message) if len(matches) < 2 { return 10 } From 9d336b7a5cfbc032ea923eb5951ab6dfd5d62ff4 Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 04:13:58 +0100 Subject: [PATCH 09/14] Update proxy.go --- packages/cmd/proxy.go | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index d265a884..c53494d7 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -353,26 +353,27 @@ func startProxyServer(cmd *cobra.Command, args []string) { Str("path", r.URL.Path). Msg("Missing projectId or environment for cache purging - skipping cache purge") return - } - - log.Debug(). - Str("method", r.Method). - Str("path", r.URL.Path). - Str("projectId", projectId). - Str("environment", environment). - Str("secretPath", secretPath). - Msg("Attempting mutation purging across all tokens") - purgedCount := cache.PurgeByMutation(projectId, environment, secretPath) - - if purgedCount == 1 { - log.Info(). - Str("mutationPath", secretPath). - Msg("Entry purged") } else { - log.Info(). - Int("purgedCount", purgedCount). - Str("mutationPath", secretPath). - Msg("Entries purged") + + log.Debug(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("projectId", projectId). + Str("environment", environment). + Str("secretPath", secretPath). + Msg("Attempting mutation purging across all tokens") + purgedCount := cache.PurgeByMutation(projectId, environment, secretPath) + + if purgedCount == 1 { + log.Info(). + Str("mutationPath", secretPath). + Msg("Entry purged") + } else { + log.Info(). + Int("purgedCount", purgedCount). + Str("mutationPath", secretPath). + Msg("Entries purged") + } } } From 70af94e9409f61338bf96634a0a406cdd079aa18 Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 04:27:15 +0100 Subject: [PATCH 10/14] Update proxy.go --- packages/cmd/proxy.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index c53494d7..48ecdd16 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -23,6 +23,8 @@ import ( "github.com/spf13/cobra" ) +const OPTIMISTIC_CACHE_EVICTION_STRATEGY = "optimistic" + var proxyCmd = &cobra.Command{ Example: `infisical proxy start`, Short: "Used to run Infisical proxy server", @@ -93,7 +95,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { util.HandleError(err, "Unable to parse eviction-strategy flag") } - if evictionStrategy != "optimistic" { + if evictionStrategy != OPTIMISTIC_CACHE_EVICTION_STRATEGY { util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid eviction-strategy '%s'. Currently only 'optimistic' is supported.", evictionStrategy)) } @@ -566,7 +568,7 @@ func isStreamingEndpoint(path string) bool { func init() { proxyStartCmd.Flags().String("domain", "", "Domain of your Infisical instance (e.g., https://app.infisical.com for cloud, https://my-self-hosted-instance.com for self-hosted)") proxyStartCmd.Flags().String("listen-address", "localhost:8081", "The address for the proxy server to listen on. Defaults to localhost:8081") - proxyStartCmd.Flags().String("eviction-strategy", "optimistic", "Cache eviction strategy. 'optimistic' keeps cached data when Infisical is unreachable for high availability. Defaults to optimistic.") + proxyStartCmd.Flags().String("eviction-strategy", OPTIMISTIC_CACHE_EVICTION_STRATEGY, "Cache eviction strategy. 'optimistic' keeps cached data when Infisical is unreachable for high availability. Currently only 'optimistic' is supported.") proxyStartCmd.Flags().String("access-token-check-interval", "5m", "How often to validate that access tokens are still valid (e.g., 5m, 1h). Defaults to 5m.") proxyStartCmd.Flags().String("static-secrets-refresh-interval", "1h", "How often to refresh cached secrets (e.g., 30m, 1h, 1d). Defaults to 1h.") proxyStartCmd.Flags().String("tls-cert-file", "", "The path to the TLS certificate file for the proxy server. Required when `tls-enabled` is set to true (default)") From 051622d52b766bac121f28e54e383d20ee76a9ec Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 04:51:01 +0100 Subject: [PATCH 11/14] Update cache.go --- packages/proxy/cache.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/proxy/cache.go b/packages/proxy/cache.go index 9364529e..5a048fd8 100644 --- a/packages/proxy/cache.go +++ b/packages/proxy/cache.go @@ -103,7 +103,8 @@ func buildTokenIndexPrefix(token string) string { // buildPathIndexKey builds the storage key for path index entry // Key format: path:{projectId}:{envSlug}:{tokenHash}:{escapedSecretPath}:{cacheKey} func buildPathIndexKey(token string, indexEntry IndexEntry) string { - // Escape colons in secretPath to avoid key parsing issues + // Escape colons in secretPath to avoid key parsing issues. + // Currently not relevant as we don't support colons in secret paths, but if we decide to broaden our allowed folder naming in the future, this would be needed escapedPath := strings.ReplaceAll(indexEntry.SecretPath, ":", "\\:") key := fmt.Sprintf("%s%s:%s:%s:%s:%s", prefixPath, From f0bb9459a911ed4d857fdd3b66401dd5a61267fa Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 04:51:03 +0100 Subject: [PATCH 12/14] Update resync.go --- packages/proxy/resync.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/proxy/resync.go b/packages/proxy/resync.go index 47289cdf..f8d989b0 100644 --- a/packages/proxy/resync.go +++ b/packages/proxy/resync.go @@ -177,7 +177,7 @@ func runAccessTokenValidation(cache *Cache, domainURL *url.URL, httpClient *http } // If 401, evict all entries for this token - if resp.StatusCode == http.StatusUnauthorized { + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { evictedCount := cache.EvictAllEntriesForToken(token) resp.Body.Close() tokensEvicted++ From ea0ba34c23b29817284468b70884a09d6d335fe9 Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 23:29:13 +0100 Subject: [PATCH 13/14] Update proxy.go --- packages/cmd/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index 48ecdd16..d73e3a59 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -314,7 +314,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { // -- Secret Mutation Purging -- - if (r.Method == http.MethodPatch || r.Method == http.MethodDelete) && + if (r.Method == http.MethodPatch || r.Method == http.MethodDelete || r.Method == http.MethodPost) && proxy.IsSecretsEndpoint(r.URL.Path) && resp.StatusCode >= 200 && resp.StatusCode < 300 { var projectId, environment, secretPath string From e7cf3a9dec47b32a400ef1a26624210732a88ed2 Mon Sep 17 00:00:00 2001 From: Daniel Hougaard Date: Wed, 14 Jan 2026 23:47:16 +0100 Subject: [PATCH 14/14] feat: memory guard cache support --- go.mod | 2 ++ go.sum | 4 ++++ packages/cmd/agent.go | 29 ++++++++++++++-------------- packages/cmd/proxy.go | 24 ++++++++++++----------- packages/util/cache/cache-storage.go | 10 ++++++---- 5 files changed, 40 insertions(+), 29 deletions(-) diff --git a/go.mod b/go.mod index 8996fabb..534c4920 100644 --- a/go.mod +++ b/go.mod @@ -58,6 +58,8 @@ require ( github.com/Masterminds/semver/v3 v3.3.0 // indirect github.com/alessio/shellescape v1.4.1 // indirect github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect + github.com/awnumar/memcall v0.4.0 // indirect + github.com/awnumar/memguard v0.23.0 // indirect github.com/aws/aws-sdk-go-v2 v1.27.2 // indirect github.com/aws/aws-sdk-go-v2/config v1.27.18 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.18 // indirect diff --git a/go.sum b/go.sum index e9f236ba..e929ea5f 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,10 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= +github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= +github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= +github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= github.com/aws/aws-sdk-go-v2 v1.27.2 h1:pLsTXqX93rimAOZG2FIYraDQstZaaGVVN4tNw65v0h8= github.com/aws/aws-sdk-go-v2 v1.27.2/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2/config v1.27.18 h1:wFvAnwOKKe7QAyIxziwSKjmer9JBMH1vzIL6W+fYuKk= diff --git a/packages/cmd/agent.go b/packages/cmd/agent.go index 1f094c86..fb1b0647 100644 --- a/packages/cmd/agent.go +++ b/packages/cmd/agent.go @@ -28,6 +28,7 @@ import ( "text/template" "time" + "github.com/awnumar/memguard" "github.com/dgraph-io/badger/v3" "github.com/go-resty/resty/v2" infisicalSdk "github.com/infisical/go-sdk" @@ -90,13 +91,13 @@ type RetryConfig struct { } type Config struct { - Version string `yaml:"version,omitempty"` - Infisical InfisicalConfig `yaml:"infisical"` - Auth AuthConfig `yaml:"auth"` - Sinks []Sink `yaml:"sinks"` - Cache CacheConfig `yaml:"cache,omitempty"` - Templates []Template `yaml:"templates"` - Certificates []AgentCertificateConfig `yaml:"certificates,omitempty"` + Version string `yaml:"version,omitempty"` + Infisical InfisicalConfig `yaml:"infisical"` + Auth AuthConfig `yaml:"auth"` + Sinks []Sink `yaml:"sinks"` + Cache CacheConfig `yaml:"cache,omitempty"` + Templates []Template `yaml:"templates"` + Certificates []AgentCertificateConfig `yaml:"certificates,omitempty"` } type TemplateWithID struct { @@ -195,10 +196,10 @@ type Template struct { } type CertificateLifecycleConfig struct { - RenewBeforeExpiry string `yaml:"renew-before-expiry"` - StatusCheckInterval string `yaml:"status-check-interval"` + RenewBeforeExpiry string `yaml:"renew-before-expiry"` + StatusCheckInterval string `yaml:"status-check-interval"` FailureRetryInterval string `yaml:"failure-retry-interval,omitempty"` - MaxFailureRetries int `yaml:"max-failure-retries,omitempty"` + MaxFailureRetries int `yaml:"max-failure-retries,omitempty"` } type CertificateAttributes struct { @@ -343,7 +344,10 @@ func NewCacheManager(ctx context.Context, cacheConfig *CacheConfig) (*CacheManag return &CacheManager{}, fmt.Errorf("unable to read service account token: %v. Please ensure the file exists and is not empty", err) } - encryptionKey := sha256.Sum256(serviceAccountToken) + hash := sha256.Sum256(serviceAccountToken) + encryptionKey := memguard.NewBufferFromBytes(hash[:]) // the hash (source) is wiped after copied to the secure buffer + + defer encryptionKey.Destroy() cacheStorage, err := cache.NewEncryptedStorage(cache.EncryptedStorageOptions{ DBPath: cacheConfig.Persistent.Path, @@ -2000,7 +2004,6 @@ func validateCertificateLifecycleConfig(certificates *[]AgentCertificateConfig) return nil } - func resolveCertificateNameReferences(certificates *[]AgentCertificateConfig, httpClient *resty.Client) error { for i := range *certificates { cert := &(*certificates)[i] @@ -2086,7 +2089,6 @@ func buildCertificateAttributes(certificate *AgentCertificateConfig) *api.Certif removeRoots = false } - attributes.RemoveRootsFromChain = removeRoots hasAny = true @@ -3207,7 +3209,6 @@ var agentCmd = &cobra.Command{ log.Warn().Msg("credential revocation timed out after 5 minutes, forcing exit") exitCode = 1 } - } os.Exit(exitCode) diff --git a/packages/cmd/proxy.go b/packages/cmd/proxy.go index d73e3a59..156b127f 100644 --- a/packages/cmd/proxy.go +++ b/packages/cmd/proxy.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,15 +16,19 @@ import ( "syscall" "time" - "github.com/Infisical/infisical-merge/packages/crypto" "github.com/Infisical/infisical-merge/packages/proxy" "github.com/Infisical/infisical-merge/packages/util" "github.com/Infisical/infisical-merge/packages/util/cache" + "github.com/awnumar/memguard" "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) -const OPTIMISTIC_CACHE_EVICTION_STRATEGY = "optimistic" +type CacheEvictionStrategy string + +const ( + CacheEvictionStrategyOptimistic CacheEvictionStrategy = "optimistic" +) var proxyCmd = &cobra.Command{ Example: `infisical proxy start`, @@ -95,7 +100,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { util.HandleError(err, "Unable to parse eviction-strategy flag") } - if evictionStrategy != OPTIMISTIC_CACHE_EVICTION_STRATEGY { + if evictionStrategy != string(CacheEvictionStrategyOptimistic) { util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid eviction-strategy '%s'. Currently only 'optimistic' is supported.", evictionStrategy)) } @@ -135,15 +140,12 @@ func startProxyServer(cmd *cobra.Command, args []string) { // Create in-memory cache (no persistence, no encryption needed for ephemeral data) // For persistent cache with encryption, use proxy.NewCacheWithOptions - - encryptionKey, err := crypto.GenerateRandomBytes(32) - if err != nil { - util.HandleError(err, "Failed to generate random encryption key") - } + encryptionKey := memguard.NewBufferRandom(32) + defer encryptionKey.Destroy() cache, err := proxy.NewCache(cache.EncryptedStorageOptions{ InMemory: true, - EncryptionKey: [32]byte(encryptionKey), + EncryptionKey: encryptionKey, }) if err != nil { @@ -287,7 +289,7 @@ func startProxyServer(cmd *cobra.Command, args []string) { flusher.Flush() } } - if err == io.EOF { + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { break } if err != nil { @@ -568,7 +570,7 @@ func isStreamingEndpoint(path string) bool { func init() { proxyStartCmd.Flags().String("domain", "", "Domain of your Infisical instance (e.g., https://app.infisical.com for cloud, https://my-self-hosted-instance.com for self-hosted)") proxyStartCmd.Flags().String("listen-address", "localhost:8081", "The address for the proxy server to listen on. Defaults to localhost:8081") - proxyStartCmd.Flags().String("eviction-strategy", OPTIMISTIC_CACHE_EVICTION_STRATEGY, "Cache eviction strategy. 'optimistic' keeps cached data when Infisical is unreachable for high availability. Currently only 'optimistic' is supported.") + proxyStartCmd.Flags().String("eviction-strategy", string(CacheEvictionStrategyOptimistic), "Cache eviction strategy. 'optimistic' keeps cached data when Infisical is unreachable for high availability. Currently only 'optimistic' is supported.") proxyStartCmd.Flags().String("access-token-check-interval", "5m", "How often to validate that access tokens are still valid (e.g., 5m, 1h). Defaults to 5m.") proxyStartCmd.Flags().String("static-secrets-refresh-interval", "1h", "How often to refresh cached secrets (e.g., 30m, 1h, 1d). Defaults to 1h.") proxyStartCmd.Flags().String("tls-cert-file", "", "The path to the TLS certificate file for the proxy server. Required when `tls-enabled` is set to true (default)") diff --git a/packages/util/cache/cache-storage.go b/packages/util/cache/cache-storage.go index 8a318242..44a7b263 100644 --- a/packages/util/cache/cache-storage.go +++ b/packages/util/cache/cache-storage.go @@ -11,13 +11,14 @@ import ( "reflect" "time" + "github.com/awnumar/memguard" "github.com/dgraph-io/badger/v3" "github.com/rs/zerolog/log" ) type EncryptedStorage struct { db *badger.DB - key [32]byte + key *memguard.LockedBuffer } type EncryptedStorageOptions struct { @@ -27,7 +28,7 @@ type EncryptedStorageOptions struct { InMemory bool // Only required if InMemory is false - EncryptionKey [32]byte + EncryptionKey *memguard.LockedBuffer } func NewEncryptedStorage(opts EncryptedStorageOptions) (*EncryptedStorage, error) { @@ -314,6 +315,7 @@ func (s *EncryptedStorage) Exists(key string) (bool, error) { } func (s *EncryptedStorage) Close() error { + s.key.Destroy() return s.db.Close() } @@ -347,7 +349,7 @@ func (s *EncryptedStorage) StartPeriodicGarbageCollection(context context.Contex } func (s *EncryptedStorage) encrypt(plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(s.key[:]) + block, err := aes.NewCipher(s.key.Bytes()) if err != nil { return nil, err } @@ -366,7 +368,7 @@ func (s *EncryptedStorage) encrypt(plaintext []byte) ([]byte, error) { } func (s *EncryptedStorage) decrypt(ciphertext []byte) ([]byte, error) { - block, err := aes.NewCipher(s.key[:]) + block, err := aes.NewCipher(s.key.Bytes()) if err != nil { return nil, err }