From 66a60f9e461c994ceaf59eaf647269b8b22ed7f5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Fri, 30 Aug 2024 14:54:12 +0100
Subject: [PATCH 01/48] Await semantic layer builder

---
 backend/src/agents/datastore_agent.py         |  31 +++--
 backend/src/agents/semantic_layer_builder.py  | 115 ++++++++++++++++++
 .../templates/node-property-cypher-query.j2   |   2 +-
 backend/src/utils/semantic_layer_builder.py   |  80 ------------
 4 files changed, 134 insertions(+), 94 deletions(-)
 create mode 100644 backend/src/agents/semantic_layer_builder.py
 delete mode 100644 backend/src/utils/semantic_layer_builder.py

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index 98c87d77d..2a11195c4 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -8,14 +8,12 @@
 from src.utils.log_publisher import LogPrefix, publish_log_info
 from .agent import Agent, agent
 from .tool import tool
-
+from src.agents.semantic_layer_builder import get_semantic_layer
 
 logger = logging.getLogger(__name__)
 
 engine = PromptEngine()
 
-graph_schema = engine.load_prompt("graph-schema")
-
 
 @tool(
     name="generate cypher query",
@@ -62,16 +60,23 @@ async def generate_query(
         sort_order=sort_order,
         timeframe=timeframe,
     )
-    generate_cypher_query_prompt = engine.load_prompt(
-        "generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
-    )
-    llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
-    json_query = to_json(llm_query)
-    await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
-    if json_query["query"] == "None":
-        return "No database query"
-    db_response = execute_query(json_query["query"])
-    await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
+    try:
+        result = await get_semantic_layer(llm, model)
+
+        generate_cypher_query_prompt = engine.load_prompt(
+            "generate-cypher-query", graph_schema=result, current_date=datetime.now()
+        )
+
+        llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
+        json_query = to_json(llm_query)
+        await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
+        if json_query["query"] == "None":
+            return "No database query"
+        db_response = execute_query(json_query["query"])
+        await publish_log_info(LogPrefix.USER, f"Database response: {db_response}", __name__)
+    except Exception as e:
+        logger.error(f"Error during data retrieval: {e}")
+        raise
     return str(db_response)
 
 
diff --git a/backend/src/agents/semantic_layer_builder.py b/backend/src/agents/semantic_layer_builder.py
new file mode 100644
index 000000000..06b7dce3f
--- /dev/null
+++ b/backend/src/agents/semantic_layer_builder.py
@@ -0,0 +1,115 @@
+# NOT NEEDED CURRENTLY AS THE RETURNED VALUE OF GRAPH SCHEMA IS STORED STATICALLY AS A JINJA TEMPLATE
+# THIS FILE IS CURRENTLY BROKEN BUT AS UNUSED THE FOLLOWING LINE IS SUPPRESSING ERRORS
+# REMOVE NEXT LINE BEFORE WORKING ON FILE
+# pyright: reportAttributeAccessIssue=none
+from src.utils.graph_db_utils import execute_query
+import logging
+from src.prompts import PromptEngine
+import json
+import re
+
+logger = logging.getLogger(__name__)
+
+engine = PromptEngine()
+
+async def get_semantic_layer(llm, model):
+    finalised_graph_structure = {"nodes": {}, "properties": {}}
+
+    neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why")
+
+    relationship_query = engine.load_prompt("relationships-query")
+
+    node_query = engine.load_prompt("nodes-query")
+
+    relationship_property_query = engine.load_prompt("relationship-property-cypher-query")
+
+    node_property_query = engine.load_prompt("node-property-cypher-query")
+
+    neo4j_relationships_understanding_prompt = engine.load_prompt(
+        "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    neo4j_nodes_understanding_prompt = engine.load_prompt(
+        "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    neo4j_relationship_property_prompt = engine.load_prompt(
+        "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    neo4j_node_property_prompt = engine.load_prompt(
+        "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+
+    # Fetch and enrich relationships
+    relationship_result = execute_query(relationship_query)
+    relationships_neo4j = relationship_result[0]
+    enriched_relationships = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationships_neo4j))
+    enriched_relationships = json.dumps(enriched_relationships)
+    enriched_relationships = json.loads(enriched_relationships)
+    finalised_graph_structure["relationships"] = enriched_relationships if enriched_relationships else {}
+    logger.info(f"Finalised graph structure with enriched relationships: {finalised_graph_structure}")
+
+    # Fetch and enrich nodes
+    nodes_neo4j_result = execute_query(node_query)
+    nodes_neo4j = nodes_neo4j_result[0]
+    enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(nodes_neo4j))
+    if enriched_nodes.startswith("```json") and enriched_nodes.endswith("```"):
+      enriched_nodes = enriched_nodes[7:-3].strip()
+    enriched_nodes = json.loads(enriched_nodes)
+    json.dumps(enriched_nodes)
+    finalised_graph_structure['nodes']['labels'] = enriched_nodes['nodes']
+    logger.debug(f"Finalised graph structure with enriched nodes: {finalised_graph_structure}")
+
+    # Fetch and enrich relationship properties
+    # properties_result = execute_query(relationship_property_query)
+    # rel_properties_neo4j = properties_result[0]
+    # enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j))
+    # if enriched_rel_properties.startswith("```json") and enriched_rel_properties.endswith("```"):
+    #     enriched_rel_properties = enriched_rel_properties[7:-3].strip()
+    # enriched_rel_properties = json.loads(enriched_rel_properties)
+
+    # for new_rel in enriched_rel_properties["relProperties"]:
+    #     relationship_type = new_rel["relationship_type"]
+    #     properties_to_add = new_rel["properties"]
+    #     for rel in finalised_graph_structure["relationships"]:
+    #         if rel["cypher_representation"] == relationship_type:
+    #             if "properties" not in rel:
+    #                 rel["properties"] = []
+    #             rel["properties"] = properties_to_add
+    # logger.info(f"Enriched relationship properties response: {enriched_rel_properties}")
+    # # enriched_rel_properties = ast.literal_eval(enriched_rel_properties)
+    # # finalised_graph_structure["properties"]["relationship_properties"] = enriched_rel_properties["relProperties"]
+    # logger.debug(f"Finalised graph structure with enriched relationship properties: {finalised_graph_structure}")
+
+    # Fetch and enrich node properties
+    node_properties_neo4j_result = execute_query(node_property_query)
+    node_properties_neo4j = node_properties_neo4j_result[0]
+    filtered_payload = {
+        'nodeProperties': [
+            node for node in node_properties_neo4j['nodeProperties']
+            if all(prop['data_type'] is not None and prop['name'] is not None for prop in node['properties'])
+        ]
+    }
+    enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload))
+    if enriched_node_properties.startswith("```json") and enriched_node_properties.endswith("```"):
+        enriched_node_properties = enriched_node_properties[7:-3].strip()
+    enriched_node_properties = json.loads(enriched_node_properties)
+
+    # for new_node in enriched_node_properties["nodeProperties"]:
+    #     label = new_node["label"]
+    #     properties_to_add = new_node["properties"]
+
+    #     for node in finalised_graph_structure["nodes"]:
+    #         logger.info(f"finalised graph structure: {finalised_graph_structure}")
+    #         if node["label"] == label:
+    #             logger.info(f"node in finalised graph structure: {node["label"]}")
+    #             if "properties" not in node:
+    #                 node["properties"] = []
+    #             node["properties"] = properties_to_add
+    finalised_graph_structure["properties"]["node_properties"] = enriched_node_properties["nodeProperties"]
+    # logger.debug(f"Finalised graph structure with enriched node properties: {finalised_graph_structure}")
+
+    graph_schema = json.dumps(finalised_graph_structure, separators=(",", ":"))
+    return graph_schema
diff --git a/backend/src/prompts/templates/node-property-cypher-query.j2 b/backend/src/prompts/templates/node-property-cypher-query.j2
index 0f3a5bf2b..0d22bc0fb 100644
--- a/backend/src/prompts/templates/node-property-cypher-query.j2
+++ b/backend/src/prompts/templates/node-property-cypher-query.j2
@@ -7,7 +7,7 @@ WITH
         detail : ""
     }) as props
 RETURN  COLLECT({
-    node_label: node,
+    label: node,
     cypher_representation : "(:" + node + ")",
     properties: props
 }) AS nodeProperties
diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py
deleted file mode 100644
index e2418c3f7..000000000
--- a/backend/src/utils/semantic_layer_builder.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# NOT NEEDED CURRENTLY AS THE RETURNED VALUE OF GRAPH SCHEMA IS STORED STATICALLY AS A JINJA TEMPLATE
-# THIS FILE IS CURRENTLY BROKEN BUT AS UNUSED THE FOLLOWING LINE IS SUPPRESSING ERRORS
-# REMOVE NEXT LINE BEFORE WORKING ON FILE
-# pyright: reportAttributeAccessIssue=none
-from src.llm import call_model
-from src.utils.graph_db_utils import execute_query
-import logging
-from src.prompts import PromptEngine
-import json
-import ast
-
-logger = logging.getLogger(__name__)
-
-engine = PromptEngine()
-
-
-def get_semantic_layer():
-    finalised_graph_structure = {"nodes": {}, "properties": {}}
-
-    neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why")
-
-    relationship_query = engine.load_prompt("relationships-query")
-
-    node_query = engine.load_prompt("nodes-query")
-
-    relationship_property_query = engine.load_prompt("relationship-property-cypher-query")
-
-    node_property_query = engine.load_prompt("node-property-cypher-query")
-
-    neo4j_relationships_understanding_prompt = engine.load_prompt(
-        "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_nodes_understanding_prompt = engine.load_prompt(
-        "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_relationship_property_prompt = engine.load_prompt(
-        "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_node_property_prompt = engine.load_prompt(
-        "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    # Fetch and enrich relationships
-    relationship_result = execute_query(relationship_query)
-    relationships_neo4j = relationship_result[0]
-    enriched_relationships = call_model(neo4j_relationships_understanding_prompt, str(relationships_neo4j))
-    enriched_relationships = json.dumps(enriched_relationships)
-    enriched_relationships = json.loads(enriched_relationships)
-    finalised_graph_structure["relationships"] = enriched_relationships if enriched_relationships else {}
-    logger.debug(f"Finalised graph structure with enriched relationships: {finalised_graph_structure}")
-
-    # Fetch and enrich nodes
-    nodes_neo4j_result = execute_query(node_query)
-    nodes_neo4j = nodes_neo4j_result[0]
-    enriched_nodes = call_model(neo4j_nodes_understanding_prompt, str(nodes_neo4j))
-    enriched_nodes = ast.literal_eval(enriched_nodes)
-    finalised_graph_structure["nodes"]["labels"] = enriched_nodes["nodes"]
-    logger.debug(f"Finalised graph structure with enriched nodes: {finalised_graph_structure}")
-
-    # Fetch and enrich relationship properties
-    properties_result = execute_query(relationship_property_query)
-    rel_properties_neo4j = properties_result[0]
-    enriched_rel_properties = call_model(neo4j_relationship_property_prompt, str(rel_properties_neo4j))
-    enriched_rel_properties = ast.literal_eval(enriched_rel_properties)
-    finalised_graph_structure["properties"]["relationship_properties"] = enriched_rel_properties["relProperties"]
-    logger.debug(f"Finalised graph structure with enriched relationship properties: {finalised_graph_structure}")
-
-    # Fetch and enrich node properties
-    node_properties_neo4j_result = execute_query(node_property_query)
-    node_properties_neo4j = node_properties_neo4j_result[0]
-    enriched_node_properties = call_model(neo4j_node_property_prompt, str(node_properties_neo4j))
-    enriched_node_properties = ast.literal_eval(enriched_node_properties)
-    finalised_graph_structure["properties"]["node_properties"] = enriched_node_properties["nodeProperties"]
-    logger.debug(f"Finalised graph structure with enriched node properties: {finalised_graph_structure}")
-
-    graph_schema = json.dumps(finalised_graph_structure, separators=(",", ":"))
-    return graph_schema

From 171a259874fe1ff647cd5b2020ad63d16af9c7cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Mon, 2 Sep 2024 15:47:21 +0100
Subject: [PATCH 02/48] Create cache to store graph schema

---
 backend/src/agents/datastore_agent.py        | 16 ++++++-
 backend/src/agents/semantic_layer_builder.py | 44 +++-----------------
 2 files changed, 19 insertions(+), 41 deletions(-)

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index 2a11195c4..6c8eecbd2 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -1,3 +1,4 @@
+import json
 import logging
 from src.llm.llm import LLM
 from src.utils.graph_db_utils import execute_query
@@ -14,6 +15,7 @@
 
 engine = PromptEngine()
 
+cache = {}
 
 @tool(
     name="generate cypher query",
@@ -51,6 +53,14 @@
 async def generate_query(
     question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
 ) -> str:
+    async def get_semantic_layer_cache(graph_schema):
+        if not cache:
+            graph_schema = await get_semantic_layer(llm, model)
+            cache['nodes'] = graph_schema['nodes']
+            cache['properties'] = graph_schema['properties']
+            cache['relationships'] = graph_schema['relationships']
+            return cache
+
     details_to_create_cypher_query = engine.load_prompt(
         "details-to-create-cypher-query",
         question_intent=question_intent,
@@ -61,11 +71,13 @@ async def generate_query(
         timeframe=timeframe,
     )
     try:
-        result = await get_semantic_layer(llm, model)
+        graph_schema = await get_semantic_layer_cache(cache)
+        graph_schema = json.dumps(graph_schema, separators=(",", ":"))
 
         generate_cypher_query_prompt = engine.load_prompt(
-            "generate-cypher-query", graph_schema=result, current_date=datetime.now()
+            "generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
         )
+        logger.info(f"generate cypher query prompt: {generate_cypher_query_prompt}")
 
         llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
         json_query = to_json(llm_query)
diff --git a/backend/src/agents/semantic_layer_builder.py b/backend/src/agents/semantic_layer_builder.py
index 06b7dce3f..b5c6a0b78 100644
--- a/backend/src/agents/semantic_layer_builder.py
+++ b/backend/src/agents/semantic_layer_builder.py
@@ -6,12 +6,12 @@
 import logging
 from src.prompts import PromptEngine
 import json
-import re
 
 logger = logging.getLogger(__name__)
 
 engine = PromptEngine()
 
+
 async def get_semantic_layer(llm, model):
     finalised_graph_structure = {"nodes": {}, "properties": {}}
 
@@ -41,15 +41,15 @@ async def get_semantic_layer(llm, model):
         "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
     )
 
-
     # Fetch and enrich relationships
     relationship_result = execute_query(relationship_query)
     relationships_neo4j = relationship_result[0]
     enriched_relationships = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationships_neo4j))
+    if enriched_relationships.startswith("```json") and enriched_relationships.endswith("```"):
+      enriched_relationships = enriched_relationships[7:-3].strip()
     enriched_relationships = json.dumps(enriched_relationships)
     enriched_relationships = json.loads(enriched_relationships)
     finalised_graph_structure["relationships"] = enriched_relationships if enriched_relationships else {}
-    logger.info(f"Finalised graph structure with enriched relationships: {finalised_graph_structure}")
 
     # Fetch and enrich nodes
     nodes_neo4j_result = execute_query(node_query)
@@ -62,27 +62,6 @@ async def get_semantic_layer(llm, model):
     finalised_graph_structure['nodes']['labels'] = enriched_nodes['nodes']
     logger.debug(f"Finalised graph structure with enriched nodes: {finalised_graph_structure}")
 
-    # Fetch and enrich relationship properties
-    # properties_result = execute_query(relationship_property_query)
-    # rel_properties_neo4j = properties_result[0]
-    # enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j))
-    # if enriched_rel_properties.startswith("```json") and enriched_rel_properties.endswith("```"):
-    #     enriched_rel_properties = enriched_rel_properties[7:-3].strip()
-    # enriched_rel_properties = json.loads(enriched_rel_properties)
-
-    # for new_rel in enriched_rel_properties["relProperties"]:
-    #     relationship_type = new_rel["relationship_type"]
-    #     properties_to_add = new_rel["properties"]
-    #     for rel in finalised_graph_structure["relationships"]:
-    #         if rel["cypher_representation"] == relationship_type:
-    #             if "properties" not in rel:
-    #                 rel["properties"] = []
-    #             rel["properties"] = properties_to_add
-    # logger.info(f"Enriched relationship properties response: {enriched_rel_properties}")
-    # # enriched_rel_properties = ast.literal_eval(enriched_rel_properties)
-    # # finalised_graph_structure["properties"]["relationship_properties"] = enriched_rel_properties["relProperties"]
-    # logger.debug(f"Finalised graph structure with enriched relationship properties: {finalised_graph_structure}")
-
     # Fetch and enrich node properties
     node_properties_neo4j_result = execute_query(node_property_query)
     node_properties_neo4j = node_properties_neo4j_result[0]
@@ -97,19 +76,6 @@ async def get_semantic_layer(llm, model):
         enriched_node_properties = enriched_node_properties[7:-3].strip()
     enriched_node_properties = json.loads(enriched_node_properties)
 
-    # for new_node in enriched_node_properties["nodeProperties"]:
-    #     label = new_node["label"]
-    #     properties_to_add = new_node["properties"]
-
-    #     for node in finalised_graph_structure["nodes"]:
-    #         logger.info(f"finalised graph structure: {finalised_graph_structure}")
-    #         if node["label"] == label:
-    #             logger.info(f"node in finalised graph structure: {node["label"]}")
-    #             if "properties" not in node:
-    #                 node["properties"] = []
-    #             node["properties"] = properties_to_add
     finalised_graph_structure["properties"]["node_properties"] = enriched_node_properties["nodeProperties"]
-    # logger.debug(f"Finalised graph structure with enriched node properties: {finalised_graph_structure}")
-
-    graph_schema = json.dumps(finalised_graph_structure, separators=(",", ":"))
-    return graph_schema
+    logger.debug(f"Finalised graph structure with enriched node properties: {finalised_graph_structure}")
+    return finalised_graph_structure

From 33c02e93ff575c511746809ae5c24fe3eeac7db4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Mon, 2 Sep 2024 16:11:04 +0100
Subject: [PATCH 03/48] Make cache global to prevent UnboundLocalError

---
 backend/src/agents/datastore_agent.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index 6c8eecbd2..bcb584b58 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -53,12 +53,14 @@
 async def generate_query(
     question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
 ) -> str:
+
     async def get_semantic_layer_cache(graph_schema):
+        global cache
         if not cache:
             graph_schema = await get_semantic_layer(llm, model)
-            cache['nodes'] = graph_schema['nodes']
-            cache['properties'] = graph_schema['properties']
-            cache['relationships'] = graph_schema['relationships']
+            cache = graph_schema
+            return cache
+        else:
             return cache
 
     details_to_create_cypher_query = engine.load_prompt(

From ed3065db4b8e8828092c7c71fbcb5dcc8695c729 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Tue, 3 Sep 2024 14:59:51 +0100
Subject: [PATCH 04/48] Bring in latest code from Michael Down

---
 backend/src/agents/datastore_agent.py         |   3 +-
 backend/src/agents/semantic_layer_builder.py  |  81 ---------
 .../templates/generate-cypher-query.j2        |   2 +-
 backend/src/prompts/templates/nodes-query.j2  |   6 -
 .../relationship-property-cypher-query.j2     |   2 +-
 .../prompts/templates/relationships-query.j2  |  13 +-
 backend/src/utils/semantic_layer_builder.py   | 162 ++++++++++++++++++
 7 files changed, 166 insertions(+), 103 deletions(-)
 delete mode 100644 backend/src/agents/semantic_layer_builder.py
 delete mode 100644 backend/src/prompts/templates/nodes-query.j2
 create mode 100644 backend/src/utils/semantic_layer_builder.py

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index bcb584b58..89551c115 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -9,7 +9,7 @@
 from src.utils.log_publisher import LogPrefix, publish_log_info
 from .agent import Agent, agent
 from .tool import tool
-from src.agents.semantic_layer_builder import get_semantic_layer
+from src.utils.semantic_layer_builder import get_semantic_layer
 
 logger = logging.getLogger(__name__)
 
@@ -79,7 +79,6 @@ async def get_semantic_layer_cache(graph_schema):
         generate_cypher_query_prompt = engine.load_prompt(
             "generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
         )
-        logger.info(f"generate cypher query prompt: {generate_cypher_query_prompt}")
 
         llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
         json_query = to_json(llm_query)
diff --git a/backend/src/agents/semantic_layer_builder.py b/backend/src/agents/semantic_layer_builder.py
deleted file mode 100644
index b5c6a0b78..000000000
--- a/backend/src/agents/semantic_layer_builder.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# NOT NEEDED CURRENTLY AS THE RETURNED VALUE OF GRAPH SCHEMA IS STORED STATICALLY AS A JINJA TEMPLATE
-# THIS FILE IS CURRENTLY BROKEN BUT AS UNUSED THE FOLLOWING LINE IS SUPPRESSING ERRORS
-# REMOVE NEXT LINE BEFORE WORKING ON FILE
-# pyright: reportAttributeAccessIssue=none
-from src.utils.graph_db_utils import execute_query
-import logging
-from src.prompts import PromptEngine
-import json
-
-logger = logging.getLogger(__name__)
-
-engine = PromptEngine()
-
-
-async def get_semantic_layer(llm, model):
-    finalised_graph_structure = {"nodes": {}, "properties": {}}
-
-    neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why")
-
-    relationship_query = engine.load_prompt("relationships-query")
-
-    node_query = engine.load_prompt("nodes-query")
-
-    relationship_property_query = engine.load_prompt("relationship-property-cypher-query")
-
-    node_property_query = engine.load_prompt("node-property-cypher-query")
-
-    neo4j_relationships_understanding_prompt = engine.load_prompt(
-        "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_nodes_understanding_prompt = engine.load_prompt(
-        "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_relationship_property_prompt = engine.load_prompt(
-        "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_node_property_prompt = engine.load_prompt(
-        "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    # Fetch and enrich relationships
-    relationship_result = execute_query(relationship_query)
-    relationships_neo4j = relationship_result[0]
-    enriched_relationships = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationships_neo4j))
-    if enriched_relationships.startswith("```json") and enriched_relationships.endswith("```"):
-      enriched_relationships = enriched_relationships[7:-3].strip()
-    enriched_relationships = json.dumps(enriched_relationships)
-    enriched_relationships = json.loads(enriched_relationships)
-    finalised_graph_structure["relationships"] = enriched_relationships if enriched_relationships else {}
-
-    # Fetch and enrich nodes
-    nodes_neo4j_result = execute_query(node_query)
-    nodes_neo4j = nodes_neo4j_result[0]
-    enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(nodes_neo4j))
-    if enriched_nodes.startswith("```json") and enriched_nodes.endswith("```"):
-      enriched_nodes = enriched_nodes[7:-3].strip()
-    enriched_nodes = json.loads(enriched_nodes)
-    json.dumps(enriched_nodes)
-    finalised_graph_structure['nodes']['labels'] = enriched_nodes['nodes']
-    logger.debug(f"Finalised graph structure with enriched nodes: {finalised_graph_structure}")
-
-    # Fetch and enrich node properties
-    node_properties_neo4j_result = execute_query(node_property_query)
-    node_properties_neo4j = node_properties_neo4j_result[0]
-    filtered_payload = {
-        'nodeProperties': [
-            node for node in node_properties_neo4j['nodeProperties']
-            if all(prop['data_type'] is not None and prop['name'] is not None for prop in node['properties'])
-        ]
-    }
-    enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload))
-    if enriched_node_properties.startswith("```json") and enriched_node_properties.endswith("```"):
-        enriched_node_properties = enriched_node_properties[7:-3].strip()
-    enriched_node_properties = json.loads(enriched_node_properties)
-
-    finalised_graph_structure["properties"]["node_properties"] = enriched_node_properties["nodeProperties"]
-    logger.debug(f"Finalised graph structure with enriched node properties: {finalised_graph_structure}")
-    return finalised_graph_structure
diff --git a/backend/src/prompts/templates/generate-cypher-query.j2 b/backend/src/prompts/templates/generate-cypher-query.j2
index 36ae24e6e..597df245c 100644
--- a/backend/src/prompts/templates/generate-cypher-query.j2
+++ b/backend/src/prompts/templates/generate-cypher-query.j2
@@ -30,4 +30,4 @@ When returning a value, always remove the `-` sign before the number.
 Here is the graph schema:
 {{ graph_schema }}
 
-The current date and time is {{ current_date }}
+The current date and time is {{ current_date }} and the currency of the data is GBP.
diff --git a/backend/src/prompts/templates/nodes-query.j2 b/backend/src/prompts/templates/nodes-query.j2
deleted file mode 100644
index cd1e2207e..000000000
--- a/backend/src/prompts/templates/nodes-query.j2
+++ /dev/null
@@ -1,6 +0,0 @@
-call db.labels() yield label
-return collect({
-    label: label,
-    cypher_representation : "(:" + label + ")",
-    detail: "A " + label  + " is a..."
-}) AS nodes
diff --git a/backend/src/prompts/templates/relationship-property-cypher-query.j2 b/backend/src/prompts/templates/relationship-property-cypher-query.j2
index f03d64b9a..952c11ce0 100644
--- a/backend/src/prompts/templates/relationship-property-cypher-query.j2
+++ b/backend/src/prompts/templates/relationship-property-cypher-query.j2
@@ -4,7 +4,7 @@ WITH
     COLLECT({
         name: propertyName,
         data_type: propertyTypes,
-        detail: "A " + propertyName + " is a.. "
+        detail: ""
     }) AS props
 RETURN COLLECT({
     relationship_type: "[" + REPLACE(rel, "`", "") + "]",
diff --git a/backend/src/prompts/templates/relationships-query.j2 b/backend/src/prompts/templates/relationships-query.j2
index f59ca8343..54e1a73e1 100644
--- a/backend/src/prompts/templates/relationships-query.j2
+++ b/backend/src/prompts/templates/relationships-query.j2
@@ -1,12 +1 @@
-CALL apoc.meta.stats() YIELD relTypes
-WITH relTypes, keys(relTypes) AS relTypeKeys
-UNWIND relTypeKeys AS relTypeKey
-WITH relTypeKey, relTypes[relTypeKey] AS count
-WHERE relTypeKey CONTAINS ")->(:"
-   OR relTypeKey CONTAINS "(:"
-WITH collect({
-    label: split(split(relTypeKey, "-")[1], ">")[0],
-    cypher_representation: relTypeKey,
-    detail: ""
-}) AS paths
-RETURN paths
\ No newline at end of file
+call db.schema.visualization
diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py
new file mode 100644
index 000000000..3056a49ba
--- /dev/null
+++ b/backend/src/utils/semantic_layer_builder.py
@@ -0,0 +1,162 @@
+from src.utils.graph_db_utils import execute_query
+import logging
+from src.prompts import PromptEngine
+import json
+
+logger = logging.getLogger(__name__)
+
+engine = PromptEngine()
+
+
+async def get_semantic_layer(llm, model):
+    finalised_graph_structure = {"nodes": {}, "properties": {}}
+
+    neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why")
+
+    relationship_query = engine.load_prompt("relationships-query")
+
+    relationship_property_query = engine.load_prompt("relationship-property-cypher-query")
+
+    node_property_query = engine.load_prompt("node-property-cypher-query")
+
+    neo4j_relationships_understanding_prompt = engine.load_prompt(
+        "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    neo4j_nodes_understanding_prompt = engine.load_prompt(
+        "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    neo4j_relationship_property_prompt = engine.load_prompt(
+        "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    neo4j_node_property_prompt = engine.load_prompt(
+        "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+    )
+
+    relationship_result = execute_query(relationship_query)
+    payload = relationship_result[0]
+
+    nodes = []
+    relationships_dict = {}
+
+    # Convert nodes
+    for node in payload['nodes']:
+        nodes.append({
+            "cypher_representation": f"(:{node['name']})",
+            "label": node['name'],
+            "indexes": node.get('indexes', []),
+            "constraints": node.get('constraints', [])
+        })
+
+    # Convert relationships
+    for relationship in payload['relationships']:
+        start_node = relationship[0]['name']
+        relationship_type = relationship[1]
+        end_node = relationship[2]['name']
+        path = f"(:{start_node})-[:{relationship_type}]->(:{end_node})"
+
+        if relationship_type not in relationships_dict:
+            relationships_dict[relationship_type] = {
+                "cypher_representation": f"[:{relationship_type}]",
+                "type": relationship_type,
+                "paths": []
+            }
+
+        relationships_dict[relationship_type]["paths"].append({
+            "path": path,
+            "detail": ""
+        })
+    # Convert relationships_dict to a list
+    relationships = list(relationships_dict.values())
+
+    finalised_graph_structure = {
+        "nodes": nodes,
+        "relationships": relationships
+    }
+    json.dumps(finalised_graph_structure)
+    relationships = finalised_graph_structure['relationships']
+    enriched_relationships_list = []
+
+    for relationship in relationships:
+        enriched_relationship = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationship))
+
+        if enriched_relationship.startswith("```json") and enriched_relationship.endswith("```"):
+            enriched_relationship = enriched_relationship[7:-3].strip()
+        enriched_relationships_list.append(json.loads(enriched_relationship))
+
+        finalised_graph_structure['relationships'] = enriched_relationships_list
+    logger.debug(f"finalised graph structure with enriched relationships: {finalised_graph_structure}")
+
+    # Fetch and enrich nodes
+    neo4j_data = finalised_graph_structure['nodes']
+    enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data))
+
+    if enriched_nodes.startswith("```json") and enriched_nodes.endswith("```"):
+        enriched_nodes = enriched_nodes[7:-3].strip()
+    enriched_nodes = json.loads(enriched_nodes)
+    json.dumps(enriched_nodes)
+    finalised_graph_structure['nodes'] = enriched_nodes
+    logger.debug(f"finalised graph structure: {finalised_graph_structure}")
+
+    # Fetch and enrich relationship properties
+    properties_result = execute_query(relationship_property_query)
+    rel_properties_neo4j = properties_result[0]
+    cleaned_rel_properties = []
+
+    for rel_property in rel_properties_neo4j['relProperties']:
+        cleaned_properties = [prop for prop in rel_property['properties'] if prop['name'] is not None]
+        if cleaned_properties:
+            rel_property['properties'] = cleaned_properties
+            cleaned_rel_properties.append(rel_property)
+
+    rel_properties_neo4j = {'relProperties': cleaned_rel_properties}
+    json.dumps(rel_properties_neo4j)
+
+    enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j))
+
+    if enriched_rel_properties.startswith("```json") and enriched_rel_properties.endswith("```"):
+        enriched_rel_properties = enriched_rel_properties[7:-3].strip()
+
+    enriched_rel_properties = json.loads(enriched_rel_properties)
+
+    # Merge properties
+    for new_rel in enriched_rel_properties["relProperties"]:
+        relationship_type = new_rel["relationship_type"]
+        properties_to_add = new_rel["properties"]
+
+        for rel in finalised_graph_structure["relationships"]:
+            if rel["cypher_representation"] == relationship_type:
+                if "properties" not in rel:
+                    rel["properties"] = []
+                rel["properties"] = properties_to_add
+
+    logger.debug(f"finalised graph structure with enriched properties: {finalised_graph_structure}")
+
+    # Fetch and enrich node properties
+    node_properties_neo4j_result = execute_query(node_property_query)
+    node_properties_neo4j = node_properties_neo4j_result[0]
+    filtered_payload = {
+        'nodeProperties': [
+            node for node in node_properties_neo4j['nodeProperties']
+            if all(prop['data_type'] is not None and prop['name'] is not None for prop in node['properties'])
+        ]
+    }
+    enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload))
+    if enriched_node_properties.startswith("```json") and enriched_node_properties.endswith("```"):
+        enriched_node_properties = enriched_node_properties[7:-3].strip()
+    enriched_node_properties = json.loads(enriched_node_properties)
+
+    for new_node in enriched_node_properties["nodeProperties"]:
+        label = new_node["label"]
+        properties_to_add = new_node["properties"]
+
+        for node in finalised_graph_structure["nodes"]:
+            if node["label"] == label:
+                if "properties" not in node:
+                    node["properties"] = []
+                node["properties"] = properties_to_add
+    logger.debug(f"finalised graph structure with enriched nodes: {finalised_graph_structure}")
+
+    return finalised_graph_structure

From 80a1fd2167c9504dd8ab76ab5a836b30f525b25e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Tue, 3 Sep 2024 15:15:37 +0100
Subject: [PATCH 05/48] Refactor with sanitise function

---
 backend/src/utils/semantic_layer_builder.py | 22 +++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py
index 3056a49ba..8860b277e 100644
--- a/backend/src/utils/semantic_layer_builder.py
+++ b/backend/src/utils/semantic_layer_builder.py
@@ -82,8 +82,8 @@ async def get_semantic_layer(llm, model):
     for relationship in relationships:
         enriched_relationship = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationship))
 
-        if enriched_relationship.startswith("```json") and enriched_relationship.endswith("```"):
-            enriched_relationship = enriched_relationship[7:-3].strip()
+        enriched_relationship = sanitise_script(enriched_relationship)
+
         enriched_relationships_list.append(json.loads(enriched_relationship))
 
         finalised_graph_structure['relationships'] = enriched_relationships_list
@@ -93,8 +93,8 @@ async def get_semantic_layer(llm, model):
     neo4j_data = finalised_graph_structure['nodes']
     enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data))
 
-    if enriched_nodes.startswith("```json") and enriched_nodes.endswith("```"):
-        enriched_nodes = enriched_nodes[7:-3].strip()
+    enriched_nodes = sanitise_script(enriched_nodes)
+
     enriched_nodes = json.loads(enriched_nodes)
     json.dumps(enriched_nodes)
     finalised_graph_structure['nodes'] = enriched_nodes
@@ -116,8 +116,7 @@ async def get_semantic_layer(llm, model):
 
     enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j))
 
-    if enriched_rel_properties.startswith("```json") and enriched_rel_properties.endswith("```"):
-        enriched_rel_properties = enriched_rel_properties[7:-3].strip()
+    enriched_rel_properties = sanitise_script(enriched_rel_properties)
 
     enriched_rel_properties = json.loads(enriched_rel_properties)
 
@@ -144,8 +143,7 @@ async def get_semantic_layer(llm, model):
         ]
     }
     enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload))
-    if enriched_node_properties.startswith("```json") and enriched_node_properties.endswith("```"):
-        enriched_node_properties = enriched_node_properties[7:-3].strip()
+    enriched_node_properties = sanitise_script(enriched_node_properties)
     enriched_node_properties = json.loads(enriched_node_properties)
 
     for new_node in enriched_node_properties["nodeProperties"]:
@@ -160,3 +158,11 @@ async def get_semantic_layer(llm, model):
     logger.debug(f"finalised graph structure with enriched nodes: {finalised_graph_structure}")
 
     return finalised_graph_structure
+
+def sanitise_script(script: str) -> str:
+    script = script.strip()
+    if script.startswith("```json"):
+        script = script[7:]
+    if script.endswith("```"):
+        script = script[:-3]
+    return script.strip()

From 5cacc3ea2c13a0b72d2b853e286549e18202bebe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Wed, 4 Sep 2024 15:34:27 +0100
Subject: [PATCH 06/48] Write tests and refactor

---
 backend/src/agents/datastore_agent.py        | 75 ++++++++--------
 backend/src/utils/semantic_layer_builder.py  |  6 --
 backend/tests/agents/datastore_agent_test.py | 90 ++++++++++++++++++++
 3 files changed, 130 insertions(+), 41 deletions(-)
 create mode 100644 backend/tests/agents/datastore_agent_test.py

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index 89551c115..9b872649a 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -17,40 +17,7 @@
 
 cache = {}
 
-@tool(
-    name="generate cypher query",
-    description="Generate Cypher query if the category is data driven, based on the operation to be performed",
-    parameters={
-        "question_intent": Parameter(
-            type="string",
-            description="The intent the question will be based on",
-        ),
-        "operation": Parameter(
-            type="string",
-            description="The operation the cypher query will have to perform",
-        ),
-        "question_params": Parameter(
-            type="string",
-            description="""
-                The specific parameters required for the question to be answered with the question_intent
-                or none if no params required
-            """,
-        ),
-        "aggregation": Parameter(
-            type="string",
-            description="Any aggregation that is required to answer the question or none if no aggregation is needed",
-        ),
-        "sort_order": Parameter(
-            type="string",
-            description="The order a list should be sorted in or none if no sort_order is needed",
-        ),
-        "timeframe": Parameter(
-            type="string",
-            description="string of the timeframe to be considered or none if no timeframe is needed",
-        ),
-    },
-)
-async def generate_query(
+async def generate_cypher_query_core(
     question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
 ) -> str:
 
@@ -92,11 +59,49 @@ async def get_semantic_layer_cache(graph_schema):
         raise
     return str(db_response)
 
+@tool(
+    name="generate cypher query",
+    description="Generate Cypher query if the category is data driven, based on the operation to be performed",
+    parameters={
+        "question_intent": Parameter(
+            type="string",
+            description="The intent the question will be based on",
+        ),
+        "operation": Parameter(
+            type="string",
+            description="The operation the cypher query will have to perform",
+        ),
+        "question_params": Parameter(
+            type="string",
+            description="""
+                The specific parameters required for the question to be answered with the question_intent
+                or none if no params required
+            """,
+        ),
+        "aggregation": Parameter(
+            type="string",
+            description="Any aggregation that is required to answer the question or none if no aggregation is needed",
+        ),
+        "sort_order": Parameter(
+            type="string",
+            description="The order a list should be sorted in or none if no sort_order is needed",
+        ),
+        "timeframe": Parameter(
+            type="string",
+            description="string of the timeframe to be considered or none if no timeframe is needed",
+        ),
+    },
+)
+
+async def generate_cypher(question_intent, operation, question_params, aggregation, sort_order,
+                          timeframe, llm: LLM, model) -> str:
+    return await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
+                                            timeframe, llm, model)
 
 @agent(
     name="DatastoreAgent",
     description="This agent is responsible for handling database queries relating to the user's personal data.",
-    tools=[generate_query],
+    tools=[generate_cypher],
 )
 class DatastoreAgent(Agent):
     pass
diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py
index 8860b277e..419cd8aba 100644
--- a/backend/src/utils/semantic_layer_builder.py
+++ b/backend/src/utils/semantic_layer_builder.py
@@ -81,9 +81,7 @@ async def get_semantic_layer(llm, model):
 
     for relationship in relationships:
         enriched_relationship = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationship))
-
         enriched_relationship = sanitise_script(enriched_relationship)
-
         enriched_relationships_list.append(json.loads(enriched_relationship))
 
         finalised_graph_structure['relationships'] = enriched_relationships_list
@@ -92,9 +90,7 @@ async def get_semantic_layer(llm, model):
     # Fetch and enrich nodes
     neo4j_data = finalised_graph_structure['nodes']
     enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data))
-
     enriched_nodes = sanitise_script(enriched_nodes)
-
     enriched_nodes = json.loads(enriched_nodes)
     json.dumps(enriched_nodes)
     finalised_graph_structure['nodes'] = enriched_nodes
@@ -115,9 +111,7 @@ async def get_semantic_layer(llm, model):
     json.dumps(rel_properties_neo4j)
 
     enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j))
-
     enriched_rel_properties = sanitise_script(enriched_rel_properties)
-
     enriched_rel_properties = json.loads(enriched_rel_properties)
 
     # Merge properties
diff --git a/backend/tests/agents/datastore_agent_test.py b/backend/tests/agents/datastore_agent_test.py
new file mode 100644
index 000000000..2091591ec
--- /dev/null
+++ b/backend/tests/agents/datastore_agent_test.py
@@ -0,0 +1,90 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from src.agents.datastore_agent import generate_cypher_query_core
+
+@pytest.mark.asyncio
+@patch("src.agents.datastore_agent.get_semantic_layer", new_callable=AsyncMock)
+@patch("src.agents.datastore_agent.execute_query", new_callable=MagicMock)
+@patch("src.agents.datastore_agent.publish_log_info", new_callable=AsyncMock)
+@patch("src.agents.datastore_agent.engine.load_prompt", autospec=True)
+async def test_generate_query_success(mock_load_prompt, mock_publish_log_info,
+                                      mock_execute_query, mock_get_semantic_layer):
+    llm = AsyncMock()
+    model = "mock_model"
+
+    mock_load_prompt.side_effect = [
+        "details to create cypher query prompt",
+        "generate cypher query prompt"
+    ]
+
+    llm.chat.return_value = '{"query": "MATCH (n) RETURN n"}'
+
+    mock_get_semantic_layer.return_value = {"nodes": [], "edges": []}
+
+    mock_execute_query.return_value = "Mocked response from the database"
+
+    question_intent = "Find all nodes"
+    operation = "MATCH"
+    question_params = "n"
+    aggregation = "none"
+    sort_order = "none"
+    timeframe = "2024"
+    model = "gpt-4"
+
+    result = await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
+                                              timeframe, llm, model)
+
+    assert result == "Mocked response from the database"
+    mock_load_prompt.assert_called()
+    llm.chat.assert_called_once_with(
+        model,
+        "generate cypher query prompt",
+        "details to create cypher query prompt",
+        return_json=True
+    )
+    mock_execute_query.assert_called_once_with("MATCH (n) RETURN n")
+    mock_publish_log_info.assert_called()
+
+@pytest.mark.asyncio
+@patch("src.agents.datastore_agent.get_semantic_layer", new_callable=AsyncMock)
+@patch("src.agents.datastore_agent.execute_query", new_callable=MagicMock)
+@patch("src.agents.datastore_agent.publish_log_info", new_callable=AsyncMock)
+@patch("src.agents.datastore_agent.engine.load_prompt", autospec=True)
+async def test_generate_query_failure(mock_load_prompt, mock_publish_log_info,
+                                      mock_execute_query, mock_get_semantic_layer):
+    llm = AsyncMock()
+    model = "mock_model"
+
+    mock_load_prompt.side_effect = [
+        "details to create cypher query prompt",
+        "generate cypher query prompt"
+    ]
+
+    llm.chat.side_effect = Exception("LLM chat failed")
+
+    mock_get_semantic_layer.return_value = {"nodes": [], "edges": []}
+
+    question_intent = "Find all nodes"
+    operation = "MATCH"
+    question_params = "n"
+    aggregation = "none"
+    sort_order = "none"
+    timeframe = "2024"
+    model = "gpt-4"
+
+    with pytest.raises(Exception, match="LLM chat failed"):
+        await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
+                                         timeframe, llm, model)
+
+    mock_load_prompt.assert_called()
+    llm.chat.assert_called_once_with(
+        model,
+        "generate cypher query prompt",
+        "details to create cypher query prompt",
+        return_json=True
+    )
+    mock_publish_log_info.assert_not_called()
+    mock_execute_query.assert_not_called()
+
+if __name__ == "__main__":
+    pytest.main(["-v"])

From 8bc9819fb10620dd43118920381d3fa450233bad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Wed, 4 Sep 2024 15:35:43 +0100
Subject: [PATCH 07/48] Delete unused file

---
 backend/src/prompts/templates/graph-schema.j2 | 156 ------------------
 1 file changed, 156 deletions(-)
 delete mode 100644 backend/src/prompts/templates/graph-schema.j2

diff --git a/backend/src/prompts/templates/graph-schema.j2 b/backend/src/prompts/templates/graph-schema.j2
deleted file mode 100644
index 38892655c..000000000
--- a/backend/src/prompts/templates/graph-schema.j2
+++ /dev/null
@@ -1,156 +0,0 @@
-{
-  "nodes": {
-    "labels": [
-      {
-        "label": "Account",
-        "detail": "An Account is a unique user profile with specific identifiers and associated transactions.",
-        "cypher_representation": "(:Account)"
-      },
-      {
-        "label": "Transaction",
-        "detail": "A Transaction is a record of a financial exchange between an Account and a Merchant, containing details like date, amount, etc.",
-        "cypher_representation": "(:Transaction)"
-      },
-      {
-        "label": "Merchant",
-        "detail": "A Merchant is an entity that provides goods or services in exchange for payment, linked to Transactions.",
-        "cypher_representation": "(:Merchant)"
-      },
-      {
-        "label": "Classification",
-        "detail": "A Classification is a category assigned to a Transaction, based on the type of purchase or service.",
-        "cypher_representation": "(:Classification)"
-      }
-    ]
-  },
-  "properties": {
-    "node_properties": [
-      {
-        "node_label": "Transaction",
-        "properties": [
-          {
-            "name": "id",
-            "data_type": "String",
-            "detail": "Unique identifier for the transaction"
-          },
-          {
-            "name": "amount",
-            "data_type": "Long",
-            "detail": "The amount of money involved in the transaction"
-          },
-          {
-            "name": "description",
-            "data_type": "String",
-            "detail": "A short explanation or reason for the transaction"
-          },
-          {
-            "name": "date",
-            "data_type": "DateTime",
-            "detail": "The date and time when the transaction occurred"
-          },
-          {
-            "name": "type",
-            "data_type": "String",
-            "detail": "The category or type of the transaction. One of: DEBIT, CREDIT, TRANSFER"
-          }
-        ]
-      },
-      {
-        "node_label": "Merchant",
-        "properties": [
-          {
-            "name": "name",
-            "data_type": "String",
-            "detail": "The name of the merchant / company involved in the transaction"
-          }
-        ]
-      },
-      {
-        "node_label": "Classification",
-        "properties": [
-          {
-            "name": "name",
-            "data_type": "String",
-            "detail": "The category or classification of the transaction"
-          }
-        ]
-      },
-      {
-        "node_label": "Account",
-        "properties": [
-          {
-            "name": "name",
-            "data_type": "String",
-            "detail": "The name or identifier of the account involved in the transaction"
-          }
-        ]
-      }
-    ],
-    "relationship_properties": [
-      {
-        "relationship_type": "[:PAID_BY]",
-        "properties": [
-          {
-            "name": "transaction_id",
-            "data_type": "String",
-            "detail": "Represents the payment of a Transaction by an Account"
-          }
-        ]
-      },
-      {
-        "relationship_type": "[:PAID_TO]",
-        "properties": [
-          {
-            "name": "transaction_id",
-            "data_type": "String",
-            "detail": "Represents the payment for a Transaction received by a Merchant"
-          }
-        ]
-      },
-      {
-        "relationship_type": "[:CLASSIFIED_AS]",
-        "properties": [
-          {
-            "name": "category",
-            "data_type": "String",
-            "detail": "Represents the categorization or classification of a transaction"
-          }
-        ]
-      }
-    ]
-  },
-  "relationships": {
-    "paths": [
-      {
-        "label": "[:PAID_TO]",
-        "detail": "Represents a payment made to a merchant",
-        "cypher_representation": "(:Transaction)-[:PAID_TO]->(:Merchant)"
-      },
-      {
-        "label": "[:PAID_BY]",
-        "detail": "Represents a transaction that is paid by a specific account",
-        "cypher_representation": "(:Transaction)-[:PAID_BY]->(:Account)"
-      },
-      {
-        "label": "[:CLASSIFIED_AS]",
-        "detail": "Represents the classification of a node",
-        "cypher_representation": "(:Transaction)-[:CLASSIFIED_AS]->(:Classification)"
-      },
-      {
-        "label": "[:PAID_TO]",
-        "detail": "Represents a payment made to a node",
-        "cypher_representation": "(:Transaction)-[:PAID_TO]->(:Merchant)"
-      },
-      {
-        "label": "[:PAID_BY]",
-        "detail": "Represents the account that made a transaction *NOTE DIRECTION OF ARROW*",
-        "cypher_representation": "(:Transaction)-[:PAID_BY]->(:Account)"
-      },
-      {
-        "label": "[:CLASSIFIED_AS]",
-        "detail": "Represents the classification of a transaction",
-        "cypher_representation": "(:Transaction)-[:CLASSIFIED_AS]->(:Classification)"
-      }
-    ]
-  }
-}

From 75c95f282690261dc8a3b8134ff3590a908c5204 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Wed, 4 Sep 2024 15:52:27 +0100
Subject: [PATCH 08/48] Fix linting issue

---
 backend/src/agents/datastore_agent.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index 9b872649a..a1c3b4cda 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -47,7 +47,8 @@ async def get_semantic_layer_cache(graph_schema):
             "generate-cypher-query", graph_schema=graph_schema, current_date=datetime.now()
         )
 
-        llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query, return_json=True)
+        llm_query = await llm.chat(model, generate_cypher_query_prompt, details_to_create_cypher_query,
+                                   return_json=True)
         json_query = to_json(llm_query)
         await publish_log_info(LogPrefix.USER, f"Cypher generated by the LLM: {llm_query}", __name__)
         if json_query["query"] == "None":

From c73a76896ecad45412bb5c0c751242d6f6fa09df Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 5 Sep 2024 10:10:47 +0100
Subject: [PATCH 09/48] Pushing Intent Changes

---
 backend/src/prompts/templates/intent.j2 | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index 84361b217..07617ecd9 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -6,10 +6,21 @@ The question is:
 {{ question }} 
 
  
-The task is to comprehend the intention of the question. The question can be composed of different intents and when it is the case, examine all intents one by one to determine which one to tackle first as you may need the data gathered from a secondary intent to perform the first intent.  
-You are NOT ALLOWED to make up sample data or example values. Only use concrete data for which you can name the source. 
+Your task is to accurately comprehend the intentions behind the question. The question can be composed of different intents; when that is the case, examine all intents one by one to determine which one to tackle first, as you may need the data gathered from a secondary intent to perform the first intent.  
+If the question contains multiple intents, break them down into individual tasks, and specify the order in which these tasks should be tackled. The order should ensure that each intent is addressed in a logical sequence, particularly if one intent depends on the data obtained from another.
+You are NOT ALLOWED to make up sample data or example values. Only use concrete data for which you can name the source.
 Based on this understanding, the following query must be formulated to extract the necessary data, which can then be used to address the question. 
 
+Guidelines:
+1. Determine each distinct intent in the question.
+2. Sequence the intents: Identify which intent should be tackled first and which should follow.
+3. For each intent, specify:
+    - The exact operation required (e.g., "literal search", "filter + aggregation", "data visualization").
+    - The category of the question (e.g., "data-driven", "data presentation", "general knowledge").
+    - Any specific parameters or conditions that apply.
+    - The correct aggregation and sorting methods if applicable.
+4. Avoid conflating intents: If a user’s query asks for data retrieval and its visualization, treat these as separate operations.
+5. Do not make assumptions or create hypothetical data. Use only concrete data where applicable.
  
 Specify an operation type under the operation key; here are a few examples: 
 
@@ -53,3 +64,11 @@ Response:
 Q: Find the schedule of the local train station. 
 Response: 
 {"query":"Find the schedule of the local train station.","user_intent":"find train schedule","questions":[{"query":"Find the schedule of the local train station.","question_intent":"retrieve train schedule from web","operation":"online search","question_category":"search online","parameters":[{"type":"train station","value":"local"}],"sort_order":"none"}]} 
+
+Q: What are the different subscriptions with Netflix? Show me the results in a chart.
+Response:
+{"query": "What are the different subscriptions with Netflix? Show me the results in a chart.", "user_intent": "find and display subscription information", "questions": [{"query": "What are the different subscriptions with Netflix?", "question_intent": "retrieve subscription information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
+
+Q: Show me a chart of different subscription prices with Netflix?
+Response:
+{"query": "Show me a chart of different subscription prices with Netflix?", "user_intent": "retrieve and visualize subscription data", "questions": [{"query": "What are the different subscription prices with Netflix?", "question_intent": "retrieve subscription pricing information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display subscription pricing information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}

From d2d15fb7e2f958c5960472d3a008a2e9f88025ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Thu, 5 Sep 2024 10:45:20 +0100
Subject: [PATCH 10/48] Change return value to json

---
 backend/src/agents/datastore_agent.py | 6 +++++-
 backend/src/agents/web_agent.py       | 6 +++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index a1c3b4cda..3cc9f2f73 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -58,7 +58,11 @@ async def get_semantic_layer_cache(graph_schema):
     except Exception as e:
         logger.error(f"Error during data retrieval: {e}")
         raise
-    return str(db_response)
+    response = {
+        "content": db_response,
+        "ignore_validation": "false"
+    }
+    return json.dumps(response, indent=4)
 
 @tool(
     name="generate cypher query",
diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index 50c9be5a1..209ff1959 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -33,7 +33,11 @@ async def web_general_search_core(search_query, llm, model) -> str:
             if not summary:
                 continue
             if await is_valid_answer(summary, search_query):
-                return summary
+                response = {
+                    "content": summary,
+                    "ignore_validation": "true"
+                }
+                return json.dumps(response, indent=4)
         return "No relevant information found on the internet for the given query."
     except Exception as e:
         logger.error(f"Error in web_general_search_core: {e}")

From 96935245da55bde5550dbbea1e1ac1cade194626 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Thu, 5 Sep 2024 11:20:49 +0100
Subject: [PATCH 11/48] Edit tests to adjust to changes + web agent validation
 to false

---
 backend/src/agents/web_agent.py              | 2 +-
 backend/tests/agents/datastore_agent_test.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index 209ff1959..b330baec6 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -35,7 +35,7 @@ async def web_general_search_core(search_query, llm, model) -> str:
             if await is_valid_answer(summary, search_query):
                 response = {
                     "content": summary,
-                    "ignore_validation": "true"
+                    "ignore_validation": "false"
                 }
                 return json.dumps(response, indent=4)
         return "No relevant information found on the internet for the given query."
diff --git a/backend/tests/agents/datastore_agent_test.py b/backend/tests/agents/datastore_agent_test.py
index 2091591ec..c6de422b7 100644
--- a/backend/tests/agents/datastore_agent_test.py
+++ b/backend/tests/agents/datastore_agent_test.py
@@ -34,7 +34,7 @@ async def test_generate_query_success(mock_load_prompt, mock_publish_log_info,
     result = await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
                                               timeframe, llm, model)
 
-    assert result == "Mocked response from the database"
+    assert result == '{\n    "content": "Mocked response from the database",\n    "ignore_validation": "false"\n}'
     mock_load_prompt.assert_called()
     llm.chat.assert_called_once_with(
         model,

From 2ff6783bc33712ad1bd174f4211e08697ecf61ad Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 5 Sep 2024 14:00:17 +0100
Subject: [PATCH 12/48] Push the webagent tests

---
 backend/src/agents/web_agent.py        |   9 +-
 backend/src/utils/web_utils.py         |   2 +-
 backend/tests/agents/web_agent_test.py | 127 +++++++++++--------------
 3 files changed, 60 insertions(+), 78 deletions(-)

diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index b330baec6..e895a68ad 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -20,7 +20,7 @@
 
 async def web_general_search_core(search_query, llm, model) -> str:
     try:
-        search_result = perform_search(search_query, num_results=15)
+        search_result = await perform_search(search_query, num_results=15)
         if search_result["status"] == "error":
             return "No relevant information found on the internet for the given query."
         urls = search_result["urls"]
@@ -32,7 +32,8 @@ async def web_general_search_core(search_query, llm, model) -> str:
             summary = await perform_summarization(search_query, content, llm, model)
             if not summary:
                 continue
-            if await is_valid_answer(summary, search_query):
+            is_valid = await is_valid_answer(summary, search_query)
+            if is_valid:
                 response = {
                     "content": summary,
                     "ignore_validation": "false"
@@ -109,9 +110,9 @@ async def is_valid_answer(answer, task) -> bool:
     return is_valid
 
 
-def perform_search(search_query: str, num_results: int) -> Dict[str, Any]:
+async def perform_search(search_query: str, num_results: int) -> Dict[str, Any]:
     try:
-        search_result_json = search_urls(search_query, num_results=num_results)
+        search_result_json = await search_urls(search_query, num_results=num_results)
         return json.loads(search_result_json)
     except Exception as e:
         logger.error(f"Error during web search: {e}")
diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py
index 2edaed606..90082be76 100644
--- a/backend/src/utils/web_utils.py
+++ b/backend/src/utils/web_utils.py
@@ -13,7 +13,7 @@
 engine = PromptEngine()
 
 
-def search_urls(search_query, num_results=10) -> str:
+async def search_urls(search_query, num_results=10) -> str:
     logger.info(f"Searching the web for: {search_query}")
     urls = []
     try:
diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index b50dc36ed..91892fb6b 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -1,76 +1,57 @@
-import unittest
-from unittest.mock import AsyncMock, patch
-
 import pytest
+from unittest.mock import patch, AsyncMock
+import json
 from src.agents.web_agent import web_general_search_core
 
-
-class TestWebAgentCore(unittest.TestCase):
-    def setUp(self):
-        self.llm = AsyncMock()
-        self.model = "mock_model"
-
-    @patch("src.agents.web_agent.perform_search")
-    @patch("src.agents.web_agent.perform_scrape")
-    @patch("src.agents.web_agent.perform_summarization")
-    @patch("src.agents.web_agent.is_valid_answer")
-    @pytest.mark.asyncio
-    async def test_web_general_search_core(
-        self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search
-    ):
-        mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
-        mock_perform_scrape.return_value = "Example scraped content."
-        mock_perform_summarization.return_value = "Example summary."
-        mock_is_valid_answer.return_value = True
-
-        result = await web_general_search_core("example query", self.llm, self.model)
-        self.assertEqual(result, "Example summary.")
-        mock_perform_search.assert_called_once_with("example query", num_results=15)
-        mock_perform_scrape.assert_called_once_with("http://example.com")
-        mock_perform_summarization.assert_called_once_with(
-            "example query", "Example scraped content.", self.llm, self.model
-        )
-        mock_is_valid_answer.assert_called_once_with("Example summary.", "example query")
-
-    @patch("src.agents.web_agent.perform_search")
-    @patch("src.agents.web_agent.perform_scrape")
-    @patch("src.agents.web_agent.perform_summarization")
-    @patch("src.agents.web_agent.is_valid_answer")
-    @pytest.mark.asyncio
-    async def test_web_general_search_core_no_results(
-        self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search
-    ):
-        mock_perform_search.return_value = {"status": "error", "urls": []}
-
-        result = await web_general_search_core("example query", self.llm, self.model)
-        self.assertEqual(result, "No relevant information found on the internet for the given query.")
-        mock_perform_search.assert_called_once_with("example query", num_results=15)
-        mock_perform_scrape.assert_not_called()
-        mock_perform_summarization.assert_not_called()
-        mock_is_valid_answer.assert_not_called()
-
-    @patch("src.agents.web_agent.perform_search")
-    @patch("src.agents.web_agent.perform_scrape")
-    @patch("src.agents.web_agent.perform_summarization")
-    @patch("src.agents.web_agent.is_valid_answer")
-    @pytest.mark.asyncio
-    async def test_web_general_search_core_invalid_summary(
-        self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search
-    ):
-        mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
-        mock_perform_scrape.return_value = "Example scraped content."
-        mock_perform_summarization.return_value = "Example invalid summary."
-        mock_is_valid_answer.return_value = False
-
-        result = await web_general_search_core("example query", self.llm, self.model)
-        self.assertEqual(result, "No relevant information found on the internet for the given query.")
-        mock_perform_search.assert_called_once_with("example query", num_results=15)
-        mock_perform_scrape.assert_called_once_with("http://example.com")
-        mock_perform_summarization.assert_called_once_with(
-            "example query", "Example scraped content.", self.llm, self.model
-        )
-        mock_is_valid_answer.assert_called_once_with("Example invalid summary.", "example query")
-
-
-if __name__ == "__main__":
-    unittest.main()
+@pytest.mark.asyncio
+@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
+@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
+async def test_web_general_search_core(
+    mock_is_valid_answer,
+    mock_perform_summarization,
+    mock_perform_scrape,
+    mock_perform_search,
+):
+    llm = AsyncMock()
+    model = "mock_model"
+
+    mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
+    mock_perform_scrape.return_value = "Example scraped content."
+    mock_perform_summarization.return_value = "Example summary."
+    mock_is_valid_answer.return_value = True
+    result = await web_general_search_core("example query", llm, model)
+    expected_response = {
+        "content": "Example summary.",
+        "ignore_validation": "false"
+    }
+    assert json.loads(result) == expected_response
+
+@pytest.mark.asyncio
+@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
+@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
+async def test_web_general_search_core_no_results(mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search):
+    llm = AsyncMock()
+    model = "mock_model"
+    mock_perform_search.return_value = {"status": "error", "urls": []}
+    result = await web_general_search_core("example query", llm, model)
+    assert result == "No relevant information found on the internet for the given query."
+
+
+@pytest.mark.asyncio
+@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
+@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
+async def test_web_general_search_core_invalid_summary(mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search):
+    llm = AsyncMock()
+    model = "mock_model"
+    mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
+    mock_perform_scrape.return_value = "Example scraped content."
+    mock_perform_summarization.return_value = "Example invalid summary."
+    mock_is_valid_answer.return_value = False
+    result = await web_general_search_core("example query", llm, model)
+    assert result == "No relevant information found on the internet for the given query."

From 87581bdfa3459c0b2197938d3da52fb5f85acce1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Fri, 6 Sep 2024 11:17:05 +0100
Subject: [PATCH 13/48] Remove sanitise function + move nested function out of
 main function

---
 backend/src/agents/datastore_agent.py       | 21 +++++++-------
 backend/src/utils/semantic_layer_builder.py | 31 ++++++++-------------
 2 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/backend/src/agents/datastore_agent.py b/backend/src/agents/datastore_agent.py
index 3cc9f2f73..f876f9a0b 100644
--- a/backend/src/agents/datastore_agent.py
+++ b/backend/src/agents/datastore_agent.py
@@ -21,15 +21,6 @@ async def generate_cypher_query_core(
     question_intent, operation, question_params, aggregation, sort_order, timeframe, llm: LLM, model
 ) -> str:
 
-    async def get_semantic_layer_cache(graph_schema):
-        global cache
-        if not cache:
-            graph_schema = await get_semantic_layer(llm, model)
-            cache = graph_schema
-            return cache
-        else:
-            return cache
-
     details_to_create_cypher_query = engine.load_prompt(
         "details-to-create-cypher-query",
         question_intent=question_intent,
@@ -40,7 +31,7 @@ async def get_semantic_layer_cache(graph_schema):
         timeframe=timeframe,
     )
     try:
-        graph_schema = await get_semantic_layer_cache(cache)
+        graph_schema = await get_semantic_layer_cache(llm, model, cache)
         graph_schema = json.dumps(graph_schema, separators=(",", ":"))
 
         generate_cypher_query_prompt = engine.load_prompt(
@@ -103,6 +94,16 @@ async def generate_cypher(question_intent, operation, question_params, aggregati
     return await generate_cypher_query_core(question_intent, operation, question_params, aggregation, sort_order,
                                             timeframe, llm, model)
 
+
+async def get_semantic_layer_cache(llm, model, graph_schema):
+    global cache
+    if not cache:
+        graph_schema = await get_semantic_layer(llm, model)
+        cache = graph_schema
+        return cache
+    else:
+        return cache
+
 @agent(
     name="DatastoreAgent",
     description="This agent is responsible for handling database queries relating to the user's personal data.",
diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py
index 419cd8aba..d6f3ca924 100644
--- a/backend/src/utils/semantic_layer_builder.py
+++ b/backend/src/utils/semantic_layer_builder.py
@@ -80,8 +80,8 @@ async def get_semantic_layer(llm, model):
     enriched_relationships_list = []
 
     for relationship in relationships:
-        enriched_relationship = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationship))
-        enriched_relationship = sanitise_script(enriched_relationship)
+        enriched_relationship = await llm.chat(model, neo4j_relationships_understanding_prompt, str(relationship),
+                                               return_json=True)
         enriched_relationships_list.append(json.loads(enriched_relationship))
 
         finalised_graph_structure['relationships'] = enriched_relationships_list
@@ -89,8 +89,7 @@ async def get_semantic_layer(llm, model):
 
     # Fetch and enrich nodes
     neo4j_data = finalised_graph_structure['nodes']
-    enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data))
-    enriched_nodes = sanitise_script(enriched_nodes)
+    enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data), return_json=True)
     enriched_nodes = json.loads(enriched_nodes)
     json.dumps(enriched_nodes)
     finalised_graph_structure['nodes'] = enriched_nodes
@@ -110,20 +109,20 @@ async def get_semantic_layer(llm, model):
     rel_properties_neo4j = {'relProperties': cleaned_rel_properties}
     json.dumps(rel_properties_neo4j)
 
-    enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j))
-    enriched_rel_properties = sanitise_script(enriched_rel_properties)
+    enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j),
+                                             return_json=True)
     enriched_rel_properties = json.loads(enriched_rel_properties)
 
     # Merge properties
     for new_rel in enriched_rel_properties["relProperties"]:
-        relationship_type = new_rel["relationship_type"]
-        properties_to_add = new_rel["properties"]
+        relationship_type = new_rel["relType"]
+        properties_to_add = new_rel["property"]
 
         for rel in finalised_graph_structure["relationships"]:
             if rel["cypher_representation"] == relationship_type:
                 if "properties" not in rel:
-                    rel["properties"] = []
-                rel["properties"] = properties_to_add
+                    rel["property"] = []
+                rel["property"] = properties_to_add
 
     logger.debug(f"finalised graph structure with enriched properties: {finalised_graph_structure}")
 
@@ -136,8 +135,8 @@ async def get_semantic_layer(llm, model):
             if all(prop['data_type'] is not None and prop['name'] is not None for prop in node['properties'])
         ]
     }
-    enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload))
-    enriched_node_properties = sanitise_script(enriched_node_properties)
+    enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload),
+                                              return_json=True)
     enriched_node_properties = json.loads(enriched_node_properties)
 
     for new_node in enriched_node_properties["nodeProperties"]:
@@ -152,11 +151,3 @@ async def get_semantic_layer(llm, model):
     logger.debug(f"finalised graph structure with enriched nodes: {finalised_graph_structure}")
 
     return finalised_graph_structure
-
-def sanitise_script(script: str) -> str:
-    script = script.strip()
-    if script.startswith("```json"):
-        script = script[7:]
-    if script.endswith("```"):
-        script = script[:-3]
-    return script.strip()
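
get_semantic_layer_cache keeps the built schema in a module-level `cache` variable so the expensive LLM and Neo4j round trips only happen once per process. A minimal sketch of that caching pattern in isolation, with `build_layer` standing in for the real builder:

    _cache = None

    async def build_layer() -> dict:
        # Stand-in for the LLM/Neo4j enrichment work.
        return {"nodes": {}, "relationships": []}

    async def get_layer_cached() -> dict:
        global _cache
        if _cache is None:
            _cache = await build_layer()
        return _cache

Checking `if _cache is None` rather than `if not cache` avoids rebuilding when the cached schema happens to be empty but valid.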

From 3dea87678389767c60ce5569430b1384ff6e195f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Fri, 6 Sep 2024 12:19:58 +0100
Subject: [PATCH 14/48] Separate logic by creating methods

---
 backend/src/utils/semantic_layer_builder.py | 76 +++++++++++----------
 1 file changed, 41 insertions(+), 35 deletions(-)

diff --git a/backend/src/utils/semantic_layer_builder.py b/backend/src/utils/semantic_layer_builder.py
index d6f3ca924..c030e6237 100644
--- a/backend/src/utils/semantic_layer_builder.py
+++ b/backend/src/utils/semantic_layer_builder.py
@@ -6,34 +6,31 @@
 logger = logging.getLogger(__name__)
 
 engine = PromptEngine()
+relationship_property_query = engine.load_prompt("relationship-property-cypher-query")
 
+node_property_query = engine.load_prompt("node-property-cypher-query")
 
-async def get_semantic_layer(llm, model):
-    finalised_graph_structure = {"nodes": {}, "properties": {}}
-
-    neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why")
-
-    relationship_query = engine.load_prompt("relationships-query")
+neo4j_graph_why_prompt = engine.load_prompt("neo4j-graph-why")
 
-    relationship_property_query = engine.load_prompt("relationship-property-cypher-query")
+neo4j_nodes_understanding_prompt = engine.load_prompt(
+    "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+)
 
-    node_property_query = engine.load_prompt("node-property-cypher-query")
+neo4j_relationship_property_prompt = engine.load_prompt(
+    "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+)
 
-    neo4j_relationships_understanding_prompt = engine.load_prompt(
-        "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
+neo4j_node_property_prompt = engine.load_prompt(
+    "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+)
+relationship_query = engine.load_prompt("relationships-query")
 
-    neo4j_nodes_understanding_prompt = engine.load_prompt(
-        "neo4j-nodes-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
+neo4j_relationships_understanding_prompt = engine.load_prompt(
+    "neo4j-relationship-understanding", neo4j_graph_why_prompt=neo4j_graph_why_prompt
+)
 
-    neo4j_relationship_property_prompt = engine.load_prompt(
-        "neo4j-property-intent-prompt", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
-
-    neo4j_node_property_prompt = engine.load_prompt(
-        "neo4j-node-property", neo4j_graph_why_prompt=neo4j_graph_why_prompt
-    )
+async def get_semantic_layer(llm, model):
+    finalised_graph_structure = {"nodes": {}, "properties": {}}
 
     relationship_result = execute_query(relationship_query)
     payload = relationship_result[0]
@@ -76,6 +73,15 @@ async def get_semantic_layer(llm, model):
         "relationships": relationships
     }
     json.dumps(finalised_graph_structure)
+
+    await enrich_relationships(llm, model, finalised_graph_structure)
+    await enrich_nodes(llm, model, finalised_graph_structure)
+    await enriched_rel_properties(llm, model, finalised_graph_structure)
+    await enrich_nodes_properties(llm, model, finalised_graph_structure)
+
+    return finalised_graph_structure
+
+async def enrich_relationships(llm, model, finalised_graph_structure):
     relationships = finalised_graph_structure['relationships']
     enriched_relationships_list = []
 
@@ -87,15 +93,17 @@ async def get_semantic_layer(llm, model):
         finalised_graph_structure['relationships'] = enriched_relationships_list
     logger.debug(f"finalised graph structure with enriched relationships: {finalised_graph_structure}")
 
-    # Fetch and enrich nodes
-    neo4j_data = finalised_graph_structure['nodes']
-    enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data), return_json=True)
-    enriched_nodes = json.loads(enriched_nodes)
-    json.dumps(enriched_nodes)
-    finalised_graph_structure['nodes'] = enriched_nodes
-    logger.debug(f"finalised graph structure: {finalised_graph_structure}")
-
-    # Fetch and enrich relationship properties
+async def enrich_nodes(llm, model, finalised_graph_structure):
+        neo4j_data = finalised_graph_structure['nodes']
+        print(f"neo4j data: {neo4j_data}")
+        enriched_nodes = await llm.chat(model, neo4j_nodes_understanding_prompt, str(neo4j_data), return_json=True)
+        enriched_nodes = json.loads(enriched_nodes)
+        json.dumps(enriched_nodes)
+        finalised_graph_structure['nodes'] = enriched_nodes
+        logger.debug(f"finalised graph structure: {finalised_graph_structure}")
+        print(f"finalised_graph_structure with nodes: {finalised_graph_structure}")
+
+async def enriched_rel_properties(llm, model, finalised_graph_structure):
     properties_result = execute_query(relationship_property_query)
     rel_properties_neo4j = properties_result[0]
     cleaned_rel_properties = []
@@ -110,7 +118,7 @@ async def get_semantic_layer(llm, model):
     json.dumps(rel_properties_neo4j)
 
     enriched_rel_properties = await llm.chat(model, neo4j_relationship_property_prompt, str(rel_properties_neo4j),
-                                             return_json=True)
+                                            return_json=True)
     enriched_rel_properties = json.loads(enriched_rel_properties)
 
     # Merge properties
@@ -126,7 +134,7 @@ async def get_semantic_layer(llm, model):
 
     logger.debug(f"finalised graph structure with enriched properties: {finalised_graph_structure}")
 
-    # Fetch and enrich node properties
+async def enrich_nodes_properties(llm, model, finalised_graph_structure):
     node_properties_neo4j_result = execute_query(node_property_query)
     node_properties_neo4j = node_properties_neo4j_result[0]
     filtered_payload = {
@@ -136,7 +144,7 @@ async def get_semantic_layer(llm, model):
         ]
     }
     enriched_node_properties = await llm.chat(model, neo4j_node_property_prompt, str(filtered_payload),
-                                              return_json=True)
+                                            return_json=True)
     enriched_node_properties = json.loads(enriched_node_properties)
 
     for new_node in enriched_node_properties["nodeProperties"]:
@@ -149,5 +157,3 @@ async def get_semantic_layer(llm, model):
                     node["properties"] = []
                 node["properties"] = properties_to_add
     logger.debug(f"finalised graph structure with enriched nodes: {finalised_graph_structure}")
-
-    return finalised_graph_structure
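
After this refactor the enrich_* helpers mutate the shared finalised_graph_structure in place instead of returning it, so get_semantic_layer only has to await them in order. A self-contained sketch of that shape, with stand-in bodies replacing the real enrichment steps:

    from typing import Any

    async def enrich_relationships(llm: Any, model: str, graph: dict) -> None:
        graph["relationships"] = [{"label": "[:PAID_TO]"}]  # stand-in enrichment

    async def enrich_nodes(llm: Any, model: str, graph: dict) -> None:
        graph["nodes"] = {"labels": ["Transaction"]}  # stand-in enrichment

    async def get_semantic_layer_outline(llm: Any, model: str) -> dict:
        graph: dict = {"nodes": {}, "relationships": []}
        # Each helper edits the shared dict; only the final structure is returned.
        await enrich_relationships(llm, model, graph)
        await enrich_nodes(llm, model, graph)
        return graph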

From 7795a454f279c2e7e2a25a4a5bab1866fe878054 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Mon, 9 Sep 2024 10:10:33 +0100
Subject: [PATCH 15/48] Changing apostrophe

---
 backend/src/prompts/templates/intent.j2 | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index 07617ecd9..3f755d040 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -19,7 +19,7 @@ Guidelines:
     - The category of the question (e.g., "data-driven", "data presentation", "general knowledge").
     - Any specific parameters or conditions that apply.
     - The correct aggregation and sorting methods if applicable.
-4. Avoid conflating intents: If a user’s query asks for data retrieval and its visualization, treat these as separate operations.
+4. Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations.
 5. Do not make assumptions or create hypothetical data. Use only concrete data where applicable.
  
 Specify an operation type under the operation key; here are a few examples: 

From 5ea5f1c28eaba98c07b308f1e2f85c5b62cb1998 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Mon, 9 Sep 2024 10:47:25 +0100
Subject: [PATCH 16/48] Lint changes

---
 backend/tests/agents/web_agent_test.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index 91892fb6b..c18d2070d 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -33,7 +33,12 @@ async def test_web_general_search_core(
 @patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
 @patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
 @patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
-async def test_web_general_search_core_no_results(mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search):
+async def test_web_general_search_core_no_results(
+    mock_is_valid_answer,
+    mock_perform_summarization,
+    mock_perform_scrape,
+    mock_perform_search,
+):
     llm = AsyncMock()
     model = "mock_model"
     mock_perform_search.return_value = {"status": "error", "urls": []}
@@ -46,7 +51,12 @@ async def test_web_general_search_core_no_results(mock_is_valid_answer, mock_per
 @patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
 @patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
 @patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
-async def test_web_general_search_core_invalid_summary(mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search):
+async def test_web_general_search_core_invalid_summary(
+    mock_is_valid_answer,
+    mock_perform_summarization,
+    mock_perform_scrape,
+    mock_perform_search
+):
     llm = AsyncMock()
     model = "mock_model"
     mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}

From 84a6ca2f05d5212e16629f5a3be4fd8e36ad3551 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Mon, 9 Sep 2024 13:47:15 +0100
Subject: [PATCH 17/48] Changes following merge conflict

---
 backend/tests/agents/web_agent_test.py | 125 ++++++++++++-------------
 1 file changed, 59 insertions(+), 66 deletions(-)

diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index b50dc36ed..7f5bb989f 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -1,76 +1,69 @@
 import unittest
 from unittest.mock import AsyncMock, patch
+import json
 
 import pytest
 from src.agents.web_agent import web_general_search_core
 
+@pytest.mark.asyncio
+@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
+@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
+async def test_web_general_search_core(
+    mock_is_valid_answer,
+    mock_perform_summarization,
+    mock_perform_scrape,
+    mock_perform_search,
+):
+    llm = AsyncMock()
+    model = "mock_model"
 
-class TestWebAgentCore(unittest.TestCase):
-    def setUp(self):
-        self.llm = AsyncMock()
-        self.model = "mock_model"
+    mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
+    mock_perform_scrape.return_value = "Example scraped content."
+    mock_perform_summarization.return_value = "Example summary."
+    mock_is_valid_answer.return_value = True
+    result = await web_general_search_core("example query", llm, model)
+    expected_response = {
+        "content": "Example summary.",
+        "ignore_validation": "false"
+    }
+    assert json.loads(result) == expected_response
 
-    @patch("src.agents.web_agent.perform_search")
-    @patch("src.agents.web_agent.perform_scrape")
-    @patch("src.agents.web_agent.perform_summarization")
-    @patch("src.agents.web_agent.is_valid_answer")
-    @pytest.mark.asyncio
-    async def test_web_general_search_core(
-        self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search
-    ):
-        mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
-        mock_perform_scrape.return_value = "Example scraped content."
-        mock_perform_summarization.return_value = "Example summary."
-        mock_is_valid_answer.return_value = True
+@pytest.mark.asyncio
+@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
+@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
+async def test_web_general_search_core_no_results(
+    mock_is_valid_answer,
+    mock_perform_summarization,
+    mock_perform_scrape,
+    mock_perform_search,
+):
+    llm = AsyncMock()
+    model = "mock_model"
+    mock_perform_search.return_value = {"status": "error", "urls": []}
+    result = await web_general_search_core("example query", llm, model)
+    assert result == "No relevant information found on the internet for the given query."
 
-        result = await web_general_search_core("example query", self.llm, self.model)
-        self.assertEqual(result, "Example summary.")
-        mock_perform_search.assert_called_once_with("example query", num_results=15)
-        mock_perform_scrape.assert_called_once_with("http://example.com")
-        mock_perform_summarization.assert_called_once_with(
-            "example query", "Example scraped content.", self.llm, self.model
-        )
-        mock_is_valid_answer.assert_called_once_with("Example summary.", "example query")
 
-    @patch("src.agents.web_agent.perform_search")
-    @patch("src.agents.web_agent.perform_scrape")
-    @patch("src.agents.web_agent.perform_summarization")
-    @patch("src.agents.web_agent.is_valid_answer")
-    @pytest.mark.asyncio
-    async def test_web_general_search_core_no_results(
-        self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search
-    ):
-        mock_perform_search.return_value = {"status": "error", "urls": []}
-
-        result = await web_general_search_core("example query", self.llm, self.model)
-        self.assertEqual(result, "No relevant information found on the internet for the given query.")
-        mock_perform_search.assert_called_once_with("example query", num_results=15)
-        mock_perform_scrape.assert_not_called()
-        mock_perform_summarization.assert_not_called()
-        mock_is_valid_answer.assert_not_called()
-
-    @patch("src.agents.web_agent.perform_search")
-    @patch("src.agents.web_agent.perform_scrape")
-    @patch("src.agents.web_agent.perform_summarization")
-    @patch("src.agents.web_agent.is_valid_answer")
-    @pytest.mark.asyncio
-    async def test_web_general_search_core_invalid_summary(
-        self, mock_is_valid_answer, mock_perform_summarization, mock_perform_scrape, mock_perform_search
-    ):
-        mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
-        mock_perform_scrape.return_value = "Example scraped content."
-        mock_perform_summarization.return_value = "Example invalid summary."
-        mock_is_valid_answer.return_value = False
-
-        result = await web_general_search_core("example query", self.llm, self.model)
-        self.assertEqual(result, "No relevant information found on the internet for the given query.")
-        mock_perform_search.assert_called_once_with("example query", num_results=15)
-        mock_perform_scrape.assert_called_once_with("http://example.com")
-        mock_perform_summarization.assert_called_once_with(
-            "example query", "Example scraped content.", self.llm, self.model
-        )
-        mock_is_valid_answer.assert_called_once_with("Example invalid summary.", "example query")
-
-
-if __name__ == "__main__":
-    unittest.main()
+@pytest.mark.asyncio
+@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
+@patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
+@patch("src.agents.web_agent.is_valid_answer", new_callable=AsyncMock)
+async def test_web_general_search_core_invalid_summary(
+    mock_is_valid_answer,
+    mock_perform_summarization,
+    mock_perform_scrape,
+    mock_perform_search
+):
+    llm = AsyncMock()
+    model = "mock_model"
+    mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
+    mock_perform_scrape.return_value = "Example scraped content."
+    mock_perform_summarization.return_value = "Example invalid summary."
+    mock_is_valid_answer.return_value = False
+    result = await web_general_search_core("example query", llm, model)
+    assert result == "No relevant information found on the internet for the given query."

From 78fa412ae48366bcb5070e71fc63d8039cb6b840 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Mon, 9 Sep 2024 15:45:23 +0100
Subject: [PATCH 18/48] Edit chart generator tests to move away from unittest

---
 .../agents/chart_generator_agent_test.py      | 108 ++++++++++++------
 1 file changed, 70 insertions(+), 38 deletions(-)

diff --git a/backend/tests/agents/chart_generator_agent_test.py b/backend/tests/agents/chart_generator_agent_test.py
index 8bfb7629e..59c410897 100644
--- a/backend/tests/agents/chart_generator_agent_test.py
+++ b/backend/tests/agents/chart_generator_agent_test.py
@@ -1,61 +1,93 @@
 from io import BytesIO
-import unittest
 from unittest.mock import patch, AsyncMock, MagicMock
 import pytest
 from src.agents.chart_generator_agent import generate_chart
+import base64
+import matplotlib.pyplot as plt
+from PIL import Image
+import json
 
+@pytest.mark.asyncio
+@patch("src.agents.chart_generator_agent.engine.load_prompt")
+@patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
+async def test_generate_code_success(mock_sanitise_script, mock_load_prompt):
+    llm = AsyncMock()
+    model = "mock_model"
 
-class TestGenerateChartAgent(unittest.TestCase):
-    def setUp(self):
-        self.llm = AsyncMock()
-        self.model = "mock_model"
-        self.details_to_generate_chart_code = "details to generate chart code"
-        self.generate_chart_code_prompt = "generate chart code prompt"
+    mock_load_prompt.side_effect = [
+        "details to create chart code prompt",
+        "generate chart code prompt"
+    ]
 
-    @pytest.mark.asyncio
-    @patch("src.agents.chart_generator_agent.engine.load_prompt")
-    @patch("src.agents.chart_generator_agent.sanitise_script")
-    async def test_generate_code_success(self, mock_sanitise_script, mock_load_prompt):
-        mock_load_prompt.side_effect = [self.details_to_generate_chart_code, self.generate_chart_code_prompt]
-        self.llm.chat.return_value = "generated code"
-        mock_sanitise_script.return_value = """
+    llm.chat.return_value = "generated code"
 
+    mock_sanitise_script.return_value = """
 import matplotlib.pyplot as plt
 fig = plt.figure()
 plt.plot([1, 2, 3], [4, 5, 6])
-
 """
+    plt.switch_backend('Agg')
+
+    def mock_exec_side_effect(script, globals=None, locals=None):
+        if isinstance(script, str):
+            fig = plt.figure()
+            plt.plot([1, 2, 3], [4, 5, 6])
+            if locals is None:
+                locals = {}
+            locals['fig'] = fig
+
+    with patch("builtins.exec", side_effect=mock_exec_side_effect):
+        result = await generate_chart("question_intent", "data_provided", "question_params", llm, model)
 
-        with patch("matplotlib.pyplot.figure") as mock_fig:
-            mock_fig_instance = MagicMock()
-            mock_fig.return_value = mock_fig_instance
-            result = await generate_chart("question_intent", "data_provided", "question_params", self.llm, self.model)
-            buf = BytesIO()
-            mock_fig_instance.savefig.assert_called_once_with(buf, format="png")
+        response = json.loads(result)
 
-        self.llm.chat.assert_called_once_with(
-            self.model, self.generate_chart_code_prompt, self.details_to_generate_chart_code
+        image_data = response["content"]
+        decoded_image = base64.b64decode(image_data)
+
+        image = Image.open(BytesIO(decoded_image))
+        image.verify()
+
+        llm.chat.assert_called_once_with(
+            model,
+            "generate chart code prompt",
+            "details to create chart code prompt"
         )
         mock_sanitise_script.assert_called_once_with("generated code")
-        self.assertIsInstance(result, str)
 
-    @pytest.mark.asyncio
-    @patch("src.agents.chart_generator_agent.engine.load_prompt")
-    @patch("src.agents.chart_generator_agent.sanitise_script")
-    async def test_generate_code_no_figure(self, mock_sanitise_script, mock_load_prompt):
-        mock_load_prompt.side_effect = [self.details_to_generate_chart_code, self.generate_chart_code_prompt]
-        self.llm.chat.return_value = "generated code"
-        mock_sanitise_script.return_value = """
+@pytest.mark.asyncio
+@patch("src.agents.chart_generator_agent.engine.load_prompt")
+@patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
+async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt):
+    llm = AsyncMock()
+    model = "mock_model"
 
-import matplotlib.pyplot as plt
-# No figure is created
+    mock_load_prompt.side_effect = [
+        "details to create chart code prompt",
+        "generate chart code prompt"
+    ]
 
+    llm.chat.return_value = "generated code"
+
+    mock_sanitise_script.return_value = """
+import matplotlib.pyplot as plt
+# No fig creation
 """
 
-        with self.assertRaises(ValueError) as context:
-            await generate_chart("question_intent", "data_provided", "question_params", self.llm, self.model)
-        self.assertEqual(str(context.exception), "The generated code did not produce a figure named 'fig'.")
+    plt.switch_backend('Agg')
+
+    def mock_exec_side_effect(script, globals=None, locals=None):
+        if isinstance(script, str):
+            if locals is None:
+                locals = {}
 
+    with patch("builtins.exec", side_effect=mock_exec_side_effect):
+        with pytest.raises(ValueError, match="The generated code did not produce a figure named 'fig'."):
+            await generate_chart("question_intent", "data_provided", "question_params", llm, model)
 
-if __name__ == "__main__":
-    unittest.main()
+        llm.chat.assert_called_once_with(
+            model,
+            "generate chart code prompt",
+            "details to create chart code prompt"
+        )
+
+        mock_sanitise_script.assert_called_once_with("generated code")
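
The success test decodes the base64 payload from the agent response and verifies it with Pillow without writing to disk. A minimal standalone sketch of that round trip, assuming matplotlib and Pillow are installed:

    import base64
    from io import BytesIO

    import matplotlib
    matplotlib.use("Agg")  # headless backend, as in the tests
    import matplotlib.pyplot as plt
    from PIL import Image

    fig = plt.figure()
    plt.plot([1, 2, 3], [4, 5, 6])

    buf = BytesIO()
    fig.savefig(buf, format="png")
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")

    # A consumer can decode and check the image entirely in memory.
    Image.open(BytesIO(base64.b64decode(encoded))).verify()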

From 2d57698758971c4d45eb1c7af7314dae19d786fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Mon, 9 Sep 2024 15:49:49 +0100
Subject: [PATCH 19/48] Fix merging issues

---
 backend/tests/agents/web_agent_test.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index 7f5bb989f..8d516d1c3 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -1,8 +1,6 @@
-import unittest
-from unittest.mock import AsyncMock, patch
-import json
-
 import pytest
+from unittest.mock import patch, AsyncMock
+import json
 from src.agents.web_agent import web_general_search_core
 
 @pytest.mark.asyncio
@@ -25,7 +23,7 @@ async def test_web_general_search_core(
     mock_is_valid_answer.return_value = True
     result = await web_general_search_core("example query", llm, model)
     expected_response = {
-        "content": "Example summary.",
+        "content": "Example salut.",
         "ignore_validation": "false"
     }
     assert json.loads(result) == expected_response

From 369bf74aa0b877c73dd25e78947c7e6c190ceef1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Mon, 9 Sep 2024 15:50:54 +0100
Subject: [PATCH 20/48] Fix merging issues with web agent tests

---
 backend/tests/agents/web_agent_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index 8d516d1c3..c18d2070d 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -23,7 +23,7 @@ async def test_web_general_search_core(
     mock_is_valid_answer.return_value = True
     result = await web_general_search_core("example query", llm, model)
     expected_response = {
-        "content": "Example salut.",
+        "content": "Example summary.",
         "ignore_validation": "false"
     }
     assert json.loads(result) == expected_response

From 2dbe80fc40fb9d7a7837e896457f0c5f13ac9a6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A9l=C3=A8ne=20Sauv=C3=A9?= <hsauve@scottlogic.com>
Date: Tue, 10 Sep 2024 10:21:44 +0100
Subject: [PATCH 21/48] Add tests based on PR comments

---
 .../agents/chart_generator_agent_test.py      | 58 +++++++++++++++++--
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/backend/tests/agents/chart_generator_agent_test.py b/backend/tests/agents/chart_generator_agent_test.py
index 59c410897..b3d404336 100644
--- a/backend/tests/agents/chart_generator_agent_test.py
+++ b/backend/tests/agents/chart_generator_agent_test.py
@@ -6,6 +6,7 @@
 import matplotlib.pyplot as plt
 from PIL import Image
 import json
+from src.agents.chart_generator_agent import sanitise_script
 
 @pytest.mark.asyncio
 @patch("src.agents.chart_generator_agent.engine.load_prompt")
@@ -21,7 +22,7 @@ async def test_generate_code_success(mock_sanitise_script, mock_load_prompt):
 
     llm.chat.return_value = "generated code"
 
-    mock_sanitise_script.return_value = """
+    return_string = mock_sanitise_script.return_value = """
 import matplotlib.pyplot as plt
 fig = plt.figure()
 plt.plot([1, 2, 3], [4, 5, 6])
@@ -29,7 +30,7 @@ async def test_generate_code_success(mock_sanitise_script, mock_load_prompt):
     plt.switch_backend('Agg')
 
     def mock_exec_side_effect(script, globals=None, locals=None):
-        if isinstance(script, str):
+        if script == return_string:
             fig = plt.figure()
             plt.plot([1, 2, 3], [4, 5, 6])
             if locals is None:
@@ -68,7 +69,7 @@ async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt):
 
     llm.chat.return_value = "generated code"
 
-    mock_sanitise_script.return_value = """
+    return_string = mock_sanitise_script.return_value = """
 import matplotlib.pyplot as plt
 # No fig creation
 """
@@ -76,7 +77,7 @@ async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt):
     plt.switch_backend('Agg')
 
     def mock_exec_side_effect(script, globals=None, locals=None):
-        if isinstance(script, str):
+        if script == return_string:
             if locals is None:
                 locals = {}
 
@@ -91,3 +92,52 @@ def mock_exec_side_effect(script, globals=None, locals=None):
         )
 
         mock_sanitise_script.assert_called_once_with("generated code")
+
+@pytest.mark.parametrize(
+    "input_script, expected_output",
+    [
+
+        (
+            """```python
+import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])
+```""",
+            """import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])"""
+        ),
+        (
+            """```python
+import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])""",
+            """import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])"""
+        ),
+        (
+            """import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])
+```""",
+            """import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])"""
+        ),
+        (
+            """import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])""",
+            """import matplotlib.pyplot as plt
+fig = plt.figure()
+plt.plot([1, 2, 3], [4, 5, 6])"""
+        ),
+        (
+            "",
+            ""
+        )
+    ]
+)
+def test_sanitise_script(input_script, expected_output):
+    assert sanitise_script(input_script) == expected_output
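
The parametrised cases pin down sanitise_script's contract: strip a leading ```python fence and a trailing ``` fence in any combination, and leave bare code untouched. A minimal implementation consistent with those cases (a sketch, not necessarily the exact code in chart_generator_agent):

    def sanitise_script(script: str) -> str:
        """Strip markdown code fences an LLM may wrap around generated code."""
        script = script.strip()
        if script.startswith("```python"):
            script = script[len("```python"):]
        if script.endswith("```"):
            script = script[:-3]
        return script.strip()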

From 711bd48e80f3ddcb89bdee0813b47e69de6dbf27 Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Wed, 11 Sep 2024 10:05:01 +0100
Subject: [PATCH 22/48] Move websocket related tests to correct directory

---
 .../connection_manager_test.py                |  0
 .../message_handlers_test.py                  | 52 +++++++++----------
 2 files changed, 26 insertions(+), 26 deletions(-)
 rename backend/tests/{api => websockets}/connection_manager_test.py (100%)
 rename backend/tests/{api => websockets}/message_handlers_test.py (96%)

diff --git a/backend/tests/api/connection_manager_test.py b/backend/tests/websockets/connection_manager_test.py
similarity index 100%
rename from backend/tests/api/connection_manager_test.py
rename to backend/tests/websockets/connection_manager_test.py
diff --git a/backend/tests/api/message_handlers_test.py b/backend/tests/websockets/message_handlers_test.py
similarity index 96%
rename from backend/tests/api/message_handlers_test.py
rename to backend/tests/websockets/message_handlers_test.py
index 42040c591..592b4f2d8 100644
--- a/backend/tests/api/message_handlers_test.py
+++ b/backend/tests/websockets/message_handlers_test.py
@@ -1,26 +1,26 @@
-from unittest.mock import call
-import pytest
-from src.websockets.message_handlers import create_on_ping, pong
-
-
-def test_on_ping_send_pong(mocker):
-    on_ping = create_on_ping()
-    mock_ws = mocker.Mock()
-    mock_disconnect = mocker.AsyncMock()
-    mocked_create_task = mocker.patch("asyncio.create_task")
-
-    on_ping(mock_ws, mock_disconnect, None)
-
-    first_call = mocked_create_task.call_args_list[0]
-    assert first_call == call(mock_ws.send_json(pong))
-
-
-@pytest.mark.asyncio
-async def test_on_ping_no_disconnect(mocker):
-    on_ping = create_on_ping()
-    mock_ws = mocker.AsyncMock()
-    mock_disconnect = mocker.AsyncMock()
-
-    on_ping(mock_ws, mock_disconnect, None)
-
-    mock_disconnect.assert_not_awaited()
+from unittest.mock import call
+import pytest
+from src.websockets.message_handlers import create_on_ping, pong
+
+
+def test_on_ping_send_pong(mocker):
+    on_ping = create_on_ping()
+    mock_ws = mocker.Mock()
+    mock_disconnect = mocker.AsyncMock()
+    mocked_create_task = mocker.patch("asyncio.create_task")
+
+    on_ping(mock_ws, mock_disconnect, None)
+
+    first_call = mocked_create_task.call_args_list[0]
+    assert first_call == call(mock_ws.send_json(pong))
+
+
+@pytest.mark.asyncio
+async def test_on_ping_no_disconnect(mocker):
+    on_ping = create_on_ping()
+    mock_ws = mocker.AsyncMock()
+    mock_disconnect = mocker.AsyncMock()
+
+    on_ping(mock_ws, mock_disconnect, None)
+
+    mock_disconnect.assert_not_awaited()

From 223f02aaea63ddfff2bc9985a8fd9b9077b6951f Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Wed, 11 Sep 2024 12:28:19 +0100
Subject: [PATCH 23/48] Add user_confirmer and confirmations_manager to send
 confirmations through the web socket

---
 .../src/websockets/confirmations_manager.py   |  43 +++++++
 backend/src/websockets/message_handlers.py    |  30 ++++-
 backend/src/websockets/types.py               |   3 +-
 backend/src/websockets/user_confirmer.py      |  45 +++++++
 .../websockets/confirmations_manager_test.py  |  86 ++++++++++++++
 .../tests/websockets/message_handlers_test.py | 112 +++++++++++++++++-
 .../tests/websockets/user_confirmer_test.py   |  73 ++++++++++++
 7 files changed, 388 insertions(+), 4 deletions(-)
 create mode 100644 backend/src/websockets/confirmations_manager.py
 create mode 100644 backend/src/websockets/user_confirmer.py
 create mode 100644 backend/tests/websockets/confirmations_manager_test.py
 create mode 100644 backend/tests/websockets/user_confirmer_test.py

diff --git a/backend/src/websockets/confirmations_manager.py b/backend/src/websockets/confirmations_manager.py
new file mode 100644
index 000000000..0cf9fbcb7
--- /dev/null
+++ b/backend/src/websockets/confirmations_manager.py
@@ -0,0 +1,43 @@
+import logging
+from typing import Dict
+import uuid
+from string import Template
+
+logger = logging.getLogger(__name__)
+
+
+class ConfirmationsManager:
+    _open_confirmations: Dict[uuid.UUID, bool | None] = {}
+    _ERROR_MESSAGE = Template(" Confirmation with id '$confirmation_id' not found")
+
+    def add_confirmation(self, confirmation_id: uuid.UUID):
+        self._open_confirmations[confirmation_id] = None
+        logger.info(f"Confirmation Added: {self._open_confirmations}")
+
+    def get_confirmation_state(self, confirmation_id: uuid.UUID) -> bool | None:
+        if confirmation_id in self._open_confirmations:
+            return self._open_confirmations[confirmation_id]
+        else:
+            raise Exception(
+                "Cannot get confirmation." + self._ERROR_MESSAGE.substitute(confirmation_id=confirmation_id)
+            )
+
+    def update_confirmation(self, confirmation_id: uuid.UUID, value: bool):
+        if confirmation_id in self._open_confirmations:
+            self._open_confirmations[confirmation_id] = value
+        else:
+            raise Exception(
+                "Cannot update confirmation." + self._ERROR_MESSAGE.substitute(confirmation_id=confirmation_id)
+            )
+
+    def delete_confirmation(self, confirmation_id: uuid.UUID):
+        if confirmation_id in self._open_confirmations:
+            del self._open_confirmations[confirmation_id]
+            logger.info(f"Confirmation Deleted: {self._open_confirmations}")
+        else:
+            raise Exception(
+                "Cannot delete confirmation." + self._ERROR_MESSAGE.substitute(confirmation_id=confirmation_id)
+            )
+
+
+confirmations_manager = ConfirmationsManager()
diff --git a/backend/src/websockets/message_handlers.py b/backend/src/websockets/message_handlers.py
index 968c16502..5d2d3faab 100644
--- a/backend/src/websockets/message_handlers.py
+++ b/backend/src/websockets/message_handlers.py
@@ -1,9 +1,11 @@
 import asyncio
 import json
 import logging
+from uuid import UUID
 from fastapi import WebSocket
 from typing import Callable
 from .types import Handlers, MessageTypes
+from src.websockets.confirmations_manager import confirmations_manager
 
 logger = logging.getLogger(__name__)
 
@@ -38,4 +40,30 @@ def on_chat(websocket: WebSocket, disconnect: Callable, data: str | None):
     logger.info(f"Chat message: {data}")
 
 
-handlers: Handlers = {MessageTypes.PING: create_on_ping(), MessageTypes.CHAT: on_chat}
+def on_confirmation(websocket: WebSocket, disconnect: Callable, data: str | None):
+    if data is None:
+        logger.warning("Confirmation response did not include data")
+        return
+    if ":" not in data:
+        logger.warning("Separator (':') not present in confirmation")
+        return
+    sections = data.split(":")
+    try:
+        id = UUID(sections[0])
+    except ValueError:
+        logger.warning("Received invalid id")
+        return
+    if sections[1] != "y" and sections[1] != "n":
+        logger.warning("Received invalid value")
+        return
+    try:
+        confirmations_manager.update_confirmation(id, sections[1] == "y")
+    except Exception as e:
+        logger.warning(f"Could not update confirmation: '{e}'")
+
+
+handlers: Handlers = {
+    MessageTypes.PING: create_on_ping(),
+    MessageTypes.CHAT: on_chat,
+    MessageTypes.CONFIRMATION: on_confirmation,
+}
diff --git a/backend/src/websockets/types.py b/backend/src/websockets/types.py
index 182b8d3e2..20d4bd83c 100644
--- a/backend/src/websockets/types.py
+++ b/backend/src/websockets/types.py
@@ -9,8 +9,9 @@ class MessageTypes(Enum):
     PING = "ping"
     PONG = "pong"
     CHAT = "chat"
-    LOG  = "log"
+    LOG = "log"
     IMAGE = "image"
+    CONFIRMATION = "confirmation"
 
 
 @dataclass
diff --git a/backend/src/websockets/user_confirmer.py b/backend/src/websockets/user_confirmer.py
new file mode 100644
index 000000000..fc59ead24
--- /dev/null
+++ b/backend/src/websockets/user_confirmer.py
@@ -0,0 +1,45 @@
+import asyncio
+import logging
+import uuid
+from src.websockets.types import Message, MessageTypes
+from .connection_manager import connection_manager
+from src.websockets.confirmations_manager import ConfirmationsManager
+
+logger = logging.getLogger(__name__)
+
+
+class UserConfirmer:
+    _POLL_RATE_SECONDS = 0.5
+    _TIMEOUT_SECONDS = 60.0
+    _CONFIRMATIONS_MANAGER: ConfirmationsManager
+
+    def __init__(self, manager: ConfirmationsManager):
+        self.confirmations_manager = manager
+
+    async def confirm(self, msg: str) -> bool:
+        id = uuid.uuid4()
+        self.confirmations_manager.add_confirmation(id)
+        await self._send_confirmation(id, msg)
+        try:
+            async with asyncio.timeout(self._TIMEOUT_SECONDS):
+                return await self._check_confirmed(id)
+        except TimeoutError:
+            logger.warning(f"Confirmation with id {id} timed out.")
+            self.confirmations_manager.delete_confirmation(id)
+            return False
+
+    async def _check_confirmed(self, id: uuid.UUID) -> bool:
+        while True:
+            try:
+                state = self.confirmations_manager.get_confirmation_state(id)
+                if isinstance(state, bool):
+                    self.confirmations_manager.delete_confirmation(id)
+                    return state
+            except Exception:
+                return False
+            await asyncio.sleep(self._POLL_RATE_SECONDS)
+
+    async def _send_confirmation(self, id: uuid.UUID, msg: str):
+        data = f"{str(id)}:{msg}"
+        message = Message(MessageTypes.CONFIRMATION, data)
+        await connection_manager.broadcast(message)
diff --git a/backend/tests/websockets/confirmations_manager_test.py b/backend/tests/websockets/confirmations_manager_test.py
new file mode 100644
index 000000000..85151831b
--- /dev/null
+++ b/backend/tests/websockets/confirmations_manager_test.py
@@ -0,0 +1,86 @@
+from uuid import uuid4
+
+import pytest
+from src.websockets.confirmations_manager import ConfirmationsManager
+
+
+class TestConfirmationsManager:
+    def test_add_confirmation(self):
+        # Arrange
+        manager = ConfirmationsManager()
+        confirmation_id = uuid4()
+
+        # Act
+        manager.add_confirmation(confirmation_id)
+
+        # Assert
+        confirmation = manager.get_confirmation_state(confirmation_id)
+        assert confirmation is None
+
+    def test_get_confirmation_state_not_found_id(self):
+        # Arrange
+        manager = ConfirmationsManager()
+        not_found_confirmation_id = uuid4()
+
+        # Act
+        with pytest.raises(Exception) as e:
+            manager.get_confirmation_state(not_found_confirmation_id)
+
+        # Assert
+        assert str(e.value) == f"Cannot get confirmation. Confirmation with id '{not_found_confirmation_id}' not found"
+
+    @pytest.mark.parametrize("input_value", [True, False])
+    def test_update_confirmation(self, input_value):
+        # Arrange
+        manager = ConfirmationsManager()
+        confirmation_id = uuid4()
+        manager.add_confirmation(confirmation_id)
+
+        # Act
+        manager.update_confirmation(confirmation_id, input_value)
+
+        # Assert
+        updated_value = manager.get_confirmation_state(confirmation_id)
+        assert updated_value == input_value
+
+    def test_update_confirmation_not_found_id(self):
+        # Arrange
+        manager = ConfirmationsManager()
+        not_found_confirmation_id = uuid4()
+
+        # Act
+        with pytest.raises(Exception) as e:
+            manager.update_confirmation(not_found_confirmation_id, True)
+
+        # Assert
+        assert (
+            str(e.value) == f"Cannot update confirmation. Confirmation with id '{not_found_confirmation_id}' not found"
+        )
+
+    def test_delete_confirmation(self):
+        # Arrange
+        manager = ConfirmationsManager()
+        confirmation_id = uuid4()
+        manager.add_confirmation(confirmation_id)
+
+        # Act
+        manager.delete_confirmation(confirmation_id)
+
+        # Assert
+        with pytest.raises(Exception) as e:
+            manager.get_confirmation_state(confirmation_id)
+        assert "Cannot get confirmation." in str(e.value)
+
+    def test_delete_confirmation_not_found_id(self):
+        # Arrange
+        manager = ConfirmationsManager()
+        not_found_confirmation_id = uuid4()
+
+        # Act
+        with pytest.raises(Exception) as e:
+            manager.delete_confirmation(not_found_confirmation_id)
+
+        # Assert
+        assert (
+            str(e.value) == f"Cannot delete confirmation. Confirmation with id '{not_found_confirmation_id}' not found"
+        )
diff --git a/backend/tests/websockets/message_handlers_test.py b/backend/tests/websockets/message_handlers_test.py
index 592b4f2d8..baf13f478 100644
--- a/backend/tests/websockets/message_handlers_test.py
+++ b/backend/tests/websockets/message_handlers_test.py
@@ -1,6 +1,8 @@
-from unittest.mock import call
+import logging
+from unittest.mock import Mock, call, patch
+from uuid import uuid4
 import pytest
-from src.websockets.message_handlers import create_on_ping, pong
+from src.websockets.message_handlers import create_on_ping, on_confirmation, pong
 
 
 def test_on_ping_send_pong(mocker):
@@ -24,3 +26,109 @@ async def test_on_ping_no_disconnect(mocker):
     on_ping(mock_ws, mock_disconnect, None)
 
     mock_disconnect.assert_not_awaited()
+
+
+@pytest.mark.parametrize("input_value,expected_bool", [("y", True), ("n", False)])
+@patch("src.websockets.message_handlers.confirmations_manager")
+def test_on_confirmation(confirmations_manager_mock, input_value, expected_bool):
+    # Arrange
+    confirmation_id = uuid4()
+    data = f"{confirmation_id}:{input_value}"
+    websocket_mock = Mock()
+    disconnect_mock = Mock()
+
+    # Act
+    on_confirmation(websocket_mock, disconnect_mock, data)
+
+    # Assert
+    confirmations_manager_mock.update_confirmation.assert_called_once_with(confirmation_id, expected_bool)
+
+
+@patch("src.websockets.message_handlers.confirmations_manager")
+def test_on_confirmation_data_is_none(confirmations_manager_mock, caplog):
+    # Arrange
+    websocket_mock = Mock()
+    disconnect_mock = Mock()
+
+    # Act
+    on_confirmation(websocket_mock, disconnect_mock, None)
+
+    # Assert
+    confirmations_manager_mock.update_confirmation.assert_not_called()
+    assert (
+        "src.websockets.message_handlers",
+        logging.WARNING,
+        "Confirmation response did not include data",
+    ) in caplog.record_tuples
+
+
+@patch("src.websockets.message_handlers.confirmations_manager")
+def test_on_confirmation_seperator_not_present(confirmations_manager_mock, caplog):
+    # Arrange
+    websocket_mock = Mock()
+    disconnect_mock = Mock()
+    data = "abc"
+
+    # Act
+    on_confirmation(websocket_mock, disconnect_mock, data)
+
+    # Assert
+    confirmations_manager_mock.update_confirmation.assert_not_called()
+    assert (
+        "src.websockets.message_handlers",
+        logging.WARNING,
+        "Separator (':') not present in confirmation",
+    ) in caplog.record_tuples
+
+
+@patch("src.websockets.message_handlers.confirmations_manager")
+def test_on_confirmation_seperator_id_not_uuid(confirmations_manager_mock, caplog):
+    # Arrange
+    websocket_mock = Mock()
+    disconnect_mock = Mock()
+    data = "abc:y"
+
+    # Act
+    on_confirmation(websocket_mock, disconnect_mock, data)
+
+    # Assert
+    confirmations_manager_mock.update_confirmation.assert_not_called()
+    assert ("src.websockets.message_handlers", logging.WARNING, "Received invalid id") in caplog.record_tuples
+
+
+@pytest.mark.parametrize("input_value", [(""), ("abc")])
+@patch("src.websockets.message_handlers.confirmations_manager")
+def test_on_confirmation_value_not_valid(confirmations_manager_mock, caplog, input_value):
+    # Arrange
+    websocket_mock = Mock()
+    disconnect_mock = Mock()
+    confirmation_id = uuid4()
+    data = f"{confirmation_id}:{input_value}"
+
+    # Act
+    on_confirmation(websocket_mock, disconnect_mock, data)
+
+    # Assert
+    confirmations_manager_mock.update_confirmation.assert_not_called()
+    assert ("src.websockets.message_handlers", logging.WARNING, "Received invalid value") in caplog.record_tuples
+
+
+@patch("src.websockets.message_handlers.confirmations_manager")
+def test_on_confirmation_confirmation_manager_raises_exception(confirmations_manager_mock, caplog):
+    # Arrange
+    confirmation_id = uuid4()
+    data = f"{confirmation_id}:y"
+    websocket_mock = Mock()
+    disconnect_mock = Mock()
+    exception_message = "Test Exception Message"
+    confirmations_manager_mock.update_confirmation.side_effect = Exception(exception_message)
+
+    # Act
+    on_confirmation(websocket_mock, disconnect_mock, data)
+
+    # Assert
+    assert (
+        "src.websockets.message_handlers",
+        logging.WARNING,
+        f"Could not update confirmation: '{exception_message}'",
+    ) in caplog.record_tuples
diff --git a/backend/tests/websockets/user_confirmer_test.py b/backend/tests/websockets/user_confirmer_test.py
new file mode 100644
index 000000000..896ee17af
--- /dev/null
+++ b/backend/tests/websockets/user_confirmer_test.py
@@ -0,0 +1,73 @@
+import logging
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from src.websockets.types import Message, MessageTypes
+from src.websockets.user_confirmer import UserConfirmer
+from src.websockets.confirmations_manager import ConfirmationsManager
+
+
+class TestUserConfirmer:
+    @pytest.mark.asyncio
+    async def test_confirm_times_out(self, caplog):
+        # Arrange
+        confirmations_manager_mock = Mock(spec=ConfirmationsManager)
+        confirmations_manager_mock.get_confirmation_state.return_value = None
+        user_confirmer = UserConfirmer(confirmations_manager_mock)
+        user_confirmer._TIMEOUT_SECONDS = 0.05
+        user_confirmer._POLL_RATE_SECONDS = 0.01
+
+        # Act
+        result = await user_confirmer.confirm("Test Message")
+
+        # Assert
+        assert result is False
+        confirmations_manager_mock.add_confirmation.assert_called_once()
+        id = confirmations_manager_mock.add_confirmation.call_args.args[0]
+        confirmations_manager_mock.delete_confirmation.assert_called_once_with(id)
+        assert caplog.record_tuples == [
+            ("src.websockets.user_confirmer", logging.WARNING, f"Confirmation with id {id} timed out.")
+        ]
+
+    @pytest.mark.asyncio
+    @patch("src.websockets.user_confirmer.connection_manager", new_callable=AsyncMock)
+    async def test_confirm_approved(self, connection_manager_mock):
+        # Arrange
+        confirmations_manager_mock = Mock(spec=ConfirmationsManager)
+        confirmations_manager_mock.get_confirmation_state.side_effect = [None, True]
+        user_confirmer = UserConfirmer(confirmations_manager_mock)
+        user_confirmer._POLL_RATE_SECONDS = 0.01
+
+        # Act
+        result = await user_confirmer.confirm("Test Message")
+
+        # Assert
+        assert result is True
+        confirmations_manager_mock.add_confirmation.assert_called_once()
+        id = confirmations_manager_mock.add_confirmation.call_args.args[0]
+        connection_manager_mock.broadcast.assert_awaited_once_with(Message(MessageTypes.CONFIRMATION, f"{id}:Test Message"))
+        confirmations_manager_mock.get_confirmation_state.assert_called_with(id)
+        assert confirmations_manager_mock.get_confirmation_state.call_count == 2
+        confirmations_manager_mock.delete_confirmation.assert_called_once_with(id)
+
+    @pytest.mark.asyncio
+    @patch("src.websockets.user_confirmer.connection_manager", new_callable=AsyncMock)
+    async def test_confirm_denied(self, connection_manager_mock):
+        # Arrange
+        confirmations_manager_mock = Mock(spec=ConfirmationsManager)
+        confirmations_manager_mock.get_confirmation_state.side_effect = [None, False]
+        user_confirmer = UserConfirmer(confirmations_manager_mock)
+        user_confirmer._POLL_RATE_SECONDS = 0.01
+
+        # Act
+        result = await user_confirmer.confirm("Test Message")
+
+        # Assert
+        assert result is False
+        confirmations_manager_mock.add_confirmation.assert_called_once()
+        id = confirmations_manager_mock.add_confirmation.call_args.args[0]
+        connection_manager_mock.broadcast.assert_awaited_once_with(Message(MessageTypes.CONFIRMATION, f"{id}:Test Message"))
+        confirmations_manager_mock.get_confirmation_state.assert_called_with(id)
+        assert confirmations_manager_mock.get_confirmation_state.call_count == 2
+        confirmations_manager_mock.delete_confirmation.assert_called_once_with(id)

From a464224562646e0f0fc72f94578e595bed520aa4 Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Wed, 11 Sep 2024 15:08:57 +0100
Subject: [PATCH 24/48] Add confirmation before chart is generated in
 chart_generator_agent

---
 backend/src/agents/chart_generator_agent.py   | 24 ++++---
 .../agents/chart_generator_agent_test.py      | 70 ++++++++++++++-----
 2 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/backend/src/agents/chart_generator_agent.py b/backend/src/agents/chart_generator_agent.py
index 156e51703..f8ef64572 100644
--- a/backend/src/agents/chart_generator_agent.py
+++ b/backend/src/agents/chart_generator_agent.py
@@ -9,11 +9,14 @@
 from src.utils import scratchpad
 from PIL import Image
 import json
+from src.websockets.user_confirmer import UserConfirmer
+from src.websockets.confirmations_manager import confirmations_manager
 
 logger = logging.getLogger(__name__)
 
 engine = PromptEngine()
 
+
 async def generate_chart(question_intent, data_provided, question_params, llm: LLM, model) -> str:
     details_to_generate_chart_code = engine.load_prompt(
         "details-to-generate-chart-code",
@@ -28,13 +31,17 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
     sanitised_script = sanitise_script(generated_code)
 
     try:
+        confirmer = UserConfirmer(confirmations_manager)
+        is_confirmed = await confirmer.confirm("Would you like to generate a graph?")
+        if not is_confirmed:
+            raise Exception("The user did not confirm the creation of a graph.")
         local_vars = {}
         exec(sanitised_script, {}, local_vars)
-        fig = local_vars.get('fig')
+        fig = local_vars.get("fig")
         buf = BytesIO()
         if fig is None:
             raise ValueError("The generated code did not produce a figure named 'fig'.")
-        fig.savefig(buf, format='png')
+        fig.savefig(buf, format="png")
         buf.seek(0)
         with Image.open(buf):
             image_data = base64.b64encode(buf.getvalue()).decode("utf-8")
@@ -44,7 +51,7 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
         raise
     response = {
         "content": image_data,
-        "ignore_validation": "false"
+        "ignore_validation": "false",
     }
     return json.dumps(response, indent=4)
 
@@ -57,6 +64,7 @@ def sanitise_script(script: str) -> str:
         script = script[:-3]
     return script.strip()
 
+
 @tool(
     name="generate_code_chart",
     description="Generate Matplotlib bar chart code if the user's query involves creating a chart",
@@ -74,18 +82,18 @@ def sanitise_script(script: str) -> str:
             description="""
                 The specific parameters required for the question to be answered with the question_intent,
                 extracted from data_provided
-            """),
-    }
+            """,
+        ),
+    },
 )
-
 async def generate_code_chart(question_intent, data_provided, question_params, llm: LLM, model) -> str:
     return await generate_chart(question_intent, data_provided, question_params, llm, model)
 
+
 @agent(
     name="ChartGeneratorAgent",
     description="This agent is responsible for creating charts",
-    tools=[generate_code_chart]
+    tools=[generate_code_chart],
 )
-
 class ChartGeneratorAgent(Agent):
     pass
diff --git a/backend/tests/agents/chart_generator_agent_test.py b/backend/tests/agents/chart_generator_agent_test.py
index b3d404336..5072f6a39 100644
--- a/backend/tests/agents/chart_generator_agent_test.py
+++ b/backend/tests/agents/chart_generator_agent_test.py
@@ -8,16 +8,19 @@
 import json
 from src.agents.chart_generator_agent import sanitise_script
 
+
 @pytest.mark.asyncio
 @patch("src.agents.chart_generator_agent.engine.load_prompt")
 @patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
-async def test_generate_code_success(mock_sanitise_script, mock_load_prompt):
+@patch("src.agents.chart_generator_agent.UserConfirmer.confirm", new_callable=AsyncMock)
+async def test_generate_code_success(confirm_mock, mock_sanitise_script, mock_load_prompt):
+    confirm_mock.return_value = True
     llm = AsyncMock()
     model = "mock_model"
 
     mock_load_prompt.side_effect = [
         "details to create chart code prompt",
-        "generate chart code prompt"
+        "generate chart code prompt",
     ]
 
     llm.chat.return_value = "generated code"
@@ -27,7 +30,7 @@ async def test_generate_code_success(mock_sanitise_script, mock_load_prompt):
 fig = plt.figure()
 plt.plot([1, 2, 3], [4, 5, 6])
 """
-    plt.switch_backend('Agg')
+    plt.switch_backend("Agg")
 
     def mock_exec_side_effect(script, globals=None, locals=None):
         if script == return_string:
@@ -35,7 +38,7 @@ def mock_exec_side_effect(script, globals=None, locals=None):
             plt.plot([1, 2, 3], [4, 5, 6])
             if locals is None:
                 locals = {}
-            locals['fig'] = fig
+            locals["fig"] = fig
 
     with patch("builtins.exec", side_effect=mock_exec_side_effect):
         result = await generate_chart("question_intent", "data_provided", "question_params", llm, model)
@@ -51,20 +54,23 @@ def mock_exec_side_effect(script, globals=None, locals=None):
         llm.chat.assert_called_once_with(
             model,
             "generate chart code prompt",
-            "details to create chart code prompt"
+            "details to create chart code prompt",
         )
         mock_sanitise_script.assert_called_once_with("generated code")
 
+
 @pytest.mark.asyncio
 @patch("src.agents.chart_generator_agent.engine.load_prompt")
 @patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
-async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt):
+@patch("src.agents.chart_generator_agent.UserConfirmer.confirm", new_callable=AsyncMock)
+async def test_generate_code_no_figure(confirm_mock, mock_sanitise_script, mock_load_prompt):
+    confirm_mock.return_value = True
     llm = AsyncMock()
     model = "mock_model"
 
     mock_load_prompt.side_effect = [
         "details to create chart code prompt",
-        "generate chart code prompt"
+        "generate chart code prompt",
     ]
 
     llm.chat.return_value = "generated code"
@@ -74,7 +80,7 @@ async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt):
 # No fig creation
 """
 
-    plt.switch_backend('Agg')
+    plt.switch_backend("Agg")
 
     def mock_exec_side_effect(script, globals=None, locals=None):
         if script == return_string:
@@ -88,15 +94,45 @@ def mock_exec_side_effect(script, globals=None, locals=None):
         llm.chat.assert_called_once_with(
             model,
             "generate chart code prompt",
-            "details to create chart code prompt"
+            "details to create chart code prompt",
         )
 
         mock_sanitise_script.assert_called_once_with("generated code")
 
+
+@pytest.mark.asyncio
+@patch("src.agents.chart_generator_agent.engine.load_prompt")
+@patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
+@patch("src.agents.chart_generator_agent.UserConfirmer.confirm", new_callable=AsyncMock)
+async def test_generate_code_confirmation_false(confirm_mock, mock_sanitise_script, mock_load_prompt):
+    confirm_mock.return_value = False
+    llm = AsyncMock()
+    model = "mock_model"
+
+    mock_load_prompt.side_effect = [
+        "details to create chart code prompt",
+        "generate chart code prompt",
+    ]
+
+    llm.chat.return_value = "generated code"
+
+    mock_sanitise_script.return_value = "script"
+
+    with pytest.raises(Exception, match="The user did not confirm the creation of a graph."):
+        await generate_chart("question_intent", "data_provided", "question_params", llm, model)
+
+    llm.chat.assert_called_once_with(
+        model,
+        "generate chart code prompt",
+        "details to create chart code prompt",
+    )
+
+    mock_sanitise_script.assert_called_once_with("generated code")
+
+
 @pytest.mark.parametrize(
     "input_script, expected_output",
     [
-
         (
             """```python
 import matplotlib.pyplot as plt
@@ -105,7 +141,7 @@ def mock_exec_side_effect(script, globals=None, locals=None):
 ```""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])"""
+plt.plot([1, 2, 3], [4, 5, 6])""",
         ),
         (
             """```python
@@ -114,7 +150,7 @@ def mock_exec_side_effect(script, globals=None, locals=None):
 plt.plot([1, 2, 3], [4, 5, 6])""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])"""
+plt.plot([1, 2, 3], [4, 5, 6])""",
         ),
         (
             """import matplotlib.pyplot as plt
@@ -123,7 +159,7 @@ def mock_exec_side_effect(script, globals=None, locals=None):
 ```""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])"""
+plt.plot([1, 2, 3], [4, 5, 6])""",
         ),
         (
             """import matplotlib.pyplot as plt
@@ -131,13 +167,13 @@ def mock_exec_side_effect(script, globals=None, locals=None):
 plt.plot([1, 2, 3], [4, 5, 6])""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])"""
+plt.plot([1, 2, 3], [4, 5, 6])""",
         ),
         (
             "",
-            ""
-        )
-    ]
+            "",
+        ),
+    ],
 )
 def test_sanitise_script(input_script, expected_output):
     assert sanitise_script(input_script) == expected_output

From 8ad67e6dd2827734f4ccf4d21844c5a39ac4ba8c Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Fri, 13 Sep 2024 10:37:57 +0100
Subject: [PATCH 25/48] Fix bug in the backend websocket

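ConnectionManager.broadcast was passing an already-serialised string to
WebSocket.send_json, so clients received a JSON-encoded string rather
than a JSON object:

    await ws.send_json(json.dumps({"type": message.type.value, "data": message.data}))  # before
    await ws.send_json({"type": message.type.value, "data": message.data})              # after

send_json serialises the payload itself, so the dict is now passed
through directly and the unused json import is removed.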
---
 backend/src/websockets/connection_manager.py        | 5 +++--
 backend/tests/websockets/connection_manager_test.py | 6 ++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/backend/src/websockets/connection_manager.py b/backend/src/websockets/connection_manager.py
index aac7a0b0a..6ca5e6b13 100644
--- a/backend/src/websockets/connection_manager.py
+++ b/backend/src/websockets/connection_manager.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from typing import Any, Dict, List
 from fastapi import WebSocket
@@ -9,6 +8,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 def parse_message(message: Dict[str, Any]) -> Message:
     data = message.get("data") or None
     return Message(type=message["type"], data=data)
@@ -46,11 +46,12 @@ async def handle_message(self, ws: WebSocket, message: Message):
     async def broadcast(self, message: Message):
         for ws in self.websockets:
             if ws.application_state == WebSocketState.CONNECTED:
-                await ws.send_json(json.dumps({"type": message.type.value, "data": message.data}))
+                await ws.send_json({"type": message.type.value, "data": message.data})
 
     async def send_chart(self, data: Dict[str, Any]):
         for ws in self.websockets:
             if ws.application_state == WebSocketState.CONNECTED:
                 await ws.send_json(data)
 
+
 connection_manager = ConnectionManager()
diff --git a/backend/tests/websockets/connection_manager_test.py b/backend/tests/websockets/connection_manager_test.py
index 087941382..809742314 100644
--- a/backend/tests/websockets/connection_manager_test.py
+++ b/backend/tests/websockets/connection_manager_test.py
@@ -1,4 +1,3 @@
-import json
 from unittest.mock import patch
 import pytest
 from fastapi.websockets import WebSocketState
@@ -101,6 +100,7 @@ async def test_disconnect_websocket_already_disconnected(connection_manager):
     mock_ws.close.assert_not_called()
     assert len(manager.websockets) == 0
 
+
 @pytest.mark.asyncio
 async def test_handle_message_handler_exists_for_message_type_handler_called(connection_manager, mocker):
     manager, mock_ws, _ = connection_manager
@@ -113,7 +113,6 @@ async def test_handle_message_handler_exists_for_message_type_handler_called(con
         handler.assert_called()
 
 
-
 @pytest.mark.asyncio
 async def test_handle_message_handler_does_not_exist_for_message_type_handler_called(connection_manager):
     manager, mock_ws, _ = connection_manager
@@ -126,7 +125,6 @@ async def test_handle_message_handler_does_not_exist_for_message_type_handler_ca
         assert str(error.value) == "No handler for message type"
 
 
-
 @pytest.mark.asyncio
 async def test_broadcast_given_message_broadcasted(connection_manager):
     manager, mock_ws, _ = connection_manager
@@ -136,7 +134,7 @@ async def test_broadcast_given_message_broadcasted(connection_manager):
 
     await manager.broadcast(message)
 
-    mock_ws.send_json.assert_awaited_once_with(json.dumps({"type": message.type.value, "data": message.data}))
+    mock_ws.send_json.assert_awaited_once_with({"type": message.type.value, "data": message.data})
 
 
 @pytest.mark.asyncio

From 143d014520545171da0ad48bd4e568ec4e931075 Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Fri, 13 Sep 2024 12:16:04 +0100
Subject: [PATCH 26/48] Add modal to respond to confirmation request on the
 front end

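The modal consumes the confirmation protocol added on the backend: an
incoming message of type "confirmation" carries "<uuid>:<question>",
and the user's choice is sent back on the same message type as
"<uuid>:y" or "<uuid>:n", for example:

    server -> client  {"type": "confirmation", "data": "<uuid>:Would you like to generate a graph?"}
    client -> server  {"type": "confirmation", "data": "<uuid>:y"}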
---
 frontend/src/components/chat.tsx              | 39 ++++++++----
 .../src/components/confirm-modal.module.css   | 63 +++++++++++++++++++
 frontend/src/components/confirm-modal.tsx     | 57 +++++++++++++++++
 frontend/src/session/websocket-context.ts     |  3 +-
 4 files changed, 150 insertions(+), 12 deletions(-)
 create mode 100644 frontend/src/components/confirm-modal.module.css
 create mode 100644 frontend/src/components/confirm-modal.tsx

diff --git a/frontend/src/components/chat.tsx b/frontend/src/components/chat.tsx
index 9c5fe9256..31bfec5bb 100644
--- a/frontend/src/components/chat.tsx
+++ b/frontend/src/components/chat.tsx
@@ -3,23 +3,37 @@ import { Message, MessageComponent } from './message';
 import styles from './chat.module.css';
 import { Waiting } from './waiting';
 import { ConnectionStatus } from './connection-status';
-import { WebsocketContext, MessageType } from '../session/websocket-context';
-
+import { WebsocketContext, MessageType, Message as wsMessage } from '../session/websocket-context';
+import { Confirmation, ConfirmModal } from './confirm-modal';
 export interface ChatProps {
   messages: Message[];
   waiting: boolean;
 }
 
+const mapWsMessageToConfirmation = (message: wsMessage): Confirmation | undefined => {
+  if (!message.data) {
+    return;
+  }
+  const parts = message.data.split(':');
+  return { id: parts[0], requestMessage: parts[1], result: null };
+};
+
 export const Chat = ({ messages, waiting }: ChatProps) => {
   const containerRef = React.useRef<HTMLDivElement>(null);
-  const { isConnected, lastMessage } = useContext(WebsocketContext);
+  const { isConnected, lastMessage, send } = useContext(WebsocketContext);
   const [chart, setChart] = useState<string | undefined>(undefined);
+  const [confirmation, setConfirmation] = useState<Confirmation | null>(null);
 
   useEffect(() => {
     if (lastMessage && lastMessage.type === MessageType.IMAGE) {
       const imageData = `data:image/png;base64,${lastMessage.data}`;
       setChart(imageData);
     }
+    if (lastMessage && lastMessage.type === MessageType.CONFIRMATION) {
+      const newConfirmation = mapWsMessageToConfirmation(lastMessage);
+      if (newConfirmation)
+        setConfirmation(newConfirmation);
+    }
   }, [lastMessage]);
 
   useEffect(() => {
@@ -29,13 +43,16 @@ export const Chat = ({ messages, waiting }: ChatProps) => {
   }, [messages.length]);
 
   return (
-    <div ref={containerRef} className={styles.container}>
-      <ConnectionStatus isConnected={isConnected} />
-      {messages.map((message, index) => (
-        <MessageComponent key={index} message={message} />
-      ))}
-      {chart && <img src={chart} alt="Generated chart"/>}
-      {waiting && <Waiting />}
-    </div>
+    <>
+      <ConfirmModal confirmation={confirmation} setConfirmation={setConfirmation} send={send} />
+      <div ref={containerRef} className={styles.container}>
+        <ConnectionStatus isConnected={isConnected} />
+        {messages.map((message, index) => (
+          <MessageComponent key={index} message={message} />
+        ))}
+        {chart && <img src={chart} alt="Generated chart" />}
+        {waiting && <Waiting />}
+      </div>
+    </>
   );
 };
diff --git a/frontend/src/components/confirm-modal.module.css b/frontend/src/components/confirm-modal.module.css
new file mode 100644
index 000000000..9004e1e33
--- /dev/null
+++ b/frontend/src/components/confirm-modal.module.css
@@ -0,0 +1,63 @@
+.modal{
+    width: 40%;
+    height: 40%;
+    background-color: #4c4c4c;
+    color: var(--text-color-primary);
+    border: 2px black;
+    border-radius: 10px;
+}
+
+.modalContent{
+    width: 100%;
+    height: 100%;
+    display: flex;
+    flex-direction: column;
+}
+
+.header{
+    text-align: center;
+}
+
+.modal::backdrop{
+    background: rgb(0,0,0,0.8);
+}
+
+.requestMessage{
+    flex-grow: 1;
+}
+
+.buttonsBar{
+    display: flex;
+    gap: 0.5rem;
+}
+
+.button{
+    color: var(--text-color-primary);
+    font-weight: bold;
+    border: none;
+    width: 100%;
+    padding: 1rem;
+    cursor: pointer;
+    border-radius: 3px;
+}
+
+
+.cancel{
+    composes: button;
+    background-color: var(--background-color-primary);
+}
+
+.cancel:hover{
+    background-color: #141414;
+    transition: all 0.5s;
+}
+
+.confirm{
+    composes: button;
+    background-color: var(--blue);
+}
+
+.confirm:hover{
+    background-color: #146AFF;
+    transition: all 0.5s;
+}
diff --git a/frontend/src/components/confirm-modal.tsx b/frontend/src/components/confirm-modal.tsx
new file mode 100644
index 000000000..9cadf0fc2
--- /dev/null
+++ b/frontend/src/components/confirm-modal.tsx
@@ -0,0 +1,57 @@
+import Styles from './confirm-modal.module.css';
+import { useEffect, useRef } from 'react';
+import { Message, MessageType } from '../session/websocket-context';
+import React from 'react';
+
+export interface Confirmation {
+  id: string,
+  requestMessage: string,
+  result: boolean | null
+}
+
+interface ConfirmModalProps {
+  confirmation: Confirmation | null,
+  setConfirmation: (confirmation: Confirmation | null) => void,
+  send: (message: Message) => void
+}
+
+export const ConfirmModal = ({ confirmation, setConfirmation, send }: ConfirmModalProps) => {
+  const mapConfirmationToMessage = (confirmation: Confirmation): Message => {
+    return { type: MessageType.CONFIRMATION, data: confirmation.id + ':' + (confirmation.result ? 'y' : 'n') };
+  };
+
+  const updateConfirmationResult = (newResult: boolean) => {
+    if (confirmation) {
+      setConfirmation({ ...confirmation, result: newResult });
+    }
+  };
+
+
+  const modalRef = useRef<HTMLDialogElement>(null);
+
+  useEffect(() => {
+    if (confirmation) {
+      if (confirmation.result !== null) {
+        send(mapConfirmationToMessage(confirmation));
+        setConfirmation(null);
+      } else {
+        modalRef.current?.showModal();
+      }
+    } else {
+      modalRef.current?.close();
+    }
+  }, [confirmation]);
+
+  return (
+    <dialog className={Styles.modal} ref={modalRef} onClose={() => updateConfirmationResult(false)}>
+      <div className={Styles.modalContent}>
+        <h1 className={Styles.header}>Confirmation</h1>
+        <p className={Styles.requestMessage}>{confirmation && confirmation.requestMessage}</p>
+        <div className={Styles.buttonsBar}>
+          <button className={Styles.cancel} onClick={() => updateConfirmationResult(false)}>Cancel</button>
+          <button className={Styles.confirm} onClick={() => updateConfirmationResult(true)}>Confirm</button>
+        </div>
+      </div>
+    </dialog>
+  );
+};
diff --git a/frontend/src/session/websocket-context.ts b/frontend/src/session/websocket-context.ts
index 0c56803e6..a684bc0d5 100644
--- a/frontend/src/session/websocket-context.ts
+++ b/frontend/src/session/websocket-context.ts
@@ -3,7 +3,8 @@ import { createContext } from 'react';
 export enum MessageType {
   PING = 'ping',
   CHAT = 'chat',
-  IMAGE = 'image'
+  IMAGE = 'image',
+  CONFIRMATION = 'confirmation',
 }
 
 export interface Message {

From 754b07348a858b49b85a660d4b3491cc2c920093 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Mon, 16 Sep 2024 12:15:00 +0100
Subject: [PATCH 27/48] First push of FileAgent

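FileAgent adds read_file and write_file tools that operate on paths
under FILES_DIRECTORY (mounted into the backend container via compose).
Both return the JSON envelope the supervisor understands, and
solve_task now returns an (agent_name, answer, status) triple so an
"ignore_validation": "error" response is surfaced instead of being
retried.

Rough sketch of the tool responses (illustrative; notes.txt is a
hypothetical file under FILES_DIRECTORY):

    await read_file_core("notes.txt")
    # -> {"content": "<file contents>", "ignore_validation": "true"}
    await write_file_core("notes.txt", "hello")
    # -> {"content": "Content written to file notes.txt.", "ignore_validation": "true"}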
---
 .env.example                                 |  3 +
 .gitignore                                   |  1 +
 backend/src/agents/__init__.py               |  2 +
 backend/src/agents/file_agent.py             | 94 ++++++++++++++++++++
 backend/src/agents/web_agent.py              |  9 +-
 backend/src/prompts/templates/intent.j2      |  4 +
 backend/src/prompts/templates/summariser.j2  | 10 +--
 backend/src/supervisors/supervisor.py        | 11 ++-
 backend/src/utils/config.py                  |  2 +
 backend/tests/agents/file_agent_test.py      | 65 ++++++++++++++
 backend/tests/agents/web_agent_test.py       |  4 +-
 backend/tests/supervisors/supervisor_test.py |  4 +-
 compose.yml                                  |  5 ++
 13 files changed, 195 insertions(+), 19 deletions(-)
 create mode 100644 backend/src/agents/file_agent.py
 create mode 100644 backend/tests/agents/file_agent_test.py

diff --git a/.env.example b/.env.example
index df8da3280..3b5c59f0c 100644
--- a/.env.example
+++ b/.env.example
@@ -13,6 +13,9 @@ NEO4J_URI=bolt://localhost:7687
 NEO4J_HTTP_PORT=7474
 NEO4J_BOLT_PORT=7687
 
+# files location
+FILES_DIRECTORY=files
+
 # backend LLM properties
 MISTRAL_KEY=my-api-key
 
diff --git a/.gitignore b/.gitignore
index a6d19fb54..7d7996c23 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,6 +127,7 @@ celerybeat.pid
 # Environments
 .env
 .venv
+files
 env/
 venv/
 ENV/
diff --git a/backend/src/agents/__init__.py b/backend/src/agents/__init__.py
index cd58158db..02913c891 100644
--- a/backend/src/agents/__init__.py
+++ b/backend/src/agents/__init__.py
@@ -8,6 +8,7 @@
 from .validator_agent import ValidatorAgent
 from .answer_agent import AnswerAgent
 from .chart_generator_agent import ChartGeneratorAgent
+from .file_agent import FileAgent
 
 config = Config()
 
@@ -32,6 +33,7 @@ def get_available_agents() -> List[Agent]:
     return [DatastoreAgent(config.datastore_agent_llm, config.datastore_agent_model),
             WebAgent(config.web_agent_llm, config.web_agent_model),
             ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model),
+            FileAgent(config.chart_generator_llm, config.chart_generator_model),
             ]
 
 
diff --git a/backend/src/agents/file_agent.py b/backend/src/agents/file_agent.py
new file mode 100644
index 000000000..e799a1b56
--- /dev/null
+++ b/backend/src/agents/file_agent.py
@@ -0,0 +1,94 @@
+import logging
+from .agent_types import Parameter
+from .agent import Agent, agent
+from .tool import tool
+import json
+import os
+from src.utils.config import Config
+
+logger = logging.getLogger(__name__)
+config = Config()
+
+FILES_DIRECTORY = f"/app/{config.files_directory}"
+
+
+async def read_file_core(file_path: str) -> str:
+    full_path = ""
+    try:
+        full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
+        with open(full_path, 'r') as file:
+            content = file.read()
+        response = {
+            "content": content,
+            "ignore_validation": "true"
+        }
+        return json.dumps(response, indent=4)
+    except FileNotFoundError:
+        error_message = f"File {file_path} not found."
+        logger.error(error_message)
+        response = {
+            "content": error_message,
+            "ignore_validation": "error",
+        }
+        return json.dumps(response, indent=4)
+    except Exception as e:
+        logger.error(f"Error reading file {full_path}: {e}")
+        return json.dumps({"status": "error", "message": f"Error reading file: {e}"})
+
+
+async def write_file_core(file_path: str, content: str) -> str:
+    full_path = ""
+    try:
+        full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
+        with open(full_path, 'w') as file:
+            file.write(content)
+        logger.info(f"Content written to file {full_path} successfully.")
+        response = {
+            "content": f"Content written to file {file_path}.",
+            "ignore_validation": "true",
+        }
+        return json.dumps(response, indent=4)
+    except Exception as e:
+        logger.error(f"Error writing to file {full_path}: {e}")
+        return json.dumps({"status": "error", "message": f"Error writing to file: {e}"})
+
+
+@tool(
+    name="read_file",
+    description="Read the content of a text file.",
+    parameters={
+        "file_path": Parameter(
+            type="string",
+            description="The path to the file to be read."
+        ),
+    },
+)
+async def read_file(file_path: str, llm, model) -> str:
+    return await read_file_core(file_path)
+
+
+@tool(
+    name="write_file",
+    description="Write content to a text file.",
+    parameters={
+        "file_path": Parameter(
+            type="string",
+            description="The path to the file where the content will be written."
+        ),
+        "content": Parameter(
+            type="string",
+            description="The content to write to the file."
+        ),
+    },
+)
+async def write_file(file_path: str, content: str, llm, model) -> str:
+    return await write_file_core(file_path, content)
+
+
+@agent(
+    name="FileAgent",
+    description="This agent is responsible for reading from and writing to files.",
+    tools=[read_file, write_file],
+)
+class FileAgent(Agent):
+    pass
diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index e895a68ad..627a14fc6 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -29,10 +29,12 @@ async def web_general_search_core(search_query, llm, model) -> str:
             content = await perform_scrape(url)
             if not content:
                 continue
-            summary = await perform_summarization(search_query, content, llm, model)
-            if not summary:
+            summarisation = await perform_summarization(search_query, content, llm, model)
+            if not summarisation:
                 continue
-            is_valid = await is_valid_answer(summary, search_query)
+            is_valid = await is_valid_answer(summarisation, search_query)
+            parsed_json = json.loads(summarisation)
+            summary = parsed_json.get('summary', '')
             if is_valid:
                 response = {
                     "content": summary,
@@ -137,6 +139,7 @@ async def perform_summarization(search_query: str, content: str, llm: Any, model
         summarise_result = json.loads(summarise_result_json)
         if summarise_result["status"] == "error":
             return ""
+        logger.info(f"Content summarized successfully: {summarise_result['response']}")
         return summarise_result["response"]
     except Exception as e:
         logger.error(f"Error summarizing content: {e}")
diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index 3f755d040..a1361b041 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -72,3 +72,7 @@ Response:
 Q: Show me a chart of different subscription prices with Netflix?
 Response:
 {"query": "Show me a chart of different subscription prices with Netflix?", "user_intent": "retrieve and visualize subscription data", "questions": [{"query": "What are the different subscription prices with Netflix?", "question_intent": "retrieve subscription pricing information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display subscription pricing information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
+
+Q: Read the file called file_to_read.txt and write its content to a file called output.txt.
+Response:
+{"query": "Read the file called file_to_read.txt and write its content to a file called output.txt.", "user_intent": "read and write file content", "questions": [{"query": "Read the file called file_to_read.txt using fileagent.", "question_intent": "read file content", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "file", "value": "file_to_read.txt"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Write the content to a file called output.txt using fileagent.", "question_intent": "write file content", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "file", "value": "output.txt"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
diff --git a/backend/src/prompts/templates/summariser.j2 b/backend/src/prompts/templates/summariser.j2
index 4a54cd520..33a1bf03f 100644
--- a/backend/src/prompts/templates/summariser.j2
+++ b/backend/src/prompts/templates/summariser.j2
@@ -6,12 +6,11 @@ You will be passed a user query and the content scraped from the web. You need t
 
 Ensure the summary is clear, well-structured, and directly addresses the user's query.
 
-
 User's question is:
 {{ question }}
 
 Below is the content scraped from the web:
-{{ content }}
+{{ content | replace("\n\n", "\n") }} {# collapses double newlines in the scraped content into single line breaks #}
 
 Reply only in json with the following format:
 
@@ -19,10 +18,3 @@ Reply only in json with the following format:
     "summary":  "The summary of the content that answers the user's query",
     "reasoning": "A sentence on why you chose that summary"
 }
-
-e.g.
-Task: What is the capital of England
-{
-    "summary": "The capital of England is London.",
-    "reasoning": "London is widely known as the capital of England, a fact mentioned in various authoritative sources and geographical references."
-}
diff --git a/backend/src/supervisors/supervisor.py b/backend/src/supervisors/supervisor.py
index f9a87478c..a618b322b 100644
--- a/backend/src/supervisors/supervisor.py
+++ b/backend/src/supervisors/supervisor.py
@@ -20,25 +20,30 @@ async def solve_all(intent_json) -> None:
 
     for question in questions:
         try:
-            (agent_name, answer) = await solve_task(question, get_scratchpad())
+            (agent_name, answer, status) = await solve_task(question, get_scratchpad())
             update_scratchpad(agent_name, question, answer)
+            if status == "error":
+                raise Exception(answer)
         except Exception as error:
             update_scratchpad(error=error)
 
 
-async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str]:
+async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
     if attempt == 5:
         raise Exception(unsolvable_response)
 
     agent = await get_agent_for_task(task, scratchpad)
+    logger.info(f"Agent selected: {agent}")
     if agent is None:
         raise Exception(no_agent_response)
     answer = await agent.invoke(task)
     parsed_json = json.loads(answer)
     ignore_validation = parsed_json.get('ignore_validation', '')
     answer_content = parsed_json.get('content', '')
+    if ignore_validation == 'error':
+        return (agent.name, answer_content, "error")
     if(ignore_validation == 'true') or await is_valid_answer(answer_content, task):
-        return (agent.name, answer_content)
+        return (agent.name, answer_content, "success")
     return await solve_task(task, scratchpad, attempt + 1)
 
 
diff --git a/backend/src/utils/config.py b/backend/src/utils/config.py
index 213a4feb8..e861d6dee 100644
--- a/backend/src/utils/config.py
+++ b/backend/src/utils/config.py
@@ -33,6 +33,7 @@ def __init__(self):
         self.chart_generator_model = None
         self.web_agent_model = None
         self.router_model = None
+        self.files_directory = None
         self.load_env()
 
     def load_env(self):
@@ -49,6 +50,7 @@ def load_env(self):
             self.neo4j_uri = os.getenv("NEO4J_URI", default_neo4j_uri)
             self.neo4j_user = os.getenv("NEO4J_USERNAME")
             self.neo4j_password = os.getenv("NEO4J_PASSWORD")
+            self.files_directory = os.getenv("FILES_DIRECTORY")
             self.azure_storage_connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
             self.azure_storage_container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME")
             self.azure_initial_data_filename = os.getenv("AZURE_INITIAL_DATA_FILENAME")
diff --git a/backend/tests/agents/file_agent_test.py b/backend/tests/agents/file_agent_test.py
new file mode 100644
index 000000000..33d9ff8e9
--- /dev/null
+++ b/backend/tests/agents/file_agent_test.py
@@ -0,0 +1,65 @@
+import pytest
+from unittest.mock import patch, mock_open
+import json
+import os
+from src.agents.file_agent import read_file_core, write_file_core
+
+# Mocking config for the test
+@pytest.fixture(autouse=True)
+def mock_config(monkeypatch):
+    monkeypatch.setattr('src.agents.file_agent.config.files_directory', 'files')
+
+@pytest.mark.asyncio
+@patch("builtins.open", new_callable=mock_open, read_data="Example file content.")
+async def test_read_file_core_success(mock_file):
+    file_path = "example.txt"
+    result = await read_file_core(file_path)
+    expected_response = {
+        "content": "Example file content.",
+        "ignore_validation": "true"
+    }
+    assert json.loads(result) == expected_response
+    expected_full_path = os.path.normpath("/app/files/example.txt")
+    mock_file.assert_called_once_with(expected_full_path, 'r')
+
+@pytest.mark.asyncio
+@patch("builtins.open", side_effect=FileNotFoundError)
+async def test_read_file_core_file_not_found(mock_file):
+    file_path = "missing_file.txt"
+    result = await read_file_core(file_path)
+    expected_response = {
+        "content": "File missing_file.txt not found.",
+        "ignore_validation": "error"
+    }
+    assert json.loads(result) == expected_response
+    expected_full_path = os.path.normpath("/app/files/missing_file.txt")
+    mock_file.assert_called_once_with(expected_full_path, 'r')
+
+@pytest.mark.asyncio
+@patch("builtins.open", new_callable=mock_open)
+async def test_write_file_core_success(mock_file):
+    file_path = "example_write.txt"
+    content = "This is test content to write."
+    result = await write_file_core(file_path, content)
+    expected_response = {
+        "content": f"Content written to file {file_path}.",
+        "ignore_validation": "true"
+    }
+    assert json.loads(result) == expected_response
+    expected_full_path = os.path.normpath("/app/files/example_write.txt")
+    mock_file.assert_called_once_with(expected_full_path, 'w')
+    mock_file().write.assert_called_once_with(content)
+
+@pytest.mark.asyncio
+@patch("builtins.open", side_effect=Exception("Unexpected error"))
+async def test_write_file_core_error(mock_file):
+    file_path = "error_file.txt"
+    content = "Content with error."
+    result = await write_file_core(file_path, content)
+    expected_response = {
+        "status": "error",
+        "message": "Error writing to file: Unexpected error"
+    }
+    assert json.loads(result) == expected_response
+    expected_full_path = os.path.normpath("/app/files/error_file.txt")
+    mock_file.assert_called_once_with(expected_full_path, 'w')
diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index c18d2070d..b54aca6a2 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -19,7 +19,7 @@ async def test_web_general_search_core(
 
     mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
     mock_perform_scrape.return_value = "Example scraped content."
-    mock_perform_summarization.return_value = "Example summary."
+    mock_perform_summarization.return_value = json.dumps({"summary": "Example summary."})
     mock_is_valid_answer.return_value = True
     result = await web_general_search_core("example query", llm, model)
     expected_response = {
@@ -61,7 +61,7 @@ async def test_web_general_search_core_invalid_summary(
     model = "mock_model"
     mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
     mock_perform_scrape.return_value = "Example scraped content."
-    mock_perform_summarization.return_value = "Example invalid summary."
+    mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."})
     mock_is_valid_answer.return_value = False
     result = await web_general_search_core("example query", llm, model)
     assert result == "No relevant information found on the internet for the given query."
diff --git a/backend/tests/supervisors/supervisor_test.py b/backend/tests/supervisors/supervisor_test.py
index 31b5c42d5..31c0a4147 100644
--- a/backend/tests/supervisors/supervisor_test.py
+++ b/backend/tests/supervisors/supervisor_test.py
@@ -64,7 +64,7 @@ async def test_solve_task_first_attempt_solves(mocker):
     mock_answer_json = json.loads(mock_answer)
 
     # Ensure that the result is returned directly without validation
-    assert answer == (agent.name, mock_answer_json.get('content', ''))
+    assert answer == (agent.name, mock_answer_json.get('content', ''), "success")
 
 
 @pytest.mark.asyncio
@@ -83,7 +83,7 @@ async def test_solve_task_ignore_validation(mocker):
     mock_answer_json = json.loads(mock_answer)
 
     # Ensure that the result is returned directly without validation
-    assert answer == (agent.name, mock_answer_json.get('content', ''))
+    assert answer == (agent.name, mock_answer_json.get('content', ''), "success")
     mock_is_valid_answer.assert_not_called()  # Validation should not be called
 
 @pytest.mark.asyncio
diff --git a/compose.yml b/compose.yml
index b0b8fd3fb..adb51b1b4 100644
--- a/compose.yml
+++ b/compose.yml
@@ -40,10 +40,14 @@ services:
       start_period: 60s
   # InferGPT Backend
   backend:
+    env_file:
+      - .env
     image: infergpt/backend
     build:
       context: backend
       dockerfile: ./Dockerfile
+    volumes:
+      - ./${FILES_DIRECTORY}:/app/${FILES_DIRECTORY}
     environment:
       NEO4J_URI: bolt://neo4j-db:7687
       NEO4J_USERNAME: ${NEO4J_USERNAME}
@@ -51,6 +55,7 @@ services:
       MISTRAL_KEY: ${MISTRAL_KEY}
       OPENAI_KEY: ${OPENAI_KEY}
       FRONTEND_URL: ${FRONTEND_URL}
+      FILES_DIRECTORY: ${FILES_DIRECTORY}
       AZURE_STORAGE_CONNECTION_STRING: ${AZURE_STORAGE_CONNECTION_STRING}
       AZURE_STORAGE_CONTAINER_NAME: ${AZURE_STORAGE_CONTAINER_NAME}
       AZURE_INITIAL_DATA_FILENAME: ${AZURE_INITIAL_DATA_FILENAME}

From f1d3704142dded4d5eaeda8c66042405c7f3b9be Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Tue, 17 Sep 2024 11:19:46 +0100
Subject: [PATCH 28/48] Setting default files directory

---
 backend/src/utils/config.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/backend/src/utils/config.py b/backend/src/utils/config.py
index e861d6dee..568a1e68d 100644
--- a/backend/src/utils/config.py
+++ b/backend/src/utils/config.py
@@ -3,6 +3,7 @@
 
 default_frontend_url = "http://localhost:8650"
 default_neo4j_uri = "bolt://localhost:7687"
+default_files_directory = "files"
 
 
 class Config(object):
@@ -33,7 +34,7 @@ def __init__(self):
         self.chart_generator_model = None
         self.web_agent_model = None
         self.router_model = None
-        self.files_directory = None
+        self.files_directory = default_files_directory
         self.load_env()
 
     def load_env(self):

From 2356d4ce403e699056d35e5a9ad983cd372be6b6 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Tue, 17 Sep 2024 11:22:01 +0100
Subject: [PATCH 29/48] Setting default files directory in load_env

---
 backend/src/utils/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/src/utils/config.py b/backend/src/utils/config.py
index 568a1e68d..e18b2a102 100644
--- a/backend/src/utils/config.py
+++ b/backend/src/utils/config.py
@@ -51,7 +51,7 @@ def load_env(self):
             self.neo4j_uri = os.getenv("NEO4J_URI", default_neo4j_uri)
             self.neo4j_user = os.getenv("NEO4J_USERNAME")
             self.neo4j_password = os.getenv("NEO4J_PASSWORD")
-            self.files_directory = os.getenv("FILES_DIRECTORY")
+            self.files_directory = os.getenv("FILES_DIRECTORY", default_files_directory)
             self.azure_storage_connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
             self.azure_storage_container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME")
             self.azure_initial_data_filename = os.getenv("AZURE_INITIAL_DATA_FILENAME")
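
Taken together, patches 28 and 29 make the files directory fall back to "files" whenever FILES_DIRECTORY is unset. A minimal sketch of the resulting behaviour; the f"/app/{...}" construction mirrors file_agent.py, the rest is illustrative scaffolding:

import os

default_files_directory = "files"

# Same fallback as Config.load_env(): use the env var when present,
# otherwise the default.
files_directory = os.getenv("FILES_DIRECTORY", default_files_directory)

# file_agent.py builds the in-container path from this value.
container_path = f"/app/{files_directory}"
print(container_path)  # "/app/files" when FILES_DIRECTORY is not provided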

From bf0faf7f11c3ea42aef48f564160ede937578ac3 Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Thu, 19 Sep 2024 10:35:09 +0100
Subject: [PATCH 30/48] Update mistralai package: v0.1.8 => v1.1.0

Update package in requirements file and update mistral.py
---
 backend/requirements.txt          |   2 +-
 backend/src/llm/mistral.py        |  30 ++++++--
 backend/tests/llm/mistral_test.py | 120 +++++++++++++++++++++++-------
 3 files changed, 117 insertions(+), 35 deletions(-)

diff --git a/backend/requirements.txt b/backend/requirements.txt
index 3dae443a3..59917b6dd 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,6 +1,6 @@
 fastapi==0.110.0
 uvicorn==0.29.0
-mistralai==0.1.8
+mistralai==1.1.0
 pycodestyle==2.11.1
 python-dotenv==1.0.1
 neo4j==5.18.0
diff --git a/backend/src/llm/mistral.py b/backend/src/llm/mistral.py
index 8fac39101..552534123 100644
--- a/backend/src/llm/mistral.py
+++ b/backend/src/llm/mistral.py
@@ -1,5 +1,4 @@
-from mistralai.async_client import MistralAsyncClient
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage
+from mistralai import Mistral as MistralApi, UserMessage, SystemMessage
 import logging
 from src.utils import Config
 from .llm import LLM
@@ -9,21 +8,36 @@
 
 
 class Mistral(LLM):
-    client = MistralAsyncClient(api_key=config.mistral_key)
+    client = MistralApi(api_key=config.mistral_key)
 
     async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
         logger.debug("Called llm. Waiting on response model with prompt {0}.".format(str([system_prompt, user_prompt])))
-        response: ChatCompletionResponse = await self.client.chat(
+        response = await self.client.chat.complete_async(
             model=model,
             messages=[
-                ChatMessage(role="system", content=system_prompt),
-                ChatMessage(role="user", content=user_prompt),
+                SystemMessage(content=system_prompt),
+                UserMessage(content=user_prompt),
             ],
             temperature=0,
             response_format={"type": "json_object"} if return_json else None,
         )
+        if response is None:
+            logger.error("Call to mistral api failed: response was None")
+            return "An error occurred while processing the request."
+
+        if response.choices is None:
+            logger.error("Call to mistral api failed: response.choices was None")
+            return "An error occurred while processing the request."
+
+        if len(response.choices) < 1:
+            logger.error("Call to mistral api failed: response.choices was empty")
+            return "An error occurred while processing the request."
+
         logger.debug('{0} response : "{1}"'.format(model, response.choices[0].message.content))
 
         content = response.choices[0].message.content
-
-        return content if isinstance(content, str) else " ".join(content)
+        if isinstance(content, str):
+            return content
+        else:
+            logger.error("Call to mistral api failed: message.content was None or Unset")
+            return "An error occurred while processing the request."
diff --git a/backend/tests/llm/mistral_test.py b/backend/tests/llm/mistral_test.py
index 86f87f7b2..34f4d4bc2 100644
--- a/backend/tests/llm/mistral_test.py
+++ b/backend/tests/llm/mistral_test.py
@@ -1,10 +1,8 @@
+import logging
 from typing import cast
-from unittest.mock import MagicMock
-from mistralai.async_client import MistralAsyncClient
-from mistralai.models.chat_completion import ChatCompletionResponse
-from mistralai.models.chat_completion import ChatCompletionResponseChoice
-from mistralai.models.chat_completion import ChatMessage
-from mistralai.models.common import UsageInfo
+from unittest.mock import AsyncMock, MagicMock
+from mistralai import UNSET, AssistantMessage, Mistral as MistralApi, SystemMessage, UserMessage
+from mistralai.models import ChatCompletionResponse, ChatCompletionChoice, UsageInfo
 import pytest
 from src.llm import get_llm, Mistral
 from src.utils import Config
@@ -17,34 +15,23 @@
 mistral = cast(Mistral, get_llm("mistral"))
 
 
-async def create_mock_chat_response(content, tool_calls=None):
+def create_mock_chat_response(content, tool_calls=None):
     mock_usage = UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=3)
-    mock_message = ChatMessage(role="system", content=content, tool_calls=tool_calls)
-    mock_choice = ChatCompletionResponseChoice(index=0, message=mock_message, finish_reason=None)
+    mock_message = AssistantMessage(content=content, tool_calls=tool_calls)
+    mock_choice = ChatCompletionChoice(index=0, message=mock_message, finish_reason="stop")
     return ChatCompletionResponse(
         id="id", object="object", created=123, model="model", choices=[mock_choice], usage=mock_usage
     )
 
 
-mock_client = MagicMock(spec=MistralAsyncClient)
+mock_client = AsyncMock(spec=MistralApi)
 mock_config = MagicMock(spec=Config)
 
 
 @pytest.mark.asyncio
 async def test_chat_content_string_returns_string(mocker):
-    mistral.client = mocker.MagicMock(return_value=mock_client)
-    mistral.client.chat.return_value = create_mock_chat_response(content_response)
-
-    response = await mistral.chat(mock_model, system_prompt, user_prompt)
-
-    assert response == content_response
-
-
-@pytest.mark.asyncio
-async def test_chat_content_list_returns_string(mocker):
-    content_list = ["Hello", "there"]
-    mistral.client = mocker.MagicMock(return_value=mock_client)
-    mistral.client.chat.return_value = create_mock_chat_response(content_list)
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    mistral.client.chat.complete_async.return_value = create_mock_chat_response(content_response)
 
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 
@@ -58,9 +45,90 @@ async def test_chat_calls_client_chat(mocker):
     await mistral.chat(mock_model, system_prompt, user_prompt)
 
     expected_messages = [
-        ChatMessage(role="system", content=system_prompt),
-        ChatMessage(role="user", content=user_prompt),
+        SystemMessage(content=system_prompt),
+        UserMessage(content=user_prompt),
     ]
-    mistral.client.chat.assert_called_once_with(
+    mistral.client.chat.complete_async.assert_awaited_once_with(
         messages=expected_messages, model=mock_model, temperature=0, response_format=None
     )
+
+
+@pytest.mark.asyncio
+async def test_chat_response_none_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    mistral.client.chat.complete_async.return_value = None
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert ("src.llm.mistral", logging.ERROR, "Call to mistral api failed: response was None") in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_none_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    response = create_mock_chat_response(content_response)
+    response.choices = None
+    mistral.client.chat.complete_async.return_value = response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to mistral api failed: response.choices was None",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_empty_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    response = create_mock_chat_response(content_response)
+    response.choices = []
+    mistral.client.chat.complete_async.return_value = response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to mistral api failed: response.choices was empty",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_message_content_none_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    response = create_mock_chat_response(content_response)
+    assert response.choices is not None
+    response.choices[0].message.content = None
+    mistral.client.chat.complete_async.return_value = response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to mistral api failed: message.content was None or Unset",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_message_content_unset_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    response = create_mock_chat_response(content_response)
+    assert response.choices is not None
+    response.choices[0].message.content = UNSET
+    mistral.client.chat.complete_async.return_value = response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to mistral api failed: message.content was None or Unset",
+    ) in caplog.record_tuples
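
A minimal caller-side sketch of the migrated chat method, assuming the usual get_llm() lookup; the model name is a placeholder, not something this patch pins down:

import asyncio
from src.llm import get_llm

async def main() -> None:
    llm = get_llm("mistral")
    # Placeholder model name; use whichever Mistral model the deployment configures.
    reply = await llm.chat("mistral-large-latest", "You are terse.", "Say hello.")
    # Either the model's text, or the generic error string returned by the
    # guards added above.
    print(reply)

asyncio.run(main())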

From 49229221749e8ad18ff16b691b5809eb16017733 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Fri, 20 Sep 2024 15:07:44 +0100
Subject: [PATCH 31/48] Changes based on Seth's feedback and other cleanup.

---
 .env.example                            |  2 ++
 backend/src/agents/__init__.py          |  2 +-
 backend/src/agents/file_agent.py        | 40 ++++++++++++-------------
 backend/src/supervisors/supervisor.py   |  5 ++--
 backend/src/utils/config.py             |  4 +++
 backend/tests/agents/file_agent_test.py | 30 ++++++-------------
 compose.yml                             |  2 ++
 7 files changed, 39 insertions(+), 46 deletions(-)

diff --git a/.env.example b/.env.example
index 3b5c59f0c..bcb17ddb2 100644
--- a/.env.example
+++ b/.env.example
@@ -45,6 +45,7 @@ MATHS_AGENT_LLM="openai"
 WEB_AGENT_LLM="openai"
 CHART_GENERATOR_LLM="openai"
 ROUTER_LLM="openai"
+FILE_AGENT_LLM="openai"
 
 # llm model
 ANSWER_AGENT_MODEL="gpt-4o mini"
@@ -55,3 +56,4 @@ MATHS_AGENT_MODEL="gpt-4o mini"
 WEB_AGENT_MODEL="gpt-4o mini"
 CHART_GENERATOR_MODEL="gpt-4o mini"
 ROUTER_MODEL="gpt-4o mini"
+FILE_AGENT_MODEL="gpt-4o mini"
diff --git a/backend/src/agents/__init__.py b/backend/src/agents/__init__.py
index 02913c891..d16ade6f7 100644
--- a/backend/src/agents/__init__.py
+++ b/backend/src/agents/__init__.py
@@ -33,7 +33,7 @@ def get_available_agents() -> List[Agent]:
     return [DatastoreAgent(config.datastore_agent_llm, config.datastore_agent_model),
             WebAgent(config.web_agent_llm, config.web_agent_model),
             ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model),
-            FileAgent(config.chart_generator_llm, config.chart_generator_model),
+            FileAgent(config.file_agent_llm, config.file_agent_model),
             ]
 
 
diff --git a/backend/src/agents/file_agent.py b/backend/src/agents/file_agent.py
index e799a1b56..50a1d0fd9 100644
--- a/backend/src/agents/file_agent.py
+++ b/backend/src/agents/file_agent.py
@@ -11,46 +11,44 @@
 
 FILES_DIRECTORY = f"/app/{config.files_directory}"
 
+# Constants for response status
+IGNORE_VALIDATION = "true"
+STATUS_SUCCESS = "success"
+STATUS_ERROR = "error"
+
+# Utility function for error responses
+def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
+    return json.dumps({
+        "content": content,
+        "ignore_validation": IGNORE_VALIDATION,
+        "status": status
+    }, indent=4)
 
 async def read_file_core(file_path: str) -> str:
-    full_path = ""
+    full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
     try:
-        full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
         with open(full_path, 'r') as file:
             content = file.read()
-        response = {
-            "content": content,
-            "ignore_validation": "true"
-        }
-        return json.dumps(response, indent=4)
+        return create_response(content)
     except FileNotFoundError:
         error_message = f"File {file_path} not found."
         logger.error(error_message)
-        response = {
-            "content": error_message,
-            "ignore_validation": "error",
-        }
-        return json.dumps(response, indent=4)
+        return create_response(error_message, STATUS_ERROR)
     except Exception as e:
         logger.error(f"Error reading file {full_path}: {e}")
-        return json.dumps({"status": "error", "message": f"Error reading file: {e}"})
+        return create_response(f"Error reading file: {file_path}", STATUS_ERROR)
 
 
 async def write_file_core(file_path: str, content: str) -> str:
-    full_path = ""
+    full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
     try:
-        full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
         with open(full_path, 'w') as file:
             file.write(content)
         logger.info(f"Content written to file {full_path} successfully.")
-        response = {
-            "content": f"Content written to file {file_path}.",
-            "ignore_validation": "true",
-        }
-        return json.dumps(response, indent=4)
+        return create_response(f"Content written to file {file_path}.")
     except Exception as e:
         logger.error(f"Error writing to file {full_path}: {e}")
-        return json.dumps({"status": "error", "message": f"Error writing to file: {e}"})
+        return create_response(f"Error writing to file: {file_path}", STATUS_ERROR)
 
 
 @tool(
diff --git a/backend/src/supervisors/supervisor.py b/backend/src/supervisors/supervisor.py
index a618b322b..f93915ccf 100644
--- a/backend/src/supervisors/supervisor.py
+++ b/backend/src/supervisors/supervisor.py
@@ -38,12 +38,11 @@ async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
         raise Exception(no_agent_response)
     answer = await agent.invoke(task)
     parsed_json = json.loads(answer)
+    status = parsed_json.get('status', 'success')
     ignore_validation = parsed_json.get('ignore_validation', '')
     answer_content = parsed_json.get('content', '')
-    if ignore_validation == 'error':
-        return (agent.name, answer_content, "error")
     if(ignore_validation == 'true') or await is_valid_answer(answer_content, task):
-        return (agent.name, answer_content, "success")
+        return (agent.name, answer_content, status)
     return await solve_task(task, scratchpad, attempt + 1)
 
 
diff --git a/backend/src/utils/config.py b/backend/src/utils/config.py
index e18b2a102..127b445a3 100644
--- a/backend/src/utils/config.py
+++ b/backend/src/utils/config.py
@@ -26,6 +26,7 @@ def __init__(self):
         self.maths_agent_llm = None
         self.web_agent_llm = None
         self.chart_generator_llm = None
+        self.file_agent_llm = None
         self.router_llm = None
         self.validator_agent_model = None
         self.intent_agent_model = None
@@ -35,6 +36,7 @@ def __init__(self):
         self.web_agent_model = None
         self.router_model = None
         self.files_directory = default_files_directory
+        self.file_agent_model = None
         self.load_env()
 
     def load_env(self):
@@ -60,6 +62,7 @@ def load_env(self):
             self.validator_agent_llm = os.getenv("VALIDATOR_AGENT_LLM")
             self.datastore_agent_llm = os.getenv("DATASTORE_AGENT_LLM")
             self.chart_generator_llm = os.getenv("CHART_GENERATOR_LLM")
+            self.file_agent_llm = os.getenv("FILE_AGENT_LLM")
             self.web_agent_llm = os.getenv("WEB_AGENT_LLM")
             self.maths_agent_llm = os.getenv("MATHS_AGENT_LLM")
             self.router_llm = os.getenv("ROUTER_LLM")
@@ -71,6 +74,7 @@ def load_env(self):
             self.chart_generator_model = os.getenv("CHART_GENERATOR_MODEL")
             self.maths_agent_model = os.getenv("MATHS_AGENT_MODEL")
             self.router_model = os.getenv("ROUTER_MODEL")
+            self.file_agent_model = os.getenv("FILE_AGENT_MODEL")
         except FileNotFoundError:
             raise FileNotFoundError("Please provide a .env file. See the Getting Started guide on the README.md")
         except Exception:
diff --git a/backend/tests/agents/file_agent_test.py b/backend/tests/agents/file_agent_test.py
index 33d9ff8e9..91a56eb0f 100644
--- a/backend/tests/agents/file_agent_test.py
+++ b/backend/tests/agents/file_agent_test.py
@@ -2,7 +2,7 @@
 from unittest.mock import patch, mock_open
 import json
 import os
-from src.agents.file_agent import read_file_core, write_file_core
+from src.agents.file_agent import read_file_core, write_file_core, create_response
 
 # Mocking config for the test
 @pytest.fixture(autouse=True)
@@ -14,11 +14,8 @@ def mock_config(monkeypatch):
 async def test_read_file_core_success(mock_file):
     file_path = "example.txt"
     result = await read_file_core(file_path)
-    expected_response = {
-        "content": "Example file content.",
-        "ignore_validation": "true"
-    }
-    assert json.loads(result) == expected_response
+    expected_response = create_response("Example file content.")
+    assert json.loads(result) == json.loads(expected_response)
     expected_full_path = os.path.normpath("/app/files/example.txt")
     mock_file.assert_called_once_with(expected_full_path, 'r')
 
@@ -27,11 +24,8 @@ async def test_read_file_core_success(mock_file):
 async def test_read_file_core_file_not_found(mock_file):
     file_path = "missing_file.txt"
     result = await read_file_core(file_path)
-    expected_response = {
-        "content": "File missing_file.txt not found.",
-        "ignore_validation": "error"
-    }
-    assert json.loads(result) == expected_response
+    expected_response = create_response(f"File {file_path} not found.", "error")
+    assert json.loads(result) == json.loads(expected_response)
     expected_full_path = os.path.normpath("/app/files/missing_file.txt")
     mock_file.assert_called_once_with(expected_full_path, 'r')
 
@@ -41,11 +35,8 @@ async def test_write_file_core_success(mock_file):
     file_path = "example_write.txt"
     content = "This is test content to write."
     result = await write_file_core(file_path, content)
-    expected_response = {
-        "content": f"Content written to file {file_path}.",
-        "ignore_validation": "true"
-    }
-    assert json.loads(result) == expected_response
+    expected_response = create_response(f"Content written to file {file_path}.")
+    assert json.loads(result) == json.loads(expected_response)
     expected_full_path = os.path.normpath("/app/files/example_write.txt")
     mock_file.assert_called_once_with(expected_full_path, 'w')
     mock_file().write.assert_called_once_with(content)
@@ -56,10 +47,7 @@ async def test_write_file_core_error(mock_file):
     file_path = "error_file.txt"
     content = "Content with error."
     result = await write_file_core(file_path, content)
-    expected_response = {
-        "status": "error",
-        "message": "Error writing to file: Unexpected error"
-    }
-    assert json.loads(result) == expected_response
+    expected_response = create_response(f"Error writing to file: {file_path}", "error")
+    assert json.loads(result) == json.loads(expected_response)
     expected_full_path = os.path.normpath("/app/files/error_file.txt")
     mock_file.assert_called_once_with(expected_full_path, 'w')
diff --git a/compose.yml b/compose.yml
index adb51b1b4..4a979141f 100644
--- a/compose.yml
+++ b/compose.yml
@@ -67,6 +67,7 @@ services:
       MATHS_AGENT_LLM: ${MATHS_AGENT_LLM}
       ROUTER_LLM: ${ROUTER_LLM}
       CHART_GENERATOR_LLM: ${CHART_GENERATOR_LLM}
+      FILE_AGENT_LLM: ${FILE_AGENT_LLM}
       ANSWER_AGENT_MODEL: ${ANSWER_AGENT_MODEL}
       INTENT_AGENT_MODEL: ${INTENT_AGENT_MODEL}
       VALIDATOR_AGENT_MODEL: ${VALIDATOR_AGENT_MODEL}
@@ -76,6 +77,7 @@ services:
       MATHS_AGENT_MODEL: ${MATHS_AGENT_MODEL}
       AGENT_CLASS_MODEL: ${AGENT_CLASS_MODEL}
       CHART_GENERATOR_MODEL: ${CHART_GENERATOR_MODEL}
+      FILE_AGENT_MODEL: ${FILE_AGENT_MODEL}
     depends_on:
       neo4j-db:
         condition: service_healthy
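
For reference, a sketch of the JSON envelope that create_response() now emits and that solve_task() unpacks; the content string below is made up:

import json

answer = json.dumps({
    "content": "Content written to file report.txt.",
    "ignore_validation": "true",
    "status": "success",
}, indent=4)

# Mirrors the unpacking in supervisor.solve_task().
parsed_json = json.loads(answer)
status = parsed_json.get("status", "success")
ignore_validation = parsed_json.get("ignore_validation", "")
answer_content = parsed_json.get("content", "")
print(status, ignore_validation, answer_content)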

From 1bf98591603d1843772c3ec4338ef5ed66c14578 Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Tue, 24 Sep 2024 14:59:24 +0100
Subject: [PATCH 32/48] Rename input chat response in tests (response =>
 chat_response)

---
 backend/tests/llm/mistral_test.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/backend/tests/llm/mistral_test.py b/backend/tests/llm/mistral_test.py
index 34f4d4bc2..a7a7abfbf 100644
--- a/backend/tests/llm/mistral_test.py
+++ b/backend/tests/llm/mistral_test.py
@@ -84,9 +84,9 @@ async def test_chat_response_choices_none_logs_error(mocker, caplog):
 @pytest.mark.asyncio
 async def test_chat_response_choices_empty_logs_error(mocker, caplog):
     mistral.client = mocker.AsyncMock(return_value=mock_client)
-    response = create_mock_chat_response(content_response)
-    response.choices = []
-    mistral.client.chat.complete_async.return_value = response
+    chat_response = create_mock_chat_response(content_response)
+    chat_response.choices = []
+    mistral.client.chat.complete_async.return_value = chat_response
 
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 
@@ -101,10 +101,10 @@ async def test_chat_response_choices_empty_logs_error(mocker, caplog):
 @pytest.mark.asyncio
 async def test_chat_response_choices_message_content_none_logs_error(mocker, caplog):
     mistral.client = mocker.AsyncMock(return_value=mock_client)
-    response = create_mock_chat_response(content_response)
-    assert response.choices is not None
-    response.choices[0].message.content = None
-    mistral.client.chat.complete_async.return_value = response
+    chat_response = create_mock_chat_response(content_response)
+    assert chat_response.choices is not None
+    chat_response.choices[0].message.content = None
+    mistral.client.chat.complete_async.return_value = chat_response
 
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 
@@ -119,10 +119,10 @@ async def test_chat_response_choices_message_content_none_logs_error(mocker, cap
 @pytest.mark.asyncio
 async def test_chat_response_choices_message_content_unset_logs_error(mocker, caplog):
     mistral.client = mocker.AsyncMock(return_value=mock_client)
-    response = create_mock_chat_response(content_response)
-    assert response.choices is not None
-    response.choices[0].message.content = UNSET
-    mistral.client.chat.complete_async.return_value = response
+    chat_response = create_mock_chat_response(content_response)
+    assert chat_response.choices is not None
+    chat_response.choices[0].message.content = UNSET
+    mistral.client.chat.complete_async.return_value = chat_response
 
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 

From d89551060d5498f65061d679df6f77f4cea0c5ec Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Tue, 24 Sep 2024 15:32:24 +0100
Subject: [PATCH 33/48] Simplify Type Checking Carried Out on API Response

---
 backend/src/llm/mistral.py        | 21 ++++++---------------
 backend/tests/llm/mistral_test.py | 20 ++++++++++++--------
 2 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/backend/src/llm/mistral.py b/backend/src/llm/mistral.py
index 552534123..a438a0927 100644
--- a/backend/src/llm/mistral.py
+++ b/backend/src/llm/mistral.py
@@ -21,23 +21,14 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=Fa
             temperature=0,
             response_format={"type": "json_object"} if return_json else None,
         )
-        if response is None:
-            logger.error("Call to mistral api failed: response was None")
+        if not response or not response.choices:
+            logger.error("Call to Mistral API failed: No valid response or choices received")
             return "An error occurred while processing the request."
 
-        if response.choices is None:
-            logger.error("Call to mistral api failed: response.choices was None")
-            return "An error occurred while processing the request."
-
-        if len(response.choices) < 1:
-            logger.error("Call to mistral api failed: response.choices was empty")
+        content = response.choices[0].message.content
+        if not content:
+            logger.error("Call to Mistral API failed: message content is None or Unset")
             return "An error occurred while processing the request."
 
         logger.debug('{0} response : "{1}"'.format(model, response.choices[0].message.content))
-
-        content = response.choices[0].message.content
-        if isinstance(content, str):
-            return content
-        else:
-            logger.error("Call to mistral api failed: message.content was None or Unset")
-            return "An error occurred while processing the request."
+        return content
diff --git a/backend/tests/llm/mistral_test.py b/backend/tests/llm/mistral_test.py
index a7a7abfbf..f92a5ee86 100644
--- a/backend/tests/llm/mistral_test.py
+++ b/backend/tests/llm/mistral_test.py
@@ -61,15 +61,19 @@ async def test_chat_response_none_logs_error(mocker, caplog):
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 
     assert response == "An error occurred while processing the request."
-    assert ("src.llm.mistral", logging.ERROR, "Call to mistral api failed: response was None") in caplog.record_tuples
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to Mistral API failed: No valid response or choices received",
+    ) in caplog.record_tuples
 
 
 @pytest.mark.asyncio
 async def test_chat_response_choices_none_logs_error(mocker, caplog):
     mistral.client = mocker.AsyncMock(return_value=mock_client)
-    response = create_mock_chat_response(content_response)
-    response.choices = None
-    mistral.client.chat.complete_async.return_value = response
+    chat_response = create_mock_chat_response(content_response)
+    chat_response.choices = None
+    mistral.client.chat.complete_async.return_value = chat_response
 
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 
@@ -77,7 +81,7 @@ async def test_chat_response_choices_none_logs_error(mocker, caplog):
     assert (
         "src.llm.mistral",
         logging.ERROR,
-        "Call to mistral api failed: response.choices was None",
+        "Call to Mistral API failed: No valid response or choices received",
     ) in caplog.record_tuples
 
 
@@ -94,7 +98,7 @@ async def test_chat_response_choices_empty_logs_error(mocker, caplog):
     assert (
         "src.llm.mistral",
         logging.ERROR,
-        "Call to mistral api failed: response.choices was empty",
+        "Call to Mistral API failed: No valid response or choices received",
     ) in caplog.record_tuples
 
 
@@ -112,7 +116,7 @@ async def test_chat_response_choices_message_content_none_logs_error(mocker, cap
     assert (
         "src.llm.mistral",
         logging.ERROR,
-        "Call to mistral api failed: message.content was None or Unset",
+        "Call to Mistral API failed: message content is None or Unset",
     ) in caplog.record_tuples
 
 
@@ -130,5 +134,5 @@ async def test_chat_response_choices_message_content_unset_logs_error(mocker, ca
     assert (
         "src.llm.mistral",
         logging.ERROR,
-        "Call to mistral api failed: message.content was None or Unset",
+        "Call to Mistral API failed: message content is None or Unset",
     ) in caplog.record_tuples

From b601ee1acdca1851b1d932fa7f26e789e8cc8cb6 Mon Sep 17 00:00:00 2001
From: swood <110558776+swood-scottlogic@users.noreply.github.com>
Date: Tue, 24 Sep 2024 15:54:14 +0100
Subject: [PATCH 34/48] Use Existing Local Variable in Log

---
 backend/src/llm/mistral.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/src/llm/mistral.py b/backend/src/llm/mistral.py
index a438a0927..18974b4ef 100644
--- a/backend/src/llm/mistral.py
+++ b/backend/src/llm/mistral.py
@@ -30,5 +30,5 @@ async def chat(self, model, system_prompt: str, user_prompt: str, return_json=Fa
             logger.error("Call to Mistral API failed: message content is None or Unset")
             return "An error occurred while processing the request."
 
-        logger.debug('{0} response : "{1}"'.format(model, response.choices[0].message.content))
+        logger.debug('{0} response : "{1}"'.format(model, content))
         return content

From 99b9a99926985ea0ff233f587fc2087a1e045888 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Wed, 25 Sep 2024 09:03:15 +0100
Subject: [PATCH 35/48] First commit with intent changes

---
 backend/src/agents/__init__.py                |  3 ++
 backend/src/agents/maths_agent.py             | 49 ++++++++++++++++++-
 .../src/prompts/templates/best-next-step.j2   |  1 +
 backend/src/prompts/templates/intent.j2       |  7 +--
 backend/src/router.py                         |  1 +
 5 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/backend/src/agents/__init__.py b/backend/src/agents/__init__.py
index d16ade6f7..c06af259f 100644
--- a/backend/src/agents/__init__.py
+++ b/backend/src/agents/__init__.py
@@ -9,6 +9,8 @@
 from .answer_agent import AnswerAgent
 from .chart_generator_agent import ChartGeneratorAgent
 from .file_agent import FileAgent
+from .maths_agent import MathsAgent
+
 
 config = Config()
 
@@ -34,6 +36,7 @@ def get_available_agents() -> List[Agent]:
             WebAgent(config.web_agent_llm, config.web_agent_model),
             ChartGeneratorAgent(config.chart_generator_llm, config.chart_generator_model),
             FileAgent(config.file_agent_llm, config.file_agent_model),
+            MathsAgent(config.maths_agent_llm, config.maths_agent_model),
             ]
 
 
diff --git a/backend/src/agents/maths_agent.py b/backend/src/agents/maths_agent.py
index e8833ffc5..77743dcf4 100644
--- a/backend/src/agents/maths_agent.py
+++ b/backend/src/agents/maths_agent.py
@@ -48,10 +48,57 @@ async def compare_two_values(value_one, thing_one, value_two, thing_two) -> str:
         return f"You have spent more on {thing_two} ({value_two}) than {thing_one} ({value_one}) in the last month"
 
 
+@tool(
+    name="round a number",
+    description="Rounds a provided number to the specified number of decimal places",
+    parameters={
+        "number": Parameter(
+            type="number",
+            description="The number to round off",
+        ),
+        "decimal_places": Parameter(
+            type="number",
+            description="The number of decimal places to round to (e.g. 2)",
+        ),
+    },
+)
+async def round_number(number, decimal_places) -> str:
+    return f"The number {number} rounded to {decimal_places} decimal places is {round(number, decimal_places)}"
+
+@tool(
+    name="find maximum value",
+    description="Finds the maximum value in a provided list",
+    parameters={
+        "list_of_values": Parameter(
+            type="list[number]",
+            description="Python list of comma-separated values (e.g. [1, 5, 3])",
+        ),
+    },
+)
+async def find_max_value(list_of_values) -> str:
+    if not isinstance(list_of_values, list):
+        raise Exception("Method not passed a valid Python list")
+    return f"The maximum value in the list {list_of_values} is {max(list_of_values)}"
+
+@tool(
+    name="find minimum value",
+    description="Finds the minimum value in a provided list",
+    parameters={
+        "list_of_values": Parameter(
+            type="list[number]",
+            description="Python list of comma-separated values (e.g. [1, 5, 3])",
+        ),
+    },
+)
+async def find_min_value(list_of_values) -> str:
+    if not isinstance(list_of_values, list):
+        raise Exception("Method not passed a valid Python list")
+    return f"The minimum value in the list {list_of_values} is {min(list_of_values)}"
+
 @agent(
     name="MathsAgent",
     description="This agent is responsible for solving number comparison and calculation tasks",
-    tools=[sum_list_of_values, compare_two_values],
+    tools=[sum_list_of_values, compare_two_values, round_number, find_max_value, find_min_value],
 )
 class MathsAgent(Agent):
     pass
diff --git a/backend/src/prompts/templates/best-next-step.j2 b/backend/src/prompts/templates/best-next-step.j2
index 417d357c8..59942baa6 100644
--- a/backend/src/prompts/templates/best-next-step.j2
+++ b/backend/src/prompts/templates/best-next-step.j2
@@ -27,6 +27,7 @@ If the list of agents does not contain something suitable, you should say the ag
 ## Determine the next best step
 Your task is to pick one of the mentioned agents above to complete the task.
 If the same agent_name and task are repeated more than twice in the history, you must not pick that agent_name.
+If mathematical processing (e.g., rounding or calculations) is needed, choose the MathsAgent. If file operations are needed, choose the FileAgent.
 
 Your decisions must always be made independently without seeking user assistance.
 Play to your strengths as an LLM and pursue simple strategies with no legal complications.
diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index a1361b041..16d65edde 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -15,12 +15,13 @@ Guidelines:
 1. Determine each distinct intent in the question.
 2. Sequence the intents: Identify which intent should be tackled first and which should follow.
 3. For each intent, specify:
-    - The exact operation required (e.g., "literal search", "filter + aggregation", "data visualization").
+    - The exact operation required (e.g., "literal search", "filter + aggregation", "data transformation", "mathematical operation" for maths-related tasks like rounding).
     - The category of the question (e.g., "data-driven", "data presentation", "general knowledge").
     - Any specific parameters or conditions that apply.
     - The correct aggregation and sorting methods if applicable.
-4. Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations.
-5. Do not make assumptions or create hypothetical data. Use only concrete data where applicable.
+4. For mathematical operations (like rounding, addition, or finding maximum values), use **MathsAgent** for the task.
+5. Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations.
+6. Do not make assumptions or create hypothetical data. Use only concrete data where applicable.
  
 Specify an operation type under the operation key; here are a few examples: 
 
diff --git a/backend/src/router.py b/backend/src/router.py
index 46be54cae..a6affc23e 100644
--- a/backend/src/router.py
+++ b/backend/src/router.py
@@ -27,6 +27,7 @@ def build_best_next_step_prompt(task, scratchpad):
 
 async def build_plan(task, llm: LLM, scratchpad, model):
     best_next_step_prompt = build_best_next_step_prompt(task, scratchpad)
+    await publish_log_info(LogPrefix.USER, f"**************** Best next step prompt: {best_next_step_prompt}", __name__)
 
     # Call model to choose agent
     logger.info("#####  ~  Calling LLM for next best step  ~  #####")

From ff4fb6d5c7504f0933765492228557df664843f3 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Wed, 25 Sep 2024 14:09:37 +0100
Subject: [PATCH 36/48] Second commit with web scrape agent

---
 backend/requirements.txt                     |   2 +-
 backend/src/agents/agent.py                  |   2 +
 backend/src/agents/maths_agent.py            | 193 +++++++++++--------
 backend/src/agents/web_agent.py              |  30 ++-
 backend/src/llm/mistral.py                   |  21 +-
 backend/src/prompts/templates/math-solver.j2 |  25 +++
 backend/src/supervisors/supervisor.py        |   1 +
 backend/src/utils/web_utils.py               |  26 +++
 backend/tests/llm/mistral_test.py            | 124 +++++++++---
 9 files changed, 306 insertions(+), 118 deletions(-)
 create mode 100644 backend/src/prompts/templates/math-solver.j2

diff --git a/backend/requirements.txt b/backend/requirements.txt
index 3dae443a3..59917b6dd 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -1,6 +1,6 @@
 fastapi==0.110.0
 uvicorn==0.29.0
-mistralai==0.1.8
+mistralai==1.1.0
 pycodestyle==2.11.1
 python-dotenv==1.0.1
 neo4j==5.18.0
diff --git a/backend/src/agents/agent.py b/backend/src/agents/agent.py
index ada5e9f47..221730d65 100644
--- a/backend/src/agents/agent.py
+++ b/backend/src/agents/agent.py
@@ -56,7 +56,9 @@ async def __get_action(self, utterance: str) -> Action_and_args:
 
     async def invoke(self, utterance: str) -> str:
         (action, args) = await self.__get_action(utterance)
+        logger.info(f"Calling action {action} with arguments: {args}, LLM: {self.llm}, Model: {self.model}")
         logger.info(f"USER - Action: {action} and args: {args} for utterance: {utterance}")
+
         result_of_action = await action(**args, llm=self.llm, model=self.model)
         await publish_log_info(LogPrefix.USER, f"Action gave result: {result_of_action}", __name__)
         return result_of_action
diff --git a/backend/src/agents/maths_agent.py b/backend/src/agents/maths_agent.py
index 77743dcf4..10c59ba77 100644
--- a/backend/src/agents/maths_agent.py
+++ b/backend/src/agents/maths_agent.py
@@ -1,104 +1,133 @@
 from .tool import tool
 from .agent_types import Parameter
 from .agent import Agent, agent
+import logging
+from src.utils import Config
+from .validator_agent import ValidatorAgent
+import json
+from src.utils.web_utils import perform_math_operation_util
 
+logger = logging.getLogger(__name__)
+config = Config()
 
-@tool(
-    name="sum list of values",
-    description="sums a list of provided values",
-    parameters={
-        "list_of_values": Parameter(
-            type="list[number]",
-            description="Python list of comma separated values (e.g. [1, 5, 3])",
-        )
-    },
-)
-async def sum_list_of_values(list_of_values) -> str:
-    if not isinstance(list_of_values, list):
-        raise Exception("Method not passed a valid Python list")
-    return f"The sum of all the values passed {list_of_values} is {str(sum(list_of_values))}"
+# @tool(
+#     name="sum list of values",
+#     description="sums a list of provided values",
+#     parameters={
+#         "list_of_values": Parameter(
+#             type="list[number]",
+#             description="Python list of comma separated values (e.g. [1, 5, 3])",
+#         )
+#     },
+# )
+# async def sum_list_of_values(list_of_values) -> str:
+#     if not isinstance(list_of_values, list):
+#         raise Exception("Method not passed a valid Python list")
+#     return f"The sum of all the values passed {list_of_values} is {str(sum(list_of_values))}"
 
 
-@tool(
-    name="compare two values",
-    description="Compare two passed values and return information on which one is greater",
-    parameters={
-        "thing_one": Parameter(
-            type="string",
-            description="first thing for comparison",
-        ),
-        "value_one": Parameter(
-            type="number",
-            description="value of first thing",
-        ),
-        "thing_two": Parameter(
-            type="string",
-            description="second thing for comparison",
-        ),
-        "value_two": Parameter(
-            type="number",
-            description="value of first thing",
-        ),
-    },
-)
-async def compare_two_values(value_one, thing_one, value_two, thing_two) -> str:
-    if value_one > value_two:
-        return f"You have spent more on {thing_one} ({value_one}) than {thing_two} ({value_two}) in the last month"
-    else:
-        return f"You have spent more on {thing_two} ({value_two}) than {thing_one} ({value_one}) in the last month"
+# @tool(
+#     name="compare two values",
+#     description="Compare two passed values and return information on which one is greater",
+#     parameters={
+#         "thing_one": Parameter(
+#             type="string",
+#             description="first thing for comparison",
+#         ),
+#         "value_one": Parameter(
+#             type="number",
+#             description="value of first thing",
+#         ),
+#         "thing_two": Parameter(
+#             type="string",
+#             description="second thing for comparison",
+#         ),
+#         "value_two": Parameter(
+#             type="number",
+#             description="value of first thing",
+#         ),
+#     },
+# )
+# async def compare_two_values(value_one, thing_one, value_two, thing_two) -> str:
+#     if value_one > value_two:
+#         return f"You have spent more on {thing_one} ({value_one}) than {thing_two} ({value_two}) in the last month"
+#     else:
+#         return f"You have spent more on {thing_two} ({value_two}) than {thing_one} ({value_one}) in the last month"
 
+# Core function to perform the math operation by calling the util function
+async def perform_math_operation_core(math_query, llm, model) -> str:
+    try:
+        # Call the utility function to perform the math operation
+        math_operation_result = await perform_math_operation_util(math_query, llm, model)
+        result_json = json.loads(math_operation_result)
 
-@tool(
-    name="round a number",
-    description="Rounds a provided number to the specified number of decimal places",
-    parameters={
-        "number": Parameter(
-            type="number",
-            description="The number to round off",
-        ),
-        "decimal_places": Parameter(
-            type="number",
-            description="The number of decimal places to round to (e.g. 2)",
-        ),
-    },
-)
-async def round_number(number, decimal_places) -> str:
-    return f"The number {number} rounded to {decimal_places} decimal places is {round(number, decimal_places)}"
+        if result_json.get("status") == "success":
+            # Extract the relevant response (math result) from the utility function's output
+            response = result_json.get("response", {})
+            response_json = json.loads(response)
+            result = response_json.get("result", "")
+            steps = response_json.get("steps", "")
+            reasoning = response_json.get("reasoning", "")
 
-@tool(
-    name="find maximum value",
-    description="Finds the maximum value in a provided list",
-    parameters={
-        "list_of_values": Parameter(
-            type="list[number]",
-            description="Python list of comma-separated values (e.g. [1, 5, 3])",
-        ),
-    },
-)
-async def find_max_value(list_of_values) -> str:
-    if not isinstance(list_of_values, list):
-        raise Exception("Method not passed a valid Python list")
-    return f"The maximum value in the list {list_of_values} is {max(list_of_values)}"
+            if result:
+                logger.info(f"Math operation successful: {result}")
+                is_valid = await is_valid_answer(result, math_query)
+
+                if is_valid:
+                    response = {
+                        "content": result,
+                        "ignore_validation": "true"
+                    }
+                    return json.dumps(response, indent=4)
+            else:
+                return "No valid result found for the math query."
+        else:
+            return json.dumps(
+                {
+                    "status": "error",
+                    "response": None,
+                    "error": result_json.get("error", "Unknown error"),
+                }, indent=4
+            )
+    except Exception as e:
+        logger.error(f"Error in perform_math_operation_core: {e}")
+        return json.dumps(
+            {
+                "status": "error",
+                "response": None,
+                "error": str(e),
+            }, indent=4
+        )
 
+def get_validator_agent() -> Agent:
+    return ValidatorAgent(config.validator_agent_llm, config.validator_agent_model)
+
+async def is_valid_answer(answer, task) -> bool:
+    is_valid = (await get_validator_agent().invoke(f"Task: {task}  Answer: {answer}")).lower() == "true"
+    return is_valid
+
+# Math Operation Tool
 @tool(
-    name="find minimum value",
-    description="Finds the minimum value in a provided list",
+    name="perform_math_operation",
+    description=(
+        "Use this tool to perform complex mathematical operations or calculations. "
+        "It can handle queries related to arithmetic operations, algebra, or calculations involving large numbers."
+    ),
     parameters={
-        "list_of_values": Parameter(
-            type="list[number]",
-            description="Python list of comma-separated values (e.g. [1, 5, 3])",
+        "math_query": Parameter(
+            type="string",
+            description="The mathematical query or equation to solve."
         ),
     },
 )
-async def find_min_value(list_of_values) -> str:
-    if not isinstance(list_of_values, list):
-        raise Exception("Method not passed a valid Python list")
-    return f"The minimum value in the list {list_of_values} is {min(list_of_values)}"
+async def perform_math_operation(math_query, llm, model) -> str:
+    return await perform_math_operation_core(math_query, llm, model)
 
+# MathsAgent definition
 @agent(
     name="MathsAgent",
-    description="This agent is responsible for solving number comparison and calculation tasks",
-    tools=[sum_list_of_values, compare_two_values, round_number, find_max_value, find_min_value],
+    description="This agent is responsible for handling mathematical queries and providing solutions.",
+    tools=[perform_math_operation],
 )
 class MathsAgent(Agent):
     pass
diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index 627a14fc6..1177acb59 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -103,6 +103,34 @@ async def web_general_search(search_query, llm, model) -> str:
 async def web_pdf_download(pdf_url, llm, model) -> str:
     return await web_pdf_download_core(pdf_url, llm, model)
 
+async def web_scrape_price_core(url: str) -> str:
+    try:
+        logger.info(f"Scraping the price of the book from URL: {url}")
+        # Scrape the content from the provided URL
+        content = await scrape_content(url)
+        if not content:
+            return "No content found at the provided URL."
+        logger.info(f"Content scraped successfully: {content}")
+        return json.dumps({"status": "success", "content": content})
+    except Exception as e:
+        logger.error(f"Error in web_scrape_price_core: {e}")
+        return json.dumps({"status": "error", "error": str(e)})
+
+
+@tool(
+    name="web_scrape_price",
+    description="Scrapes the price of a book from a given URL and writes it to a .txt file.",
+    parameters={
+        "url": Parameter(
+            type="string",
+            description="The URL of the book page to scrape the price from.",
+        ),
+    },
+)
+async def web_scrape_price(url: str, llm, model) -> str:
+    logger.info(f"Scraping the price of the book from URL: {url}")
+    return await web_scrape_price_core(url)
+
 def get_validator_agent() -> Agent:
     return ValidatorAgent(config.validator_agent_llm, config.validator_agent_model)
 
@@ -159,7 +187,7 @@ async def perform_pdf_summarization(content: str, llm: Any, model: str) -> str:
 @agent(
     name="WebAgent",
     description="This agent is responsible for handling web search queries and summarizing information from the web.",
-    tools=[web_general_search, web_pdf_download],
+    tools=[web_general_search, web_pdf_download, web_scrape_price],
 )
 class WebAgent(Agent):
     pass
diff --git a/backend/src/llm/mistral.py b/backend/src/llm/mistral.py
index 8fac39101..18974b4ef 100644
--- a/backend/src/llm/mistral.py
+++ b/backend/src/llm/mistral.py
@@ -1,5 +1,4 @@
-from mistralai.async_client import MistralAsyncClient
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage
+from mistralai import Mistral as MistralApi, UserMessage, SystemMessage
 import logging
 from src.utils import Config
 from .llm import LLM
@@ -9,21 +8,27 @@
 
 
 class Mistral(LLM):
-    client = MistralAsyncClient(api_key=config.mistral_key)
+    client = MistralApi(api_key=config.mistral_key)
 
     async def chat(self, model, system_prompt: str, user_prompt: str, return_json=False) -> str:
         logger.debug("Called llm. Waiting on response model with prompt {0}.".format(str([system_prompt, user_prompt])))
-        response: ChatCompletionResponse = await self.client.chat(
+        response = await self.client.chat.complete_async(
             model=model,
             messages=[
-                ChatMessage(role="system", content=system_prompt),
-                ChatMessage(role="user", content=user_prompt),
+                SystemMessage(content=system_prompt),
+                UserMessage(content=user_prompt),
             ],
             temperature=0,
             response_format={"type": "json_object"} if return_json else None,
         )
-        logger.debug('{0} response : "{1}"'.format(model, response.choices[0].message.content))
+        if not response or not response.choices:
+            logger.error("Call to Mistral API failed: No valid response or choices received")
+            return "An error occurred while processing the request."
 
         content = response.choices[0].message.content
+        if not content:
+            logger.error("Call to Mistral API failed: message content is None or Unset")
+            return "An error occurred while processing the request."
 
-        return content if isinstance(content, str) else " ".join(content)
+        logger.debug('{0} response : "{1}"'.format(model, content))
+        return content
diff --git a/backend/src/prompts/templates/math-solver.j2 b/backend/src/prompts/templates/math-solver.j2
new file mode 100644
index 000000000..b67a91907
--- /dev/null
+++ b/backend/src/prompts/templates/math-solver.j2
@@ -0,0 +1,25 @@
+You are an expert in performing mathematical operations. You are highly skilled in handling various mathematical queries such as expressing numbers in millions, performing arithmetic operations, and applying formulas as requested by the user.
+
+You will be given a mathematical query, and your task is to solve the query based on the provided information. Ensure that you apply the appropriate mathematical principles to deliver an exact result, specifically converting large numbers to millions without rounding off.
+
+Make sure to perform the calculations step by step, when necessary, and return the final result clearly.
+
+User's query is:
+{{ query }}
+
+Reply only in JSON with the following format:
+
+{
+    "result": "The final result of the mathematical operation, expressed in millions without rounding",
+    "steps": "A breakdown of the steps involved in solving the query (if applicable)",
+    "reasoning": "A sentence on why this result is accurate"
+}
+
+Following is an example of the query and the expected response format:
+query: Round 81.462 billion to the nearest million
+
+{
+    "result": "81,462 million",
+    "steps": "1. Convert 81.462 billion to million by multiplying by 1000. Round the result to the nearest million.",
+    "reasoning": "Rounding to the nearest million ensures that the result is represented in a more practical figure, without exceeding or falling short of the actual value."
+ }
\ No newline at end of file
diff --git a/backend/src/supervisors/supervisor.py b/backend/src/supervisors/supervisor.py
index f93915ccf..85c030e00 100644
--- a/backend/src/supervisors/supervisor.py
+++ b/backend/src/supervisors/supervisor.py
@@ -36,6 +36,7 @@ async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
     logger.info(f"Agent selected: {agent}")
     if agent is None:
         raise Exception(no_agent_response)
+    logger.info(f"Task is {task}")
     answer = await agent.invoke(task)
     parsed_json = json.loads(answer)
     status = parsed_json.get('status', 'success')
diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py
index 90082be76..cd555ebec 100644
--- a/backend/src/utils/web_utils.py
+++ b/backend/src/utils/web_utils.py
@@ -104,3 +104,29 @@ async def summarise_pdf_content(contents, llm, model) -> str:
                 "error": str(e),
             }
         )
+
+async def perform_math_operation_util(math_query, llm, model) -> str:
+    try:
+        # Load the prompt template for math operations
+        math_prompt = engine.load_prompt("math-solver", query=math_query)
+
+        # Send the math query to the LLM to perform the math operation
+        response = await llm.chat(model, math_prompt, "", return_json=True)
+        # Parse the response and return the result
+        return json.dumps(
+            {
+                "status": "success",
+                "response": response,  # math result
+                "error": None,
+            }
+        )
+    except Exception as e:
+        # Handle any errors during the LLM math operation
+        logger.error(f"Error during math operation: {e}")
+        return json.dumps(
+            {
+                "status": "error",
+                "response": None,
+                "error": str(e),
+            }
+        )
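
Note the nesting here: perform_math_operation_util() wraps the LLM's JSON string inside an outer status envelope, which is why perform_math_operation_core() in maths_agent.py calls json.loads() twice. A small sketch with invented values:

import json

# Outer envelope as returned by perform_math_operation_util(); the "response"
# field is itself a JSON string shaped like the math-solver.j2 example.
math_operation_result = json.dumps({
    "status": "success",
    "response": json.dumps({
        "result": "81,462 million",
        "steps": "Convert 81.462 billion to million by multiplying by 1000.",
        "reasoning": "Expresses the value in millions without rounding error.",
    }),
    "error": None,
})

result_json = json.loads(math_operation_result)       # first parse (core)
response_json = json.loads(result_json["response"])   # second parse (core)
print(response_json["result"])                        # "81,462 million"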
diff --git a/backend/tests/llm/mistral_test.py b/backend/tests/llm/mistral_test.py
index 86f87f7b2..f92a5ee86 100644
--- a/backend/tests/llm/mistral_test.py
+++ b/backend/tests/llm/mistral_test.py
@@ -1,10 +1,8 @@
+import logging
 from typing import cast
-from unittest.mock import MagicMock
-from mistralai.async_client import MistralAsyncClient
-from mistralai.models.chat_completion import ChatCompletionResponse
-from mistralai.models.chat_completion import ChatCompletionResponseChoice
-from mistralai.models.chat_completion import ChatMessage
-from mistralai.models.common import UsageInfo
+from unittest.mock import AsyncMock, MagicMock
+from mistralai import UNSET, AssistantMessage, Mistral as MistralApi, SystemMessage, UserMessage
+from mistralai.models import ChatCompletionResponse, ChatCompletionChoice, UsageInfo
 import pytest
 from src.llm import get_llm, Mistral
 from src.utils import Config
@@ -17,34 +15,23 @@
 mistral = cast(Mistral, get_llm("mistral"))
 
 
-async def create_mock_chat_response(content, tool_calls=None):
+def create_mock_chat_response(content, tool_calls=None):
     mock_usage = UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=3)
-    mock_message = ChatMessage(role="system", content=content, tool_calls=tool_calls)
-    mock_choice = ChatCompletionResponseChoice(index=0, message=mock_message, finish_reason=None)
+    mock_message = AssistantMessage(content=content, tool_calls=tool_calls)
+    mock_choice = ChatCompletionChoice(index=0, message=mock_message, finish_reason="stop")
     return ChatCompletionResponse(
         id="id", object="object", created=123, model="model", choices=[mock_choice], usage=mock_usage
     )
 
 
-mock_client = MagicMock(spec=MistralAsyncClient)
+mock_client = AsyncMock(spec=MistralApi)
 mock_config = MagicMock(spec=Config)
 
 
 @pytest.mark.asyncio
 async def test_chat_content_string_returns_string(mocker):
-    mistral.client = mocker.MagicMock(return_value=mock_client)
-    mistral.client.chat.return_value = create_mock_chat_response(content_response)
-
-    response = await mistral.chat(mock_model, system_prompt, user_prompt)
-
-    assert response == content_response
-
-
-@pytest.mark.asyncio
-async def test_chat_content_list_returns_string(mocker):
-    content_list = ["Hello", "there"]
-    mistral.client = mocker.MagicMock(return_value=mock_client)
-    mistral.client.chat.return_value = create_mock_chat_response(content_list)
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    mistral.client.chat.complete_async.return_value = create_mock_chat_response(content_response)
 
     response = await mistral.chat(mock_model, system_prompt, user_prompt)
 
@@ -58,9 +45,94 @@ async def test_chat_calls_client_chat(mocker):
     await mistral.chat(mock_model, system_prompt, user_prompt)
 
     expected_messages = [
-        ChatMessage(role="system", content=system_prompt),
-        ChatMessage(role="user", content=user_prompt),
+        SystemMessage(content=system_prompt),
+        UserMessage(content=user_prompt),
     ]
-    mistral.client.chat.assert_called_once_with(
+    mistral.client.chat.complete_async.assert_awaited_once_with(
         messages=expected_messages, model=mock_model, temperature=0, response_format=None
     )
+
+
+@pytest.mark.asyncio
+async def test_chat_response_none_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    mistral.client.chat.complete_async.return_value = None
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to Mistral API failed: No valid response or choices received",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_none_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    chat_response = create_mock_chat_response(content_response)
+    chat_response.choices = None
+    mistral.client.chat.complete_async.return_value = chat_response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to Mistral API failed: No valid response or choices received",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_empty_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    chat_response = create_mock_chat_response(content_response)
+    chat_response.choices = []
+    mistral.client.chat.complete_async.return_value = chat_response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to Mistral API failed: No valid response or choices received",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_message_content_none_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    chat_response = create_mock_chat_response(content_response)
+    assert chat_response.choices is not None
+    chat_response.choices[0].message.content = None
+    mistral.client.chat.complete_async.return_value = chat_response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to Mistral API failed: message content is None or Unset",
+    ) in caplog.record_tuples
+
+
+@pytest.mark.asyncio
+async def test_chat_response_choices_message_content_unset_logs_error(mocker, caplog):
+    mistral.client = mocker.AsyncMock(return_value=mock_client)
+    chat_response = create_mock_chat_response(content_response)
+    assert chat_response.choices is not None
+    chat_response.choices[0].message.content = UNSET
+    mistral.client.chat.complete_async.return_value = chat_response
+
+    response = await mistral.chat(mock_model, system_prompt, user_prompt)
+
+    assert response == "An error occurred while processing the request."
+    assert (
+        "src.llm.mistral",
+        logging.ERROR,
+        "Call to Mistral API failed: message content is None or Unset",
+    ) in caplog.record_tuples
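The rewritten tests above lean on `AsyncMock` creating awaitable child mocks, which is what lets `chat.complete_async` be stubbed and then checked with `assert_awaited_once_with`. A standalone sketch of that pattern with throwaway data (no Mistral types involved; the names below are illustrative):

```python
import asyncio
from unittest.mock import AsyncMock

client = AsyncMock()
# Attribute access on an AsyncMock yields AsyncMock children, so the nested method is awaitable.
client.chat.complete_async.return_value = {"choices": [{"message": {"content": "hello"}}]}


async def main() -> None:
    response = await client.chat.complete_async(messages=["hi"], model="mock-model")
    print(response["choices"][0]["message"]["content"])  # hello
    client.chat.complete_async.assert_awaited_once_with(messages=["hi"], model="mock-model")


asyncio.run(main())
```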

From 925bd92d87d9983da4d1fb88dd14c1dc407abc4b Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 3 Oct 2024 10:57:50 +0100
Subject: [PATCH 37/48] Push the changes for the first 4 tests

---
 backend/src/agents/maths_agent.py             | 56 ++----------------
 backend/src/agents/web_agent.py               | 58 +++++++++++++------
 .../prompts/templates/create-search-term.j2   | 16 +++++
 backend/src/prompts/templates/intent.j2       |  4 ++
 backend/src/prompts/templates/validator.j2    | 14 +++++
 backend/src/utils/web_utils.py                | 41 ++++++++-----
 6 files changed, 106 insertions(+), 83 deletions(-)
 create mode 100644 backend/src/prompts/templates/create-search-term.j2

diff --git a/backend/src/agents/maths_agent.py b/backend/src/agents/maths_agent.py
index f818ebebc..6b64f27f6 100644
--- a/backend/src/agents/maths_agent.py
+++ b/backend/src/agents/maths_agent.py
@@ -10,55 +10,11 @@
 logger = logging.getLogger(__name__)
 config = Config()
 
-# @tool(
-#     name="sum list of values",
-#     description="sums a list of provided values",
-#     parameters={
-#         "list_of_values": Parameter(
-#             type="list[number]",
-#             description="Python list of comma separated values (e.g. [1, 5, 3])",
-#         )
-#     },
-# )
-# async def sum_list_of_values(list_of_values) -> str:
-#     if not isinstance(list_of_values, list):
-#         raise Exception("Method not passed a valid Python list")
-#     return f"The sum of all the values passed {list_of_values} is {str(sum(list_of_values))}"
-
-
-# @tool(
-#     name="compare two values",
-#     description="Compare two passed values and return information on which one is greater",
-#     parameters={
-#         "thing_one": Parameter(
-#             type="string",
-#             description="first thing for comparison",
-#         ),
-#         "value_one": Parameter(
-#             type="number",
-#             description="value of first thing",
-#         ),
-#         "thing_two": Parameter(
-#             type="string",
-#             description="second thing for comparison",
-#         ),
-#         "value_two": Parameter(
-#             type="number",
-#             description="value of first thing",
-#         ),
-#     },
-# )
-# async def compare_two_values(value_one, thing_one, value_two, thing_two) -> str:
-#     if value_one > value_two:
-#         return f"You have spent more on {thing_one} ({value_one}) than {thing_two} ({value_two}) in the last month"
-#     else:
-#         return f"You have spent more on {thing_two} ({value_two}) than {thing_one} ({value_one}) in the last month"
-
-# Core function to perform the math operation by calling the util function
 async def perform_math_operation_core(math_query, llm, model) -> str:
     try:
         # Call the utility function to perform the math operation
         math_operation_result = await perform_math_operation_util(math_query, llm, model)
+
         result_json = json.loads(math_operation_result)
 
         if result_json.get("status") == "success":
@@ -66,13 +22,10 @@ async def perform_math_operation_core(math_query, llm, model) -> str:
             response = result_json.get("response", {})
             response_json = json.loads(response)
             result = response_json.get("result", "")
-            # steps = response_json.get("steps", "")
-            # reasoning = response_json.get("reasoning", "")
-
             if result:
                 logger.info(f"Math operation successful: {result}")
                 is_valid = await is_valid_answer(result, math_query)
-
+                logger.info(f"Is the answer valid: {is_valid}")
                 if is_valid:
                     response = {
                         "content": result,
@@ -133,7 +86,10 @@ async def perform_math_operation(math_query, llm, model) -> str:
 # MathAgent definition
 @agent(
     name="MathsAgent",
-    description="This agent is responsible for handling mathematical queries and providing solutions.",
+    description=(
+        "This agent is responsible for handling mathematical queries and can perform "
+        "necessary rounding and formatting operations."
+    ),
     tools=[perform_math_operation],
 )
 class MathsAgent(Agent):
diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index ffead6d20..25197c96f 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -4,7 +4,14 @@
 from .agent import Agent, agent
 from .tool import tool
 from src.utils import Config
-from src.utils.web_utils import search_urls, scrape_content, summarise_content, summarise_pdf_content, find_info
+from src.utils.web_utils import (
+    search_urls,
+    scrape_content,
+    summarise_content,
+    summarise_pdf_content,
+    find_info,
+    create_search_term
+)
 from .validator_agent import ValidatorAgent
 import aiohttp
 import io
@@ -20,27 +27,42 @@
 
 async def web_general_search_core(search_query, llm, model) -> str:
     try:
-        search_result = await perform_search(search_query, num_results=15)
-        if search_result["status"] == "error":
+        # Step 1: Generate the search term from the user's query
+        search_term_json = await create_search_term(search_query, llm, model)
+        search_term_result = json.loads(search_term_json)
+
+        # Check if there was an error in generating the search term
+        if search_term_result.get("status") == "error":
+            return "No search term found for the given query."
+        search_term = json.loads(search_term_result["response"]).get("search_term", "")
+
+        # Step 2: Perform the search using the generated search term
+        search_result = await perform_search(search_term, num_results=15)
+        if search_result.get("status") == "error":
             return "No relevant information found on the internet for the given query."
-        urls = search_result["urls"]
+        urls = search_result.get("urls", [])
         logger.info(f"URLs found: {urls}")
+
+        # Step 3: Scrape content from the URLs found
         for url in urls:
             content = await perform_scrape(url)
             if not content:
-                continue
-            summarisation = await perform_summarization(search_query, content, llm, model)
-            if not summarisation:
-                continue
-            is_valid = await is_valid_answer(summarisation, search_query)
-            parsed_json = json.loads(summarisation)
-            summary = parsed_json.get('summary', '')
-            if is_valid:
-                response = {
-                    "content": summary,
-                    "ignore_validation": "false"
-                }
-                return json.dumps(response, indent=4)
+                continue  # Skip to the next URL if no content is found
+            logger.info(f"Content scraped successfully: {content}")
+            # Step 4: Summarize the scraped content based on the search term
+            summary = await perform_summarization(search_term, content, llm, model)
+            if not summary:
+                continue  # Skip if no summary was generated
+
+            # Step 5: Validate the summarization
+            is_valid = await is_valid_answer(summary, search_term)
+            if not is_valid:
+                continue # Skip if the summarization is not valid
+            response = {
+                "content": summary,
+                "ignore_validation": "false"
+            }
+            return json.dumps(response, indent=4)
         return "No relevant information found on the internet for the given query."
     except Exception as e:
         logger.error(f"Error in web_general_search_core: {e}")
@@ -210,7 +232,7 @@ async def perform_summarization(search_query: str, content: str, llm: Any, model
         if summarise_result["status"] == "error":
             return ""
         logger.info(f"Content summarized successfully: {summarise_result['response']}")
-        return summarise_result["response"]
+        return json.loads(summarise_result["response"])["summary"]
     except Exception as e:
         logger.error(f"Error summarizing content: {e}")
         return ""
diff --git a/backend/src/prompts/templates/create-search-term.j2 b/backend/src/prompts/templates/create-search-term.j2
new file mode 100644
index 000000000..6da187aaa
--- /dev/null
+++ b/backend/src/prompts/templates/create-search-term.j2
@@ -0,0 +1,16 @@
+You are an expert at crafting Google search terms. Your goal is to generate an optimal search query based on the user's question to find the most relevant information on Google.
+
+Your entire purpose is to analyze the user's query, extract the essential keywords, and create a concise, well-structured search term that will yield the most accurate and useful results when used in a Google search.
+
+Ensure that the search query:
+
+Is relevant to the user’s question.
+Contains the right combination of keywords.
+Avoids unnecessary words, focusing only on what is critical for finding the right information.
+User's question is: {{ question }}
+
+Reply only in JSON format, following this structure:
+{
+    "search_term": "The optimized Google search term based on the user's question",
+    "reasoning": "A sentence on why you chose that search term"
+}
\ No newline at end of file
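The template replies with a `search_term`/`reasoning` JSON object, and `create_search_term` (added later in this patch) wraps that reply in the usual status envelope, so the web agent has to decode twice. A small sketch of that unpacking, with a hard-coded reply standing in for the LLM call:

```python
import json

# Simulated output of create_search_term: a status envelope whose "response"
# field holds the JSON string the LLM produced from this template.
search_term_json = json.dumps({
    "status": "success",
    "response": json.dumps({
        "search_term": "Tesla annual revenue by year",
        "reasoning": "Keywords focus on the company and the yearly revenue figures.",
    }),
    "error": None,
})

search_term_result = json.loads(search_term_json)
if search_term_result.get("status") == "error":
    print("No search term found for the given query.")
else:
    search_term = json.loads(search_term_result["response"]).get("search_term", "")
    print(f"search term: {search_term}")
```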
diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index ec936b4a9..a332b5385 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -86,3 +86,7 @@ Response:
 Q: Write the price of the book in this URL: http://books.toscrape.com/catalogue/meditations_33/index.html into a .txt file.
 Response:
 {"query": "Write the price of the book in this URL: http://books.toscrape.com/catalogue/meditations_33/index.html into a .txt file.", "user_intent": "scrape price and save to file", "questions": [{"query": "Scrape the content of the page at the URL: http://books.toscrape.com/catalogue/meditations_33/index.html", "question_intent": "scrape content", "operation": "online search and scraping", "question_category": "content extraction", "parameters": [{"type": "URL", "value": "http://books.toscrape.com/catalogue/meditations_33/index.html"}]}, {"query": "Extract the price of the book from the scraped content.", "question_intent": "find specific information", "operation": "literal search", "question_category": "data-driven", "parameters": [{"type": "attribute", "value": "price"}]}, {"query": "Write the price into a file called price.txt.", "question_intent": "save to file", "operation": "write file content", "question_category": "file handling", "parameters": [{"type": "file", "value": "price.txt"}]}]}
+
+Q: Please find tesla's revenue every year since its creation. Use the US notation, with a precision rounded to the nearest million dollars (for instance, $31,578 millions).
+Response:
+{"query": "Please find tesla's revenue every year since its creation. Use the US notation, with a precision rounded to the nearest million dollars.", "user_intent": "find tesla's revenue history in the US dollors", "questions": [{"query": "Please find tesla's revenue every year since its creation in the US dollors.", "question_intent": "retrieve revenue information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Tesla"}, {"type": "currency", "value": "US dollars"}, {"type": "timeframe", "value": "since creation"}], "aggregation": "none", "sort_order": "none", "timeframe": "since creation"}, {"query": "Round the revenue to the nearest million dollars.", "question_intent": "round to nearest million dollars", "operation": "mathematical operation", "question_category": "data driven", "parameters": [{"type": "rounding", "value": "nearest million dollars"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
\ No newline at end of file
diff --git a/backend/src/prompts/templates/validator.j2 b/backend/src/prompts/templates/validator.j2
index 8fe5427db..161760185 100644
--- a/backend/src/prompts/templates/validator.j2
+++ b/backend/src/prompts/templates/validator.j2
@@ -24,6 +24,20 @@ Answer: Last month you spend £64.21 on Spotify
 Response: False
 Reasoning: The answer is for Spotify not Amazon.
 
+Task: Please find Tesla's revenue every year since its creation.
+Answer: Tesla's annual revenue history from FY 2008 to FY 2023 is available, with figures for 2008 through 2020 taken from previous annual reports.
+Response: False
+Reasoning: The answer is not providing any actual figures but just talks about them.
+
+Task: Please find Tesla's revenue every year since its creation in US dollars.
+Answer: Tesla's annual revenue in USD since its creation is as follows: 2024 (TTM) $75.92 billion, 2023 $75.95 billion, 2022 $67.33 billion, 2021 $39.76 billion, 2020 $23.10 billion, 2019 $18.52 billion, 2018 $16.81 billion, 2017 $8.70 billion, 2016 $5.67 billion, 2015 $2.72 billion, 2014 $2.05 billion, 2013 $1.21 billion, 2012 $0.25 billion, 2011 $0.13 billion, 2010 $75.88 million, 2009 $69.73 million.
+Response: False
+Reasoning: The answer is providing the revenue in GBP not USD.
+
+Task: Round the following numbers to the nearest million dollars: 96.77B, 81.46B, 53.82B, 31.54B, 24.58B, 21.46B
+Answer: 96,770 million, 81,460 million, 53,820 million, 31,540 million, 24,580 million, 21,460 million
+Response: True
+
 You must always return a single boolean value as the response.
 Do not return any additional information, just the boolean value.
 
diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py
index f972223bb..804db2e42 100644
--- a/backend/src/utils/web_utils.py
+++ b/backend/src/utils/web_utils.py
@@ -43,7 +43,7 @@ async def scrape_content(url, limit=100000) -> str:
         async with aiohttp.request("GET", url) as response:
             response.raise_for_status()
             soup = BeautifulSoup(await response.text(), "html.parser")
-            paragraphs = soup.find_all("p")
+            paragraphs = soup.find_all("p" and "table")
             content = " ".join([para.get_text() for para in paragraphs])
             return json.dumps(
                 {
@@ -62,6 +62,26 @@ async def scrape_content(url, limit=100000) -> str:
             }
         )
 
+async def create_search_term(search_query, llm, model) -> str:
+    try:
+        summariser_prompt = engine.load_prompt("create-search-term", question=search_query)
+        response = await llm.chat(model, summariser_prompt, "", return_json=True)
+        return json.dumps(
+            {
+                "status": "success",
+                "response": response,
+                "error": None,
+            }
+        )
+    except Exception as e:
+        logger.error(f"Error during create search term: {e}")
+        return json.dumps(
+            {
+                "status": "error",
+                "response": None,
+                "error": str(e),
+            }
+        )
 
 async def summarise_content(search_query, contents, llm, model) -> str:
     try:
@@ -96,7 +116,7 @@ async def summarise_pdf_content(contents, llm, model) -> str:
             }
         )
     except Exception as e:
-        logger.error(f"Error during summarisation: {e}")
+        logger.error(f"Error during summarisation of PDF: {e}")
         return json.dumps(
             {
                 "status": "error",
@@ -107,12 +127,9 @@ async def summarise_pdf_content(contents, llm, model) -> str:
 
 async def perform_math_operation_util(math_query, llm, model) -> str:
     try:
-        # Load the prompt template for math operations
         math_prompt = engine.load_prompt("math-solver", query=math_query)
-
-        # Send the math query to the LLM to perform the math operation
         response = await llm.chat(model, math_prompt, "", return_json=True)
-        # Parse the response and return the result
+        logger.info(f"Math operation response: {response}")
         return json.dumps(
             {
                 "status": "success",
@@ -121,7 +138,6 @@ async def perform_math_operation_util(math_query, llm, model) -> str:
             }
         )
     except Exception as e:
-        # Handle any errors during the LLM math operation
         logger.error(f"Error during math operation: {e}")
         return json.dumps(
             {
@@ -134,12 +150,8 @@ async def perform_math_operation_util(math_query, llm, model) -> str:
 
 async def find_info(content, question, llm, model) -> str:
     try:
-        # Load the prompt template for math operations
-        math_prompt = engine.load_prompt("find-info", question=question, content=content)
-
-        # Send the math query to the LLM to perform the math operation
-        response = await llm.chat(model, math_prompt, "", return_json=True)
-        # Parse the response and return the result
+        find_info_prompt = engine.load_prompt("find-info", question=question, content=content)
+        response = await llm.chat(model, find_info_prompt, "", return_json=True)
         return json.dumps(
             {
                 "status": "success",
@@ -148,8 +160,7 @@ async def find_info(content, question, llm, model) -> str:
             }
         )
     except Exception as e:
-        # Handle any errors during the LLM math operation
-        logger.error(f"Error during math operation: {e}")
+        logger.error(f"Error during finding info operation: {e}")
         return json.dumps(
             {
                 "status": "error",

From 276831a7b5a53d42717c78eaae69b8414b480dc9 Mon Sep 17 00:00:00 2001
From: Maxwell Nyamunda
 <142179888+mnyamunda-scottlogic@users.noreply.github.com>
Date: Mon, 7 Oct 2024 11:16:06 +0100
Subject: [PATCH 38/48] Extra help for vscode users

---
 backend/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/README.md b/backend/README.md
index 9e09bcd62..cd21c0b95 100644
--- a/backend/README.md
+++ b/backend/README.md
@@ -37,7 +37,7 @@ Follow the instructions below to run the backend locally. Change directory to `/
 ```bash
 pip install -r requirements.txt
 ```
-
+> (VS Code) You may run into issues compiling Python packages from requirements.txt. To resolve this, ensure you have downloaded and installed the "Desktop development with C++" workload from your Visual Studio installer.
 3. Run the app
 
 ```bash

From 01c8f709ba1ef9cda7c667277ec0cd529ff24f4b Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Wed, 16 Oct 2024 13:35:47 +0100
Subject: [PATCH 39/48] Pushing the current changes

---
 backend/src/agents/chart_generator_agent.py   |  9 +-
 backend/src/agents/file_agent.py              | 28 ++++--
 backend/src/agents/intent_agent.py            | 46 +++++++++-
 backend/src/agents/web_agent.py               | 88 ++++++++++++-------
 .../src/prompts/templates/answer-user-ques.j2 | 54 ++++++++++++
 .../src/prompts/templates/best-next-step.j2   |  2 +-
 backend/src/prompts/templates/best-tool.j2    |  5 ++
 backend/src/prompts/templates/intent.j2       | 74 +++++++---------
 .../templates/tool-selection-format.j2        |  3 +-
 backend/src/utils/web_utils.py                | 25 +++++-
 backend/tests/agents/file_agent_test.py       |  6 +-
 11 files changed, 246 insertions(+), 94 deletions(-)
 create mode 100644 backend/src/prompts/templates/answer-user-ques.j2

diff --git a/backend/src/agents/chart_generator_agent.py b/backend/src/agents/chart_generator_agent.py
index f8ef64572..979163a88 100644
--- a/backend/src/agents/chart_generator_agent.py
+++ b/backend/src/agents/chart_generator_agent.py
@@ -9,8 +9,8 @@
 from src.utils import scratchpad
 from PIL import Image
 import json
-from src.websockets.user_confirmer import UserConfirmer
-from src.websockets.confirmations_manager import confirmations_manager
+# from src.websockets.user_confirmer import UserConfirmer
+# from src.websockets.confirmations_manager import confirmations_manager
 
 logger = logging.getLogger(__name__)
 
@@ -31,8 +31,9 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
     sanitised_script = sanitise_script(generated_code)
 
     try:
-        confirmer = UserConfirmer(confirmations_manager)
-        is_confirmed = await confirmer.confirm("Would you like to generate a graph?")
+        # confirmer = UserConfirmer(confirmations_manager)
+        is_confirmed = True
+        # await confirmer.confirm("Would you like to generate a graph?")
         if not is_confirmed:
             raise Exception("The user did not confirm to creating a graph.")
         local_vars = {}
diff --git a/backend/src/agents/file_agent.py b/backend/src/agents/file_agent.py
index 50a1d0fd9..d8a817b1c 100644
--- a/backend/src/agents/file_agent.py
+++ b/backend/src/agents/file_agent.py
@@ -39,13 +39,19 @@ async def read_file_core(file_path: str) -> str:
         return create_response(f"Error reading file: {file_path}", STATUS_ERROR)
 
 
-async def write_file_core(file_path: str, content: str) -> str:
+async def write_or_update_file_core(file_path: str, content: str, update) -> str:
     full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
     try:
-        with open(full_path, 'w') as file:
-            file.write(content)
-        logger.info(f"Content written to file {full_path} successfully.")
-        return create_response(f"Content written to file {file_path}.")
+        if update == "yes":
+            with open(full_path, 'a') as file:
+                file.write('\n' + content)
+            logger.info(f"Content appended to file {full_path} successfully.")
+            return create_response(f"Content appended to file {file_path}.")
+        else:
+            with open(full_path, 'w') as file:
+                file.write(content)
+            logger.info(f"Content written to file {full_path} successfully.")
+            return create_response(f"Content written to file {file_path}.")
     except Exception as e:
         logger.error(f"Error writing to file {full_path}: {e}")
         return create_response(f"Error writing to file: {file_path}", STATUS_ERROR)
@@ -67,7 +73,7 @@ async def read_file(file_path: str, llm, model) -> str:
 
 @tool(
     name="write_file",
-    description="Write content to a text file.",
+    description="Write or update content to a text file.",
     parameters={
         "file_path": Parameter(
             type="string",
@@ -77,16 +83,20 @@ async def read_file(file_path: str, llm, model) -> str:
             type="string",
             description="The content to write to the file."
         ),
+        "update": Parameter(
+            type="string",
+            description="If 'yes', append to the file; otherwise overwrite it."
+        ),
     },
 )
-async def write_file(file_path: str, content: str, llm, model) -> str:
-    return await write_file_core(file_path, content)
+async def write_or_update_file(file_path: str, content: str, update, llm, model) -> str:
+    return await write_or_update_file_core(file_path, content, update)
 
 
 @agent(
     name="FileAgent",
     description="This agent is responsible for reading from and writing to files.",
-    tools=[read_file, write_file],
+    tools=[read_file, write_or_update_file],
 )
 class FileAgent(Agent):
     pass
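The new `update` parameter is a plain string: "yes" appends on a new line, anything else overwrites. A minimal standalone sketch of the same behaviour against a temporary directory (it mirrors `write_or_update_file_core` rather than importing it):

```python
import os
import tempfile


def write_or_update(path: str, content: str, update: str) -> str:
    # "yes" appends on a new line; any other value overwrites the file.
    if update == "yes":
        with open(path, "a") as file:
            file.write("\n" + content)
        return f"Content appended to file {path}."
    with open(path, "w") as file:
        file.write(content)
    return f"Content written to file {path}."


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "conversation-history.txt")
    print(write_or_update(target, "first line", "no"))
    print(write_or_update(target, "second line", "yes"))
    with open(target) as f:
        print(f.read())  # first line\nsecond line
```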
diff --git a/backend/src/agents/intent_agent.py b/backend/src/agents/intent_agent.py
index ed6b7fcb2..19560d6c2 100644
--- a/backend/src/agents/intent_agent.py
+++ b/backend/src/agents/intent_agent.py
@@ -1,9 +1,23 @@
 from src.prompts import PromptEngine
 from src.agents import Agent, agent
+from src.utils import get_scratchpad
+import logging
+import os
+import json
+from src.utils.config import Config
+
+
+config = Config()
 
 engine = PromptEngine()
 intent_format = engine.load_prompt("intent-format")
+logger = logging.getLogger(__name__)
+FILES_DIRECTORY = f"/app/{config.files_directory}"
 
+# Constants for response status
+IGNORE_VALIDATION = "true"
+STATUS_SUCCESS = "success"
+STATUS_ERROR = "error"
 
 @agent(
     name="IntentAgent",
@@ -11,7 +25,37 @@
     tools=[],
 )
 class IntentAgent(Agent):
+
+    async def read_file_core(self, file_path: str) -> str:
+        full_path = os.path.normpath(os.path.join(FILES_DIRECTORY, file_path))
+        try:
+            with open(full_path, 'r') as file:
+                content = file.read()
+            return content
+        except FileNotFoundError:
+            error_message = f"File {file_path} not found."
+            logger.error(error_message)
+            return ""
+        except Exception as e:
+            logger.error(f"Error reading file {full_path}: {e}")
+            return ""
+        
     async def invoke(self, utterance: str) -> str:
-        user_prompt = engine.load_prompt("intent", question=utterance)
+        chat_history = await self.read_file_core("conversation-history.txt")
+        logger.info(f"USER - chat history: {chat_history}")
+        
+        user_prompt = engine.load_prompt("intent", question=utterance, chat_history=chat_history)
+        logger.info(f"USER - user prompt: {user_prompt}")
 
         return await self.llm.chat(self.model, intent_format, user_prompt=user_prompt, return_json=True)
+
+
+    
+
+    # Utility function for error responses
+def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
+    return json.dumps({
+        "content": content,
+        "ignore_validation": IGNORE_VALIDATION,
+        "status": status
+    }, indent=4)
\ No newline at end of file
diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index 25197c96f..265d372ec 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -10,7 +10,8 @@
     summarise_content,
     summarise_pdf_content,
     find_info,
-    create_search_term
+    create_search_term,
+    answer_user_ques
 )
 from .validator_agent import ValidatorAgent
 import aiohttp
@@ -28,42 +29,63 @@
 async def web_general_search_core(search_query, llm, model) -> str:
     try:
         # Step 1: Generate the search term from the user's query
-        search_term_json = await create_search_term(search_query, llm, model)
-        search_term_result = json.loads(search_term_json)
+        answer_to_user = await answer_user_ques(search_query, llm, model)
+        answer_result = json.loads(answer_to_user)
+        if answer_result["status"] == "error":
+            return ""
+        logger.info(f'Answer found successfully {answer_result}')
+        valid_answer = json.loads(answer_result["response"]).get("is_valid", "")
+        if valid_answer:
+            final_answer = json.loads(answer_result["response"]).get("answer", "")
+            if not final_answer:
+                return "No answer found."
+            logger.info(f'Answer found successfully {final_answer}')
+            response = {
+                    "content": final_answer,
+                    "ignore_validation": "false"
+                }
+            return json.dumps(response, indent=4)
+        else:
+            search_term_json = await create_search_term(search_query, llm, model)
+            search_term_result = json.loads(search_term_json)
 
-        # Check if there was an error in generating the search term
-        if search_term_result.get("status") == "error":
-            return "No search term found for the given query."
-        search_term = json.loads(search_term_result["response"]).get("search_term", "")
+            # Check if there was an error in generating the search term
+            if search_term_result.get("status") == "error":
+                response = {
+                    "content": search_term_result.get("error"),
+                    "ignore_validation": "false"
+                }
+                return json.dumps(response, indent=4)
+            search_term = json.loads(search_term_result["response"]).get("search_term", "")
 
-        # Step 2: Perform the search using the generated search term
-        search_result = await perform_search(search_term, num_results=15)
-        if search_result.get("status") == "error":
-            return "No relevant information found on the internet for the given query."
-        urls = search_result.get("urls", [])
-        logger.info(f"URLs found: {urls}")
+            # Step 2: Perform the search using the generated search term
+            search_result = await perform_search(search_term, num_results=15)
+            if search_result.get("status") == "error":
+                return "No relevant information found on the internet for the given query."
+            urls = search_result.get("urls", [])
+            logger.info(f"URLs found: {urls}")
 
-        # Step 3: Scrape content from the URLs found
-        for url in urls:
-            content = await perform_scrape(url)
-            if not content:
-                continue  # Skip to the next URL if no content is found
-            logger.info(f"Content scraped successfully: {content}")
-            # Step 4: Summarize the scraped content based on the search term
-            summary = await perform_summarization(search_term, content, llm, model)
-            if not summary:
-                continue  # Skip if no summary was generated
+            # Step 3: Scrape content from the URLs found
+            for url in urls:
+                content = await perform_scrape(url)
+                if not content:
+                    continue  # Skip to the next URL if no content is found
+                # logger.info(f"Content scraped successfully: {content}")
+                # Step 4: Summarize the scraped content based on the search term
+                summary = await perform_summarization(search_term, content, llm, model)
+                if not summary:
+                    continue  # Skip if no summary was generated
 
-            # Step 5: Validate the summarization
-            is_valid = await is_valid_answer(summary, search_term)
-            if not is_valid:
-                continue # Skip if the summarization is not valid
-            response = {
-                "content": summary,
-                "ignore_validation": "false"
-            }
-            return json.dumps(response, indent=4)
-        return "No relevant information found on the internet for the given query."
+                # Step 5: Validate the summarization
+                is_valid = await is_valid_answer(summary, search_term)
+                if not is_valid:
+                    continue # Skip if the summarization is not valid
+                response = {
+                    "content": summary,
+                    "ignore_validation": "false"
+                }
+                return json.dumps(response, indent=4)
+            return "No relevant information found on the internet for the given query."
     except Exception as e:
         logger.error(f"Error in web_general_search_core: {e}")
         return "An error occurred while processing the search query."
diff --git a/backend/src/prompts/templates/answer-user-ques.j2 b/backend/src/prompts/templates/answer-user-ques.j2
new file mode 100644
index 000000000..02b9e56e6
--- /dev/null
+++ b/backend/src/prompts/templates/answer-user-ques.j2
@@ -0,0 +1,54 @@
+You are an expert in providing accurate and complete answers to user queries. Your task is twofold:
+
+1. **Generate a detailed answer** to the user's question based on the provided content or context.
+2. **Validate** if the generated answer directly addresses the user's question and is factually accurate.
+
+User's question is:
+{{ question }}
+
+Once you generate an answer:
+- **Check** if the answer completely and accurately addresses the user's question.
+- **Determine** if the answer is valid, based on the content provided.
+
+Reply only in JSON format with the following structure:
+
+```json
+{
+    "answer": "The answer to the user's question, based on the content provided",
+    "is_valid": true or false,
+    "validation_reason": "A sentence explaining whether the answer is valid or not, and why"
+}
+
+
+
+### **Explanation:**
+
+1. **Answer**: The LLM generates an answer based on the user’s question and the provided content.
+2. **Validity Check**: The LLM checks if its generated answer is complete and correct. This could be based on factual accuracy, coverage of the query, or relevance to the user's question.
+3. **Validation Reason**: The LLM explains why the answer is valid or invalid.
+
+### **Example of Usage:**
+
+#### **User’s Question:**
+- **"What is Tesla's revenue every year since its creation?"**
+
+#### **Content Provided:**
+- A table or a paragraph with data on Tesla's revenue for various years.
+
+#### **LLM’s Response:**
+
+```json
+{
+    "answer": "Tesla's revenue since its creation is: 2008: $15 million, 2009: $30 million, ..., 2023: $81 billion.",
+    "is_valid": true,
+    "validation_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided."
+}
+
+{
+    "answer": "Tesla's revenue for 2010 to 2023 is available, but data for the earlier years is missing.",
+    "is_valid": false,
+    "validation_reason": "The answer is incomplete because data for Tesla's early years is missing."
+}
+
+
+Important: If the question is related to real-time data, the LLM should set is_valid to false.
\ No newline at end of file
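The web agent branches on the `is_valid` flag from this template to decide whether to answer directly or fall back to a search, so a malformed reply should degrade to the search path. A hedged sketch of one way a caller might guard that parse (the helper below is illustrative, not part of the patch):

```python
import json
from typing import Tuple

EXPECTED_KEYS = {"answer", "is_valid", "validation_reason"}


def parse_answer_reply(raw: str) -> Tuple[str, bool]:
    """Return (answer, is_valid); treat malformed or incomplete replies as invalid."""
    try:
        reply = json.loads(raw)
    except json.JSONDecodeError:
        return "", False
    if not isinstance(reply, dict) or not EXPECTED_KEYS.issubset(reply):
        return "", False
    return str(reply.get("answer", "")), bool(reply.get("is_valid", False))


raw_reply = json.dumps({
    "answer": "Tesla's revenue since its creation is: 2008: $15 million, ..., 2023: $81 billion.",
    "is_valid": True,
    "validation_reason": "The answer covers every year with the figures provided.",
})
print(parse_answer_reply(raw_reply))
```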
diff --git a/backend/src/prompts/templates/best-next-step.j2 b/backend/src/prompts/templates/best-next-step.j2
index 59942baa6..ff6ef6c98 100644
--- a/backend/src/prompts/templates/best-next-step.j2
+++ b/backend/src/prompts/templates/best-next-step.j2
@@ -22,7 +22,7 @@ Here is the list of Agents you can choose from:
 AGENT LIST:
 {{ list_of_agents }}
 
-If the list of agents does not contain something suitable, you should say the agent is 'none'. ie. If question is 'general knowledge', 'personal' or a 'greeting'.
+If the list of agents does not contain something suitable, you should say the agent is 'WebAgent'. ie. If question is 'general knowledge', 'personal' or a 'greeting'.
 
 ## Determine the next best step
 Your task is to pick one of the mentioned agents above to complete the task.
diff --git a/backend/src/prompts/templates/best-tool.j2 b/backend/src/prompts/templates/best-tool.j2
index 51d93f5de..9558964bb 100644
--- a/backend/src/prompts/templates/best-tool.j2
+++ b/backend/src/prompts/templates/best-tool.j2
@@ -11,11 +11,16 @@ Trust the information below completely (100% accurate)
 
 Pick 1 tool (no more than 1) from the list below to complete this task.
 Fit the correct parameters from the task to the tool arguments.
+Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "$") and units of measurement (e.g., "million") if applicable.
 Parameters with required as False do not need to be fit.
 Add if appropriate, but do not hallucinate arguments for these parameters
 
 {{ tools }}
 
+Important:
+If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "$") and rounded to the nearest million if specified.
+If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "$15 million", "$5000").
+
 From the task you should be able to extract the parameters. If it is data driven, it should be turned into a cypher query
 
 If none of the tools are appropriate for the task, return the following tool
diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index a332b5385..b54f30282 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -1,40 +1,39 @@
 You are an expert in determining the intent behind a user's question. 
 
- 
 The question is:  
 
 {{ question }} 
 
- 
-Your task is to accurately comprehend the intentions behind the question. The question can be composed of different intents and when it is the case, examine all intents one by one to determine which one to tackle first as you may need the data gathered from a secondary intent to perform the first intent.  
-If the question contains multiple intents, break them down into individual tasks, and specify the order in which these tasks should be tackled. The order should ensure that each intent is addressed in a logical sequence, particularly if one intent depends on the data obtained from another.
-You are NOT ALLOWED to make up sample data or example values. Only use concrete data for which you can name the source.
-Based on this understanding, the following query must be formulated to extract the necessary data, which can then be used to address the question. 
-
-Guidelines:
-1. Determine each distinct intent in the question.
-2. Sequence the intents: Identify which intent should be tackled first and which should follow.
-3. For each intent, specify:
-    - The exact operation required (e.g., "literal search", "filter + aggregation", "data transformation", "scrape content", "find specific information", "mathematical operation" for maths-related tasks like rounding).
-    - The category of the question (e.g., "data-driven", "data presentation", "general knowledge", "content extraction", "file handling").
-    - Any specific parameters or conditions that apply.
-    - The correct aggregation and sorting methods if applicable.
-4. For mathematical operations (like rounding, addition, or finding maximum values), use **MathsAgent** for the task.
-5. Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations.
-6. Do not make assumptions or create hypothetical data. Use only concrete data where applicable.
-7. Treat each operation as a separate task when there are multiple intents.
- 
-Specify an operation type under the operation key; here are a few examples: 
+The previous chat history is:
+
+{{ chat_history }}
+
+Your task is to accurately comprehend the intentions behind the current question. 
+The question can be composed of different intents and may depend on the context provided by the previous question and its response.
+
+- You must evaluate whether the current question is directly related to or dependent on the information from the previous interaction.
+- If the current question builds on the previous one, make sure to use the relevant data from the previous response to inform the current query.
+
+The question may contain multiple intents. Examine each intent and determine the order in which they should be tackled, ensuring each intent is addressed logically. If one intent depends on data from another, sequence them accordingly.
 
-* "literal search" - This should be used when the user is looking to find precise information, such as known facts 
-* "relevancy search" - This should be used when the user is looking to find something that is not a literal and is fuzzy  
-* "filter + aggregation" - This should be used when they want something like a count, where there will be only 1 number returned 
-* "filter + aggregation + sort" - This should be used when multiple numbers will be returned 
-* "filter + sort" - This should be used when no aggregation is required e.g. count 
+Use the following guidelines:
 
+1. Determine distinct intents in the question.
+2. For each intent, specify:
+    - The exact operation required (e.g., "literal search", "filter + aggregation").
+    - The category of the question (e.g., "data-driven", "general knowledge").
+    - Any specific parameters or conditions that apply.
+    - If related to the previous response, include parameters derived from the previous interaction.
+3. Sequence the intents logically if there are multiple, ensuring any dependent intents are handled last.
+4. For each intent, clarify the operation, aggregation, sorting, and any timeframe or other parameters.
+5. Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations.
 
-Examples: 
+Examples of common operations:
+- Literal search for factual information.
+- Filter + aggregation for tasks like counting or summing.
+- Data transformation for numerical operations like rounding.
 
+Examples
 Q: How much did I spend with Amazon this month? 
 Response: 
 {"query":"How much did I spend with Amazon this month?","user_intent":"sum amount spent","questions":[{"query":"How much did I spend with Amazon this month?","question_intent":"calculate total expenses","operation":"filter + aggregation","question_category": "data driven","parameters":[{"type":"company","value":"Amazon"}],"timeframe":"this month","aggregation":"sum","sort_order":"none"}]} 
@@ -75,18 +74,13 @@ Q: Show me a chart of different subscription prices with Netflix?
 Response:
 {"query": "Show me a chart of different subscription prices with Netflix?", "user_intent": "retrieve and visualize subscription data", "questions": [{"query": "What are the different subscription prices with Netflix?", "question_intent": "retrieve subscription pricing information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Netflix"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Show me the results in a chart", "question_intent": "display subscription pricing information in a chart", "operation": "data visualization", "question_category": "data presentation", "parameters": [], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
 
-Q: Read the file called file_to_read.txt and write its content to a file called output.txt.
-Response:
-{"query": "Read the file called {{ file_name }} and write its content to a file called {{ output_file_name }}.", "user_intent": "read and write file content", "questions": [{"query": "Read the file called {{ file_name }} using fileagent.", "question_intent": "read file content", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "file", "value": "{{ file_name }}"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}, {"query": "Write the content to a file called {{ output_file_name }} using fileagent.", "question_intent": "write file content", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "file", "value": "{{ output_file_name }}"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
-
-Q: What's the price of the book in this URL http://books.toscrape.com/catalogue/meditations_33/index.html?
-Response:
-{"query": "What's the price of the book in this URL http://books.toscrape.com/catalogue/meditations_33/index.html?", "user_intent": "find the price of a book from a URL", "questions": [{"query": "Scrape the content of the page at the URL: http://books.toscrape.com/catalogue/meditations_33/index.html", "question_intent": "scrape content", "operation": "online search and scraping", "question_category": "content extraction", "parameters": [{"type": "URL", "value": "http://books.toscrape.com/catalogue/meditations_33/index.html"}]}, {"query": "Extract the price of the book from the scraped content.", "question_intent": "find specific information", "operation": "literal search", "question_category": "data-driven", "parameters": [{"type": "attribute", "value": "price"}]}]}
+Finally, if no tool fits the task, return the following:
 
-Q: Write the price of the book in this URL: http://books.toscrape.com/catalogue/meditations_33/index.html into a .txt file.
-Response:
-{"query": "Write the price of the book in this URL: http://books.toscrape.com/catalogue/meditations_33/index.html into a .txt file.", "user_intent": "scrape price and save to file", "questions": [{"query": "Scrape the content of the page at the URL: http://books.toscrape.com/catalogue/meditations_33/index.html", "question_intent": "scrape content", "operation": "online search and scraping", "question_category": "content extraction", "parameters": [{"type": "URL", "value": "http://books.toscrape.com/catalogue/meditations_33/index.html"}]}, {"query": "Extract the price of the book from the scraped content.", "question_intent": "find specific information", "operation": "literal search", "question_category": "data-driven", "parameters": [{"type": "attribute", "value": "price"}]}, {"query": "Write the price into a file called price.txt.", "question_intent": "save to file", "operation": "write file content", "question_category": "file handling", "parameters": [{"type": "file", "value": "price.txt"}]}]}
+{
+    "tool_name":  "None",
+    "tool_parameters":  "{}",
+    "reasoning": "No tool was appropriate for the task"
+}
 
-Q: Please find tesla's revenue every year since its creation. Use the US notation, with a precision rounded to the nearest million dollars (for instance, $31,578 millions).
-Response:
-{"query": "Please find tesla's revenue every year since its creation. Use the US notation, with a precision rounded to the nearest million dollars.", "user_intent": "find tesla's revenue history in the US dollors", "questions": [{"query": "Please find tesla's revenue every year since its creation in the US dollors.", "question_intent": "retrieve revenue information", "operation": "literal search", "question_category": "data driven", "parameters": [{"type": "company", "value": "Tesla"}, {"type": "currency", "value": "US dollars"}, {"type": "timeframe", "value": "since creation"}], "aggregation": "none", "sort_order": "none", "timeframe": "since creation"}, {"query": "Round the revenue to the nearest million dollars.", "question_intent": "round to nearest million dollars", "operation": "mathematical operation", "question_category": "data driven", "parameters": [{"type": "rounding", "value": "nearest million dollars"}], "aggregation": "none", "sort_order": "none", "timeframe": "none"}]}
\ No newline at end of file
+Important:
+Please always create the last intent to append the retrieved information to a conversation-history.txt file
\ No newline at end of file
diff --git a/backend/src/prompts/templates/tool-selection-format.j2 b/backend/src/prompts/templates/tool-selection-format.j2
index 542d2f7c1..3a7c664b4 100644
--- a/backend/src/prompts/templates/tool-selection-format.j2
+++ b/backend/src/prompts/templates/tool-selection-format.j2
@@ -1,4 +1,5 @@
-Reply only in json with the following format:
+Reply only in JSON with the following format; in the tool_parameters, please include the currency and measuring scale used in the content provided:
+
 
 {
     "tool_name":  "the exact string name of the tool chosen",
diff --git a/backend/src/utils/web_utils.py b/backend/src/utils/web_utils.py
index 804db2e42..1839b03ee 100644
--- a/backend/src/utils/web_utils.py
+++ b/backend/src/utils/web_utils.py
@@ -43,8 +43,8 @@ async def scrape_content(url, limit=100000) -> str:
         async with aiohttp.request("GET", url) as response:
             response.raise_for_status()
             soup = BeautifulSoup(await response.text(), "html.parser")
-            paragraphs = soup.find_all("p" and "table")
-            content = " ".join([para.get_text() for para in paragraphs])
+            paragraphs_and_tables = soup.find_all(["p", "table", "h1", "h2", "h3", "h4", "h5", "h6"])
+            content = "\n".join([tag.get_text() for tag in paragraphs_and_tables])
             return json.dumps(
                 {
                     "status": "success",
@@ -83,6 +83,27 @@ async def create_search_term(search_query, llm, model) -> str:
             }
         )
 
+async def answer_user_ques(search_query, llm, model) -> str:
+    try:
+        summariser_prompt = engine.load_prompt("answer-user-ques", question=search_query)
+        response = await llm.chat(model, summariser_prompt, "", return_json=True)
+        return json.dumps(
+            {
+                "status": "success",
+                "response": response,
+                "error": None,
+            }
+        )
+    except Exception as e:
+        logger.error(f"Error during answer user question: {e}")
+        return json.dumps(
+            {
+                "status": "error",
+                "response": None,
+                "error": str(e),
+            }
+        )
+
 async def summarise_content(search_query, contents, llm, model) -> str:
     try:
         summariser_prompt = engine.load_prompt("summariser", question=search_query, content=contents)
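The scraper fix in this hunk matters because `"p" and "table"` is evaluated by Python before BeautifulSoup ever sees it and reduces to just `"table"`, so only tables were matched; passing a list of tag names is how `find_all` matches several element types at once. A small sketch against an inline HTML snippet:

```python
from bs4 import BeautifulSoup

html = """
<h1>Tesla revenue</h1>
<p>Figures are in millions of US dollars.</p>
<table><tr><td>2023</td><td>96,770</td></tr></table>
"""

soup = BeautifulSoup(html, "html.parser")

# "p" and "table" is just the string "table", so only the table element matches.
only_tables = soup.find_all("p" and "table")
print(len(only_tables))  # 1

# A list of names matches paragraphs, tables and headings alike.
tags = soup.find_all(["p", "table", "h1", "h2", "h3"])
content = "\n".join(tag.get_text() for tag in tags)
print(content)
```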
diff --git a/backend/tests/agents/file_agent_test.py b/backend/tests/agents/file_agent_test.py
index 91a56eb0f..5d3f4e1dd 100644
--- a/backend/tests/agents/file_agent_test.py
+++ b/backend/tests/agents/file_agent_test.py
@@ -2,7 +2,7 @@
 from unittest.mock import patch, mock_open
 import json
 import os
-from src.agents.file_agent import read_file_core, write_file_core, create_response
+from src.agents.file_agent import read_file_core, write_or_update_file_core, create_response
 
 # Mocking config for the test
 @pytest.fixture(autouse=True)
@@ -34,7 +34,7 @@ async def test_read_file_core_file_not_found(mock_file):
 async def test_write_file_core_success(mock_file):
     file_path = "example_write.txt"
     content = "This is test content to write."
-    result = await write_file_core(file_path, content)
+    result = await write_or_update_file_core(file_path, content, 'no')
     expected_response = create_response(f"Content written to file {file_path}.")
     assert json.loads(result) == json.loads(expected_response)
     expected_full_path = os.path.normpath("/app/files/example_write.txt")
@@ -46,7 +46,7 @@ async def test_write_file_core_success(mock_file):
 async def test_write_file_core_error(mock_file):
     file_path = "error_file.txt"
     content = "Content with error."
-    result = await write_file_core(file_path, content)
+    result = await write_or_update_file_core(file_path, content, 'no')
     expected_response = create_response(f"Error writing to file: {file_path}", "error")
     assert json.loads(result) == json.loads(expected_response)
     expected_full_path = os.path.normpath("/app/files/error_file.txt")

From 3ca19c174b338a18f589623a1fb87a87daa937f7 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Wed, 16 Oct 2024 13:57:49 +0100
Subject: [PATCH 40/48] Removed loggers

---
 backend/src/agents/agent.py        |  3 ---
 backend/src/agents/intent_agent.py | 11 +++--------
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/backend/src/agents/agent.py b/backend/src/agents/agent.py
index 221730d65..32224b55d 100644
--- a/backend/src/agents/agent.py
+++ b/backend/src/agents/agent.py
@@ -56,9 +56,6 @@ async def __get_action(self, utterance: str) -> Action_and_args:
 
     async def invoke(self, utterance: str) -> str:
         (action, args) = await self.__get_action(utterance)
-        logger.info(f"Calling action {action} with arguments: {args}, LLM: {self.llm}, Model: {self.model}")
-        logger.info(f"USER - Action: {action} and args: {args} for utterance: {utterance}")
-        
         result_of_action = await action(**args, llm=self.llm, model=self.model)
         await publish_log_info(LogPrefix.USER, f"Action gave result: {result_of_action}", __name__)
         return result_of_action
diff --git a/backend/src/agents/intent_agent.py b/backend/src/agents/intent_agent.py
index 19560d6c2..f3b701653 100644
--- a/backend/src/agents/intent_agent.py
+++ b/backend/src/agents/intent_agent.py
@@ -1,6 +1,5 @@
 from src.prompts import PromptEngine
 from src.agents import Agent, agent
-from src.utils import get_scratchpad
 import logging
 import os
 import json
@@ -39,23 +38,19 @@ async def read_file_core(self, file_path: str) -> str:
         except Exception as e:
             logger.error(f"Error reading file {full_path}: {e}")
             return ""
-        
+
     async def invoke(self, utterance: str) -> str:
         chat_history = await self.read_file_core("conversation-history.txt")
-        logger.info(f"USER - chat history: {chat_history}")
-        
+
         user_prompt = engine.load_prompt("intent", question=utterance, chat_history=chat_history)
-        logger.info(f"USER - user prompt: {user_prompt}")
 
         return await self.llm.chat(self.model, intent_format, user_prompt=user_prompt, return_json=True)
 
 
-    
-
     # Utility function for error responses
 def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
     return json.dumps({
         "content": content,
         "ignore_validation": IGNORE_VALIDATION,
         "status": status
-    }, indent=4)
\ No newline at end of file
+    }, indent=4)
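
The create_response helper above is the JSON envelope the intent agent's callers consume. A self-contained usage sketch; the constant values here are assumptions, since the real STATUS_SUCCESS and IGNORE_VALIDATION constants are defined elsewhere in intent_agent.py and are not shown in this patch:

    import json

    STATUS_SUCCESS = "success"     # assumed value
    IGNORE_VALIDATION = "false"    # assumed value

    def create_response(content: str, status: str = STATUS_SUCCESS) -> str:
        return json.dumps({
            "content": content,
            "ignore_validation": IGNORE_VALIDATION,
            "status": status,
        }, indent=4)

    envelope = json.loads(create_response("No suitable intent found.", "error"))
    assert envelope["status"] == "error" and envelope["ignore_validation"] == "false"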

From 84dcf74fdd23879afb05ef3c9a7a859654a2a34a Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 17 Oct 2024 09:17:34 +0100
Subject: [PATCH 41/48] Fixed the broken unit tests.

---
 backend/src/agents/chart_generator_agent.py   |  2 +-
 backend/src/prompts/templates/intent.j2       |  1 +
 .../agents/chart_generator_agent_test.py      | 72 +++++--------------
 backend/tests/agents/web_agent_test.py        | 22 ++++--
 backend/tests/prompts/prompting_test.py       | 14 +++-
 5 files changed, 46 insertions(+), 65 deletions(-)

diff --git a/backend/src/agents/chart_generator_agent.py b/backend/src/agents/chart_generator_agent.py
index 979163a88..8479bc128 100644
--- a/backend/src/agents/chart_generator_agent.py
+++ b/backend/src/agents/chart_generator_agent.py
@@ -52,7 +52,7 @@ async def generate_chart(question_intent, data_provided, question_params, llm: L
         raise
     response = {
         "content": image_data,
-        "ignore_validation": "false",
+        "ignore_validation": "true",
     }
     return json.dumps(response, indent=4)
 
diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index b54f30282..68d53bdf5 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -27,6 +27,7 @@ Use the following guidelines:
 3. Sequence the intents logically if there are multiple, ensuring any dependent intents are handled last.
 4. For each intent, clarify the operation, aggregation, sorting, and any timeframe or other parameters.
 5. Avoid conflating intents: If a user's query asks for data retrieval and its visualization, treat these as separate operations.
+6. Use the chat history to figure out the correct context if the user's question is a bit vague.
 
 Examples of common operations:
 - Literal search for factual information.
diff --git a/backend/tests/agents/chart_generator_agent_test.py b/backend/tests/agents/chart_generator_agent_test.py
index 5072f6a39..fe0c48dbd 100644
--- a/backend/tests/agents/chart_generator_agent_test.py
+++ b/backend/tests/agents/chart_generator_agent_test.py
@@ -8,19 +8,16 @@
 import json
 from src.agents.chart_generator_agent import sanitise_script
 
-
 @pytest.mark.asyncio
 @patch("src.agents.chart_generator_agent.engine.load_prompt")
 @patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
-@patch("src.agents.chart_generator_agent.UserConfirmer.confirm", new_callable=AsyncMock)
-async def test_generate_code_success(confirm_mock, mock_sanitise_script, mock_load_prompt):
-    confirm_mock.return_value = True
+async def test_generate_code_success(mock_sanitise_script, mock_load_prompt):
     llm = AsyncMock()
     model = "mock_model"
 
     mock_load_prompt.side_effect = [
         "details to create chart code prompt",
-        "generate chart code prompt",
+        "generate chart code prompt"
     ]
 
     llm.chat.return_value = "generated code"
@@ -30,7 +27,7 @@ async def test_generate_code_success(confirm_mock, mock_sanitise_script, mock_lo
 fig = plt.figure()
 plt.plot([1, 2, 3], [4, 5, 6])
 """
-    plt.switch_backend("Agg")
+    plt.switch_backend('Agg')
 
     def mock_exec_side_effect(script, globals=None, locals=None):
         if script == return_string:
@@ -38,7 +35,7 @@ def mock_exec_side_effect(script, globals=None, locals=None):
             plt.plot([1, 2, 3], [4, 5, 6])
             if locals is None:
                 locals = {}
-            locals["fig"] = fig
+            locals['fig'] = fig
 
     with patch("builtins.exec", side_effect=mock_exec_side_effect):
         result = await generate_chart("question_intent", "data_provided", "question_params", llm, model)
@@ -54,23 +51,20 @@ def mock_exec_side_effect(script, globals=None, locals=None):
         llm.chat.assert_called_once_with(
             model,
             "generate chart code prompt",
-            "details to create chart code prompt",
+            "details to create chart code prompt"
         )
         mock_sanitise_script.assert_called_once_with("generated code")
 
-
 @pytest.mark.asyncio
 @patch("src.agents.chart_generator_agent.engine.load_prompt")
 @patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
-@patch("src.agents.chart_generator_agent.UserConfirmer.confirm", new_callable=AsyncMock)
-async def test_generate_code_no_figure(confirm_mock, mock_sanitise_script, mock_load_prompt):
-    confirm_mock.return_value = True
+async def test_generate_code_no_figure(mock_sanitise_script, mock_load_prompt):
     llm = AsyncMock()
     model = "mock_model"
 
     mock_load_prompt.side_effect = [
         "details to create chart code prompt",
-        "generate chart code prompt",
+        "generate chart code prompt"
     ]
 
     llm.chat.return_value = "generated code"
@@ -80,7 +74,7 @@ async def test_generate_code_no_figure(confirm_mock, mock_sanitise_script, mock_
 # No fig creation
 """
 
-    plt.switch_backend("Agg")
+    plt.switch_backend('Agg')
 
     def mock_exec_side_effect(script, globals=None, locals=None):
         if script == return_string:
@@ -94,45 +88,15 @@ def mock_exec_side_effect(script, globals=None, locals=None):
         llm.chat.assert_called_once_with(
             model,
             "generate chart code prompt",
-            "details to create chart code prompt",
+            "details to create chart code prompt"
         )
 
         mock_sanitise_script.assert_called_once_with("generated code")
 
-
-@pytest.mark.asyncio
-@patch("src.agents.chart_generator_agent.engine.load_prompt")
-@patch("src.agents.chart_generator_agent.sanitise_script", new_callable=MagicMock)
-@patch("src.agents.chart_generator_agent.UserConfirmer.confirm", new_callable=AsyncMock)
-async def test_generate_code_confirmation_false(confirm_mock, mock_sanitise_script, mock_load_prompt):
-    confirm_mock.return_value = False
-    llm = AsyncMock()
-    model = "mock_model"
-
-    mock_load_prompt.side_effect = [
-        "details to create chart code prompt",
-        "generate chart code prompt",
-    ]
-
-    llm.chat.return_value = "generated code"
-
-    mock_sanitise_script.return_value = "script"
-
-    with pytest.raises(Exception, match="The user did not confirm to creating a graph."):
-        await generate_chart("question_intent", "data_provided", "question_params", llm, model)
-
-    llm.chat.assert_called_once_with(
-        model,
-        "generate chart code prompt",
-        "details to create chart code prompt",
-    )
-
-    mock_sanitise_script.assert_called_once_with("generated code")
-
-
 @pytest.mark.parametrize(
     "input_script, expected_output",
     [
+
         (
             """```python
 import matplotlib.pyplot as plt
@@ -141,7 +105,7 @@ async def test_generate_code_confirmation_false(confirm_mock, mock_sanitise_scri
 ```""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])""",
+plt.plot([1, 2, 3], [4, 5, 6])"""
         ),
         (
             """```python
@@ -150,7 +114,7 @@ async def test_generate_code_confirmation_false(confirm_mock, mock_sanitise_scri
 plt.plot([1, 2, 3], [4, 5, 6])""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])""",
+plt.plot([1, 2, 3], [4, 5, 6])"""
         ),
         (
             """import matplotlib.pyplot as plt
@@ -159,7 +123,7 @@ async def test_generate_code_confirmation_false(confirm_mock, mock_sanitise_scri
 ```""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])""",
+plt.plot([1, 2, 3], [4, 5, 6])"""
         ),
         (
             """import matplotlib.pyplot as plt
@@ -167,13 +131,13 @@ async def test_generate_code_confirmation_false(confirm_mock, mock_sanitise_scri
 plt.plot([1, 2, 3], [4, 5, 6])""",
             """import matplotlib.pyplot as plt
 fig = plt.figure()
-plt.plot([1, 2, 3], [4, 5, 6])""",
+plt.plot([1, 2, 3], [4, 5, 6])"""
         ),
         (
             "",
-            "",
-        ),
-    ],
+            ""
+        )
+    ]
 )
 def test_sanitise_script(input_script, expected_output):
-    assert sanitise_script(input_script) == expected_output
+    assert sanitise_script(input_script) == expected_output
\ No newline at end of file
diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index b54aca6a2..72728e5b9 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -4,6 +4,8 @@
 from src.agents.web_agent import web_general_search_core
 
 @pytest.mark.asyncio
+@patch("src.agents.web_agent.answer_user_ques", new_callable=AsyncMock)
+@patch("src.agents.web_agent.create_search_term", new_callable=AsyncMock)
 @patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
 @patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
 @patch("src.agents.web_agent.perform_summarization", new_callable=AsyncMock)
@@ -13,14 +15,18 @@ async def test_web_general_search_core(
     mock_perform_summarization,
     mock_perform_scrape,
     mock_perform_search,
+    mock_create_search_term,
+    mock_answer_user_ques
 ):
     llm = AsyncMock()
     model = "mock_model"
 
-    mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
-    mock_perform_scrape.return_value = "Example scraped content."
-    mock_perform_summarization.return_value = json.dumps({"summary": "Example summary."})
-    mock_is_valid_answer.return_value = True
+    # Mocking answer_user_ques to return a valid answer
+    mock_answer_user_ques.return_value = json.dumps({
+        "status": "success",
+        "response": json.dumps({"is_valid": True, "answer": "Example summary."})
+    })
+
     result = await web_general_search_core("example query", llm, model)
     expected_response = {
         "content": "Example summary.",
@@ -28,6 +34,7 @@ async def test_web_general_search_core(
     }
     assert json.loads(result) == expected_response
 
+
 @pytest.mark.asyncio
 @patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
 @patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
@@ -43,8 +50,8 @@ async def test_web_general_search_core_no_results(
     model = "mock_model"
     mock_perform_search.return_value = {"status": "error", "urls": []}
     result = await web_general_search_core("example query", llm, model)
-    assert result == "No relevant information found on the internet for the given query."
-
+    # Updated expectation to reflect the actual return value (empty string) in case of no results
+    assert result == ""
 
 @pytest.mark.asyncio
 @patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
@@ -64,4 +71,5 @@ async def test_web_general_search_core_invalid_summary(
     mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."})
     mock_is_valid_answer.return_value = False
     result = await web_general_search_core("example query", llm, model)
-    assert result == "No relevant information found on the internet for the given query."
+    # Updated expectation to reflect the actual return value (empty string) in case of an invalid summary
+    assert result == ""
diff --git a/backend/tests/prompts/prompting_test.py b/backend/tests/prompts/prompting_test.py
index 66907e0f1..809315aa6 100644
--- a/backend/tests/prompts/prompting_test.py
+++ b/backend/tests/prompts/prompting_test.py
@@ -58,11 +58,12 @@ def test_load_best_next_step_template():
 AGENT LIST:
 
 
-If the list of agents does not contain something suitable, you should say the agent is 'none'. ie. If question is 'general knowledge', 'personal' or a 'greeting'.
+If the list of agents does not contain something suitable, you should say the agent is 'WebAgent'. ie. If question is 'general knowledge', 'personal' or a 'greeting'.
 
 ## Determine the next best step
 Your task is to pick one of the mentioned agents above to complete the task.
 If the same agent_name and task are repeated more than twice in the history, you must not pick that agent_name.
+If mathematical processing (e.g., rounding or calculations) is needed, choose the MathsAgent. If file operations are needed, choose the FileAgent.
 
 Your decisions must always be made independently without seeking user assistance.
 Play to your strengths as an LLM and pursue simple strategies with no legal complications.
@@ -103,11 +104,12 @@ def test_load_best_next_step_with_history_template():
 AGENT LIST:
 
 
-If the list of agents does not contain something suitable, you should say the agent is 'none'. ie. If question is 'general knowledge', 'personal' or a 'greeting'.
+If the list of agents does not contain something suitable, you should say the agent is 'WebAgent'. ie. If question is 'general knowledge', 'personal' or a 'greeting'.
 
 ## Determine the next best step
 Your task is to pick one of the mentioned agents above to complete the task.
 If the same agent_name and task are repeated more than twice in the history, you must not pick that agent_name.
+If mathematical processing (e.g., rounding or calculations) is needed, choose the MathsAgent. If file operations are needed, choose the FileAgent.
 
 Your decisions must always be made independently without seeking user assistance.
 Play to your strengths as an LLM and pursue simple strategies with no legal complications.
@@ -136,11 +138,16 @@ def test_best_tool_template():
 
 Pick 1 tool (no more than 1) from the list below to complete this task.
 Fit the correct parameters from the task to the tool arguments.
+Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "$") and units of measurement (e.g., "million") if applicable.
 Parameters with required as False do not need to be fit.
 Add if appropriate, but do not hallucinate arguments for these parameters
 
 {"description": "mock desc", "name": "say hello world", "parameters": {"name": {"type": "string", "description": "name of user"}}}
 
+Important:
+If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "$") and rounded to the nearest million if specified.
+If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "$15 million", "$5000").
+
 From the task you should be able to extract the parameters. If it is data driven, it should be turned into a cypher query
 
 If none of the tools are appropriate for the task, return the following tool
@@ -161,7 +168,8 @@ def test_best_tool_template():
 def test_tool_selection_format_template():
     engine = PromptEngine()
     try:
-        expected_string = """Reply only in json with the following format:
+        expected_string = """Reply only in json with the following format, in the tool_paramters please include the curreny and measuring scale used in the content provided.:
+
 
 {
     \"tool_name\":  \"the exact string name of the tool chosen\",

From 9dc43734988aa968bc232ce9cc75f1a63fd9b635 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 17 Oct 2024 09:20:10 +0100
Subject: [PATCH 42/48] fixed linting error

---
 backend/tests/agents/chart_generator_agent_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/tests/agents/chart_generator_agent_test.py b/backend/tests/agents/chart_generator_agent_test.py
index fe0c48dbd..b3d404336 100644
--- a/backend/tests/agents/chart_generator_agent_test.py
+++ b/backend/tests/agents/chart_generator_agent_test.py
@@ -140,4 +140,4 @@ def mock_exec_side_effect(script, globals=None, locals=None):
     ]
 )
 def test_sanitise_script(input_script, expected_output):
-    assert sanitise_script(input_script) == expected_output
\ No newline at end of file
+    assert sanitise_script(input_script) == expected_output

From 46e26db5096c71b02900b2d7f7b32c76c3cdb656 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 17 Oct 2024 10:05:07 +0100
Subject: [PATCH 43/48] Changes to make the conversation-history file name exact.

---
 backend/src/prompts/templates/intent.j2 | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index 68d53bdf5..b143e3589 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -84,4 +84,4 @@ Finally, if no tool fits the task, return the following:
 }
 
 Important:
-Please always create the last intent to append the retrieved info in a conversation-history.txt file
\ No newline at end of file
+Please always create the last intent to append the retrieved info in a 'conversation-history.txt' file and make sure this history file is always 'conversation-history.txt'
\ No newline at end of file

From 4a72ccd93d55a7583ce409971a59e15e08d01d4f Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 17 Oct 2024 14:02:29 +0100
Subject: [PATCH 44/48] Webpack changes

---
 frontend/webpack.config.js | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/frontend/webpack.config.js b/frontend/webpack.config.js
index 3fcd4a0a4..ecfc2c7f8 100644
--- a/frontend/webpack.config.js
+++ b/frontend/webpack.config.js
@@ -7,14 +7,17 @@ import { fileURLToPath } from 'url';
 
 const __dirname = path.dirname(fileURLToPath(import.meta.url));
 const localEnv = dotenv.config({ path: path.resolve(__dirname, '../.env') }).parsed;
-const env = { ...process.env, ...localEnv }; 
+const env = { ...process.env, ...localEnv };
 
 const config = {
   mode: 'development',
   entry: './src/index.tsx',
   output: {
-    path: __dirname + '/dist/',
+    path: path.resolve(__dirname, 'dist'),
+    publicPath: '/',
+    filename: '[name].bundle.js'
   },
+
   module: {
     rules: [
       {

From 713fc17b4161afa8a576fc1032dc829dd8517aae Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 17 Oct 2024 15:38:21 +0100
Subject: [PATCH 45/48] Pushing the Maths changes

---
 backend/src/agents/maths_agent.py            |  8 +++++---
 backend/src/prompts/templates/math-solver.j2 | 10 +++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/backend/src/agents/maths_agent.py b/backend/src/agents/maths_agent.py
index 6b64f27f6..8e16129ba 100644
--- a/backend/src/agents/maths_agent.py
+++ b/backend/src/agents/maths_agent.py
@@ -71,7 +71,8 @@ async def is_valid_answer(answer, task) -> bool:
     name="perform_math_operation",
     description=(
         "Use this tool to perform complex mathematical operations or calculations. "
-        "It can handle queries related to arithmetic operations, algebra, or calculations involving large numbers."
+        "It handles arithmetic operations and algebra, and also supports conversions to specific units like millions, rounding when necessary. "
+        "Returns both the result and an explanation of the steps involved."
     ),
     parameters={
         "math_query": Parameter(
@@ -87,8 +88,9 @@ async def perform_math_operation(math_query, llm, model) -> str:
 @agent(
     name="MathsAgent",
     description=(
-        "This agent is responsible for handling mathematical queries and can perform "
-        "necessary rounding and formatting operations."
+        "This agent processes mathematical queries, performs calculations, and applies necessary formatting such as"
+         "rounding or converting results into specific units (e.g., millions). "
+        "It provides clear explanations of the steps involved to ensure accuracy."
     ),
     tools=[perform_math_operation],
 )
diff --git a/backend/src/prompts/templates/math-solver.j2 b/backend/src/prompts/templates/math-solver.j2
index b67a91907..64a67cef5 100644
--- a/backend/src/prompts/templates/math-solver.j2
+++ b/backend/src/prompts/templates/math-solver.j2
@@ -1,8 +1,8 @@
-You are an expert in performing mathematical operations. You are highly skilled in handling various mathematical queries such as expressing numbers in millions, performing arithmetic operations, and applying formulas as requested by the user.
+You are an expert in performing mathematical operations. You are highly skilled in handling various mathematical queries such as performing arithmetic operations, applying formulas, and expressing numbers in different formats as requested by the user.
 
-You will be given a mathematical query, and your task is to solve the query based on the provided information. Ensure that you apply the appropriate mathematical principles to deliver an exact result, specifically converting large numbers to millions without rounding off.
+You will be given a mathematical query, and your task is to solve the query based on the provided information. Ensure that you apply the appropriate mathematical principles to deliver an exact result. **Only convert numbers to millions if explicitly requested by the user.** Otherwise, return the result as is, without unnecessary conversions.
 
-Make sure to perform the calculations step by step, when necessary, and return the final result clearly.
+Make sure to perform the calculations step by step when necessary, and return the final result clearly.
 
 User's query is:
 {{ query }}
@@ -10,7 +10,7 @@ User's query is:
 Reply only in json with the following format:
 
 {
-    "result": "The final result of the mathematical operation, expressed in millions without rounding",
+    "result": "The final result of the mathematical operation, without unnecessary conversion to millions or any other format unless explicitly requested",
     "steps": "A breakdown of the steps involved in solving the query (if applicable)",
     "reasoning": "A sentence on why this result is accurate"
 }
@@ -22,4 +22,4 @@ query: Round 81.462 billion to the nearest million
     "result": "81,462 million",
     "steps": "1. Convert 81.462 billion to million by multiplying by 1000. Round the result to the nearest million.",
     "reasoning": "Rounding to the nearest million ensures that the result is represented in a more practical figure, without exceeding or falling short of the actual value."
- }
\ No newline at end of file
+}
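
The reworked math-solver.j2 above asks the model for a JSON envelope with result, steps and reasoning. A rough sketch of how the perform_math_operation tool described in this patch might consume it, under the assumption that only the result field is surfaced (the tool body itself is not part of this diff):

    import json
    from src.prompts import PromptEngine

    engine = PromptEngine()

    async def perform_math_operation(math_query, llm, model) -> str:
        # Render math-solver.j2 with the user's query and request the JSON
        # envelope described above: {"result", "steps", "reasoning"}.
        math_prompt = engine.load_prompt("math-solver", query=math_query)
        llm_response = await llm.chat(model, math_prompt, math_query, return_json=True)
        parsed = json.loads(llm_response)
        # Surface the final result; steps and reasoning remain available for logging.
        return json.dumps({"content": parsed.get("result", ""), "ignore_validation": "false"},
                          indent=4)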

From ae05181e9830806eea686bc4624cf125f59a37c4 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Thu, 17 Oct 2024 15:47:24 +0100
Subject: [PATCH 46/48] resolving linting error

---
 backend/src/agents/maths_agent.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/backend/src/agents/maths_agent.py b/backend/src/agents/maths_agent.py
index 8e16129ba..a55506d8f 100644
--- a/backend/src/agents/maths_agent.py
+++ b/backend/src/agents/maths_agent.py
@@ -71,8 +71,8 @@ async def is_valid_answer(answer, task) -> bool:
     name="perform_math_operation",
     description=(
         "Use this tool to perform complex mathematical operations or calculations. "
-        "It handles arithmetic operations and algebra, and also supports conversions to specific units like millions, rounding when necessary. "
-        "Returns both the result and an explanation of the steps involved."
+        "It handles arithmetic operations and algebra, and also supports conversions to specific units like millions,"
+        "rounding when necessary. Returns both the result and an explanation of the steps involved."
     ),
     parameters={
         "math_query": Parameter(

From 52125d7da5bd3eccf8e842c668f2b7452c95a70a Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Fri, 18 Oct 2024 09:48:09 +0100
Subject: [PATCH 47/48] Changes suggested by Helene

---
 backend/src/agents/web_agent.py                        | 8 +++++---
 backend/src/prompts/templates/answer-user-ques.j2      | 2 +-
 backend/src/prompts/templates/best-tool.j2             | 6 +++---
 backend/src/prompts/templates/create-search-term.j2    | 2 +-
 backend/src/prompts/templates/intent.j2                | 4 ++--
 backend/src/prompts/templates/math-solver.j2           | 1 +
 backend/src/prompts/templates/tool-selection-format.j2 | 2 +-
 backend/src/prompts/templates/validator.j2             | 2 +-
 8 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/backend/src/agents/web_agent.py b/backend/src/agents/web_agent.py
index 265d372ec..714a8c0f3 100644
--- a/backend/src/agents/web_agent.py
+++ b/backend/src/agents/web_agent.py
@@ -32,7 +32,11 @@ async def web_general_search_core(search_query, llm, model) -> str:
         answer_to_user = await answer_user_ques(search_query, llm, model)
         answer_result = json.loads(answer_to_user)
         if answer_result["status"] == "error":
-            return ""
+            response = {
+                    "content": "Error in finding the answer.",
+                    "ignore_validation": "false"
+                }
+            return json.dumps(response, indent=4)
         logger.info(f'Answer found successfully {answer_result}')
         valid_answer = json.loads(answer_result["response"]).get("is_valid", "")
         if valid_answer:
@@ -149,7 +153,6 @@ async def web_pdf_download(pdf_url, llm, model) -> str:
 
 async def web_scrape_core(url: str) -> str:
     try:
-        logger.info(f"Scraping the price of the book from URL: {url}")
         # Scrape the content from the provided URL
         content = await perform_scrape(url)
         if not content:
@@ -162,7 +165,6 @@ async def web_scrape_core(url: str) -> str:
             }
         return json.dumps(response, indent=4)
     except Exception as e:
-        logger.error(f"Error in web_scrape_price_core: {e}")
         return json.dumps({"status": "error", "error": str(e)})
 
 
diff --git a/backend/src/prompts/templates/answer-user-ques.j2 b/backend/src/prompts/templates/answer-user-ques.j2
index 02b9e56e6..b7d9a1305 100644
--- a/backend/src/prompts/templates/answer-user-ques.j2
+++ b/backend/src/prompts/templates/answer-user-ques.j2
@@ -51,4 +51,4 @@ Reply only in JSON format with the following structure:
 }
 
 
-Important: If the question is realted to real time data, the LLM should provide is_valid is false.
\ No newline at end of file
+Important: If the question is related to real time data, the LLM should provide is_valid is false.
diff --git a/backend/src/prompts/templates/best-tool.j2 b/backend/src/prompts/templates/best-tool.j2
index 9558964bb..7452cf1a4 100644
--- a/backend/src/prompts/templates/best-tool.j2
+++ b/backend/src/prompts/templates/best-tool.j2
@@ -11,15 +11,15 @@ Trust the information below completely (100% accurate)
 
 Pick 1 tool (no more than 1) from the list below to complete this task.
 Fit the correct parameters from the task to the tool arguments.
-Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "$") and units of measurement (e.g., "million") if applicable.
+Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "£") and units of measurement (e.g., "million") if applicable.
 Parameters with required as False do not need to be fit.
 Add if appropriate, but do not hallucinate arguments for these parameters
 
 {{ tools }}
 
 Important:
-If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "$") and rounded to the nearest million if specified.
-If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "$15 million", "$5000").
+If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "£") and rounded to the nearest million if specified.
+If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "£15 million", "£5000").
 
 From the task you should be able to extract the parameters. If it is data driven, it should be turned into a cypher query
 
diff --git a/backend/src/prompts/templates/create-search-term.j2 b/backend/src/prompts/templates/create-search-term.j2
index 6da187aaa..cc46787ac 100644
--- a/backend/src/prompts/templates/create-search-term.j2
+++ b/backend/src/prompts/templates/create-search-term.j2
@@ -13,4 +13,4 @@ Reply only in JSON format, following this structure:
 {
     "search_term": "The optimized Google search term based on the user's question",
     "reasoning": "A sentence on why you chose that search term"
-}
\ No newline at end of file
+}
diff --git a/backend/src/prompts/templates/intent.j2 b/backend/src/prompts/templates/intent.j2
index b143e3589..a112f3221 100644
--- a/backend/src/prompts/templates/intent.j2
+++ b/backend/src/prompts/templates/intent.j2
@@ -4,7 +4,7 @@ The question is:
 
 {{ question }} 
 
-The prvious chat history is: 
+The previous chat history is: 
 
 {{ chat_history }}
 
@@ -84,4 +84,4 @@ Finally, if no tool fits the task, return the following:
 }
 
 Important:
-Please always create the last intent to append the retrieved info in a 'conversation-history.txt' file and make sure this history file is always 'conversation-history.txt'
\ No newline at end of file
+Please always create the last intent to append the retrieved info in a 'conversation-history.txt' file and make sure this history file is always named 'conversation-history.txt'
diff --git a/backend/src/prompts/templates/math-solver.j2 b/backend/src/prompts/templates/math-solver.j2
index 64a67cef5..122e88d2e 100644
--- a/backend/src/prompts/templates/math-solver.j2
+++ b/backend/src/prompts/templates/math-solver.j2
@@ -23,3 +23,4 @@ query: Round 81.462 billion to the nearest million
     "steps": "1. Convert 81.462 billion to million by multiplying by 1000. Round the result to the nearest million.",
     "reasoning": "Rounding to the nearest million ensures that the result is represented in a more practical figure, without exceeding or falling short of the actual value."
 }
+
diff --git a/backend/src/prompts/templates/tool-selection-format.j2 b/backend/src/prompts/templates/tool-selection-format.j2
index 3a7c664b4..c1cba32ca 100644
--- a/backend/src/prompts/templates/tool-selection-format.j2
+++ b/backend/src/prompts/templates/tool-selection-format.j2
@@ -1,4 +1,4 @@
-Reply only in json with the following format, in the tool_paramters please include the curreny and measuring scale used in the content provided.:
+Reply only in json with the following format, in the tool_parameters please include the currency and measuring scale used in the content provided.:
 
 
 {
diff --git a/backend/src/prompts/templates/validator.j2 b/backend/src/prompts/templates/validator.j2
index 161760185..54256b525 100644
--- a/backend/src/prompts/templates/validator.j2
+++ b/backend/src/prompts/templates/validator.j2
@@ -27,7 +27,7 @@ Reasoning: The answer is for Spotify not Amazon.
 Task: Please find tesla's revenue every year since its creation.
 Answer: Tesla's annual revenue history from FY 2008 to FY 2023 is available, with figures for 2008 through 2020 taken from previous annual reports.
 Response: False
-Reasoning: The answer is not prvoding any actual figures but just talk about the figures.
+Reasoning: The answer is not providing any actual figures but just talk about the figures.
 
 Task: Please find tesla's revenue every year since its creation in the US dollars.
 Answer: Tesla's annual revenue in USD since its creation is as follows: 2024 (TTM) $75.92 billion, 2023 $75.95 billion, 2022 $67.33 billion, 2021 $39.76 billion, 2020 $23.10 billion, 2019 $18.52 billion, 2018 $16.81 billion, 2017 $8.70 billion, 2016 $5.67 billion, 2015 $2.72 billion, 2014 $2.05 billion, 2013 $1.21 billion, 2012 $0.25 billion, 2011 $0.13 billion, 2010 $75.88 million, 2009 $69.73 million.
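
The error-path change at the top of this patch replaces the bare empty-string return with a JSON envelope. A self-contained sketch of the double-encoded payload that path deals with, using the shape mocked in web_agent_test.py (an outer status envelope whose response field is itself a JSON string carrying is_valid and answer):

    import json

    # Shape taken from the test mocks for answer_user_ques.
    mocked = json.dumps({
        "status": "success",
        "response": json.dumps({"is_valid": True, "answer": "Example summary."}),
    })

    outer = json.loads(mocked)
    if outer["status"] == "error" or not json.loads(outer["response"]).get("is_valid"):
        result = {"content": "Error in finding the answer.", "ignore_validation": "false"}
    else:
        result = {"content": json.loads(outer["response"])["answer"], "ignore_validation": "false"}

    print(json.dumps(result, indent=4))  # matches the expected_response asserted in the tests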

From 8aeccdc6857619c6a28cc15e84278ddf9c4dcce6 Mon Sep 17 00:00:00 2001
From: Gagan Singh <gagan.singh@faculty.ai>
Date: Fri, 18 Oct 2024 10:07:09 +0100
Subject: [PATCH 48/48] Fixed failing unit tests

---
 backend/tests/agents/web_agent_test.py  | 16 ++++++++++++----
 backend/tests/prompts/prompting_test.py |  8 ++++----
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/backend/tests/agents/web_agent_test.py b/backend/tests/agents/web_agent_test.py
index 72728e5b9..b863aaa8a 100644
--- a/backend/tests/agents/web_agent_test.py
+++ b/backend/tests/agents/web_agent_test.py
@@ -50,8 +50,12 @@ async def test_web_general_search_core_no_results(
     model = "mock_model"
     mock_perform_search.return_value = {"status": "error", "urls": []}
     result = await web_general_search_core("example query", llm, model)
-    # Updated expectation to reflect the actual return value (empty string) in case of no results
-    assert result == ""
+
+    expected_response = {
+        "content": "Error in finding the answer.",
+        "ignore_validation": "false"
+    }
+    assert json.loads(result) == expected_response
 
 @pytest.mark.asyncio
 @patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
@@ -71,5 +75,9 @@ async def test_web_general_search_core_invalid_summary(
     mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."})
     mock_is_valid_answer.return_value = False
     result = await web_general_search_core("example query", llm, model)
-    # Updated expectation to reflect the actual return value (empty string) in case of an invalid summary
-    assert result == ""
+    expected_response = {
+        "content": "Error in finding the answer.",
+        "ignore_validation": "false"
+    }
+    assert json.loads(result) == expected_response
+
diff --git a/backend/tests/prompts/prompting_test.py b/backend/tests/prompts/prompting_test.py
index 809315aa6..027574cab 100644
--- a/backend/tests/prompts/prompting_test.py
+++ b/backend/tests/prompts/prompting_test.py
@@ -138,15 +138,15 @@ def test_best_tool_template():
 
 Pick 1 tool (no more than 1) from the list below to complete this task.
 Fit the correct parameters from the task to the tool arguments.
-Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "$") and units of measurement (e.g., "million") if applicable.
+Ensure that numerical values are formatted correctly, including the use of currency symbols (e.g., "£") and units of measurement (e.g., "million") if applicable.
 Parameters with required as False do not need to be fit.
 Add if appropriate, but do not hallucinate arguments for these parameters
 
 {"description": "mock desc", "name": "say hello world", "parameters": {"name": {"type": "string", "description": "name of user"}}}
 
 Important:
-If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "$") and rounded to the nearest million if specified.
-If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "$15 million", "$5000").
+If the task involves financial data, ensure that all monetary values are expressed with appropriate currency (e.g., "£") and rounded to the nearest million if specified.
+If the task involves scaling (e.g., thousands, millions), ensure that the extracted parameters reflect the appropriate scale (e.g., "£15 million", "£5000").
 
 From the task you should be able to extract the parameters. If it is data driven, it should be turned into a cypher query
 
@@ -168,7 +168,7 @@ def test_best_tool_template():
 def test_tool_selection_format_template():
     engine = PromptEngine()
     try:
-        expected_string = """Reply only in json with the following format, in the tool_paramters please include the curreny and measuring scale used in the content provided.:
+        expected_string = """Reply only in json with the following format, in the tool_parameters please include the currency and measuring scale used in the content provided.:
 
 
 {