From 5a30ef30bb49d73a67ffcbf3a7baca3e31475f4b Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 18:58:44 +0200 Subject: [PATCH] fix(SemanticAgent): join data to be fixed --- pandasai/ee/helpers/query_builder.py | 2 +- tests/unit_tests/ee/helpers/schema.py | 43 +++++++++++++++++++ .../ee/helpers/test_query_builder.py | 39 ++++++++++++++++- 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/pandasai/ee/helpers/query_builder.py b/pandasai/ee/helpers/query_builder.py index 32a11071e..7f7202085 100644 --- a/pandasai/ee/helpers/query_builder.py +++ b/pandasai/ee/helpers/query_builder.py @@ -296,7 +296,7 @@ def _build_from_clause(self, main_table_entry): def _build_joins_clause(self, main_table_entry, referenced_tables): sql = "" - main_table = main_table_entry["table"] + main_table = main_table_entry["name"] for table_name in referenced_tables: if table_name != main_table: diff --git a/tests/unit_tests/ee/helpers/schema.py b/tests/unit_tests/ee/helpers/schema.py index 80bae8bad..f82ac7365 100644 --- a/tests/unit_tests/ee/helpers/schema.py +++ b/tests/unit_tests/ee/helpers/schema.py @@ -43,3 +43,46 @@ ], } ] + + +MULTI_JOIN_SCHEMA = [ + { + "name": "Sales", + "table": "sales", + "measures": [ + {"name": "total_revenue", "type": "sum", "sql": "revenue"}, + {"name": "total_sales", "type": "count", "sql": "id"}, + ], + "dimensions": [ + {"name": "product", "type": "string", "sql": "product"}, + {"name": "region", "type": "string", "sql": "region"}, + {"name": "sales_date", "type": "date", "sql": "sales_date"}, + {"name": "id", "type": "string", "sql": "id"}, + ], + "joins": [ + { + "name": "Engagement", + "join_type": "left", + "sql": "${Sales.id} = ${Engagement.id}", + } + ], + }, + { + "name": "Engagement", + "table": "engagement", + "measures": [{"name": "total_duration", "type": "sum", "sql": "duration"}], + "dimensions": [ + {"name": "id", "type": "string", "sql": "id"}, + {"name": "user_id", "type": "string", "sql": "user_id"}, + {"name": "activity_type", "type": "string", "sql": "activity_type"}, + {"name": "engagement_date", "type": "date", "sql": "engagement_date"}, + ], + "joins": [ + { + "name": "Sales", + "join_type": "right", + "sql": "${Engagement.id} = ${Sales.id}", + } + ], + }, +] diff --git a/tests/unit_tests/ee/helpers/test_query_builder.py b/tests/unit_tests/ee/helpers/test_query_builder.py index 4dd2cc8e9..5f73b2e81 100644 --- a/tests/unit_tests/ee/helpers/test_query_builder.py +++ b/tests/unit_tests/ee/helpers/test_query_builder.py @@ -1,7 +1,7 @@ import unittest from pandasai.ee.helpers.query_builder import QueryBuilder -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA +from tests.unit_tests.ee.helpers.schema import MULTI_JOIN_SCHEMA, VIZ_QUERY_SCHEMA class TestQueryBuilder(unittest.TestCase): @@ -191,3 +191,40 @@ def test_sql_with_filters_with_set_filter(self): "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc", "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc", ] + + def test_sql_with_filters_with_join(self): + query_builder = QueryBuilder(MULTI_JOIN_SCHEMA) + + json_str = { + "type": "bar", + "dimensions": ["Engagement.activity_type"], + "measures": ["Sales.total_revenue"], + "timeDimensions": [], + "options": { + "xLabel": "Activity Type", + "yLabel": "Total Revenue", + "title": "Total Revenue Generated from Users who Logged in Before Purchase", + "legend": {"display": True, "position": "top"}, + }, + "joins": [ + { + "name": "Engagement", + "join_type": "right", + "sql": "${Sales.id} = ${Engagement.id}", + } + ], + "filters": [ + { + "member": "Engagement.engagement_date", + "operator": "beforeDate", + "values": ["${Sales.sales_date}"], + } + ], + "order": [{"id": "Sales.total_revenue", "direction": "asc"}], + } + sql_query = query_builder.generate_sql(json_str) + + assert ( + sql_query + == "SELECT `engagement`.`activity_type` AS activity_type, SUM(`sales`.`revenue`) AS total_revenue FROM `sales` RIGHT JOIN `engagement` ON `engagement`.`id` = `sales`.`id` WHERE `engagement`.`engagement_date` < '${Sales.sales_date}' GROUP BY activity_type ORDER BY total_revenue asc" + )