diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py index 1d0b1cd75..5338f861a 100644 --- a/pandasai/agent/base.py +++ b/pandasai/agent/base.py @@ -341,9 +341,7 @@ def train( "No vector store provided. Please provide a vector store to train the agent." ) - if (queries is not None and codes is None) or ( - queries is None and codes is not None - ): + if (queries and not codes) or (not queries and codes): raise ValueError( "If either queries or codes are provided, both must be provided." ) diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py index a9144f57b..f3a2f0ad6 100644 --- a/pandasai/ee/agents/semantic_agent/__init__.py +++ b/pandasai/ee/agents/semantic_agent/__init__.py @@ -14,7 +14,7 @@ from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import ( GenerateDFSchemaPrompt, ) -from pandasai.exceptions import InvalidConfigError +from pandasai.exceptions import InvalidConfigError, InvalidTrainJson from pandasai.helpers.cache import Cache from pandasai.helpers.memory import Memory from pandasai.llm.bamboo_llm import BambooLLM @@ -84,6 +84,32 @@ def __init__( ) ) + def validate_and_convert_json(self, jsons): + json_strs = [] + + try: + for json_data in jsons: + if isinstance(json_data, str): + json.loads(json_data) + json_strs.append(json_data) + elif isinstance(json_data, dict): + json_strs.append(json.dumps(json_data)) + + except Exception as e: + raise InvalidTrainJson("Error validating JSON string") from e + + return json_strs + + def train( + self, + queries: Optional[List[str]] = None, + jsons: Optional[List[Union[dict, str]]] = None, + docs: Optional[List[str]] = None, + ) -> None: + json_strs = self.validate_and_convert_json(jsons) if jsons else None + + super().train(queries=queries, codes=json_strs, docs=docs) + def query(self, query): query_pipeline = Pipeline( context=self.context, diff --git a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py index e767cec2a..cb5789523 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py @@ -45,13 +45,15 @@ def execute(self, input: Any, **kwargs) -> Any: ) try: # Validate is valid Json - response = json.loads(response) + response_json = json.loads(response) + + pipeline_context.add("llm_call", response) return LogicUnitOutput( - response, + response_json, True, "Code Generated Successfully", - {"content_type": "string", "value": response}, + {"content_type": "string", "value": response_json}, ) except Exception: if retry_count == pipeline_context.config.max_retries: diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py index 729b47fb0..b511640b2 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py @@ -8,13 +8,18 @@ from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( SemanticPromptGeneration, ) +from pandasai.ee.agents.semantic_agent.pipeline.semantic_result_parsing import ( + SemanticResultParser, +) from pandasai.ee.agents.semantic_agent.pipeline.validate_pipeline_input import ( ValidatePipelineInput, ) from pandasai.helpers.logger import Logger from pandasai.pipelines.chat.cache_lookup import CacheLookup from pandasai.pipelines.chat.code_cleaning import CodeCleaning +from pandasai.pipelines.chat.code_execution import CodeExecution from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline +from pandasai.pipelines.chat.result_validation import ResultValidation from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext @@ -65,6 +70,23 @@ def __init__( ], ) + self.code_execution_pipeline = Pipeline( + context=context, + logger=logger, + query_exec_tracker=self.query_exec_tracker, + steps=[ + CodeExecution( + before_execution=before_code_execution, + on_failure=self.on_code_execution_failure, + on_retry=self.on_code_retry, + ), + ResultValidation(), + SemanticResultParser( + before_execution=on_result, + ), + ], + ) + self.code_exec_error_pipeline = ErrorCorrectionPipeline( context=context, logger=logger, diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py new file mode 100644 index 000000000..f897e81f2 --- /dev/null +++ b/pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py @@ -0,0 +1,23 @@ +from pandasai.pipelines.chat.result_parsing import ResultParsing +from pandasai.pipelines.pipeline_context import PipelineContext + + +class SemanticResultParser(ResultParsing): + """ + Semantic Agent Result Parsing Stage + """ + + pass + + def _add_result_to_memory(self, result: dict, context: PipelineContext): + """ + Add the result to the memory. + + Args: + result (dict): The result to add to the memory + context (PipelineContext) : Pipeline Context + """ + if result is None: + return + + context.memory.add(context.get("llm_call"), False) diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl index 83524b7b3..06fa68338 100644 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl +++ b/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl @@ -1,5 +1,5 @@ === SemanticAgent === - +{% include 'shared/vectordb_docs.tmpl' with context %} # SCHEMA {{schema}} diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl new file mode 100644 index 000000000..0fe6be43a --- /dev/null +++ b/pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl @@ -0,0 +1,8 @@ +{% if context.vectorstore %}{% set documents = context.vectorstore.get_relevant_qa_documents(context.memory.get_last_message()) %} +{% if documents|length > 0%}You can utilize these examples as a reference for generating json.{% endif %} +{% for document in documents %} +{{ document}}{% endfor %}{% endif %} +{% if context.vectorstore %}{% set documents = context.vectorstore.get_relevant_docs_documents(context.memory.get_last_message()) %} +{% if documents|length > 0%}Here are additional documents for reference. Feel free to use them to answer.{% endif %} +{% for document in documents %}{{ document}} +{% endfor %}{% endif %} \ No newline at end of file diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 8d4c1ed49..ef8be33a7 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -246,3 +246,11 @@ class PandasConnectorTableNotFound(Exception): Args: Exception (Exception): PandasConnectorTableNotFound """ + + +class InvalidTrainJson(Exception): + """ + Raise error if train json is not correct + Args: + Exception (Exception): Invalid train json + """ diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py b/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py index 5031ea2dd..0226cc24c 100644 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py +++ b/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py @@ -12,6 +12,7 @@ SQLConnectorConfig, ) from pandasai.ee.agents.semantic_agent import SemanticAgent +from pandasai.exceptions import InvalidTrainJson from pandasai.helpers.dataframe_serializer import DataframeSerializerType from pandasai.llm.bamboo_llm import BambooLLM from pandasai.llm.fake import FakeLLM @@ -121,7 +122,9 @@ def pgsql_connector(self, create_engine): return PostgreSQLConnector(self.config) @pytest.fixture - def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent: + def agent(self, sample_df: pd.DataFrame) -> Agent: + llm = MockBambooLLM() + config = {"llm": llm} return SemanticAgent(sample_df, config, vectorstore=MagicMock()) def test_base_agent_contruct(self, sample_df): @@ -189,3 +192,40 @@ def test_cache_of_schema(self, mock_cache_get, sample_df): assert not llm.call.called assert agent._schema == VIZ_QUERY_SCHEMA + + def test_train_method_with_qa(self, agent): + queries = ["query1"] + jsons = ['{"name": "test"}'] + agent.train(queries=queries, jsons=jsons) + + agent._vectorstore.add_docs.assert_not_called() + agent._vectorstore.add_question_answer.assert_called_once_with(queries, jsons) + + def test_train_method_with_docs(self, agent): + docs = ["doc1"] + agent.train(docs=docs) + + agent._vectorstore.add_question_answer.assert_not_called() + agent._vectorstore.add_docs.assert_called_once() + agent._vectorstore.add_docs.assert_called_once_with(docs) + + def test_train_method_with_docs_and_qa(self, agent): + docs = ["doc1"] + queries = ["query1"] + jsons = ['{"name": "test"}'] + agent.train(queries, jsons, docs=docs) + + agent._vectorstore.add_question_answer.assert_called_once() + agent._vectorstore.add_question_answer.assert_called_once_with(queries, jsons) + agent._vectorstore.add_docs.assert_called_once() + agent._vectorstore.add_docs.assert_called_once_with(docs) + + def test_train_method_with_queries_but_no_code(self, agent): + queries = ["query1", "query2"] + with pytest.raises(ValueError): + agent.train(queries) + + def test_train_method_with_code_but_no_queries(self, agent): + jsons = ["code1", "code2"] + with pytest.raises(InvalidTrainJson): + agent.train(jsons=jsons) diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py b/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py index 83daef4ba..2e770b131 100644 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py +++ b/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py @@ -154,6 +154,7 @@ def test_validate_input_semantic_prompt(self, sample_df, context, logger): response.output.to_string() == """=== SemanticAgent === + # SCHEMA [{"name": "Orders", "table": "orders", "measures": [{"name": "order_count", "type": "count"}, {"name": "total_freight", "type": "sum", "sql": "freight"}], "dimensions": [{"name": "order_id", "type": "int", "sql": "order_id"}, {"name": "customer_id", "type": "string", "sql": "customer_id"}, {"name": "employee_id", "type": "int", "sql": "employee_id"}, {"name": "order_date", "type": "date", "sql": "order_date"}, {"name": "required_date", "type": "date", "sql": "required_date"}, {"name": "shipped_date", "type": "date", "sql": "shipped_date"}, {"name": "ship_via", "type": "int", "sql": "ship_via"}, {"name": "ship_name", "type": "string", "sql": "ship_name"}, {"name": "ship_address", "type": "string", "sql": "ship_address"}, {"name": "ship_city", "type": "string", "sql": "ship_city"}, {"name": "ship_region", "type": "string", "sql": "ship_region"}, {"name": "ship_postal_code", "type": "string", "sql": "ship_postal_code"}, {"name": "ship_country", "type": "string", "sql": "ship_country"}], "joins": []}]