Skip to content

Commit

Permalink
feat(train): semantic agent to accept jsons (#1208)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem authored Jun 5, 2024
1 parent bdd5ce4 commit 9bcdc58
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 9 deletions.
4 changes: 1 addition & 3 deletions pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,7 @@ def train(
"No vector store provided. Please provide a vector store to train the agent."
)

if (queries is not None and codes is None) or (
queries is None and codes is not None
):
if (queries and not codes) or (not queries and codes):
raise ValueError(
"If either queries or codes are provided, both must be provided."
)
Expand Down
28 changes: 27 additions & 1 deletion pandasai/ee/agents/semantic_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import (
GenerateDFSchemaPrompt,
)
from pandasai.exceptions import InvalidConfigError
from pandasai.exceptions import InvalidConfigError, InvalidTrainJson
from pandasai.helpers.cache import Cache
from pandasai.helpers.memory import Memory
from pandasai.llm.bamboo_llm import BambooLLM
Expand Down Expand Up @@ -84,6 +84,32 @@ def __init__(
)
)

def validate_and_convert_json(self, jsons):
    """Validate training examples and normalize them to JSON strings.

    Args:
        jsons: Iterable of JSON strings and/or dicts.

    Returns:
        list[str]: One JSON string per input entry, in input order.

    Raises:
        InvalidTrainJson: If an entry is not valid JSON, cannot be
            serialized, or is of an unsupported type.
    """
    json_strs = []

    try:
        for json_data in jsons:
            if isinstance(json_data, str):
                # Parse only to confirm validity; keep the original string.
                json.loads(json_data)
                json_strs.append(json_data)
            elif isinstance(json_data, dict):
                json_strs.append(json.dumps(json_data))
            else:
                # Previously any other type was silently skipped, losing
                # training data without any signal to the caller.
                raise TypeError(
                    f"Unsupported type in jsons: {type(json_data).__name__}"
                )

    except Exception as e:
        raise InvalidTrainJson("Error validating JSON string") from e

    return json_strs

def train(
    self,
    queries: Optional[List[str]] = None,
    jsons: Optional[List[Union[dict, str]]] = None,
    docs: Optional[List[str]] = None,
) -> None:
    """Train the semantic agent on query/JSON pairs and/or documents.

    JSON examples are validated and normalized to strings before being
    handed to the base agent's ``train`` as its ``codes`` argument.
    """
    if jsons:
        json_strs = self.validate_and_convert_json(jsons)
    else:
        json_strs = None

    super().train(queries=queries, codes=json_strs, docs=docs)

def query(self, query):
query_pipeline = Pipeline(
context=self.context,
Expand Down
8 changes: 5 additions & 3 deletions pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ def execute(self, input: Any, **kwargs) -> Any:
)
try:
# Validate is valid Json
response = json.loads(response)
response_json = json.loads(response)

pipeline_context.add("llm_call", response)

return LogicUnitOutput(
response,
response_json,
True,
"Code Generated Successfully",
{"content_type": "string", "value": response},
{"content_type": "string", "value": response_json},
)
except Exception:
if retry_count == pipeline_context.config.max_retries:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import (
SemanticPromptGeneration,
)
from pandasai.ee.agents.semantic_agent.pipeline.semantic_result_parsing import (
SemanticResultParser,
)
from pandasai.ee.agents.semantic_agent.pipeline.validate_pipeline_input import (
ValidatePipelineInput,
)
from pandasai.helpers.logger import Logger
from pandasai.pipelines.chat.cache_lookup import CacheLookup
from pandasai.pipelines.chat.code_cleaning import CodeCleaning
from pandasai.pipelines.chat.code_execution import CodeExecution
from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline
from pandasai.pipelines.chat.result_validation import ResultValidation
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext

Expand Down Expand Up @@ -65,6 +70,23 @@ def __init__(
],
)

self.code_execution_pipeline = Pipeline(
context=context,
logger=logger,
query_exec_tracker=self.query_exec_tracker,
steps=[
CodeExecution(
before_execution=before_code_execution,
on_failure=self.on_code_execution_failure,
on_retry=self.on_code_retry,
),
ResultValidation(),
SemanticResultParser(
before_execution=on_result,
),
],
)

