Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 54 additions & 31 deletions backend/src/agents/datastore_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from src.llm.llm import LLM
from src.utils.graph_db_utils import execute_query
Expand All @@ -8,15 +9,60 @@
from src.utils.log_publisher import LogPrefix, publish_log_info
from .agent import Agent, agent
from .tool import tool
import json

from src.utils.semantic_layer_builder import get_semantic_layer

logger = logging.getLogger(__name__)

engine = PromptEngine()

graph_schema = engine.load_prompt("graph-schema")
cache = {}

async def generate_cypher_query_core(
question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
) -> str:

async def get_semantic_layer_cache(graph_schema):
global cache
if not cache:
graph_schema = await get_semantic_layer(llm, model)
cache = graph_schema
return cache
else:
return cache

details_to_create_cypher_query = engine.load_prompt(
"details-to-create-cypher-query",
question_intent=question_intent,
operation=operation,
question_params=question_params,
aggregation=aggregation,
sort_order=sort_order,
timeframe=timeframe,
)
try:
graph_schema = await get_semantic_layer_cache(cache)
graph_schema = json.dumps(graph_schema, separators=(",", ":"))

generate_cypher_query_prompt = engine.load_prompt(
"generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
)

llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query,
return_json=True)
json_query = to_json(llm_query)
await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
if json_query["query"] == "None":
return "No database query"
db_response = execute_query(json_query["query"])
await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
except Exception as e:
logger.error(f"Error during data retrieval: {e}")
raise
response = {
"content": db_response,
"ignore_validation": "false"
}
return json.dumps(response, indent=4)

@tool(
name="generate cypher query",
Expand Down Expand Up @@ -51,39 +97,16 @@
),
},
)
async def generate_query(
question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
) -> str:
details_to_create_cypher_query = engine.load_prompt(
"details-to-create-cypher-query",
question_intent=question_intent,
operation=operation,
question_params=question_params,
aggregation=aggregation,
sort_order=sort_order,
timeframe=timeframe,
)
generate_cypher_query_prompt = engine.load_prompt(
"generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
)
llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
json_query = to_json(llm_query)
await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
if json_query["query"] == "None":
return "No database query"
db_response = execute_query(json_query["query"])
await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
response = {
"content": db_response,
"ignore_validation": "false"
}
return json.dumps(response, indent=4)

async def generate_cypher(question_intent, operation, question_params, aggregation, sort_order,
timeframe, llm: LLM, model) -> str:
return await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
timeframe, llm, model)

@agent(
name="DatastoreAgent",
description="This agent is responsible for handling database queries relating to the user's personal data.",
tools=[generate_query],
tools=[generate_cypher],
)
class DatastoreAgent(Agent):
pass
6 changes: 5 additions & 1 deletion backend/src/agents/web_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ async def web_general_search_core(search_query, llm, model) -> str:
if not summary:
continue
if await is_valid_answer(summary, search_query):
return summary
response = {
"content": summary,
"ignore_validation": "false"
}
return json.dumps(response, indent=4)
return "No relevant information found on the internet for the given query."
except Exception as e:
logger.error(f"Error in web_general_search_core: {e}")
Expand Down
2 changes: 1 addition & 1 deletion backend/src/prompts/templates/generate-cypher-query.j2
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ When returning a value, always remove the `-` sign before the number.
Here is the graph schema:
{{ graph_schema }}

The current date and time is {{ current_date }}
The current date and time is {{ current_date }} and the currency of the data is GBP.
156 changes: 0 additions & 156 deletions backend/src/prompts/templates/graph-schema.j2

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ WITH
detail : ""
}) as props
RETURN COLLECT({
node_label: node,
label: node,
cypher_representation : "(:" + node + ")",
properties: props
}) AS nodeProperties
6 changes: 0 additions & 6 deletions backend/src/prompts/templates/nodes-query.j2

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ WITH
COLLECT({
name: propertyName,
data_type: propertyTypes,
detail: "A " + propertyName + " is a.. "
detail: ""
}) AS props
RETURN COLLECT({
relationship_type: "[" + REPLACE(rel, "`", "") + "]",
Expand Down
13 changes: 1 addition & 12 deletions backend/src/prompts/templates/relationships-query.j2
Original file line number Diff line number Diff line change
@@ -1,12 +1 @@
CALL apoc.meta.stats() YIELD relTypes
WITH relTypes, keys(relTypes) AS relTypeKeys
UNWIND relTypeKeys AS relTypeKey
WITH relTypeKey, relTypes[relTypeKey] AS count
WHERE relTypeKey CONTAINS ")->(:"
OR relTypeKey CONTAINS "(:"
WITH collect({
label: split(split(relTypeKey, "-")[1], ">")[0],
cypher_representation: relTypeKey,
detail: ""
}) AS paths
RETURN paths
call db.schema.visualization
Loading