Skip to content

Commit 4342bb1

Browse files
committed
tweak some prompts
1 parent 43df527 commit 4342bb1

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

geodini/agents/simple_geocoder_agent.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,14 @@ class RerankingResult:
5353
5454
From the results list, return a JSON object with:
5555
1. "most_probable": The ID of the most relevant result
56-
2. "next_probable": A list of IDs of the next 3 most relevant results in order of relevance
56+
2. "next_probable": A list of IDs of the next 5 most relevant results in order of relevance
5757
5858
Make sure the returned IDs are in the results list.
59+
60+
While reranking, consider the following:
61+
- The query might be a shortened name and the result might be a full name. For example, "United States" or "United States of America" is a match for "USA" or "The US".
62+
- The query might be a informal name and the result might be a formal name. For example, "District of Columbia" is a candidate for "DC" or "Washington, D.C." or even just "Washington". So consider all possible variations of the query when matching.
63+
- Consider geographical context. For example, "London in Canada" should rank "London, Ontario" higher than "London, England" because it is more likely to be the correct answer.
5964
""",
6065
)
6166

@@ -184,13 +189,13 @@ async def search_places(query: str) -> list[Place]:
184189
for place_id in reranked_results.output.next_probable
185190
]
186191
else:
187-
most_probable = []
192+
most_probable = None
188193
next_probable = []
189194

190195
return {
191196
"most_probable": most_probable,
192197
"next_probable": next_probable,
193-
"results": list(results_dict.values()),
198+
"results": [most_probable, *next_probable],
194199
"query": query,
195200
"rephrased_query": rephrased_query.output.query,
196201
"country_code": rephrased_query.output.country_code,

geodini/agents/utils/geocoder.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ def geocode(query: str) -> list[dict[str, Any]]:
4747

4848
try:
4949
with engine.begin() as conn:
50-
if limit is not None:
51-
result = conn.execute(text(sql_query), {"query": query, "limit": limit})
52-
else:
5350
result = conn.execute(text(sql_query), {"query": query})
5451

5552
rows = result.fetchall()
@@ -125,8 +122,8 @@ def build_postgis_query() -> str:
125122
WHEN 'macroregion' THEN 2.0
126123
WHEN 'region' THEN 2.0
127124
WHEN 'macrocounty' THEN 2.0
128-
WHEN 'county' THEN 1.1
129-
WHEN 'localadmin' THEN 1.5
125+
WHEN 'county' THEN 1.0
126+
WHEN 'localadmin' THEN 1.1
130127
WHEN 'locality' THEN 0.9
131128
WHEN 'borough' THEN 0.8
132129
WHEN 'macrohood' THEN 0.8
@@ -164,3 +161,10 @@ def build_postgis_query() -> str:
164161
if postgis_results:
165162
print("\nSample PostgreSQL result:")
166163
pprint(postgis_results[0])
164+
165+
# print the name and similarity score of all results
166+
for result in postgis_results:
167+
print(
168+
f"Name: {result['name']}, Similarity: {result['similarity']}, Type: {result['subtype']}, Country: {result['country']}"
169+
)
170+
print(len(postgis_results))

0 commit comments

Comments
 (0)