self.code_exec_error_pipeline = ErrorCorrectionPipeline(
context=context,
logger=logger,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pandasai.pipelines.chat.result_parsing import ResultParsing
from pandasai.pipelines.pipeline_context import PipelineContext


class SemanticResultParser(ResultParsing):
    """
    Semantic Agent Result Parsing Stage

    Overrides the base memory bookkeeping so the raw LLM response cached
    in the pipeline context under "llm_call" is what gets added to
    conversation memory instead of the parsed result.
    """

    def _add_result_to_memory(self, result: dict, context: PipelineContext):
        """
        Add the cached LLM response to the memory.

        Args:
            result (dict): The pipeline result; only checked against None
                to decide whether memory should be updated at all.
            context (PipelineContext): Pipeline Context holding the memory
                and the cached "llm_call" value.
        """
        if result is None:
            return

        context.memory.add(context.get("llm_call"), False)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
=== SemanticAgent ===

{% include 'shared/vectordb_docs.tmpl' with context %}
# SCHEMA
{{schema}}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{% if context.vectorstore %}{% set documents = context.vectorstore.get_relevant_qa_documents(context.memory.get_last_message()) %}
{% if documents|length > 0%}You can utilize these examples as a reference for generating json.{% endif %}
{% for document in documents %}
{{ document}}{% endfor %}{% endif %}
{% if context.vectorstore %}{% set documents = context.vectorstore.get_relevant_docs_documents(context.memory.get_last_message()) %}
{% if documents|length > 0%}Here are additional documents for reference. Feel free to use them to answer.{% endif %}
{% for document in documents %}{{ document}}
{% endfor %}{% endif %}
8 changes: 8 additions & 0 deletions pandasai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,11 @@ class PandasConnectorTableNotFound(Exception):
Args:
Exception (Exception): PandasConnectorTableNotFound
"""


class InvalidTrainJson(Exception):
    """
    Raised when JSON supplied for agent training is malformed or invalid.

    Args:
        Exception (Exception): Invalid train json
    """
42 changes: 41 additions & 1 deletion tests/unit_tests/ee/semantic_agent/test_semantic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SQLConnectorConfig,
)
from pandasai.ee.agents.semantic_agent import SemanticAgent
from pandasai.exceptions import InvalidTrainJson
from pandasai.helpers.dataframe_serializer import DataframeSerializerType
from pandasai.llm.bamboo_llm import BambooLLM
from pandasai.llm.fake import FakeLLM
Expand Down Expand Up @@ -121,7 +122,9 @@ def pgsql_connector(self, create_engine):
return PostgreSQLConnector(self.config)

@pytest.fixture
def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
def agent(self, sample_df: pd.DataFrame) -> Agent:
llm = MockBambooLLM()
config = {"llm": llm}
return SemanticAgent(sample_df, config, vectorstore=MagicMock())

def test_base_agent_contruct(self, sample_df):
Expand Down Expand Up @@ -189,3 +192,40 @@ def test_cache_of_schema(self, mock_cache_get, sample_df):

assert not llm.call.called
assert agent._schema == VIZ_QUERY_SCHEMA

def test_train_method_with_qa(self, agent):
    """Training with queries + jsons stores Q/A pairs and no docs."""
    sample_queries = ["query1"]
    sample_jsons = ['{"name": "test"}']

    agent.train(queries=sample_queries, jsons=sample_jsons)

    vectorstore = agent._vectorstore
    vectorstore.add_docs.assert_not_called()
    vectorstore.add_question_answer.assert_called_once_with(
        sample_queries, sample_jsons
    )

def test_train_method_with_docs(self, agent):
    """Training with only docs stores docs and no Q/A pairs."""
    docs = ["doc1"]

    agent.train(docs=docs)

    agent._vectorstore.add_question_answer.assert_not_called()
    # assert_called_once_with already verifies a single call, so the
    # separate assert_called_once() was redundant.
    agent._vectorstore.add_docs.assert_called_once_with(docs)

def test_train_method_with_docs_and_qa(self, agent):
    """Training with queries, jsons and docs stores both kinds of data."""
    docs = ["doc1"]
    queries = ["query1"]
    jsons = ['{"name": "test"}']

    agent.train(queries, jsons, docs=docs)

    # assert_called_once_with already verifies a single call, so the
    # bare assert_called_once() checks were redundant.
    agent._vectorstore.add_question_answer.assert_called_once_with(queries, jsons)
    agent._vectorstore.add_docs.assert_called_once_with(docs)

def test_train_method_with_queries_but_no_code(self, agent):
    """Queries without matching jsons must raise ValueError."""
    with pytest.raises(ValueError):
        agent.train(["query1", "query2"])

def test_train_method_with_code_but_no_queries(self, agent):
    """Non-JSON strings passed as jsons must raise InvalidTrainJson."""
    bad_jsons = ["code1", "code2"]
    with pytest.raises(InvalidTrainJson):
        agent.train(jsons=bad_jsons)
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_validate_input_semantic_prompt(self, sample_df, context, logger):
response.output.to_string()
== """=== SemanticAgent ===
# SCHEMA
[{"name": "Orders", "table": "orders", "measures": [{"name": "order_count", "type": "count"}, {"name": "total_freight", "type": "sum", "sql": "freight"}], "dimensions": [{"name": "order_id", "type": "int", "sql": "order_id"}, {"name": "customer_id", "type": "string", "sql": "customer_id"}, {"name": "employee_id", "type": "int", "sql": "employee_id"}, {"name": "order_date", "type": "date", "sql": "order_date"}, {"name": "required_date", "type": "date", "sql": "required_date"}, {"name": "shipped_date", "type": "date", "sql": "shipped_date"}, {"name": "ship_via", "type": "int", "sql": "ship_via"}, {"name": "ship_name", "type": "string", "sql": "ship_name"}, {"name": "ship_address", "type": "string", "sql": "ship_address"}, {"name": "ship_city", "type": "string", "sql": "ship_city"}, {"name": "ship_region", "type": "string", "sql": "ship_region"}, {"name": "ship_postal_code", "type": "string", "sql": "ship_postal_code"}, {"name": "ship_country", "type": "string", "sql": "ship_country"}], "joins": []}]
Expand Down

0 comments on commit 9bcdc58

Please sign in to comment.