|
| 1 | +import asyncio |
1 | 2 | import logging |
2 | 3 | import os |
3 | 4 | from concurrent.futures import ThreadPoolExecutor |
|
20 | 21 | ) |
21 | 22 |
|
22 | 23 | # Multithreading for Diffbot API |
23 | | -MAX_WORKERS = min(os.cpu_count() * 5, 20) |
| 24 | +MAX_WORKERS = min((os.cpu_count() or 1) * 5, 20) |
24 | 25 |
|
25 | 26 | app = FastAPI() |
26 | 27 |
|
|
35 | 36 |
|
36 | 37 |
|
37 | 38 | @app.post("/import_articles/") |
38 | | -def import_articles_endpoint(article_data: ArticleData) -> int: |
| 39 | +async def import_articles_endpoint(article_data: ArticleData) -> int: |
39 | 40 | logging.info(f"Starting to process article import with params: {article_data}") |
40 | | - if not article_data.query and not article_data.tag: |
| 41 | + if not article_data.query and not article_data.category and not article_data.tag: |
41 | 42 | raise HTTPException( |
42 | | - status_code=500, detail="Either `query` or `tag` must be provided" |
| 43 | + status_code=500, |
| 44 | + detail="Either `query` or `category` or `tag` must be provided", |
43 | 45 | ) |
44 | | - data = get_articles(article_data.query, article_data.tag, article_data.size) |
45 | | - logging.info(f"Articles fetched: {len(data['data'])} articles.") |
| 46 | + articles = await get_articles( |
| 47 | + article_data.query, article_data.category, article_data.tag, article_data.size |
| 48 | + ) |
| 49 | + logging.info(f"Articles fetched: {len(articles)} articles.") |
46 | 50 | try: |
47 | | - params = process_params(data) |
| 51 | + params = process_params(articles) |
48 | 52 | except Exception as e: |
49 | 53 | # You could log the exception here if needed |
50 | | - raise HTTPException(status_code=500, detail=e) |
| 54 | +        raise HTTPException(status_code=500, detail=str(e)) from e
51 | 55 | graph.query(import_cypher_query, params={"data": params}) |
52 | | - logging.info(f"Article import query executed successfully.") |
| 56 | + logging.info("Article import query executed successfully.") |
53 | 57 | return len(params) |
54 | 58 |
|
55 | 59 |
|
@@ -124,26 +128,41 @@ def fetch_unprocessed_count(count_data: CountData) -> int: |
124 | 128 |
|
125 | 129 |
|
126 | 130 | @app.post("/enhance_entities/") |
127 | | -def enhance_entities(entity_data: EntityData) -> str: |
| 131 | +async def enhance_entities(entity_data: EntityData) -> str: |
128 | 132 | entities = graph.query( |
129 | 133 | "MATCH (a:Person|Organization) WHERE a.processed IS NULL " |
130 | 134 | "WITH a LIMIT toInteger($limit) " |
131 | 135 | "RETURN [el in labels(a) WHERE el <> '__Entity__' | el][0] " |
132 | 136 | "AS label, collect(a.name) AS entities", |
133 | 137 | params={"limit": entity_data.size}, |
134 | 138 | ) |
135 | | - enhanced_data = [] |
136 | | - with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: |
137 | | - # Submitting all tasks and creating a list of future objects |
138 | | - for row in entities: |
139 | | - futures = [ |
140 | | - executor.submit(process_entities, el, row["label"]) |
141 | | - for el in row["entities"] |
142 | | - ] |
143 | | - |
144 | | - for future in futures: |
145 | | - response = future.result() |
146 | | - enhanced_data.append(response) |
| 139 | + enhanced_data = {} |
| 140 | + |
| 141 | +    # Run process_entities concurrently via a bounded pool of asyncio worker tasks
| 142 | + |
| 143 | + queue = asyncio.Queue() |
| 144 | + for row in entities: |
| 145 | + for el in row["entities"]: |
| 146 | + await queue.put((el, row["label"])) |
| 147 | + |
| 148 | +    async def worker():
| 149 | +        while True:
| 150 | +            el, label = await queue.get()
| 151 | +            try:
| 152 | +                response = await process_entities(el, label)
| 153 | +                enhanced_data[response[0]] = response[1]
| 154 | +            except Exception:
| 155 | +                # Keep the worker alive on per-entity failures; an uncaught
| 156 | +                # exception kills the worker and can deadlock queue.join()
| 157 | +                logging.exception("Failed to process entity %s", el)
| 158 | +            finally:
| 159 | +                queue.task_done()
| 156 | + |
| 157 | + tasks = [] |
| 158 | +    for _ in range(MAX_WORKERS):  # reuse the module-level concurrency cap
| 159 | + tasks.append(asyncio.create_task(worker())) |
| 160 | + |
| 161 | + await queue.join() |
| 162 | + |
| 163 | +    for task in tasks:
| 164 | +        task.cancel()
| 165 | +    # Await the cancelled workers so their CancelledError is consumed
| 166 | +    await asyncio.gather(*tasks, return_exceptions=True)
| 167 | +
147 | 166 | store_enhanced_data(enhanced_data) |
148 | 167 | return "Finished enhancing entities." |
149 | 168 |
|
|
0 commit comments