Skip to content

Commit af4edfa

Browse files
committed
Utilize diffbot-kg client
1 parent 7bc7c78 commit af4edfa

File tree

3 files changed

+72
-47
lines changed

3 files changed

+72
-47
lines changed

api/app/enhance.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,32 @@
1-
import logging
21
import os
32
from datetime import datetime
4-
from typing import Any, Dict, List, Optional, Union
5-
from urllib.parse import urlencode
3+
from typing import Dict, List, Literal, Optional, Tuple, Union
64

7-
import requests
5+
from diffbot_kg import DiffbotEnhanceClient
86
from utils import graph
97

108
CATEGORY_THRESHOLD = 0.50
119
params = []
1210

1311
DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]
1412

13+
client = DiffbotEnhanceClient(DIFF_TOKEN)
1514

1615
def get_datetime(value: Optional[Union[str, int, float]]) -> Optional[datetime]:
    """
    Convert a millisecond epoch timestamp into a local ``datetime``.

    :param value: epoch time in milliseconds as a str/int/float, or a
        missing/falsy value
    :return: the corresponding ``datetime``, or ``None`` when `value` is
        missing/falsy
    """
    if not value:
        # Previously the falsy value itself (e.g. "" or 0) leaked through
        # despite the `datetime` return annotation; normalize to None.
        return None
    # Diffbot timestamps are in milliseconds; fromtimestamp expects seconds.
    return datetime.fromtimestamp(float(value) / 1000.0)
2019

2120

async def process_entities(entity: str, type: str) -> Tuple[str, List[Dict]]:
    """
    Enhance a single entity through the Diffbot KG enhance endpoint.

    :param entity: name of the entity to enhance
    :param type: Diffbot entity type (e.g. "Person", "Organization")
    :return: the entity name paired with the matching Diffbot entities
    """
    response = await client.enhance({"type": type, "name": entity})
    return entity, response.entities
3130

3231

3332
def get_people_params(row: Dict) -> Optional[Dict]:

api/app/importing.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,40 @@
11
import logging
22
import os
3-
from typing import Any, Dict, List, Optional
3+
from typing import Dict, List, Optional
44

5-
import requests
5+
from diffbot_kg import DiffbotSearchClient
66
from utils import embeddings, text_splitter
77

88
CATEGORY_THRESHOLD = 0.50
99
params = []
1010

1111
DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]
1212

13+
client = DiffbotSearchClient(token=DIFF_TOKEN)
1314

async def get_articles(
    query: Optional[str],
    tag: Optional[str],
    size: int = 5,
    offset: int = 0,
) -> List[Dict]:
    """
    Fetch relevant articles from the Diffbot KG search endpoint.

    Builds a DQL query for English articles sorted by date, optionally
    narrowed by a strict full-text match on `query` and/or a tag-label
    match on `tag`.

    :param query: optional full-text phrase to match
    :param tag: optional article tag label to filter on
    :param size: maximum number of articles to return
    :param offset: pagination offset into the result set
    :return: list of article entity dicts from the Diffbot response
    """
    # NOTE(review): the /import_articles/ endpoint in main.py calls this as
    # get_articles(query, category, tag, size) — a `category` positional
    # argument this signature does not accept; confirm the intended
    # signature and argument order against that caller.
    search_query = "type:Article language:en sortBy:date"
    if query:
        search_query += f' strict:text:"{query}"'
    if tag:
        search_query += f' tags.label:"{tag}"'

    params = {"query": search_query, "size": size, "offset": offset}

    # Lazy %-style args avoid string formatting when INFO is disabled.
    logging.info("Fetching articles with params: %s", params)

    # The previous `except Exception as ex: raise ex` wrapper was a no-op
    # re-raise; let exceptions propagate unchanged.
    response = await client.search(params)
    return response.entities
3340

api/app/main.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
import os
34
from concurrent.futures import ThreadPoolExecutor
@@ -20,7 +21,7 @@
2021
)
2122

2223
# Multithreading for Diffbot API
23-
MAX_WORKERS = min(os.cpu_count() * 5, 20)
24+
MAX_WORKERS = min((os.cpu_count() or 1) * 5, 20)
2425

2526
app = FastAPI()
2627

@@ -35,21 +36,24 @@
3536

3637

3738
@app.post("/import_articles/")
async def import_articles_endpoint(article_data: ArticleData) -> int:
    """
    Import articles from the Diffbot KG into the graph.

    Requires at least one of `query`, `category` or `tag` in the payload;
    returns the number of imported article rows.
    """
    logging.info(f"Starting to process article import with params: {article_data}")
    # NOTE(review): a missing filter is a client error — 400/422 would fit
    # better than 500, but the status code is kept to preserve the
    # existing API contract.
    if not article_data.query and not article_data.category and not article_data.tag:
        raise HTTPException(
            status_code=500,
            detail="Either `query` or `category` or `tag` must be provided",
        )
    articles = await get_articles(
        article_data.query, article_data.category, article_data.tag, article_data.size
    )
    logging.info(f"Articles fetched: {len(articles)} articles.")
    try:
        params = process_params(articles)
    except Exception as e:
        # `detail` must be JSON-serializable; passing the exception object
        # itself breaks response encoding, so stringify it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    graph.query(import_cypher_query, params={"data": params})
    logging.info("Article import query executed successfully.")
    return len(params)
5458

5559

@@ -124,26 +128,41 @@ def fetch_unprocessed_count(count_data: CountData) -> int:
124128

125129

126130
@app.post("/enhance_entities/")
async def enhance_entities(entity_data: EntityData) -> str:
    """
    Enhance unprocessed Person/Organization nodes with Diffbot KG data.

    Pulls up to `entity_data.size` unprocessed entities from the graph,
    enhances them concurrently through a small worker pool, and stores
    the results.
    """
    entities = graph.query(
        "MATCH (a:Person|Organization) WHERE a.processed IS NULL "
        "WITH a LIMIT toInteger($limit) "
        "RETURN [el in labels(a) WHERE el <> '__Entity__' | el][0] "
        "AS label, collect(a.name) AS entities",
        params={"limit": entity_data.size},
    )
    enhanced_data = {}

    # Fan the (name, label) pairs out to a bounded worker pool so we do
    # not fire one Diffbot request per entity all at once.
    queue = asyncio.Queue()
    for row in entities:
        for el in row["entities"]:
            await queue.put((el, row["label"]))

    async def worker():
        while True:
            el, label = await queue.get()
            try:
                name, enriched = await process_entities(el, label)
                enhanced_data[name] = enriched
            except Exception:
                # A failed enhancement must not kill the worker: if all
                # workers died while items remained, queue.join() below
                # would block this request forever. Log and continue.
                logging.exception("Failed to enhance entity %s", el)
            finally:
                queue.task_done()

    tasks = [asyncio.create_task(worker()) for _ in range(4)]  # pool size

    await queue.join()

    # Workers idle in queue.get() once the queue drains; cancel them so
    # the tasks do not outlive the request.
    for task in tasks:
        task.cancel()

    store_enhanced_data(enhanced_data)
    return "Finished enhancing entities."
149168

0 commit comments

Comments
 (0)