From 5a30ef30bb49d73a67ffcbf3a7baca3e31475f4b Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 18:58:44 +0200 Subject: [PATCH 1/2] fix(SemanticAgent): join data to be fixed --- pandasai/ee/helpers/query_builder.py | 2 +- tests/unit_tests/ee/helpers/schema.py | 43 +++++++++++++++++++ .../ee/helpers/test_query_builder.py | 39 ++++++++++++++++- 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/pandasai/ee/helpers/query_builder.py b/pandasai/ee/helpers/query_builder.py index 32a11071e..7f7202085 100644 --- a/pandasai/ee/helpers/query_builder.py +++ b/pandasai/ee/helpers/query_builder.py @@ -296,7 +296,7 @@ def _build_from_clause(self, main_table_entry): def _build_joins_clause(self, main_table_entry, referenced_tables): sql = "" - main_table = main_table_entry["table"] + main_table = main_table_entry["name"] for table_name in referenced_tables: if table_name != main_table: diff --git a/tests/unit_tests/ee/helpers/schema.py b/tests/unit_tests/ee/helpers/schema.py index 80bae8bad..f82ac7365 100644 --- a/tests/unit_tests/ee/helpers/schema.py +++ b/tests/unit_tests/ee/helpers/schema.py @@ -43,3 +43,46 @@ ], } ] + + +MULTI_JOIN_SCHEMA = [ + { + "name": "Sales", + "table": "sales", + "measures": [ + {"name": "total_revenue", "type": "sum", "sql": "revenue"}, + {"name": "total_sales", "type": "count", "sql": "id"}, + ], + "dimensions": [ + {"name": "product", "type": "string", "sql": "product"}, + {"name": "region", "type": "string", "sql": "region"}, + {"name": "sales_date", "type": "date", "sql": "sales_date"}, + {"name": "id", "type": "string", "sql": "id"}, + ], + "joins": [ + { + "name": "Engagement", + "join_type": "left", + "sql": "${Sales.id} = ${Engagement.id}", + } + ], + }, + { + "name": "Engagement", + "table": "engagement", + "measures": [{"name": "total_duration", "type": "sum", "sql": "duration"}], + "dimensions": [ + {"name": "id", "type": "string", "sql": "id"}, + {"name": "user_id", "type": "string", "sql": "user_id"}, + {"name": "activity_type", "type": "string", "sql": "activity_type"}, + {"name": "engagement_date", "type": "date", "sql": "engagement_date"}, + ], + "joins": [ + { + "name": "Sales", + "join_type": "right", + "sql": "${Engagement.id} = ${Sales.id}", + } + ], + }, +] diff --git a/tests/unit_tests/ee/helpers/test_query_builder.py b/tests/unit_tests/ee/helpers/test_query_builder.py index 4dd2cc8e9..5f73b2e81 100644 --- a/tests/unit_tests/ee/helpers/test_query_builder.py +++ b/tests/unit_tests/ee/helpers/test_query_builder.py @@ -1,7 +1,7 @@ import unittest from pandasai.ee.helpers.query_builder import QueryBuilder -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA +from tests.unit_tests.ee.helpers.schema import MULTI_JOIN_SCHEMA, VIZ_QUERY_SCHEMA class TestQueryBuilder(unittest.TestCase): @@ -191,3 +191,40 @@ def test_sql_with_filters_with_set_filter(self): "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc", "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc", ] + + def test_sql_with_filters_with_join(self): + query_builder = QueryBuilder(MULTI_JOIN_SCHEMA) + + json_str = { + "type": "bar", + "dimensions": ["Engagement.activity_type"], + "measures": ["Sales.total_revenue"], + "timeDimensions": [], + "options": { + "xLabel": "Activity Type", + "yLabel": "Total Revenue", + "title": "Total Revenue Generated from Users who Logged in Before Purchase", + "legend": {"display": True, "position": "top"}, + }, + "joins": [ + { + "name": "Engagement", + "join_type": "right", + "sql": "${Sales.id} = ${Engagement.id}", + } + ], + "filters": [ + { + "member": "Engagement.engagement_date", + "operator": "beforeDate", + "values": ["${Sales.sales_date}"], + } + ], + "order": [{"id": "Sales.total_revenue", "direction": "asc"}], + } + sql_query = query_builder.generate_sql(json_str) + + assert ( + sql_query + == "SELECT `engagement`.`activity_type` AS activity_type, SUM(`sales`.`revenue`) AS total_revenue FROM `sales` RIGHT JOIN `engagement` ON `engagement`.`id` = `sales`.`id` WHERE `engagement`.`engagement_date` < '${Sales.sales_date}' GROUP BY activity_type ORDER BY total_revenue asc" + ) From c0f900391767575f6ca9e62b6c7869cca246d523 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Wed, 19 Jun 2024 12:18:54 +0200 Subject: [PATCH 2/2] fix(semantic_agent): json load to also look for json in backtick --- pandasai/ee/agents/semantic_agent/__init__.py | 3 +- .../semantic_agent/pipeline/llm_call.py | 4 +- .../prompts/generate_df_schema.py | 5 +- .../prompts/templates/generate_df_schema.tmpl | 167 ++++++++++-------- pandasai/ee/helpers/json_helper.py | 14 ++ 5 files changed, 113 insertions(+), 80 deletions(-) create mode 100644 pandasai/ee/helpers/json_helper.py diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py index dc8aee31b..d6f736372 100644 --- a/pandasai/ee/agents/semantic_agent/__init__.py +++ b/pandasai/ee/agents/semantic_agent/__init__.py @@ -15,6 +15,7 @@ from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import ( GenerateDFSchemaPrompt, ) +from pandasai.ee.helpers.json_helper import extract_json_from_json_str from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson from pandasai.helpers.cache import Cache from pandasai.helpers.memory import Memory @@ -186,7 +187,7 @@ def _create_schema(self): """ ) self._schema = result.replace("# SAMPLE SCHEMA", "") - schema_data = json.loads(result.replace("# SAMPLE SCHEMA", "")) + schema_data = extract_json_from_json_str(result.replace("# SAMPLE SCHEMA", "")) if isinstance(schema_data, dict): schema_data = [schema_data] diff --git a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py index e9946140d..af1bd2e18 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py @@ -1,6 +1,6 @@ -import json from typing import Any +from pandasai.ee.helpers.json_helper import extract_json_from_json_str from pandasai.helpers.logger import Logger from pandasai.pipelines.base_logic_unit import BaseLogicUnit from pandasai.pipelines.logic_unit_output import LogicUnitOutput @@ -42,7 +42,7 @@ def execute(self, input: Any, **kwargs) -> Any: ) try: # Validate is valid Json - response_json = json.loads(response) + response_json = extract_json_from_json_str(response) pipeline_context.add("llm_call", response) diff --git a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py index 80237c944..28390f8b7 100644 --- a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py +++ b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py @@ -3,6 +3,7 @@ from jinja2 import Environment, FileSystemLoader +from pandasai.ee.helpers.json_helper import extract_json_from_json_str from pandasai.prompts.base import BasePrompt @@ -30,7 +31,9 @@ def __init__(self, **kwargs): def validate(self, output: str) -> bool: try: - json_data = json.loads(output.replace("# SAMPLE SCHEMA", "")) + json_data = extract_json_from_json_str( + output.replace("# SAMPLE SCHEMA", "") + ) context = self.props["context"] if isinstance(json_data, dict): json_data = [json_data] diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl index 6a45e5fe1..edec51e2d 100644 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl +++ b/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl @@ -1,132 +1,147 @@ # SAMPLE SCHEMA [ { - "name":"Contracts", - "table":"contracts", - "measures":[ + "name": "Contracts", + "table": "contracts", + "measures": [ { - "name":"contract_count", - "type":"count", - "sql":"store_id" + "name": "contract_count", + "type": "count", + "sql": "store_id" }, { - "name":"contract_duration", - "type":"number", - "sql":"${contract_end_date} - ${contract_start_date}" + "name": "contract_duration", + "type": "number", + "sql": "${contract_end_date} - ${contract_start_date}" }, { - "name":"contract_avg_duration", - "type":"avg", - "sql":"${contract_duration}" + "name": "contract_avg_duration", + "type": "avg", + "sql": "${contract_duration}" } ], - "dimensions":[ + "dimensions": [ { - "name":"contract_code", - "type":"string", - "sql":"contract_code" + "name": "contract_code", + "type": "string", + "sql": "contract_code", + "samples": ["C12345", "C67890"] }, { - "name":"store_id", - "type":"string", - "sql":"store_id" + "name": "store_id", + "type": "string", + "sql": "store_id", + "samples": ["S12345", "S67890"] }, { - "name":"tenant_code", - "type":"string", - "sql":"tenant_code" + "name": "tenant_code", + "type": "string", + "sql": "tenant_code", + "samples": ["T12345", "T67890"] }, { - "name":"tenant_name", - "type":"string", - "sql":"tenant_name" + "name": "tenant_name", + "type": "string", + "sql": "tenant_name", + "samples": ["Tenant A", "Tenant B"] }, { - "name":"store_brand", - "type":"string", - "sql":"store_brand" + "name": "store_brand", + "type": "string", + "sql": "store_brand", + "samples": ["Brand X", "Brand Y"] }, { - "name":"branch_segment_1", - "type":"string", - "sql":"branch_segment_1" + "name": "branch_segment_1", + "type": "string", + "sql": "branch_segment_1", + "samples": ["Segment 1", "Segment 2"] }, { - "name":"branch_segment_2", - "type":"string", - "sql":"branch_segment_2" + "name": "branch_segment_2", + "type": "string", + "sql": "branch_segment_2", + "samples": ["Segment A", "Segment B"] }, { - "name":"contract_start_date", - "type":"date", - "sql":"contract_start_date" + "name": "contract_start_date", + "type": "date", + "sql": "contract_start_date", + "samples": ["2023-01-01", "2023-02-01"] }, { - "name":"contract_end_date", - "type":"date", - "sql":"contract_end_date" + "name": "contract_end_date", + "type": "date", + "sql": "contract_end_date", + "samples": ["2024-01-01", "2024-02-01"] } ], - "joins":[ + "joins": [ { - "name":"corrispettivi", - "join_type":"left", - "sql":"${Contracts.contract_code} = ${Fees.contract_id}" + "name": "Fee", + "join_type": "left", + "sql": "${Contracts.contract_code} = ${Fees.contract_id}" } ] }, { - "name":"Fees", - "table":"fees", - "measures":[ + "name": "Fees", + "table": "fees", + "measures": [ { - "name":"total_taxable", - "type":"sum", - "sql":"imponibile_tot" + "name": "total_taxable", + "type": "sum", + "sql": "imponibile_tot" }, { - "name":"total_revenue", - "type":"sum", - "sql":"totale_tot" + "name": "total_revenue", + "type": "sum", + "sql": "totale_tot" } ], - "dimensions":[ + "dimensions": [ { - "name":"contract_id", - "type":"string", - "sql":"contract_id" + "name": "contract_id", + "type": "string", + "sql": "contract_id", + "samples": ["C12345", "C67890"] }, { - "name":"code", - "type":"string", - "sql":"code" + "name": "code", + "type": "string", + "sql": "code", + "samples": ["F12345", "F67890"] }, { - "name":"station", - "type":"string", - "sql":"station" + "name": "station", + "type": "string", + "sql": "station", + "samples": ["Station X", "Station Y"] }, { - "name":"tenant_id", - "type":"string", - "sql":"tenant_id" + "name": "tenant_id", + "type": "string", + "sql": "tenant_id", + "samples": ["T12345", "T67890"] }, { - "name":"day", - "type":"date", - "sql":"day" + "name": "day", + "type": "date", + "sql": "day", + "samples": ["2023-01-01", "2023-02-01"] }, { - "name":"store_id", - "type":"string", - "sql":"store_id" + "name": "store_id", + "type": "string", + "sql": "store_id", + "samples": ["S12345", "S67890"] } ], - "joins":[ + "joins": [ { - "name":"contracts", - "join_type":"right", - "sql":"${Fees.contract_id} = ${Fees.contract_code}" + "name": "Contracts", + "join_type": "right", + "sql": "${Fees.contract_id} = ${Contracts.contract_code}" } ] } diff --git a/pandasai/ee/helpers/json_helper.py b/pandasai/ee/helpers/json_helper.py new file mode 100644 index 000000000..a7ca0bce2 --- /dev/null +++ b/pandasai/ee/helpers/json_helper.py @@ -0,0 +1,14 @@ +import json + + +def extract_json_from_json_str(json_str): + start_index = json_str.find("```json") + + end_index = json_str.find("```", start_index) + + if start_index == -1: + return json.loads(json_str) + + json_data = json_str[(start_index + len("```json")) : end_index].strip() + + return json.loads(json_data)