Skip to content

Commit 46ff1bd

Browse files
committed
Merge remote-tracking branch 'dfroger/fix-pipeline-ttl' into fix/dfroger/fix-pipeline-ttl
2 parents 2bcc4c2 + e9b6ab7 commit 46ff1bd

File tree

4 files changed

+100
-3
lines changed

4 files changed

+100
-3
lines changed

redisvl/index/storage.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ async def _aget(
175175
"""Asynchronously get data from Redis using the provided client or pipeline."""
176176
raise NotImplementedError
177177

178+
@staticmethod
179+
async def _aexpire(client: AsyncRedisClientOrPipeline, key: str, ttl: int):
180+
"""Asynchronously set TTL on a key using the provided client or pipeline
181+
182+
Args:
183+
client (AsyncRedisClientOrPipeline): The async Redis client or pipeline instance.
184+
key (str): The key for which to set the TTL.
185+
ttl (int): Time-to-live in seconds for each key.
186+
"""
187+
if isinstance(client, (AsyncPipeline, AsyncClusterPipeline)):
188+
client.expire(key, ttl)
189+
else:
190+
await client.expire(key, ttl)
191+
178192
def _validate(self, obj: Dict[str, Any]) -> Dict[str, Any]:
179193
"""
180194
Validate an object against the schema using Pydantic-based validation.
@@ -490,7 +504,7 @@ async def awrite(
490504

491505
# Set TTL if provided
492506
if ttl:
493-
await pipe.expire(key, ttl)
507+
await self._aexpire(pipe, key, ttl)
494508

495509
added_keys.append(key)
496510

@@ -615,7 +629,7 @@ async def _aset(client: AsyncRedisClientOrPipeline, key: str, obj: Dict[str, Any
615629
"""Asynchronously set a hash value in Redis for the given key.
616630
617631
Args:
618-
client (AsyncClientOrPipeline): The async Redis client or pipeline instance.
632+
client (AsyncRedisClientOrPipeline): The async Redis client or pipeline instance.
619633
key (str): The key under which to store the hash.
620634
obj (Dict[str, Any]): The hash to store in Redis.
621635
"""
@@ -644,7 +658,7 @@ async def _aget(
644658
"""Asynchronously retrieve a hash value from Redis for the given key.
645659
646660
Args:
647-
client (AsyncRedisClient): The async Redis client or pipeline instance.
661+
client (AsyncRedisClientOrPipeline): The async Redis client or pipeline instance.
648662
key (str): The key for which to retrieve the hash.
649663
650664
Returns:

tests/integration/test_async_search_index.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,26 @@ async def test_async_search_index_connect(index_schema, redis_url):
707707
await async_index.connect(redis_url=redis_url)
708708
assert async_index.client is not None
709709
await async_index.disconnect()
710+
711+
712+
@pytest.mark.asyncio
713+
@pytest.mark.parametrize("ttl", [None, 30])
714+
async def test_search_index_load_with_ttl(async_index, ttl):
715+
"""Test that TTL is correctly set on keys when using load() with ttl parameter."""
716+
await async_index.create(overwrite=True, drop=True)
717+
718+
# Load test data with TTL parameter
719+
data = [{"id": "1", "test": "foo"}]
720+
keys = await async_index.load(data, id_field="id", ttl=ttl)
721+
722+
# Check TTL on the loaded key
723+
client = await async_index._get_client()
724+
key_ttl = await client.ttl(keys[0])
725+
726+
if ttl is None:
727+
# No TTL set, should return -1
728+
assert key_ttl == -1
729+
else:
730+
# TTL should be set and close to the expected value
731+
assert key_ttl > 0
732+
assert abs(key_ttl - ttl) <= 5

tests/integration/test_cluster_pipelining.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,42 @@ def test_batch_search_with_real_cluster(redis_cluster_url):
152152

153153
finally:
154154
index.delete()
155+
156+
157+
@pytest.mark.requires_cluster
158+
@pytest.mark.parametrize("ttl", [None, 30])
159+
def test_cluster_load_with_ttl(redis_cluster_url, ttl):
160+
"""
161+
Test that TTL is correctly set on keys when using load() with ttl parameter on cluster.
162+
"""
163+
schema_dict = {
164+
"index": {"name": "test-ttl-cluster", "prefix": "ttl", "storage_type": "hash"},
165+
"fields": [
166+
{"name": "id", "type": "tag"},
167+
{"name": "text", "type": "text"},
168+
],
169+
}
170+
171+
schema = IndexSchema.from_dict(schema_dict)
172+
index = SearchIndex(schema, redis_url=redis_cluster_url)
173+
174+
index.create(overwrite=True)
175+
176+
try:
177+
# Load test data with TTL parameter
178+
data = [{"id": "1", "text": "foo"}]
179+
keys = index.load(data, id_field="id", ttl=ttl)
180+
181+
# Check TTL on the loaded key
182+
key_ttl = index.client.ttl(keys[0])
183+
184+
if ttl is None:
185+
# No TTL set, should return -1
186+
assert key_ttl == -1
187+
else:
188+
# TTL should be set and close to the expected value
189+
assert key_ttl > 0
190+
assert abs(key_ttl - ttl) <= 5
191+
192+
finally:
193+
index.delete()

tests/integration/test_search_index.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,24 @@ def test_search_index_validates_query_with_hnsw_algorithm(hnsw_index, sample_dat
697697
)
698698
# Should not raise
699699
hnsw_index.query(query)
700+
701+
702+
@pytest.mark.parametrize("ttl", [None, 30])
703+
def test_search_index_load_with_ttl(index, ttl):
704+
"""Test that TTL is correctly set on keys when using load() with ttl parameter."""
705+
index.create(overwrite=True, drop=True)
706+
707+
# Load test data with TTL parameter
708+
data = [{"id": "1", "test": "foo"}]
709+
keys = index.load(data, id_field="id", ttl=ttl)
710+
711+
# Check TTL on the loaded key
712+
key_ttl = index.client.ttl(keys[0])
713+
714+
if ttl is None:
715+
# No TTL set, should return -1
716+
assert key_ttl == -1
717+
else:
718+
# TTL should be set and close to the expected value
719+
assert key_ttl > 0
720+
assert abs(key_ttl - ttl) <= 5

0 commit comments

Comments
 (0)