diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 76eb58c2d7..f316ed9f14 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -13,6 +13,7 @@ class Field: SORTABLE = "SORTABLE" NOINDEX = "NOINDEX" AS = "AS" + GEOSHAPE = "GEOSHAPE" def __init__( self, @@ -91,6 +92,21 @@ def __init__(self, name: str, **kwargs): Field.__init__(self, name, args=[Field.NUMERIC], **kwargs) +class GeoShapeField(Field): + """ + GeoShapeField is used to enable within/contain indexing/searching + """ + + SPHERICAL = "SPHERICAL" + FLAT = "FLAT" + + def __init__(self, name: str, coord_system=None, **kwargs): + args = [Field.GEOSHAPE] + if coord_system: + args.append(coord_system) + Field.__init__(self, name, args=args, **kwargs) + + class GeoField(Field): """ GeoField is used to define a geo-indexing field in a schema definition diff --git a/tests/test_search.py b/tests/test_search.py index 9bbfc3c696..7469123453 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -13,6 +13,7 @@ from redis.commands.search import Search from redis.commands.search.field import ( GeoField, + GeoShapeField, NumericField, TagField, TextField, @@ -2266,3 +2267,20 @@ def test_query_timeout(r: redis.Redis): q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): r.ft().search(q2) + + +@pytest.mark.redismod +def test_geoshape(client: redis.Redis): + client.ft().create_index((GeoShapeField("geom", GeoShapeField.FLAT))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + client.hset("small", "geom", "POLYGON((1 1, 1 100, 100 100, 100 1, 1 1))") + client.hset("large", "geom", "POLYGON((1 1, 1 200, 200 200, 200 1, 1 1))") + q1 = Query("@geom:[WITHIN $poly]").dialect(3) + qp1 = {"poly": "POLYGON((0 0, 0 150, 150 150, 150 0, 0 0))"} + q2 = Query("@geom:[CONTAINS $poly]").dialect(3) + qp2 = {"poly": "POLYGON((2 2, 2 50, 50 50, 50 2, 2 2))"} + result = client.ft().search(q1, query_params=qp1) + assert len(result.docs) == 1 + assert result.docs[0]["id"] == "small" + result = client.ft().search(q2, query_params=qp2) + assert len(result.docs) == 2