Skip to content

Commit 00f5be4

Browse files
rbs333vladvildanov
andauthored
adds scorer to AggregateRequest (#3409)
* adds scorer to AggregateRequest * fix linting * update tests for BM25 * enum for aggregation scorer * update signature * revert back to string input --------- Co-authored-by: Vladyslav Vildanov <[email protected]>
1 parent 4c4d4af commit 00f5be4

File tree

3 files changed

+126
-0
lines changed

3 files changed

+126
-0
lines changed

redis/commands/search/aggregation.py

+16
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
112112
self._cursor = []
113113
self._dialect = None
114114
self._add_scores = False
115+
self._scorer = "TFIDF"
115116

116117
def load(self, *fields: List[str]) -> "AggregateRequest":
117118
"""
@@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
300301
self._add_scores = True
301302
return self
302303

304+
def scorer(self, scorer: str) -> "AggregateRequest":
305+
"""
306+
Use a different scoring function to evaluate document relevance.
307+
Default is `TFIDF`.
308+
309+
:param scorer: The scoring function to use
310+
(e.g. `TFIDF.DOCNORM` or `BM25`)
311+
"""
312+
self._scorer = scorer
313+
return self
314+
303315
def verbatim(self) -> "AggregateRequest":
304316
self._verbatim = True
305317
return self
@@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
323335
if self._verbatim:
324336
ret.append("VERBATIM")
325337

338+
if self._scorer:
339+
ret.extend(["SCORER", self._scorer])
340+
326341
if self._add_scores:
327342
ret.append("ADDSCORES")
328343

@@ -332,6 +347,7 @@ def build_args(self) -> List[str]:
332347
if self._loadall:
333348
ret.append("LOAD")
334349
ret.append("*")
350+
335351
elif self._loadfields:
336352
ret.append("LOAD")
337353
ret.append(str(len(self._loadfields)))

tests/test_asyncio/test_search.py

+55
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
15561556
assert res.rows[1] == ["__score", "0.2"]
15571557

15581558

1559+
@pytest.mark.redismod
1560+
@skip_ifmodversion_lt("2.10.05", "search")
1561+
async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
1562+
assert await decoded_r.ft().create_index(
1563+
(
1564+
TextField("name", sortable=True, weight=5.0),
1565+
TextField("description", sortable=True, weight=5.0),
1566+
VectorField(
1567+
"vector",
1568+
"HNSW",
1569+
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
1570+
),
1571+
)
1572+
)
1573+
1574+
assert await decoded_r.hset(
1575+
"doc1",
1576+
mapping={
1577+
"name": "cat book",
1578+
"description": "an animal book about cats",
1579+
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
1580+
},
1581+
)
1582+
assert await decoded_r.hset(
1583+
"doc2",
1584+
mapping={
1585+
"name": "dog book",
1586+
"description": "an animal book about dogs",
1587+
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
1588+
},
1589+
)
1590+
1591+
query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
1592+
req = (
1593+
aggregations.AggregateRequest(query_string)
1594+
.scorer("BM25")
1595+
.add_scores()
1596+
.apply(hybrid_score="@__score + @dist")
1597+
.load("*")
1598+
.dialect(4)
1599+
)
1600+
1601+
res = await decoded_r.ft().aggregate(
1602+
req,
1603+
query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
1604+
)
1605+
1606+
if isinstance(res, dict):
1607+
assert len(res["results"]) == 2
1608+
else:
1609+
assert len(res.rows) == 2
1610+
for row in res.rows:
1611+
len(row) == 6
1612+
1613+
15591614
@pytest.mark.redismod
15601615
@skip_if_redis_enterprise()
15611616
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

+55
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,61 @@ def test_aggregations_add_scores(client):
14661466
assert res.rows[1] == ["__score", "0.2"]
14671467

14681468

1469+
@pytest.mark.redismod
1470+
@skip_ifmodversion_lt("2.10.05", "search")
1471+
async def test_aggregations_hybrid_scoring(client):
1472+
client.ft().create_index(
1473+
(
1474+
TextField("name", sortable=True, weight=5.0),
1475+
TextField("description", sortable=True, weight=5.0),
1476+
VectorField(
1477+
"vector",
1478+
"HNSW",
1479+
{"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
1480+
),
1481+
)
1482+
)
1483+
1484+
client.hset(
1485+
"doc1",
1486+
mapping={
1487+
"name": "cat book",
1488+
"description": "an animal book about cats",
1489+
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
1490+
},
1491+
)
1492+
client.hset(
1493+
"doc2",
1494+
mapping={
1495+
"name": "dog book",
1496+
"description": "an animal book about dogs",
1497+
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
1498+
},
1499+
)
1500+
1501+
query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
1502+
req = (
1503+
aggregations.AggregateRequest(query_string)
1504+
.scorer("BM25")
1505+
.add_scores()
1506+
.apply(hybrid_score="@__score + @dist")
1507+
.load("*")
1508+
.dialect(4)
1509+
)
1510+
1511+
res = client.ft().aggregate(
1512+
req,
1513+
query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()},
1514+
)
1515+
1516+
if isinstance(res, dict):
1517+
assert len(res["results"]) == 2
1518+
else:
1519+
assert len(res.rows) == 2
1520+
for row in res.rows:
1521+
len(row) == 6
1522+
1523+
14691524
@pytest.mark.redismod
14701525
@skip_ifmodversion_lt("2.0.0", "search")
14711526
def test_index_definition(client):

0 commit comments

Comments
 (0)