Skip to content

Commit

Permalink
fix(semantic_agent): json load to also look for json in backtick
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Jun 19, 2024
1 parent 5a30ef3 commit c0f9003
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 80 deletions.
3 changes: 2 additions & 1 deletion pandasai/ee/agents/semantic_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import (
GenerateDFSchemaPrompt,
)
from pandasai.ee.helpers.json_helper import extract_json_from_json_str
from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson
from pandasai.helpers.cache import Cache
from pandasai.helpers.memory import Memory
Expand Down Expand Up @@ -186,7 +187,7 @@ def _create_schema(self):
"""
)
self._schema = result.replace("# SAMPLE SCHEMA", "")
schema_data = json.loads(result.replace("# SAMPLE SCHEMA", ""))
schema_data = extract_json_from_json_str(result.replace("# SAMPLE SCHEMA", ""))
if isinstance(schema_data, dict):
schema_data = [schema_data]

Expand Down
4 changes: 2 additions & 2 deletions pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Any

from pandasai.ee.helpers.json_helper import extract_json_from_json_str
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
Expand Down Expand Up @@ -42,7 +42,7 @@ def execute(self, input: Any, **kwargs) -> Any:
)
try:
# Validate is valid Json
response_json = json.loads(response)
response_json = extract_json_from_json_str(response)

pipeline_context.add("llm_call", response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from jinja2 import Environment, FileSystemLoader

from pandasai.ee.helpers.json_helper import extract_json_from_json_str
from pandasai.prompts.base import BasePrompt


Expand Down Expand Up @@ -30,7 +31,9 @@ def __init__(self, **kwargs):

def validate(self, output: str) -> bool:
try:
json_data = json.loads(output.replace("# SAMPLE SCHEMA", ""))
json_data = extract_json_from_json_str(
output.replace("# SAMPLE SCHEMA", "")
)
context = self.props["context"]
if isinstance(json_data, dict):
json_data = [json_data]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,132 +1,147 @@
# SAMPLE SCHEMA
[
{
"name":"Contracts",
"table":"contracts",
"measures":[
"name": "Contracts",
"table": "contracts",
"measures": [
{
"name":"contract_count",
"type":"count",
"sql":"store_id"
"name": "contract_count",
"type": "count",
"sql": "store_id"
},
{
"name":"contract_duration",
"type":"number",
"sql":"${contract_end_date} - ${contract_start_date}"
"name": "contract_duration",
"type": "number",
"sql": "${contract_end_date} - ${contract_start_date}"
},
{
"name":"contract_avg_duration",
"type":"avg",
"sql":"${contract_duration}"
"name": "contract_avg_duration",
"type": "avg",
"sql": "${contract_duration}"
}
],
"dimensions":[
"dimensions": [
{
"name":"contract_code",
"type":"string",
"sql":"contract_code"
"name": "contract_code",
"type": "string",
"sql": "contract_code",
"samples": ["C12345", "C67890"]
},
{
"name":"store_id",
"type":"string",
"sql":"store_id"
"name": "store_id",
"type": "string",
"sql": "store_id",
"samples": ["S12345", "S67890"]
},
{
"name":"tenant_code",
"type":"string",
"sql":"tenant_code"
"name": "tenant_code",
"type": "string",
"sql": "tenant_code",
"samples": ["T12345", "T67890"]
},
{
"name":"tenant_name",
"type":"string",
"sql":"tenant_name"
"name": "tenant_name",
"type": "string",
"sql": "tenant_name",
"samples": ["Tenant A", "Tenant B"]
},
{
"name":"store_brand",
"type":"string",
"sql":"store_brand"
"name": "store_brand",
"type": "string",
"sql": "store_brand",
"samples": ["Brand X", "Brand Y"]
},
{
"name":"branch_segment_1",
"type":"string",
"sql":"branch_segment_1"
"name": "branch_segment_1",
"type": "string",
"sql": "branch_segment_1",
"samples": ["Segment 1", "Segment 2"]
},
{
"name":"branch_segment_2",
"type":"string",
"sql":"branch_segment_2"
"name": "branch_segment_2",
"type": "string",
"sql": "branch_segment_2",
"samples": ["Segment A", "Segment B"]
},
{
"name":"contract_start_date",
"type":"date",
"sql":"contract_start_date"
"name": "contract_start_date",
"type": "date",
"sql": "contract_start_date",
"samples": ["2023-01-01", "2023-02-01"]
},
{
"name":"contract_end_date",
"type":"date",
"sql":"contract_end_date"
"name": "contract_end_date",
"type": "date",
"sql": "contract_end_date",
"samples": ["2024-01-01", "2024-02-01"]
}
],
"joins":[
"joins": [
{
"name":"corrispettivi",
"join_type":"left",
"sql":"${Contracts.contract_code} = ${Fees.contract_id}"
"name": "Fee",
"join_type": "left",
"sql": "${Contracts.contract_code} = ${Fees.contract_id}"
}
]
},
{
"name":"Fees",
"table":"fees",
"measures":[
"name": "Fees",
"table": "fees",
"measures": [
{
"name":"total_taxable",
"type":"sum",
"sql":"imponibile_tot"
"name": "total_taxable",
"type": "sum",
"sql": "imponibile_tot"
},
{
"name":"total_revenue",
"type":"sum",
"sql":"totale_tot"
"name": "total_revenue",
"type": "sum",
"sql": "totale_tot"
}
],
"dimensions":[
"dimensions": [
{
"name":"contract_id",
"type":"string",
"sql":"contract_id"
"name": "contract_id",
"type": "string",
"sql": "contract_id",
"samples": ["C12345", "C67890"]
},
{
"name":"code",
"type":"string",
"sql":"code"
"name": "code",
"type": "string",
"sql": "code",
"samples": ["F12345", "F67890"]
},
{
"name":"station",
"type":"string",
"sql":"station"
"name": "station",
"type": "string",
"sql": "station",
"samples": ["Station X", "Station Y"]
},
{
"name":"tenant_id",
"type":"string",
"sql":"tenant_id"
"name": "tenant_id",
"type": "string",
"sql": "tenant_id",
"samples": ["T12345", "T67890"]
},
{
"name":"day",
"type":"date",
"sql":"day"
"name": "day",
"type": "date",
"sql": "day",
"samples": ["2023-01-01", "2023-02-01"]
},
{
"name":"store_id",
"type":"string",
"sql":"store_id"
"name": "store_id",
"type": "string",
"sql": "store_id",
"samples": ["S12345", "S67890"]
}
],
"joins":[
"joins": [
{
"name":"contracts",
"join_type":"right",
"sql":"${Fees.contract_id} = ${Fees.contract_code}"
"name": "Contracts",
"join_type": "right",
"sql": "${Fees.contract_id} = ${Contracts.contract_code}"
}
]
}
Expand Down
14 changes: 14 additions & 0 deletions pandasai/ee/helpers/json_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import json


def extract_json_from_json_str(json_str):
start_index = json_str.find("```json")

end_index = json_str.find("```", start_index)

if start_index == -1:
return json.loads(json_str)

json_data = json_str[(start_index + len("```json")) : end_index].strip()

Check warning on line 12 in pandasai/ee/helpers/json_helper.py

View check run for this annotation

Codecov / codecov/patch

pandasai/ee/helpers/json_helper.py#L12

Added line #L12 was not covered by tests

return json.loads(json_data)

Check warning on line 14 in pandasai/ee/helpers/json_helper.py

View check run for this annotation

Codecov / codecov/patch

pandasai/ee/helpers/json_helper.py#L14

Added line #L14 was not covered by tests

0 comments on commit c0f9003

Please sign in to comment.