Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(SemanticAgent): add samples in schema and support back-tick json load #1241

Merged
merged 2 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pandasai/ee/agents/semantic_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import (
GenerateDFSchemaPrompt,
)
from pandasai.ee.helpers.json_helper import extract_json_from_json_str
from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson
from pandasai.helpers.cache import Cache
from pandasai.helpers.memory import Memory
Expand Down Expand Up @@ -186,7 +187,7 @@ def _create_schema(self):
"""
)
self._schema = result.replace("# SAMPLE SCHEMA", "")
schema_data = json.loads(result.replace("# SAMPLE SCHEMA", ""))
schema_data = extract_json_from_json_str(result.replace("# SAMPLE SCHEMA", ""))
if isinstance(schema_data, dict):
schema_data = [schema_data]

Expand Down
4 changes: 2 additions & 2 deletions pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Any

from pandasai.ee.helpers.json_helper import extract_json_from_json_str
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
Expand Down Expand Up @@ -42,7 +42,7 @@ def execute(self, input: Any, **kwargs) -> Any:
)
try:
# Validate is valid Json
response_json = json.loads(response)
response_json = extract_json_from_json_str(response)

pipeline_context.add("llm_call", response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from jinja2 import Environment, FileSystemLoader

from pandasai.ee.helpers.json_helper import extract_json_from_json_str
from pandasai.prompts.base import BasePrompt


Expand Down Expand Up @@ -30,7 +31,9 @@ def __init__(self, **kwargs):

def validate(self, output: str) -> bool:
try:
json_data = json.loads(output.replace("# SAMPLE SCHEMA", ""))
json_data = extract_json_from_json_str(
output.replace("# SAMPLE SCHEMA", "")
)
context = self.props["context"]
if isinstance(json_data, dict):
json_data = [json_data]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,132 +1,147 @@
# SAMPLE SCHEMA
[
{
"name":"Contracts",
"table":"contracts",
"measures":[
"name": "Contracts",
"table": "contracts",
"measures": [
{
"name":"contract_count",
"type":"count",
"sql":"store_id"
"name": "contract_count",
"type": "count",
"sql": "store_id"
},
{
"name":"contract_duration",
"type":"number",
"sql":"${contract_end_date} - ${contract_start_date}"
"name": "contract_duration",
"type": "number",
"sql": "${contract_end_date} - ${contract_start_date}"
},
{
"name":"contract_avg_duration",
"type":"avg",
"sql":"${contract_duration}"
"name": "contract_avg_duration",
"type": "avg",
"sql": "${contract_duration}"
}
],
"dimensions":[
"dimensions": [
{
"name":"contract_code",
"type":"string",
"sql":"contract_code"
"name": "contract_code",
"type": "string",
"sql": "contract_code",
"samples": ["C12345", "C67890"]
},
{
"name":"store_id",
"type":"string",
"sql":"store_id"
"name": "store_id",
"type": "string",
"sql": "store_id",
"samples": ["S12345", "S67890"]
},
{
"name":"tenant_code",
"type":"string",
"sql":"tenant_code"
"name": "tenant_code",
"type": "string",
"sql": "tenant_code",
"samples": ["T12345", "T67890"]
},
{
"name":"tenant_name",
"type":"string",
"sql":"tenant_name"
"name": "tenant_name",
"type": "string",
"sql": "tenant_name",
"samples": ["Tenant A", "Tenant B"]
},
{
"name":"store_brand",
"type":"string",
"sql":"store_brand"
"name": "store_brand",
"type": "string",
"sql": "store_brand",
"samples": ["Brand X", "Brand Y"]
},
{
"name":"branch_segment_1",
"type":"string",
"sql":"branch_segment_1"
"name": "branch_segment_1",
"type": "string",
"sql": "branch_segment_1",
"samples": ["Segment 1", "Segment 2"]
},
{
"name":"branch_segment_2",
"type":"string",
"sql":"branch_segment_2"
"name": "branch_segment_2",
"type": "string",
"sql": "branch_segment_2",
"samples": ["Segment A", "Segment B"]
},
{
"name":"contract_start_date",
"type":"date",
"sql":"contract_start_date"
"name": "contract_start_date",
"type": "date",
"sql": "contract_start_date",
"samples": ["2023-01-01", "2023-02-01"]
},
{
"name":"contract_end_date",
"type":"date",
"sql":"contract_end_date"
"name": "contract_end_date",
"type": "date",
"sql": "contract_end_date",
"samples": ["2024-01-01", "2024-02-01"]
}
],
"joins":[
"joins": [
{
"name":"corrispettivi",
"join_type":"left",
"sql":"${Contracts.contract_code} = ${Fees.contract_id}"
"name": "Fee",
"join_type": "left",
"sql": "${Contracts.contract_code} = ${Fees.contract_id}"
}
]
},
{
"name":"Fees",
"table":"fees",
"measures":[
"name": "Fees",
"table": "fees",
"measures": [
{
"name":"total_taxable",
"type":"sum",
"sql":"imponibile_tot"
"name": "total_taxable",
"type": "sum",
"sql": "imponibile_tot"
},
{
"name":"total_revenue",
"type":"sum",
"sql":"totale_tot"
"name": "total_revenue",
"type": "sum",
"sql": "totale_tot"
}
],
"dimensions":[
"dimensions": [
{
"name":"contract_id",
"type":"string",
"sql":"contract_id"
"name": "contract_id",
"type": "string",
"sql": "contract_id",
"samples": ["C12345", "C67890"]
},
{
"name":"code",
"type":"string",
"sql":"code"
"name": "code",
"type": "string",
"sql": "code",
"samples": ["F12345", "F67890"]
},
{
"name":"station",
"type":"string",
"sql":"station"
"name": "station",
"type": "string",
"sql": "station",
"samples": ["Station X", "Station Y"]
},
{
"name":"tenant_id",
"type":"string",
"sql":"tenant_id"
"name": "tenant_id",
"type": "string",
"sql": "tenant_id",
"samples": ["T12345", "T67890"]
},
{
"name":"day",
"type":"date",
"sql":"day"
"name": "day",
"type": "date",
"sql": "day",
"samples": ["2023-01-01", "2023-02-01"]
},
{
"name":"store_id",
"type":"string",
"sql":"store_id"
"name": "store_id",
"type": "string",
"sql": "store_id",
"samples": ["S12345", "S67890"]
}
],
"joins":[
"joins": [
{
"name":"contracts",
"join_type":"right",
"sql":"${Fees.contract_id} = ${Fees.contract_code}"
"name": "Contracts",
"join_type": "right",
"sql": "${Fees.contract_id} = ${Contracts.contract_code}"
}
]
}
Expand Down
14 changes: 14 additions & 0 deletions pandasai/ee/helpers/json_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import json


def extract_json_from_json_str(json_str):
    """Parse JSON from a string that may wrap it in a ```json fenced block.

    If the string contains a ```json fence, only the text between that
    opening fence and the next ``` is parsed. Otherwise the whole string is
    handed to ``json.loads``.

    Args:
        json_str: Raw text, either plain JSON or text containing a
            ```json ... ``` fenced block.

    Returns:
        The decoded Python object (dict, list, ...).

    Raises:
        json.JSONDecodeError: If the selected text is not valid JSON.
    """
    fence_open = "```json"
    start_index = json_str.find(fence_open)

    if start_index == -1:
        return json.loads(json_str)

    # Search for the closing fence *after* the opening one. Searching from
    # start_index would match the opening fence itself ("```json" begins with
    # "```"), yielding an empty slice and a guaranteed decode error.
    content_start = start_index + len(fence_open)
    end_index = json_str.find("```", content_start)

    if end_index == -1:
        # Unterminated fence: parse everything after the opening fence.
        end_index = len(json_str)

    return json.loads(json_str[content_start:end_index].strip())

