-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathdatastore_agent.py
More file actions
103 lines (93 loc) · 3.63 KB
/
datastore_agent.py
File metadata and controls
103 lines (93 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import json
import logging
from src.llm.llm import LLM
from src.utils.graph_db_utils import execute_query
from src.prompts import PromptEngine
from datetime import datetime
from src.utils import to_json
from .agent_types import Parameter
from src.utils.log_publisher import LogPrefix, publish_log_info
from .agent import Agent, agent
from .tool import tool
from src.agents.semantic_layer_builder import get_semantic_layer
# Module-level logger for this agent.
logger = logging.getLogger(__name__)
# Shared prompt loader, constructed once at import time; generate_query uses it
# to render both the "details" prompt and the Cypher-generation system prompt.
engine = PromptEngine()
# Memoised semantic layer (graph schema). It starts empty and is replaced by the
# value returned from get_semantic_layer() the first time generate_query needs it
# while it is still falsy.
# NOTE(review): after that first fetch this name holds whatever get_semantic_layer
# returns (not necessarily a dict) — confirm the type against semantic_layer_builder.
cache = {}
async def _get_semantic_layer_cached(llm: LLM, model):
    """Return the semantic layer, fetching it with the LLM only while ``cache`` is empty.

    The first call (or any call while the module-level ``cache`` is still falsy)
    awaits ``get_semantic_layer(llm, model)`` and stores the result in ``cache``.
    Later calls return the stored value without contacting the LLM.

    NOTE(review): there is an ``await`` between the emptiness check and the
    assignment and no lock, so concurrent first calls may each fetch the layer.
    The last one to finish wins. This is wasteful but harmless.
    """
    global cache
    if not cache:
        cache = await get_semantic_layer(llm, model)
    return cache


@tool(
    name="generate cypher query",
    description="Generate Cypher query if the category is data driven, based on the operation to be performed",
    parameters={
        "question_intent": Parameter(
            type="string",
            description="The intent the question will be based on",
        ),
        "operation": Parameter(
            type="string",
            description="The operation the cypher query will have to perform",
        ),
        "question_params": Parameter(
            type="string",
            description=(
                "\nThe specific parameters required for the question to be answered with the question_intent"
                "\nor none if no params required\n"
            ),
        ),
        "aggregation": Parameter(
            type="string",
            description="Any aggregation that is required to answer the question or none if no aggregation is needed",
        ),
        "sort_order": Parameter(
            type="string",
            description="The order a list should be sorted in or none if no sort_order is needed",
        ),
        "timeframe": Parameter(
            type="string",
            description="string of the timeframe to be considered or none if no timeframe is needed",
        ),
    },
)
async def generate_query(
    question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
) -> str:
    """Have the LLM write a Cypher query for the question, run it, and return the result.

    Args:
        question_intent: Intent the question is based on.
        operation: Operation the Cypher query has to perform.
        question_params: Parameters needed to answer the question, or "none".
        aggregation: Aggregation required by the question, or "none".
        sort_order: Sort order for list results, or "none".
        timeframe: Timeframe to consider, or "none".
        llm: LLM client used both to build the semantic layer and to generate the query.
        model: Model identifier forwarded to ``get_semantic_layer`` and ``llm.chat``.

    Returns:
        ``str()`` of the database response, or ``"No database query"`` when the LLM
        produced no query (a missing ``"query"`` key, JSON ``null``, or the string ``"None"``).

    Raises:
        Any exception raised while fetching the schema, calling the LLM, parsing its
        output or executing the query is logged with its traceback and re-raised unchanged.
    """
    details_to_create_cypher_query = engine.load_prompt(
        "details-to-create-cypher-query",
        question_intent=question_intent,
        operation=operation,
        question_params=question_params,
        aggregation=aggregation,
        sort_order=sort_order,
        timeframe=timeframe,
    )
    try:
        graph_schema = await _get_semantic_layer_cached(llm, model)
        # Compact separators keep the serialised schema as short as possible in the prompt.
        graph_schema = json.dumps(graph_schema, separators=(",", ":"))
        generate_cypher_query_prompt = engine.load_prompt(
            "generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
        )
        logger.info("generate cypher query prompt: %s", generate_cypher_query_prompt)
        llm_query = await llm.chat(
            model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True
        )
        # Presumably to_json yields a dict like {"query": "..."} — confirm in src.utils.
        json_query = to_json(llm_query)
        await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
        # The LLM signals "no query needed" with the string "None". A JSON null or a
        # missing key means the same thing and must not reach execute_query.
        query = json_query.get("query")
        if query is None or query == "None":
            return "No database query"
        # NOTE(review): execute_query is not awaited; if it does blocking I/O it stalls
        # the event loop for the duration of the query (consider asyncio.to_thread).
        db_response = execute_query(query)
        await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
    except Exception as e:
        logger.exception("Error during data retrieval: %s", e)
        raise
    return str(db_response)
@agent(
    name="DatastoreAgent",
    description="This agent is responsible for handling database queries relating to the user's personal data.",
    tools=[generate_query],
)
class DatastoreAgent(Agent):
    """Agent for personal-data questions, exposing ``generate_query`` as its only tool.

    Its name, description and tool list come entirely from the ``@agent``
    decorator arguments; the class body adds nothing beyond ``Agent``.
    """