diff --git a/CHANGELOG.md b/CHANGELOG.md index e93d58ce..87c91cc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ adhere to [Semantic Versioning](http://semver.org/spec/v2.0.0.html) starting v1. - Remove dependency: github.com/pkg/errors (#443) - Add public Cache.RemainingCost() method - Add support for uint keys +- Implement public Cache.IterValues() method (#475) **Fixed** diff --git a/cache.go b/cache.go index b70ed006..8db6f248 100644 --- a/cache.go +++ b/cache.go @@ -414,6 +414,16 @@ func (c *Cache[K, V]) GetTTL(key K) (time.Duration, bool) { return time.Until(expiration), true } +// IterValues iterates the values of the Map, passing them to the callback. +// It guarantees that any value in the Map will be visited only once. +// The set of values visited by IterValues is non-deterministic. +func (c *Cache[K, V]) IterValues(cb func(v V) (stop bool)) { + if c == nil || c.isClosed.Load() { + return + } + c.storedItems.IterValues(cb) +} + // Close stops all goroutines and closes all channels. func (c *Cache[K, V]) Close() { if c == nil || c.isClosed.Load() { diff --git a/cache_test.go b/cache_test.go index 1f5316c0..6959f8ec 100644 --- a/cache_test.go +++ b/cache_test.go @@ -388,6 +388,96 @@ func TestCacheGet(t *testing.T) { require.Zero(t, val) } +func TestCacheIterValues(t *testing.T) { + c, err := NewCache(&Config[string, int]{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + IgnoreInternalCost: true, + Metrics: true, + }) + require.NoError(t, err) + + expectedValues := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + } + for k, v := range expectedValues { + key, conflict := z.KeyToHash(k) + i := Item[int]{ + Key: key, + Conflict: conflict, + Value: v, + } + c.storedItems.Set(&i) + } + + resultValues := make([]int, 0) + c.IterValues(func(v int) (stop bool) { + resultValues = append(resultValues, v) + return false + }) + + expectedSlice := make([]int, 0, len(expectedValues)) + for _, v := range expectedValues { + expectedSlice = append(expectedSlice, v) + } + require.ElementsMatch(t, expectedSlice, resultValues) +} + +func TestCacheIterValuesNil(t *testing.T) { + // Test that calling IterValues on a nil cache is safe and doesn't panic + var c *Cache[int, int] + + callbackCalled := false + c.IterValues(func(v int) (stop bool) { + callbackCalled = true + return false + }) + + // Callback should never be called on a nil cache + require.False(t, callbackCalled) +} + +func TestCacheIterValuesAfterClose(t *testing.T) { + c, err := NewCache(&Config[int, int]{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + IgnoreInternalCost: true, + Metrics: true, + }) + require.NoError(t, err) + + expectedCacheLen := 5 + for k := 0; k < expectedCacheLen; k++ { + c.Set(k, k*10, 1) + } + c.Wait() + + // Verify values exist before closing + resultsBefore := make([]int, 0) + c.IterValues(func(v int) (stop bool) { + resultsBefore = append(resultsBefore, v) + return false + }) + require.Len(t, resultsBefore, expectedCacheLen) + + c.Close() + + // Try to iterate after close - should not panic and callback should not be called + callbackCalled := false + c.IterValues(func(v int) (stop bool) { + callbackCalled = true + return false + }) + + // Callback should never be called on a closed cache + require.False(t, callbackCalled) +} + // retrySet calls SetWithTTL until the item is accepted by the cache. func retrySet(t *testing.T, c *Cache[int, int], key, value int, cost int64, ttl time.Duration) { for { diff --git a/store.go b/store.go index 6ba3fc0b..39901e46 100644 --- a/store.go +++ b/store.go @@ -45,6 +45,10 @@ type store[V any] interface { // Clear clears all contents of the store. Clear(onEvict func(item *Item[V])) SetShouldUpdateFn(f updateFn[V]) + // IterValues iterates the values of the Map, passing them to the callback. + // It guarantees that any value in the Map will be visited only once. + // The set of values visited by IterValues is non-deterministic. + IterValues(cb func(v V) (stop bool)) } // newStore returns the default store implementation. @@ -76,6 +80,32 @@ func (m *shardedMap[V]) SetShouldUpdateFn(f updateFn[V]) { } } +// IterValues iterates the values of the Map, passing them to the callback. +// It guarantees that any value in the Map will be visited only once. +// The set of values visited by IterValues is non-deterministic. +func (sm *shardedMap[V]) IterValues(cb func(v V) (stop bool)) { + for _, shard := range sm.shards { + stopped := func() bool { + shard.RLock() + defer shard.RUnlock() + + for _, item := range shard.data { + if !item.expiration.IsZero() && time.Now().After(item.expiration) { + continue + } + if stop := cb(item.value); stop { + return true + } + } + return false + }() + + if stopped { + break + } + } +} + func (sm *shardedMap[V]) Get(key, conflict uint64) (V, bool) { return sm.shards[key%numShards].get(key, conflict) } diff --git a/store_test.go b/store_test.go index 2e60b98d..827bf7fc 100644 --- a/store_test.go +++ b/store_test.go @@ -61,6 +61,127 @@ func TestStoreDel(t *testing.T) { s.Del(2, 0) } +func TestStoreIterValues(t *testing.T) { + s := newStore[int]() + expectedValues := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + } + for k, v := range expectedValues { + key, conflict := z.KeyToHash(k) + i := Item[int]{ + Key: key, + Conflict: conflict, + Value: v, + } + s.Set(&i) + } + + resultValues := make([]int, 0) + s.IterValues(func(v int) (stop bool) { + resultValues = append(resultValues, v) + return false + }) + + expectedSlice := make([]int, 0, len(expectedValues)) + for _, v := range expectedValues { + expectedSlice = append(expectedSlice, v) + } + require.ElementsMatch(t, expectedSlice, resultValues) +} + +func TestStoreIterValuesWithStop(t *testing.T) { + s := newStore[int]() + expectedValues := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + } + for k, v := range expectedValues { + key, conflict := z.KeyToHash(k) + i := Item[int]{ + Key: key, + Conflict: conflict, + Value: v, + } + s.Set(&i) + } + + resultValues := make([]int, 0) + index := 1 + expectedLength := 3 + s.IterValues(func(v int) (stop bool) { + resultValues = append(resultValues, v) + + // Only three elements should be present + if index == expectedLength { + return true + } + + index++ + return false + }) + + require.Len(t, resultValues, expectedLength) + // Verify all returned values are valid (exist in expectedValues) + expectedSlice := make([]int, 0, len(expectedValues)) + for _, v := range expectedValues { + expectedSlice = append(expectedSlice, v) + } + require.Subset(t, expectedSlice, resultValues) +} + +func TestStoreIterValuesSkipsExpiredItems(t *testing.T) { + s := newStore[int]() + now := time.Now() + + // Add items with various expiration states + items := []struct { + key string + value int + expiration time.Time + shouldSee bool + }{ + {"valid1", 1, now.Add(time.Hour), true}, // Expires in 1 hour + {"expired1", 2, now.Add(-time.Hour), false}, // Expired 1 hour ago + {"expired2", 3, now.Add(-time.Second), false}, // Expired 1 second ago + {"valid2", 4, now.Add(2 * time.Minute), true}, // Expires in 2 minutes + {"noexpiry", 5, time.Time{}, true}, // No expiration set + } + + for _, item := range items { + key, conflict := z.KeyToHash(item.key) + i := Item[int]{ + Key: key, + Conflict: conflict, + Value: item.value, + Expiration: item.expiration, + } + s.Set(&i) + } + + // Collect values from iteration + resultValues := make([]int, 0) + s.IterValues(func(v int) (stop bool) { + resultValues = append(resultValues, v) + return false + }) + + // Only non-expired items should be returned + expectedValues := []int{} + for _, item := range items { + if item.shouldSee { + expectedValues = append(expectedValues, item.value) + } + } + + require.ElementsMatch(t, expectedValues, resultValues) + require.Len(t, resultValues, 3) +} + func TestStoreClear(t *testing.T) { s := newStore[uint64]() for i := uint64(0); i < 1000; i++ {