diff --git a/internal/server/spanner/client.go b/internal/server/spanner/client.go
index 7793ad6bc..3f21ece42 100644
--- a/internal/server/spanner/client.go
+++ b/internal/server/spanner/client.go
@@ -18,19 +18,40 @@ package spanner
 import (
 	"context"
 	"fmt"
+	"sync"
+	"time"
 
 	"cloud.google.com/go/spanner"
 	"gopkg.in/yaml.v3"
 )
 
+const (
+	// CACHE_DURATION defines how long the CompletionTimestamp is kept in memory before being refetched.
+	CACHE_DURATION = 5 * time.Second
+)
+
 // SpannerClient encapsulates the Spanner client.
 type SpannerClient struct {
 	client *spanner.Client
+
+	// Cache for storing CompletionTimestamp for stale reads.
+	cacheMutex      sync.RWMutex
+	cachedTimestamp *time.Time
+	cacheExpiry     time.Time
+
+	// For mocking in tests.
+	timestampFetcher func(context.Context) (*time.Time, error)
+	clock            func() time.Time
 }
 
 // newSpannerClient creates a new SpannerClient.
 func newSpannerClient(client *spanner.Client) *SpannerClient {
-	return &SpannerClient{client: client}
+	sc := &SpannerClient{
+		client: client,
+		clock:  time.Now, // Default to real time
+	}
+	sc.timestampFetcher = sc.fetchCompletionTimestampFromSpanner
+	return sc
 }
 
 // NewSpannerClient creates a new SpannerClient from the config yaml string.
@@ -43,7 +64,17 @@ func NewSpannerClient(ctx context.Context, spannerConfigYaml string) (*SpannerCl
 	if err != nil {
 		return nil, fmt.Errorf("failed to create SpannerClient: %w", err)
 	}
-	return newSpannerClient(client), nil
+	sc := newSpannerClient(client)
+
+	// Cache initial CompletionTimestamp
+	_, err = sc.GetStalenessTimestampBound(ctx)
+	if err != nil {
+		// Release the Spanner client: it is not handed to the caller on failure.
+		client.Close()
+		return nil, fmt.Errorf("failed to warm up stable timestamp cache: %w", err)
+	}
+
+	return sc, nil
 }
 
 // createSpannerClient creates the database name string and initializes the Spanner client.
