Skip to content

Commit ac3ac11

Browse files
authored
feat: add config options, simplification settings, geometry retrieval by id (#11)
* feat: add config options, simplification settings, geometry retrieval by id * fix: update sql param and model key selection
1 parent 123eac1 commit ac3ac11

13 files changed

Lines changed: 254 additions & 40 deletions

geodini/agents/geocoder_agent.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import logging
4+
import os
45
import time
56
import traceback
67
from concurrent.futures import ThreadPoolExecutor
@@ -15,6 +16,7 @@
1516
from shapely.ops import transform
1617

1718
from geodini import hookspecs, lib
19+
from geodini.models import MODEL_LIGHT
1820
from geodini.agents.utils.postgis_exec import (
1921
clear_geometries_table,
2022
create_geometries_table,
@@ -72,9 +74,12 @@ class RerankingResult:
7274
most_probable: str
7375

7476

77+
_GEOMETRY_PRECISION = int(os.getenv("GEOMETRY_DECIMAL_PRECISION", "6"))
78+
79+
7580
class RoundedFloat(float):
7681
def __repr__(self):
77-
return f"{self:.2f}"
82+
return f"{self:.{_GEOMETRY_PRECISION}f}"
7883

7984

8085
def recursively_convert(obj):
@@ -93,8 +98,11 @@ def clip_coordinates_with_rounding(geojson: dict[str, Any]) -> dict[str, Any]:
9398

9499

95100
def simplify_geometry(
96-
geometry: dict[str, Any], tolerance_m: float = 10000
101+
geometry: dict[str, Any],
102+
tolerance_m: float | None = None,
97103
) -> dict[str, Any]:
104+
if tolerance_m is None:
105+
tolerance_m = float(os.getenv("GEOMETRY_AGENT_SIMPLIFY_TOLERANCE", "10000"))
98106
to_meters = Transformer.from_crs("EPSG:4326", "EPSG:3857", always_xy=True).transform
99107
to_degrees = Transformer.from_crs(
100108
"EPSG:3857", "EPSG:4326", always_xy=True
@@ -111,8 +119,28 @@ def simplify_geometry(
111119
return geojson
112120

113121

122+
def simplify_geometry_to_size(
123+
geometry: dict[str, Any],
124+
max_bytes: int,
125+
initial_tolerance_m: float = 1000,
126+
max_iterations: int = 10,
127+
) -> dict[str, Any]:
128+
"""Iteratively simplify geometry until JSON representation fits within max_bytes."""
129+
tolerance = initial_tolerance_m
130+
result = simplify_geometry(geometry, tolerance_m=tolerance)
131+
132+
for _ in range(max_iterations):
133+
serialized = json.dumps(result)
134+
if len(serialized.encode("utf-8")) <= max_bytes:
135+
return result
136+
tolerance *= 2
137+
result = simplify_geometry(geometry, tolerance_m=tolerance)
138+
139+
return result
140+
141+
114142
rephrase_agent = Agent(
115-
"openai:gpt-4.1-mini",
143+
MODEL_LIGHT,
116144
output_type=RephrasedQuery,
117145
system_prompt="""
118146
Given the search query, rephrase it to be more specific and accurate. We will be using this query to search for places in the overture database. So it helps to make the query be full formal name of the place.
@@ -131,7 +159,7 @@ def simplify_geometry(
131159

132160

133161
routing_agent = Agent(
134-
"openai:gpt-4.1-mini",
162+
MODEL_LIGHT,
135163
output_type=RoutingResult,
136164
system_prompt="""
137165
Given the search query, determine if it is a simple or complex query.
@@ -143,7 +171,7 @@ def simplify_geometry(
143171

144172

145173
complex_geocode_query_agent = Agent(
146-
"openai:gpt-4.1-mini",
174+
MODEL_LIGHT,
147175
output_type=ComplexGeocodeResult,
148176
system_prompt="""
149177
Given the search query, return ALL relevant places to search for in the query as queries.
@@ -170,8 +198,7 @@ def simplify_geometry(
170198

171199

172200
rerank_agent = Agent(
173-
# 4o-mini is smarter than 3.5-turbo. And does better in edge cases.
174-
"openai:gpt-4.1-mini",
201+
MODEL_LIGHT,
175202
output_type=RerankingResult,
176203
system_prompt="""
177204
Given the search query and results, rank them in order of
@@ -312,6 +339,7 @@ async def simple_geocode(query: str, simplify_geometry: bool = True) -> dict:
312339
"query": query,
313340
"results": [
314341
{
342+
"id": most_probable["id"] if most_probable else None,
315343
"geometry": most_probable["geometry"] if most_probable else None,
316344
"country": most_probable["country"] if most_probable else None,
317345
"name": most_probable["name"] if most_probable else query,

geodini/agents/utils/geocoder.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def build_postgis_query(simplify_geometry: bool = True) -> str:
107107
"""Build PostgreSQL query for searching overture unified data using trigram similarity"""
108108

109109
# Choose geometry function based on whether to simplify
110-
geometry_func = "ST_AsGeoJSON(ST_Simplify(geometry, 0.001))" if simplify_geometry else "ST_AsGeoJSON(geometry)"
110+
db_tolerance = float(os.getenv("GEOMETRY_DB_SIMPLIFY_TOLERANCE", "0.001"))
111+
geometry_func = f"ST_AsGeoJSON(ST_Simplify(geometry, {db_tolerance}))" if simplify_geometry else "ST_AsGeoJSON(geometry)"
111112

112113
sql_query = f"""
113114
SELECT
@@ -162,6 +163,48 @@ def build_postgis_query(simplify_geometry: bool = True) -> str:
162163
return sql_query
163164

164165

166+
def get_geometry_by_id(
167+
division_id: str, simplify: bool = True
168+
) -> dict[str, Any] | None:
169+
"""Retrieve geometry for a specific division by ID."""
170+
engine = get_postgis_engine()
171+
db_tolerance = float(os.getenv("GEOMETRY_DB_SIMPLIFY_TOLERANCE", "0.001"))
172+
173+
geometry_func = (
174+
f"ST_AsGeoJSON(ST_Simplify(geometry, {db_tolerance}))"
175+
if simplify
176+
else "ST_AsGeoJSON(geometry)"
177+
)
178+
179+
sql_query = f"""
180+
SELECT
181+
id,
182+
COALESCE(common_en_name, primary_name) as name,
183+
subtype, country,
184+
{geometry_func} as geometry
185+
FROM all_geometries
186+
WHERE id = :id
187+
LIMIT 1
188+
"""
189+
190+
try:
191+
with engine.begin() as conn:
192+
result = conn.execute(text(sql_query), {"id": division_id})
193+
row = result.fetchone()
194+
if row:
195+
return {
196+
"id": row.id,
197+
"name": row.name,
198+
"subtype": row.subtype,
199+
"country": row.country,
200+
"geometry": json.loads(row.geometry) if row.geometry else None,
201+
}
202+
except Exception as e:
203+
logger.error(f"Error retrieving geometry by ID: {e}")
204+
205+
return None
206+
207+
165208
if __name__ == "__main__":
166209
import time
167210

geodini/agents/utils/postgis_exec.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import psycopg2
66
from pydantic_ai import Agent
77

8+
from geodini.models import MODEL_HEAVY
9+
810

911
@dataclass
1012
class PostGISResult:
@@ -96,33 +98,34 @@ def search_subtype_within_aoi(subtype: str, aoi: dict) -> list[dict]:
9698
# aoi is the geojson geometry as dict as returned from run_postgis_query
9799
# return a list of dictionaries of the form:
98100
# { "geometry": result_geometry_as_json_dict, "country": country_name }
101+
db_tolerance = float(os.getenv("GEOMETRY_DB_SIMPLIFY_TOLERANCE", "0.001"))
99102
conn = get_postgis_connection()
100103
try:
101104
with conn.cursor() as cur:
102105
# Convert AOI dict to GeoJSON string
103106
aoi_geojson = json.dumps(aoi)
104-
107+
105108
# SQL query to find places of given subtype within the AOI
106109
sql_query = """
107-
SELECT
108-
ST_AsGeoJSON(ST_Simplify(geometry, 0.001)) as geometry,
110+
SELECT
111+
ST_AsGeoJSON(ST_Simplify(geometry, %s)) as geometry,
109112
country,
110113
COALESCE(common_en_name, primary_name) as name
111114
FROM all_geometries
112-
WHERE
115+
WHERE
113116
source_type = 'division'
114117
AND subtype = %s
115118
AND geometry IS NOT NULL
116119
AND ST_Within(
117120
geometry,
118121
ST_GeomFromGeoJSON(%s)
119122
)
120-
ORDER BY
123+
ORDER BY
121124
ST_Area(geometry) DESC
122125
LIMIT 100
123126
"""
124-
125-
cur.execute(sql_query, (subtype, aoi_geojson))
127+
128+
cur.execute(sql_query, (db_tolerance, subtype, aoi_geojson))
126129
results = cur.fetchall()
127130

128131
# Convert results to expected format
@@ -141,7 +144,7 @@ def search_subtype_within_aoi(subtype: str, aoi: dict) -> list[dict]:
141144

142145

143146
postgis_agent = Agent(
144-
"openai:gpt-4.1",
147+
MODEL_HEAVY,
145148
output_type=PostGISResult,
146149
system_prompt="""
147150
You are a helpful assistant that can help with PostGIS queries.
@@ -185,7 +188,7 @@ def search_subtype_within_aoi(subtype: str, aoi: dict) -> list[dict]:
185188

186189

187190
postgis_query_judgement_agent = Agent(
188-
"openai:gpt-4o",
191+
MODEL_HEAVY,
189192
output_type=PostGISResult,
190193
system_prompt="""
191194
You are a helpful assistant that can help with PostGIS queries.

geodini/api/api.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from fastapi import FastAPI, HTTPException, Query
88
from fastapi.middleware.cors import CORSMiddleware
99

10-
from geodini.agents.geocoder_agent import search
10+
from geodini.agents.geocoder_agent import search, simplify_geometry_to_size
11+
from geodini.agents.utils.geocoder import get_geometry_by_id
1112
from geodini.agents.utils.postgis_exec import get_postgis_connection
1213
from geodini.cache import init_cache
1314

@@ -50,10 +51,11 @@ async def root():
5051
@app.get("/search")
5152
async def search_endpoint(
5253
query: str = Query(..., description="The search query string"),
54+
max_bytes: int = Query(0, description="Max geometry size in bytes (0 = use server default)"),
5355
) -> dict[str, Any]:
5456
"""
5557
Unified search endpoint that handles both simple and complex queries.
56-
58+
5759
Simple queries: "New York City", "London in Canada", "India"
5860
Complex queries: "India and Sri Lanka", "Within 100km of Mumbai", "France north of Paris"
5961
@@ -65,6 +67,15 @@ async def search_endpoint(
6567
# Get result from unified search
6668
result = await search(query)
6769

70+
# Apply byte-threshold simplification if requested
71+
effective_max_bytes = max_bytes or int(os.getenv("GEOMETRY_MAX_BYTES", "0"))
72+
if effective_max_bytes > 0 and result.get("results"):
73+
for r in result["results"]:
74+
if r.get("geometry"):
75+
r["geometry"] = simplify_geometry_to_size(
76+
r["geometry"], max_bytes=effective_max_bytes
77+
)
78+
6879
return result
6980

7081
except Exception as e:
@@ -74,6 +85,25 @@ async def search_endpoint(
7485
)
7586

7687

88+
@app.get("/geometry/{geometry_id}")
89+
async def get_geometry_endpoint(
90+
geometry_id: str,
91+
simplify: bool = Query(True, description="Whether to simplify the geometry"),
92+
max_bytes: int = Query(0, description="Max geometry size in bytes (0 = no limit)"),
93+
) -> dict[str, Any]:
94+
"""Retrieve full geometry by division ID."""
95+
result = get_geometry_by_id(geometry_id, simplify=simplify)
96+
if result is None:
97+
raise HTTPException(status_code=404, detail="Geometry not found")
98+
99+
if max_bytes > 0 and result.get("geometry"):
100+
result["geometry"] = simplify_geometry_to_size(
101+
result["geometry"], max_bytes=max_bytes
102+
)
103+
104+
return result
105+
106+
77107
@app.get("/health")
78108
async def health_check():
79109
"""Health check endpoint to verify the API is running."""

geodini/api/mcp_server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
from mcp.server.fastmcp import FastMCP
24

35
from geodini.agents.geocoder_agent import search, simplify_geometry
@@ -10,7 +12,8 @@ async def geocode(query: str) -> str:
1012
"""Geocode a query and download the geojson geometry"""
1113
result = await search(query)
1214
if "result" in result and "geometry" in result["result"]:
13-
return simplify_geometry(result["result"]["geometry"], tolerance_m=1000)
15+
mcp_tolerance = float(os.getenv("GEOMETRY_MCP_SIMPLIFY_TOLERANCE", "1000"))
16+
return simplify_geometry(result["result"]["geometry"], tolerance_m=mcp_tolerance)
1417
return "No geometry found for query"
1518

1619

geodini/models.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Centralized LLM model configuration.
2+
3+
Two tiers control all agent models:
4+
- MODEL_LIGHT: fast/cheap tasks (rephrase, routing, complex parse, rerank)
5+
- MODEL_HEAVY: SQL generation and error correction
6+
7+
Each model's provider is extracted from its prefix (e.g. "anthropic"
8+
from "anthropic:claude-sonnet-4-6"). The provider-specific API key env
9+
var (e.g. ANTHROPIC_API_KEY) takes priority; LLM_API_KEY is used as a
10+
shared fallback for any provider that doesn't have its own key set.
11+
12+
Supported provider prefixes: openai, anthropic, groq, mistral.
13+
Bedrock uses AWS creds (AWS_ACCESS_KEY_ID, etc.).
14+
Ollama uses OLLAMA_BASE_URL (no API key needed).
15+
"""
16+
17+
import os
18+
19+
MODEL_LIGHT = os.getenv("MODEL_LIGHT", "anthropic:claude-haiku-4-5")
20+
MODEL_HEAVY = os.getenv("MODEL_HEAVY", "anthropic:claude-sonnet-4-6")
21+
22+
KEY_MAP = {
23+
"openai": "OPENAI_API_KEY",
24+
"anthropic": "ANTHROPIC_API_KEY",
25+
"groq": "GROQ_API_KEY",
26+
"mistral": "MISTRAL_API_KEY",
27+
}
28+
29+
llm_fallback_key = os.getenv("LLM_API_KEY", "")
30+
for model in (MODEL_LIGHT, MODEL_HEAVY):
31+
provider = model.split(":")[0]
32+
if provider in KEY_MAP and llm_fallback_key:
33+
os.environ.setdefault(KEY_MAP[provider], llm_fallback_key)

helm/geodini/templates/_helpers.tpl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ Create the name of the service account to use
6161
{{- end -}}
6262
{{- end -}}
6363

64+
{{/*
65+
Return the secret name to use for credentials.
66+
*/}}
67+
{{- define "geodini.secretName" -}}
68+
{{- if .Values.secrets.existingSecret -}}
69+
{{- .Values.secrets.existingSecret -}}
70+
{{- else -}}
71+
{{- include "geodini.fullname" . }}-geodini-secret
72+
{{- end -}}
73+
{{- end -}}
74+
6475
{{/*
6576
Return the appropriate apiVersion for deployment.
6677
*/}}

0 commit comments

Comments
 (0)