Check warning on line 14 in pandasai/ee/helpers/json_helper.py

View check run for this annotation

Codecov / codecov/patch

pandasai/ee/helpers/json_helper.py#L14

Added line #L14 was not covered by tests
2 changes: 1 addition & 1 deletion pandasai/ee/helpers/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _build_from_clause(self, main_table_entry):

def _build_joins_clause(self, main_table_entry, referenced_tables):
sql = ""
main_table = main_table_entry["table"]
main_table = main_table_entry["name"]

for table_name in referenced_tables:
if table_name != main_table:
Expand Down
43 changes: 43 additions & 0 deletions tests/unit_tests/ee/helpers/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,46 @@
],
}
]


# Two-entry schema fixture ("Sales" and "Engagement") used by the join test in
# test_query_builder.py. Each entry declares a join to the other entry,
# referenced by its "name" and matched on the `id` dimension.
MULTI_JOIN_SCHEMA = [
{
"name": "Sales",
"table": "sales",
"measures": [
{"name": "total_revenue", "type": "sum", "sql": "revenue"},
{"name": "total_sales", "type": "count", "sql": "id"},
],
"dimensions": [
{"name": "product", "type": "string", "sql": "product"},
{"name": "region", "type": "string", "sql": "region"},
{"name": "sales_date", "type": "date", "sql": "sales_date"},
{"name": "id", "type": "string", "sql": "id"},
],
# Sales -> Engagement: left join on id.
"joins": [
{
"name": "Engagement",
"join_type": "left",
"sql": "${Sales.id} = ${Engagement.id}",
}
],
},
{
"name": "Engagement",
"table": "engagement",
"measures": [{"name": "total_duration", "type": "sum", "sql": "duration"}],
"dimensions": [
{"name": "id", "type": "string", "sql": "id"},
{"name": "user_id", "type": "string", "sql": "user_id"},
{"name": "activity_type", "type": "string", "sql": "activity_type"},
{"name": "engagement_date", "type": "date", "sql": "engagement_date"},
],
# Engagement -> Sales: right join on id.
"joins": [
{
"name": "Sales",
"join_type": "right",
"sql": "${Engagement.id} = ${Sales.id}",
}
],
},
]
39 changes: 38 additions & 1 deletion tests/unit_tests/ee/helpers/test_query_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from pandasai.ee.helpers.query_builder import QueryBuilder
from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA
from tests.unit_tests.ee.helpers.schema import MULTI_JOIN_SCHEMA, VIZ_QUERY_SCHEMA


