diff --git a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py index 2f8a87efa..0b01e82fd 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py @@ -86,7 +86,8 @@ def execute(self, input_data: Any, **kwargs) -> Any: def _get_type(self, input: dict) -> bool: return ( "plot" - if input["type"] in ["bar", "line", "histogram", "pie", "scatter"] + if input["type"] + in ["bar", "line", "histogram", "pie", "scatter", "boxplot"] else input["type"] ) diff --git a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py index c50d26752..fb1f261b6 100644 --- a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py +++ b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py @@ -307,11 +307,10 @@ def test_generate_matplolib_boxplot_chart_code( logic_unit = code_gen.execute(json_str, context=context, logger=logger) assert isinstance(logic_unit, LogicUnitOutput) - print(logic_unit.output) assert ( logic_unit.output == """ - +import matplotlib.pyplot as plt import pandas as pd sql_query="SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country"