Skip to content

Commit

Permalink
Merge pull request #1233 from Sinaptik-AI/semantic_schema_timestamp
Browse files Browse the repository at this point in the history
fix: error prompt generation and extra validation on schema generation
  • Loading branch information
ArslanSaleem authored Jun 14, 2024
2 parents 2bbbb92 + 4f1e527 commit f594212
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
11 changes: 6 additions & 5 deletions pandasai/ee/agents/semantic_agent/pipeline/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ def execute(self, input_data: Any, **kwargs) -> Any:

traceback_errors = traceback.format_exc()

input_data = self.on_failure(input, traceback_errors)
input_data = self.on_failure(input_data, traceback_errors)

retry_count += 1

def _get_type(self, input: dict) -> bool:
return (
"plot"
if input["type"] in ["bar", "line", "histogram", "pie", "scatter"]
if input["type"]
in ["bar", "line", "histogram", "pie", "scatter", "boxplot"]
else input["type"]
)

Expand All @@ -99,7 +100,7 @@ def _generate_code(self, type, query):
"""
elif type == "dataframe":
return """
result = {{"type": "dataframe","value": data}}
result = {"type": "dataframe","value": data}
"""
else:
code = self.generate_matplotlib_code(query)
Expand All @@ -119,8 +120,8 @@ def _generate_code_for_number(self, query: dict) -> str:

def generate_matplotlib_code(self, query: dict) -> str:
chart_type = query["type"]
x_label = query["options"].get("xLabel", None)
y_label = query["options"].get("yLabel", None)
x_label = query.get("options", {}).get("xLabel", None)
y_label = query.get("options", {}).get("yLabel", None)
title = query["options"].get("title", None)
legend_display = {"display": True}
legend_position = "best"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
on_code_generation=on_code_generation,
on_prompt_generation=on_prompt_generation,
)
self.query_exec_tracker = query_exec_tracker

self._context = context
self._logger = logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ def __init__(self, **kwargs):
def validate(self, output: str) -> bool:
try:
json_data = json.loads(output.replace("# SAMPLE SCHEMA", ""))
context = self.props["context"]
if isinstance(json_data, dict):
json_data = [json_data]
if isinstance(json_data, list):
for record in json_data:
if not all(key in record for key in ("name", "table")):
return False
return True

return len(context.dfs) == len(json_data)

except json.JSONDecodeError:
pass
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,10 @@ def test_generate_matplolib_boxplot_chart_code(

logic_unit = code_gen.execute(json_str, context=context, logger=logger)
assert isinstance(logic_unit, LogicUnitOutput)
print(logic_unit.output)
assert (
logic_unit.output
== """
import matplotlib.pyplot as plt
import pandas as pd
sql_query="SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country"
Expand Down

0 comments on commit f594212

Please sign in to comment.