class TestQueryBuilder(unittest.TestCase):
Expand Down Expand Up @@ -191,3 +191,40 @@ def test_sql_with_filters_with_set_filter(self):
"SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc",
"SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc",
]

def test_sql_with_filters_with_join(self):
    """Check the SQL generated for a query that spans two joined entries.

    The request selects a dimension from ``Engagement`` and a measure from
    ``Sales``, adds a ``beforeDate`` filter and an ascending order, and the
    result is compared against one exact SQL string.
    """
    builder = QueryBuilder(MULTI_JOIN_SCHEMA)

    chart_request = {
        "type": "bar",
        "dimensions": ["Engagement.activity_type"],
        "measures": ["Sales.total_revenue"],
        "timeDimensions": [],
        "options": {
            "xLabel": "Activity Type",
            "yLabel": "Total Revenue",
            "title": "Total Revenue Generated from Users who Logged in Before Purchase",
            "legend": {"display": True, "position": "top"},
        },
        "joins": [
            {
                "name": "Engagement",
                "join_type": "right",
                "sql": "${Sales.id} = ${Engagement.id}",
            }
        ],
        "filters": [
            {
                "member": "Engagement.engagement_date",
                "operator": "beforeDate",
                "values": ["${Sales.sales_date}"],
            }
        ],
        "order": [{"id": "Sales.total_revenue", "direction": "asc"}],
    }

    expected_sql = "SELECT `engagement`.`activity_type` AS activity_type, SUM(`sales`.`revenue`) AS total_revenue FROM `sales` RIGHT JOIN `engagement` ON `engagement`.`id` = `sales`.`id` WHERE `engagement`.`engagement_date` < '${Sales.sales_date}' GROUP BY activity_type ORDER BY total_revenue asc"

    generated_sql = builder.generate_sql(chart_request)

    assert generated_sql == expected_sql
Loading