Skip to content

Commit

Permalink
fix(SemanticAgent): join data to be fixed (#1239)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Jun 18, 2024
1 parent e321aa3 commit 9741640
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pandasai/ee/helpers/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _build_from_clause(self, main_table_entry):

def _build_joins_clause(self, main_table_entry, referenced_tables):
sql = ""
main_table = main_table_entry["table"]
main_table = main_table_entry["name"]

for table_name in referenced_tables:
if table_name != main_table:
Expand Down
43 changes: 43 additions & 0 deletions tests/unit_tests/ee/helpers/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,46 @@
],
}
]


MULTI_JOIN_SCHEMA = [
{
"name": "Sales",
"table": "sales",
"measures": [
{"name": "total_revenue", "type": "sum", "sql": "revenue"},
{"name": "total_sales", "type": "count", "sql": "id"},
],
"dimensions": [
{"name": "product", "type": "string", "sql": "product"},
{"name": "region", "type": "string", "sql": "region"},
{"name": "sales_date", "type": "date", "sql": "sales_date"},
{"name": "id", "type": "string", "sql": "id"},
],
"joins": [
{
"name": "Engagement",
"join_type": "left",
"sql": "${Sales.id} = ${Engagement.id}",
}
],
},
{
"name": "Engagement",
"table": "engagement",
"measures": [{"name": "total_duration", "type": "sum", "sql": "duration"}],
"dimensions": [
{"name": "id", "type": "string", "sql": "id"},
{"name": "user_id", "type": "string", "sql": "user_id"},
{"name": "activity_type", "type": "string", "sql": "activity_type"},
{"name": "engagement_date", "type": "date", "sql": "engagement_date"},
],
"joins": [
{
"name": "Sales",
"join_type": "right",
"sql": "${Engagement.id} = ${Sales.id}",
}
],
},
]
39 changes: 38 additions & 1 deletion tests/unit_tests/ee/helpers/test_query_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from pandasai.ee.helpers.query_builder import QueryBuilder
from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA
from tests.unit_tests.ee.helpers.schema import MULTI_JOIN_SCHEMA, VIZ_QUERY_SCHEMA


class TestQueryBuilder(unittest.TestCase):
Expand Down Expand Up @@ -191,3 +191,40 @@ def test_sql_with_filters_with_set_filter(self):
"SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc",
"SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc",
]

def test_sql_with_filters_with_join(self):
query_builder = QueryBuilder(MULTI_JOIN_SCHEMA)

json_str = {
"type": "bar",
"dimensions": ["Engagement.activity_type"],
"measures": ["Sales.total_revenue"],
"timeDimensions": [],
"options": {
"xLabel": "Activity Type",
"yLabel": "Total Revenue",
"title": "Total Revenue Generated from Users who Logged in Before Purchase",
"legend": {"display": True, "position": "top"},
},
"joins": [
{
"name": "Engagement",
"join_type": "right",
"sql": "${Sales.id} = ${Engagement.id}",
}
],
"filters": [
{
"member": "Engagement.engagement_date",
"operator": "beforeDate",
"values": ["${Sales.sales_date}"],
}
],
"order": [{"id": "Sales.total_revenue", "direction": "asc"}],
}
sql_query = query_builder.generate_sql(json_str)

assert (
sql_query
== "SELECT `engagement`.`activity_type` AS activity_type, SUM(`sales`.`revenue`) AS total_revenue FROM `sales` RIGHT JOIN `engagement` ON `engagement`.`id` = `sales`.`id` WHERE `engagement`.`engagement_date` < '${Sales.sales_date}' GROUP BY activity_type ORDER BY total_revenue asc"
)

0 comments on commit 9741640

Please sign in to comment.