diff --git a/internal/server/spanner/client_test.go b/internal/server/spanner/client_test.go
new file mode 100644
index 000000000..d1203654f
--- /dev/null
+++ b/internal/server/spanner/client_test.go
@@ -0,0 +1,118 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spanner
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestCacheHit(t *testing.T) {
+	var fetchCount int
+	mockTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC)
+	stableTime := mockTime.Add(-1 * time.Minute)
+
+	sc := &SpannerClient{clock: func() time.Time { return mockTime }}
+	sc.timestampFetcher = func(ctx context.Context) (*time.Time, error) {
+		fetchCount++
+		return &stableTime, nil
+	}
+	// Initialization will populate cache.
+	_, err := sc.getCompletionTimestamp(context.Background())
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+	if fetchCount != 1 {
+		t.Fatalf("Setup failed, expected 1 fetch, got %d", fetchCount)
+	}
+
+	// This call is immediately after initialization, within the 5-second duration.
+	_, err = sc.getCompletionTimestamp(context.Background())
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+	if fetchCount != 1 {
+		t.Errorf("Expected timestamp fetch count to remain 1, got %d", fetchCount)
+	}
+}
+
+func TestCacheExpiration(t *testing.T) {
+	var fetchCount int
+	mockTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC)
+	stableTime := mockTime.Add(-1 * time.Minute)
+
+	sc := &SpannerClient{
+		cacheExpiry: mockTime.Add(CACHE_DURATION),
+		clock:       func() time.Time { return mockTime },
+	}
+	sc.timestampFetcher = func(ctx context.Context) (*time.Time, error) {
+		fetchCount++
+		return &stableTime, nil
+	}
+	// Initialization will populate cache.
+	_, err := sc.getCompletionTimestamp(context.Background())
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+	if fetchCount != 1 {
+		t.Fatalf("Setup failed, expected 1 fetch, got %d", fetchCount)
+	}
+
+	// Advance time past expiration.
+	mockTime = mockTime.Add(6 * time.Second)
+	_, err = sc.getCompletionTimestamp(context.Background())
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+	if fetchCount != 2 {
+		t.Errorf("Expected timestamp fetch count to increase to 2, got %d", fetchCount)
+	}
+	expectedExpiry := mockTime.Add(CACHE_DURATION)
+	if !sc.cacheExpiry.Equal(expectedExpiry) {
+		t.Errorf("Cache expiry was not correctly updated after refetch.")
+	}
+}
+
+func TestGetStalenessTimestampBound(t *testing.T) {
+	mockTime := time.Date(2025, time.January, 1, 10, 0, 0, 0, time.UTC)
+	stableTime := mockTime.Add(-5 * time.Minute) // Stable time is 5 minutes ago
+
+	sc := &SpannerClient{
+		cacheExpiry: mockTime.Add(CACHE_DURATION),
+		clock:       func() time.Time { return mockTime },
+	}
+	sc.timestampFetcher = func(ctx context.Context) (*time.Time, error) {
+		return &stableTime, nil
+	}
+
+	timestamp, err := sc.GetStalenessTimestampBound(context.Background())
+
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+	if timestamp == nil {
+		t.Fatal("Expected a non-nil TimestampBound")
+	} else {
+		// Approximate check relying on String() representation of ReadTimestamp.
+		expectedString := fmt.Sprintf("ReadTimestamp(%s)", stableTime.Format(time.RFC3339Nano))
+		actualString := (*timestamp).String()
+		if !strings.Contains(actualString, stableTime.Format("2006-01-02")) {
+			t.Errorf("Expected ReadTimestamp containing %v, got %s", expectedString, actualString)
+		}
+	}
+}
diff --git a/internal/server/spanner/query.go b/internal/server/spanner/query.go
index 69d4df861..92edca03b 100644
--- a/internal/server/spanner/query.go
+++ b/internal/server/spanner/query.go
@@ -18,10 +18,13 @@ package spanner
 import (
 	"context"
 	"fmt"
+	"log/slog"
+	"time"
 
 	"cloud.google.com/go/spanner"
 	v2 "github.com/datacommonsorg/mixer/internal/server/v2"
 	"google.golang.org/api/iterator"
+	"google.golang.org/grpc/codes"
 )
 
 const (
@@ -211,16 +214,49 @@ func (sc *SpannerClient) queryAndCollect(
 	newStruct func() interface{},
 	withStruct func(interface{}),
 ) error {
-	iter := sc.client.Single().Query(ctx, stmt)
+	timestampBound, err := sc.GetStalenessTimestampBound(ctx)
+	if err != nil {
+		return err
+	}
+
+	// Attempt stale read
+	iter := sc.client.Single().WithTimestampBound(*timestampBound).Query(ctx, stmt)
 	defer iter.Stop()
 
+	err = sc.processRows(iter, newStruct, withStruct)
+
+	// Check if the error is due to an expired timestamp (FAILED_PRECONDITION).
+	// Currently the timestamp is set manually so can naturally get stale.
+	// So for now, just log an error and fallback to a strong read.
+	// TODO: Once the Spanner instance is set to periodically update the timestamp, increase severity of check, as this indicates that ingestion failed.
+	if spanner.ErrCode(err) == codes.FailedPrecondition {
+		slog.Error("Stale read timestamp expired (before earliest_version_time). Falling back to StrongRead.",
+			"expiredTimestamp", timestampBound.String())
+
+		// Fallback to strong read
+		strongBound := spanner.StrongRead()
+		iter = sc.client.Single().WithTimestampBound(strongBound).Query(ctx, stmt)
+		defer iter.Stop()
+
+		err = sc.processRows(iter, newStruct, withStruct)
+	}
+	if err != nil {
+		return fmt.Errorf("failed to execute Spanner query: %w", err)
+	}
+
+	return nil
+}
+
+// processRows reads rows from iter until iterator.Done, creating a struct with newStruct
+// and passing it to withStruct for each row. Iteration errors are returned unwrapped.
+func (sc *SpannerClient) processRows(iter *spanner.RowIterator, newStruct func() interface{}, withStruct func(interface{})) error {
 	for {
 		row, err := iter.Next()
 		if err == iterator.Done {
 			break
 		}
 		if err != nil {
-			return fmt.Errorf("failed to fetch row: %w", err)
+			return err
 		}
 
 		rowStruct := newStruct()
@@ -229,6 +265,69 @@ func (sc *SpannerClient) queryAndCollect(
 		}
 		withStruct(rowStruct)
 	}
-
 	return nil
 }
+
+// fetchCompletionTimestampFromSpanner returns the latest reported CompletionTimestamp in IngestionHistory.
+func (sc *SpannerClient) fetchCompletionTimestampFromSpanner(ctx context.Context) (*time.Time, error) {
+	iter := sc.client.Single().Query(ctx, *GetCompletionTimestampQuery())
+	defer iter.Stop()
+
+	row, err := iter.Next()
+	if err == iterator.Done {
+		return nil, fmt.Errorf("no rows found in IngestionHistory")
+	}
+	if err != nil {
+		return nil, fmt.Errorf("failed to fetch row: %w", err)
+	}
+
+	var timestamp time.Time
+	if err := row.Column(0, &timestamp); err != nil {
+		return nil, fmt.Errorf("failed to read CompletionTimestamp column: %w", err)
+	}
+
+	return &timestamp, nil
+}
+
+// getCompletionTimestamp returns the latest reported CompletionTimestamp.
+// It prioritizes returning a value from an in-memory cache to reduce Spanner traffic.
+func (sc *SpannerClient) getCompletionTimestamp(ctx context.Context) (*time.Time, error) {
+	// Check cache
+	sc.cacheMutex.RLock()
+	if sc.cachedTimestamp != nil && sc.clock().Before(sc.cacheExpiry) {
+		sc.cacheMutex.RUnlock()
+		return sc.cachedTimestamp, nil
+	}
+	sc.cacheMutex.RUnlock()
+
+	// Fetch from Spanner
+	sc.cacheMutex.Lock()
+	defer sc.cacheMutex.Unlock()
+
+	// Re-check the cache under the write lock (to prevent a race condition
+	// where another goroutine updated it between the RUnlock and this Lock)
+	if sc.cachedTimestamp != nil && sc.clock().Before(sc.cacheExpiry) {
+		return sc.cachedTimestamp, nil
+	}
+	timestamp, err := sc.timestampFetcher(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// Update cache
+	sc.cachedTimestamp = timestamp
+	sc.cacheExpiry = sc.clock().Add(CACHE_DURATION)
+
+	return timestamp, nil
+}
+
+// GetStalenessTimestampBound returns the TimestampBound that should be used for stale reads in Spanner.
+func (sc *SpannerClient) GetStalenessTimestampBound(ctx context.Context) (*spanner.TimestampBound, error) {
+	completionTimestamp, err := sc.getCompletionTimestamp(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	timestampBound := spanner.ReadTimestamp(*completionTimestamp)
+	return &timestampBound, nil
+}
diff --git a/internal/server/spanner/query_builder.go b/internal/server/spanner/query_builder.go
index 966a531ff..4eaa8cd2e 100644
--- a/internal/server/spanner/query_builder.go
+++ b/internal/server/spanner/query_builder.go
@@ -26,6 +26,13 @@ import (
 	v2 "github.com/datacommonsorg/mixer/internal/server/v2"
 )
 
+// GetCompletionTimestampQuery returns the statement that fetches the latest CompletionTimestamp.
+func GetCompletionTimestampQuery() *spanner.Statement {
+	return &spanner.Statement{
+		SQL: statements.getCompletionTimestamp,
+	}
+}
+
 func GetNodePropsQuery(ids []string, out bool) *spanner.Statement {
 	switch out {
 	case true:
diff --git a/internal/server/spanner/statements.go b/internal/server/spanner/statements.go
index 6c406bb67..a8da1a0d4 100644
--- a/internal/server/spanner/statements.go
+++ b/internal/server/spanner/statements.go
@@ -21,6 +21,8 @@ import (
 
 // SQL / GQL statements executed by the SpannerClient
 var statements = struct {
+	// Fetch latest CompletionTimestamp from IngestionHistory table.
+	getCompletionTimestamp string
 	// Fetch Properties for out arcs.
 	getPropsBySubjectID string
 	// Fetch Properties for in arcs.
@@ -74,6 +76,13 @@ var statements = struct {
 	// Resolve one property to another.
 	resolvePropToProp string
 }{
+	getCompletionTimestamp: `		SELECT
+			CompletionTimestamp
+		FROM
+			IngestionHistory
+		ORDER BY
+			CompletionTimestamp DESC
+		LIMIT 1`,
 	getPropsBySubjectID: `		GRAPH DCGraph MATCH -[e:Edge
 		WHERE
 			e.subject_id IN UNNEST(@ids)]->
diff --git a/test/setup.go b/test/setup.go
index cc43747da..003cbd83a 100644
--- a/test/setup.go
+++ b/test/setup.go
@@ -19,6 +19,7 @@ import (
 	"context"
 	"encoding/json"
 	"log"
+	"log/slog"
 	"net"
 	"os"
 	"path"
@@ -392,7 +393,16 @@ func NewSpannerClient() *spanner.SpannerClient {
 	}
 	_, filename, _, _ := runtime.Caller(0)
 	spannerGraphInfoYamlPath := path.Join(path.Dir(filename), "../deploy/storage/spanner_graph_info.yaml")
-	return newSpannerClient(context.Background(), spannerGraphInfoYamlPath)
+	sc := newSpannerClient(context.Background(), spannerGraphInfoYamlPath)
+
+	// Cache initial CompletionTimestamp
+	_, err := sc.GetStalenessTimestampBound(context.Background())
+	if err != nil {
+		slog.Error("failed to warm up stable timestamp cache", "error", err)
+		return nil
+	}
+
+	return sc
 }
 
 func newSpannerClient(ctx context.Context, spannerGraphInfoYamlPath string) *spanner.SpannerClient {