From b520871346725c92ca3b4a6c8e1a1ab1bc2dcbef Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 00:08:33 +0200 Subject: [PATCH 1/6] feat(Judge): implementation of judge agent to validate code matches the user query --- examples/judge_agent.py | 35 +++ pandasai/agent/agent.py | 4 + pandasai/agent/base_judge.py | 18 ++ pandasai/ee/agents/judge_agent/__init__.py | 30 +++ .../judge_agent/pipeline/judge_pipeline.py | 34 +++ .../pipeline/judge_prompt_generation.py | 43 ++++ .../agents/judge_agent/pipeline/llm_call.py | 64 +++++ .../judge_agent/prompts/judge_agent_prompt.py | 39 +++ .../prompts/templates/judge_agent_prompt.tmpl | 10 + pandasai/ee/agents/semantic_agent/__init__.py | 4 + .../error_correction_pipeline.py | 2 +- .../fix_semantic_json_pipeline.py | 2 +- .../semantic_agent/pipeline/llm_call.py | 7 +- .../pipeline/semantic_chat_pipeline.py | 5 +- .../pipelines/chat/generate_chat_pipeline.py | 23 +- .../pipelines/judge/judge_pipeline_input.py | 11 + pandasai/pipelines/pipeline.py | 32 ++- .../ee/judge_agent/test_judge_agent.py | 229 ++++++++++++++++++ .../ee/judge_agent/test_judge_llm_call.py | 180 ++++++++++++++ .../ee/judge_agent/test_judge_prompt_gen.py | 170 +++++++++++++ .../semantic_agent/test_semantic_llm_call.py | 5 +- tests/unit_tests/pipelines/test_pipeline.py | 19 ++ 22 files changed, 950 insertions(+), 16 deletions(-) create mode 100644 examples/judge_agent.py create mode 100644 pandasai/agent/base_judge.py create mode 100644 pandasai/ee/agents/judge_agent/__init__.py create mode 100644 pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py create mode 100644 pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py create mode 100644 pandasai/ee/agents/judge_agent/pipeline/llm_call.py create mode 100644 pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py create mode 100644 pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl create mode 100644 pandasai/pipelines/judge/judge_pipeline_input.py create mode 100644 tests/unit_tests/ee/judge_agent/test_judge_agent.py create mode 100644 tests/unit_tests/ee/judge_agent/test_judge_llm_call.py create mode 100644 tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py diff --git a/examples/judge_agent.py b/examples/judge_agent.py new file mode 100644 index 000000000..dc930600f --- /dev/null +++ b/examples/judge_agent.py @@ -0,0 +1,35 @@ +import os + +import pandas as pd + +from pandasai.agent.agent import Agent +from pandasai.ee.agents.judge_agent import JudgeAgent +from pandasai.llm.openai import OpenAI + + +os.environ["PANDASAI_API_KEY"] = "$2a****************************" + +github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv") + +judge = JudgeAgent() +agent = Agent([github_stars], judge=judge) + +print(agent.chat("return total stars count")) + + +# Using Judge standalone +llm = OpenAI("openai_key") +judge_agent = JudgeAgent(config={"llm": llm}) +judge_agent.evaluate( + query="return total github star count for year 2023", + code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc" + data = execute_sql_query(sql_query) + plt.plot(data['starred_at_by_month'], data['user_count']) + plt.xlabel('Month') + plt.ylabel('User Count') + plt.title('GitHub Star Count Per Month - Year 2023') + plt.legend(loc='best') + 
plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png') + result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'} + """, +) diff --git a/pandasai/agent/agent.py b/pandasai/agent/agent.py index 179ae8c23..8a3fbfcc8 100644 --- a/pandasai/agent/agent.py +++ b/pandasai/agent/agent.py @@ -3,6 +3,7 @@ import pandas as pd from pandasai.agent.base import BaseAgent +from pandasai.agent.base_judge import BaseJudge from pandasai.connectors.base import BaseConnector from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.schemas.df_config import Config @@ -20,6 +21,7 @@ def __init__( pipeline: Optional[Type[GenerateChatPipeline]] = None, vectorstore: Optional[VectorStore] = None, description: str = None, + judge: BaseJudge = None, ): super().__init__(dfs, config, memory_size, vectorstore, description) @@ -31,6 +33,7 @@ def __init__( on_code_generation=self._callbacks.on_code_generation, before_code_execution=self._callbacks.before_code_execution, on_result=self._callbacks.on_result, + judge=judge, ) if pipeline else GenerateChatPipeline( @@ -40,6 +43,7 @@ def __init__( on_code_generation=self._callbacks.on_code_generation, before_code_execution=self._callbacks.before_code_execution, on_result=self._callbacks.on_result, + judge=judge, ) ) diff --git a/pandasai/agent/base_judge.py b/pandasai/agent/base_judge.py new file mode 100644 index 000000000..d4c6b6136 --- /dev/null +++ b/pandasai/agent/base_judge.py @@ -0,0 +1,18 @@ +from pandasai.helpers.logger import Logger +from pandasai.pipelines.pipeline import Pipeline +from pandasai.pipelines.pipeline_context import PipelineContext + + +class BaseJudge: + context: PipelineContext + pipeline: Pipeline + logger: Logger + + def __init__( + self, + pipeline: Pipeline, + ) -> None: + self.pipeline = pipeline + + def evaluate(self, query: str, code: str) -> bool: + raise NotImplementedError diff --git a/pandasai/ee/agents/judge_agent/__init__.py b/pandasai/ee/agents/judge_agent/__init__.py new file mode 100644 index 000000000..e33d43a40 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/__init__.py @@ -0,0 +1,30 @@ +from typing import Optional, Union +from pandasai.agent.base_judge import BaseJudge +from pandasai.config import load_config_from_json +from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline +from pandasai.pipelines.abstract_pipeline import AbstractPipeline +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.schemas.df_config import Config + + +class JudgeAgent(BaseJudge): + def __init__( + self, + pipeline: AbstractPipeline = None, + config: Optional[Union[Config, dict]] = None, + ) -> None: + context = None + if config: + if isinstance(config, dict): + config = Config(**load_config_from_json(config)) + + connectors = context + context = PipelineContext(connectors, config) + + pipeline = pipeline or JudgePipeline(context=context) + super().__init__(pipeline) + + def evaluate(self, query: str, code: str) -> bool: + input_data = JudgePipelineInput(query, code) + return self.pipeline.run(input_data) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py b/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py new file mode 100644 index 000000000..8797a6ee0 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py @@ -0,0 +1,34 @@ +from typing import Optional + 
+from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( + JudgePromptGeneration, ) +from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall +from pandasai.helpers.logger import Logger +from pandasai.helpers.query_exec_tracker import QueryExecTracker +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.pipeline import Pipeline +from pandasai.pipelines.pipeline_context import PipelineContext + + +class JudgePipeline: + def __init__( + self, + context: Optional[PipelineContext] = None, + logger: Optional[Logger] = None, + query_exec_tracker: QueryExecTracker = None, + ): + self.query_exec_tracker = query_exec_tracker + + self.pipeline = Pipeline( + context=context, + logger=logger, + query_exec_tracker=self.query_exec_tracker, + steps=[ + JudgePromptGeneration(), + LLMCall(), + ], + ) + + def run(self, input: JudgePipelineInput): + return self.pipeline.run(input) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py new file mode 100644 index 000000000..23dc7c11e --- /dev/null +++ b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py @@ -0,0 +1,43 @@ +from typing import Any + +from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt +from pandasai.helpers.logger import Logger +from pandasai.pipelines.base_logic_unit import BaseLogicUnit +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.logic_unit_output import LogicUnitOutput + + +class JudgePromptGeneration(BaseLogicUnit): + """ + Judge Prompt Generation Stage + """ + + pass + + def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Last logic unit output + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: LogicUnitOutput(prompt) + """ + self.context = kwargs.get("context") + self.logger: Logger = kwargs.get("logger") + + prompt = JudgeAgentPrompt( + query=input_data.query, code=input_data.code, context=self.context + ) + self.logger.log(f"Using prompt: {prompt}") + + return LogicUnitOutput( + prompt, + True, + "Prompt Generated Successfully", + {"content_type": "prompt", "value": prompt.to_string()}, + ) diff --git a/pandasai/ee/agents/judge_agent/pipeline/llm_call.py b/pandasai/ee/agents/judge_agent/pipeline/llm_call.py new file mode 100644 index 000000000..47758b263 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/pipeline/llm_call.py @@ -0,0 +1,64 @@ +from typing import Any + +from pandasai.exceptions import InvalidOutputValueMismatch +from pandasai.helpers.logger import Logger +from pandasai.pipelines.base_logic_unit import BaseLogicUnit +from pandasai.pipelines.logic_unit_output import LogicUnitOutput +from pandasai.pipelines.pipeline_context import PipelineContext + + +class LLMCall(BaseLogicUnit): + """ + LLM Judgement Call Stage + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging.
+ - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. + """ + pipeline_context: PipelineContext = kwargs.get("context") + logger: Logger = kwargs.get("logger") + + retry_count = 0 + while retry_count <= pipeline_context.config.max_retries: + response = pipeline_context.config.llm.call(input, pipeline_context) + + logger.log( + f"""LLM response: + {response} + """ + ) + try: + result = False + if "<Yes>" in response: + result = True + elif "<No>" in response: + result = False + else: + raise InvalidOutputValueMismatch("Invalid response of LLM Call") + + pipeline_context.add("llm_call", response) + + return LogicUnitOutput( + result, + True, + "Code Validated Successfully", + {"content_type": "string", "value": response}, + ) + except Exception: + if retry_count == pipeline_context.config.max_retries: + raise + + retry_count += 1 diff --git a/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py b/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py new file mode 100644 index 000000000..91616aaf8 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py @@ -0,0 +1,39 @@ +from pathlib import Path + +from jinja2 import Environment, FileSystemLoader + +from pandasai.prompts.base import BasePrompt + + +class JudgeAgentPrompt(BasePrompt): + """Prompt to validate the generated code against the user query.""" + + template_path = "judge_agent_prompt.tmpl" + + def __init__(self, **kwargs): + """Initialize the prompt.""" + self.props = kwargs + + if self.template: + env = Environment() + self.prompt = env.from_string(self.template) + elif self.template_path: + # find path to template file + current_dir_path = Path(__file__).parent + + path_to_template = current_dir_path / "templates" + env = Environment(loader=FileSystemLoader(path_to_template)) + self.prompt = env.get_template(self.template_path) + + self._resolved_prompt = None + + def to_json(self): + context = self.props["context"] + memory = context.memory + conversations = memory.to_json() + system_prompt = memory.get_system_prompt() + return { + "conversation": conversations, + "system_prompt": system_prompt, + "prompt": self.to_string(), + } diff --git a/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl new file mode 100644 index 000000000..a444781cd --- /dev/null +++ b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl @@ -0,0 +1,10 @@ +### QUERY +{{query}} +### GENERATED CODE +{{code}} + +Reason step by step and at the end answer: +1. Explain what the code does +2. Explain what the user query asks for +3. 
Strictly compare the query with the code that is generated +Always return <Yes> or <No> depending on whether the code exactly meets the requirements diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py index b28642cd7..dc8aee31b 100644 --- a/pandasai/ee/agents/semantic_agent/__init__.py +++ b/pandasai/ee/agents/semantic_agent/__init__.py @@ -4,6 +4,7 @@ import pandas as pd from pandasai.agent.base import BaseAgent +from pandasai.agent.base_judge import BaseJudge from pandasai.connectors.base import BaseConnector from pandasai.connectors.pandas import PandasConnector from pandasai.constants import PANDASBI_SETUP_MESSAGE @@ -41,6 +42,7 @@ def __init__( pipeline: Optional[Type[GenerateChatPipeline]] = None, vectorstore: Optional[VectorStore] = None, description: str = None, + judge: BaseJudge = None, ): super().__init__(dfs, config, memory_size, vectorstore, description) @@ -70,6 +72,7 @@ def __init__( pipeline( self.context, self.logger, + judge=judge, on_prompt_generation=self._callbacks.on_prompt_generation, on_code_generation=self._callbacks.on_code_generation, before_code_execution=self._callbacks.before_code_execution, @@ -79,6 +82,7 @@ def __init__( else SemanticChatPipeline( self.context, self.logger, + judge=judge, on_prompt_generation=self._callbacks.on_prompt_generation, on_code_generation=self._callbacks.on_code_generation, before_code_execution=self._callbacks.before_code_execution, diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py index 65e34db94..79e4b3dd7 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py @@ -4,7 +4,6 @@ from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import ( FixSemanticJsonPipeline, ) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( SemanticPromptGeneration, ) @@ -14,6 +13,7 @@ from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( ErrorCorrectionPipelineInput, ) +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py index 3ec39ea40..0d5479871 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py @@ -3,12 +3,12 @@ from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_schema_prompt import ( FixSemanticSchemaPrompt, ) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.helpers.logger import Logger from pandasai.helpers.query_exec_tracker import QueryExecTracker from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( ErrorCorrectionPipelineInput, ) +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from 
pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext diff --git a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py index cb5789523..e9946140d 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable +from typing import Any from pandasai.helpers.logger import Logger from pandasai.pipelines.base_logic_unit import BaseLogicUnit @@ -12,11 +12,8 @@ class LLMCall(BaseLogicUnit): LLM Code Generation Stage """ - def __init__( - self, on_code_generation: Callable[[str, Exception], None] = None, **kwargs - ): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.on_code_generation = on_code_generation def execute(self, input: Any, **kwargs) -> Any: """ diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py index 14e9ea870..eb8ba480c 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py @@ -1,5 +1,6 @@ from typing import Optional +from pandasai.agent.base_judge import BaseJudge from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.error_correction_pipeline import ( ErrorCorrectionPipeline, @@ -7,7 +8,6 @@ from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import ( FixSemanticJsonPipeline, ) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( SemanticPromptGeneration, ) @@ -26,6 +26,7 @@ ) from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.pipelines.chat.result_validation import ResultValidation +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext @@ -41,6 +42,7 @@ def __init__( self, context: Optional[PipelineContext] = None, logger: Optional[Logger] = None, + judge: BaseJudge = None, on_prompt_generation=None, on_code_generation=None, before_code_execution=None, @@ -49,6 +51,7 @@ def __init__( super().__init__( context, logger, + judge=judge, on_prompt_generation=on_prompt_generation, on_code_generation=on_code_generation, before_code_execution=before_code_execution, diff --git a/pandasai/pipelines/chat/generate_chat_pipeline.py b/pandasai/pipelines/chat/generate_chat_pipeline.py index 94609fd68..61126451d 100644 --- a/pandasai/pipelines/chat/generate_chat_pipeline.py +++ b/pandasai/pipelines/chat/generate_chat_pipeline.py @@ -1,5 +1,6 @@ from typing import Optional +from pandasai.agent.base_judge import BaseJudge from pandasai.helpers.query_exec_tracker import QueryExecTracker from pandasai.pipelines.chat.chat_pipeline_input import ( ChatPipelineInput, @@ -41,6 +42,7 @@ def __init__( self, context: Optional[PipelineContext] = None, logger: Optional[Logger] = None, + judge: BaseJudge = None, on_prompt_generation=None, on_code_generation=None, before_code_execution=None, @@ -99,6 +101,13 @@ def __init__( on_prompt_generation=on_prompt_generation, ) + self.judge = judge + + if self.judge: + 
self.judge.pipeline.pipeline.context = context + self.judge.pipeline.pipeline.logger = logger + self.judge.pipeline.pipeline.query_exec_tracker = self.query_exec_tracker + self.context = context self._logger = logger self.last_error = None @@ -304,7 +313,19 @@ def run(self, input: ChatPipelineInput) -> dict: } ) try: - if self.code_execution_pipeline: + if self.judge: + code = self.code_generation_pipeline.run(input) + + retry_count = 0 + while retry_count < self.context.config.max_retries: + if self.judge.evaluate(query=input.query, code=code): + break + code = self.code_generation_pipeline.run(input) + retry_count += 1 + + output = self.code_execution_pipeline.run(code) + + elif self.code_execution_pipeline: output = ( self.code_generation_pipeline | self.code_execution_pipeline ).run(input) diff --git a/pandasai/pipelines/judge/judge_pipeline_input.py b/pandasai/pipelines/judge/judge_pipeline_input.py new file mode 100644 index 000000000..aaceea15c --- /dev/null +++ b/pandasai/pipelines/judge/judge_pipeline_input.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + + +@dataclass +class JudgePipelineInput: + query: str + code: str + + def __init__(self, query: str, code: str) -> None: + self.query = query + self.code = code diff --git a/pandasai/pipelines/pipeline.py b/pandasai/pipelines/pipeline.py index 59965007b..76ae0a57c 100644 --- a/pandasai/pipelines/pipeline.py +++ b/pandasai/pipelines/pipeline.py @@ -43,21 +43,21 @@ def __init__( logger: (Logger): logger """ - if not isinstance(context, PipelineContext): + if context and not isinstance(context, PipelineContext): config = Config(**load_config_from_json(config)) connectors = context context = PipelineContext(connectors, config) self._logger = ( Logger(save_logs=context.config.save_logs, verbose=context.config.verbose) - if logger is None + if logger is None and context else logger ) self._context = context self._steps = steps or [] - self._query_exec_tracker = query_exec_tracker or QueryExecTracker( - server_config=self._context.config.log_server + self._query_exec_tracker = query_exec_tracker or ( + context and QueryExecTracker(server_config=self._context.config.log_server) ) def add_step(self, logic: BaseLogicUnit): @@ -167,3 +167,27 @@ def __or__(self, pipeline: "Pipeline") -> Any: combined_pipeline.add_step(step) return combined_pipeline + + @property + def context(self): + return self._context + + @context.setter + def context(self, context: PipelineContext): + self._context = context + + @property + def logger(self): + return self._logger + + @logger.setter + def logger(self, logger: Logger): + self._logger = logger + + @property + def query_exec_tracker(self): + return self._query_exec_tracker + + @query_exec_tracker.setter + def query_exec_tracker(self, query_exec_tracker: QueryExecTracker): + self._query_exec_tracker = query_exec_tracker diff --git a/tests/unit_tests/ee/judge_agent/test_judge_agent.py b/tests/unit_tests/ee/judge_agent/test_judge_agent.py new file mode 100644 index 000000000..bae937445 --- /dev/null +++ b/tests/unit_tests/ee/judge_agent/test_judge_agent.py @@ -0,0 +1,229 @@ +from typing import Optional +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from pandasai.agent import Agent +from pandasai.connectors.sql import ( + PostgreSQLConnector, + SQLConnector, + SQLConnectorConfig, +) +from pandasai.ee.agents.judge_agent import JudgeAgent +from pandasai.helpers.dataframe_serializer import DataframeSerializerType +from pandasai.llm.bamboo_llm import BambooLLM +from 
pandasai.llm.fake import FakeLLM +from tests.unit_tests.ee.helpers.schema import ( + VIZ_QUERY_SCHEMA_STR, +) + + +class MockBambooLLM(BambooLLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.call = MagicMock(return_value=VIZ_QUERY_SCHEMA_STR) + + +class TestJudgeAgent: + "Unit tests for Agent class" + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "order_id": [ + 10248, + 10249, + 10250, + 10251, + 10252, + 10253, + 10254, + 10255, + 10256, + 10257, + ], + "customer_id": [ + "VINET", + "TOMSP", + "HANAR", + "VICTE", + "SUPRD", + "HANAR", + "CHOPS", + "RICSU", + "WELLI", + "HILAA", + ], + "employee_id": [5, 6, 4, 3, 4, 3, 4, 7, 3, 4], + "order_date": pd.to_datetime( + [ + "1996-07-04", + "1996-07-05", + "1996-07-08", + "1996-07-08", + "1996-07-09", + "1996-07-10", + "1996-07-11", + "1996-07-12", + "1996-07-15", + "1996-07-16", + ] + ), + "required_date": pd.to_datetime( + [ + "1996-08-01", + "1996-08-16", + "1996-08-05", + "1996-08-05", + "1996-08-06", + "1996-08-07", + "1996-08-08", + "1996-08-09", + "1996-08-12", + "1996-08-13", + ] + ), + "shipped_date": pd.to_datetime( + [ + "1996-07-16", + "1996-07-10", + "1996-07-12", + "1996-07-15", + "1996-07-11", + "1996-07-16", + "1996-07-23", + "1996-07-26", + "1996-07-17", + "1996-07-22", + ] + ), + "ship_via": [3, 1, 2, 1, 2, 2, 2, 3, 2, 1], + "ship_name": [ + "Vins et alcools Chevalier", + "Toms Spezialitäten", + "Hanari Carnes", + "Victuailles en stock", + "Suprêmes délices", + "Hanari Carnes", + "Chop-suey Chinese", + "Richter Supermarkt", + "Wellington Importadora", + "HILARION-Abastos", + ], + "ship_address": [ + "59 rue de l'Abbaye", + "Luisenstr. 48", + "Rua do Paço, 67", + "2, rue du Commerce", + "Boulevard Tirou, 255", + "Rua do Paço, 67", + "Hauptstr. 31", + "Starenweg 5", + "Rua do Mercado, 12", + "Carrera 22 con Ave. 
Carlos Soublette #8-35", + ], + "ship_city": [ + "Reims", + "Münster", + "Rio de Janeiro", + "Lyon", + "Charleroi", + "Rio de Janeiro", + "Bern", + "Genève", + "Resende", + "San Cristóbal", + ], + "ship_region": [ + "CJ", + None, + "RJ", + "RH", + None, + "RJ", + None, + None, + "SP", + "Táchira", + ], + "ship_postal_code": [ + "51100", + "44087", + "05454-876", + "69004", + "B-6000", + "05454-876", + "3012", + "1204", + "08737-363", + "5022", + ], + "ship_country": [ + "France", + "Germany", + "Brazil", + "France", + "Belgium", + "Brazil", + "Switzerland", + "Switzerland", + "Brazil", + "Venezuela", + ], + } + ) + + @pytest.fixture + def llm(self, output: Optional[str] = None) -> FakeLLM: + return FakeLLM(output=output) + + @pytest.fixture + def config(self, llm: FakeLLM) -> dict: + return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV} + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + + @pytest.fixture + def agent(self) -> Agent: + return JudgeAgent() + + def test_construct_with_pipeline(self, sample_df): + JudgeAgent(pipeline=MagicMock()) diff --git a/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py new file mode 100644 index 000000000..35bb977a5 --- /dev/null +++ b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py @@ -0,0 +1,180 @@ +from typing import Optional +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from pandasai.connectors.sql import ( + PostgreSQLConnector, + SQLConnector, + SQLConnectorConfig, +) + +from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall +from pandasai.exceptions import InvalidOutputValueMismatch +from pandasai.helpers.logger import Logger +from pandasai.llm.bamboo_llm import BambooLLM +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.logic_unit_output import LogicUnitOutput +from pandasai.pipelines.pipeline_context import PipelineContext +from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR + + +class MockBambooLLM(BambooLLM): + def __init__(self): + pass + + def call(self, *args, **kwargs): + return VIZ_QUERY_SCHEMA_STR + + +class TestJudgeLLMCall: + "Unit tests for the judge LLMCall logic unit" + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 
19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="pgsql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of the LLMCall logic unit + code_generator = LLMCall() + assert isinstance(code_generator, LLMCall) + + def test_llm_call(self, sample_df, context, logger, config): + input_validator = LLMCall() + + config["llm"].call = MagicMock(return_value="<Yes>") + + context = PipelineContext([sample_df], config) + + result = input_validator.execute(input="test", context=context, logger=logger) + + assert isinstance(result, LogicUnitOutput) + assert result.output is True + + def test_llm_call_no(self, sample_df, context, logger, config): + input_validator = LLMCall() + + config["llm"].call = MagicMock(return_value="<No>") + + context = PipelineContext([sample_df], config) + + result = input_validator.execute(input="test", context=context, logger=logger) + + assert isinstance(result, LogicUnitOutput) + assert result.output is False + + def test_llm_call_(self, sample_df, context, logger, config): + input_validator = LLMCall() + + config["llm"].call = MagicMock(return_value="<No>") + + context = PipelineContext([sample_df], config) + + result = input_validator.execute(input="test", context=context, logger=logger) + + assert isinstance(result, LogicUnitOutput) + assert result.output is False + + def test_llm_call_with_no_tags(self, sample_df, context, logger, config): + input_validator = LLMCall() + + config["llm"].call = MagicMock(return_value="yes") + + context = PipelineContext([sample_df], config) + + with pytest.raises(InvalidOutputValueMismatch): + input_validator.execute(input="test", context=context, logger=logger) diff --git a/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py new file mode 100644 index 000000000..d670115bb --- /dev/null +++ b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py @@ -0,0 +1,170 @@ +from typing import Optional +from unittest.mock import patch + +import pandas as pd +import pytest + 
+from pandasai.connectors.sql import ( + PostgreSQLConnector, + SQLConnector, + SQLConnectorConfig, +) +from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( + JudgePromptGeneration, +) +from pandasai.helpers.logger import Logger +from pandasai.llm.bamboo_llm import BambooLLM +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.pipeline_context import PipelineContext +from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA, VIZ_QUERY_SCHEMA_STR + + +class MockBambooLLM(BambooLLM): + def __init__(self): + pass + + def call(self, *args, **kwargs): + return VIZ_QUERY_SCHEMA_STR + + +class TestJudgePromptGeneration: + "Unit test for Validate Pipeline Input" + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="pgsql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of the CodeGenerator + code_generator = JudgePromptGeneration() + assert isinstance(code_generator, JudgePromptGeneration) + + def test_validate_input_semantic_prompt(self, sample_df, context, logger): + semantic_prompter = JudgePromptGeneration() + + llm = MockBambooLLM() + + # context for true config + config = {"llm": llm, "enable_cache": True, "direct_sql": False} + + context = PipelineContext([sample_df], config) + + context.memory.add("hello word!", True) + + context.add("df_schema", VIZ_QUERY_SCHEMA) + + input_data = JudgePipelineInput( + query="What is test?", code="print('Code Data')" + ) + + response = semantic_prompter.execute( + input_data=input_data, context=context, logger=logger + ) + + assert ( + response.output.to_string() + == """### QUERY +What 
is test? +### GENERATED CODE +print('Code Data') + +Reason step by step and at the end answer: +1. Explain what the code does +2. Explain what the user query asks for +3. Strictly compare the query with the code that is generated +Always return <Yes> or <No> depending on whether the code exactly meets the requirements""" + ) diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py index 05894733b..df309a8a2 100644 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py +++ b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py @@ -9,12 +9,11 @@ SQLConnector, SQLConnectorConfig, ) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import ( - LLMCall, -) + from pandasai.helpers.logger import Logger from pandasai.llm.bamboo_llm import BambooLLM from pandasai.llm.fake import FakeLLM +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline_context import PipelineContext from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR diff --git a/tests/unit_tests/pipelines/test_pipeline.py b/tests/unit_tests/pipelines/test_pipeline.py index 16bb93317..3e60e77a1 100644 --- a/tests/unit_tests/pipelines/test_pipeline.py +++ b/tests/unit_tests/pipelines/test_pipeline.py @@ -5,8 +5,11 @@ import pytest from pandasai.connectors import BaseConnector, PandasConnector +from pandasai.ee.agents.judge_agent import JudgeAgent +from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM from pandasai.pipelines.base_logic_unit import BaseLogicUnit +from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext from pandasai.schemas.df_config import Config @@ -77,6 +80,10 @@ def config(self, llm): def context(self, sample_df, config): return PipelineContext([sample_df], config) + @pytest.fixture + def logger(self): + return Logger(True, False) + def test_init(self, context, config): # Test the initialization of the Pipeline pipeline = Pipeline(context) @@ -156,3 +163,15 @@ def execute(self, data, logger, config, context): result = pipeline_2.run(5) assert result == 8 + + def test_pipeline_constructor_with_judge(self, context): + judge_agent = JudgeAgent() + pipeline = GenerateChatPipeline(context=context, judge=judge_agent) + assert pipeline.judge == judge_agent + assert isinstance(pipeline.context, PipelineContext) + + def test_pipeline_constructor_with_no_judge(self, context): + # No judge is passed, so the pipeline should not configure one + pipeline = GenerateChatPipeline(context=context) + assert pipeline.judge is None + assert isinstance(pipeline.context, PipelineContext) From a9fce026b5d838a21d6d5c08617aba1a6aa0a216 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 00:13:53 +0200 Subject: [PATCH 2/6] fix: ruff errors --- examples/judge_agent.py | 1 - pandasai/ee/agents/judge_agent/__init__.py | 1 + .../error_correction_pipeline/error_correction_pipeline.py | 2 +- .../error_correction_pipeline/fix_semantic_json_pipeline.py | 2 +- .../agents/semantic_agent/pipeline/semantic_chat_pipeline.py | 2 +- tests/unit_tests/ee/judge_agent/test_judge_llm_call.py | 1 - tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py | 3 +-- 7 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/judge_agent.py b/examples/judge_agent.py index dc930600f..4d38d648d 100644 --- a/examples/judge_agent.py +++ b/examples/judge_agent.py @@ 
-6,7 +6,6 @@ from pandasai.ee.agents.judge_agent import JudgeAgent from pandasai.llm.openai import OpenAI - os.environ["PANDASAI_API_KEY"] = "$2a****************************" github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv") diff --git a/pandasai/ee/agents/judge_agent/__init__.py b/pandasai/ee/agents/judge_agent/__init__.py index e33d43a40..93099f950 100644 --- a/pandasai/ee/agents/judge_agent/__init__.py +++ b/pandasai/ee/agents/judge_agent/__init__.py @@ -1,4 +1,5 @@ from typing import Optional, Union + from pandasai.agent.base_judge import BaseJudge from pandasai.config import load_config_from_json from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py index 79e4b3dd7..65e34db94 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py @@ -4,6 +4,7 @@ from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import ( FixSemanticJsonPipeline, ) +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( SemanticPromptGeneration, ) @@ -13,7 +14,6 @@ from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( ErrorCorrectionPipelineInput, ) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py index 0d5479871..3ec39ea40 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py @@ -3,12 +3,12 @@ from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_schema_prompt import ( FixSemanticSchemaPrompt, ) +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.helpers.logger import Logger from pandasai.helpers.query_exec_tracker import QueryExecTracker from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( ErrorCorrectionPipelineInput, ) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py index eb8ba480c..023cb3ff6 100644 --- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py +++ b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py @@ -8,6 +8,7 @@ from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import ( FixSemanticJsonPipeline, ) +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from 
pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( SemanticPromptGeneration, ) @@ -26,7 +27,6 @@ ) from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.pipelines.chat.result_validation import ResultValidation -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext diff --git a/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py index 35bb977a5..52fc017db 100644 --- a/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py +++ b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py @@ -9,7 +9,6 @@ SQLConnector, SQLConnectorConfig, ) - from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall from pandasai.exceptions import InvalidOutputValueMismatch from pandasai.helpers.logger import Logger diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py index df309a8a2..89abc1ff3 100644 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py +++ b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py @@ -9,11 +9,10 @@ SQLConnector, SQLConnectorConfig, ) - +from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.helpers.logger import Logger from pandasai.llm.bamboo_llm import BambooLLM from pandasai.llm.fake import FakeLLM -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall from pandasai.pipelines.pipeline_context import PipelineContext from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR From 8f0079a4fc69fa311d4b7156eafef3f0ba0a6c5a Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 10:59:01 +0200 Subject: [PATCH 3/6] feat(JudgeAgent): make judge agent using memory from chat agent --- pandasai/ee/agents/judge_agent/__init__.py | 5 ++--- pandasai/pipelines/chat/generate_chat_pipeline.py | 6 +++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pandasai/ee/agents/judge_agent/__init__.py b/pandasai/ee/agents/judge_agent/__init__.py index 93099f950..a47d45045 100644 --- a/pandasai/ee/agents/judge_agent/__init__.py +++ b/pandasai/ee/agents/judge_agent/__init__.py @@ -12,16 +12,15 @@ class JudgeAgent(BaseJudge): def __init__( self, - pipeline: AbstractPipeline = None, config: Optional[Union[Config, dict]] = None, + pipeline: AbstractPipeline = None, ) -> None: context = None if config: if isinstance(config, dict): config = Config(**load_config_from_json(config)) - connectors = context - context = PipelineContext(connectors, config) + context = PipelineContext(None, config) pipeline = pipeline or JudgePipeline(context=context) super().__init__(pipeline) diff --git a/pandasai/pipelines/chat/generate_chat_pipeline.py b/pandasai/pipelines/chat/generate_chat_pipeline.py index 61126451d..2b3595107 100644 --- a/pandasai/pipelines/chat/generate_chat_pipeline.py +++ b/pandasai/pipelines/chat/generate_chat_pipeline.py @@ -104,7 +104,11 @@ def __init__( self.judge = judge if self.judge: - self.judge.pipeline.pipeline.context = context + if self.judge.pipeline.pipeline.context: + self.judge.pipeline.pipeline.context.memory = context.memory + else: + self.judge.pipeline.pipeline.context = context + self.judge.pipeline.pipeline.logger = logger self.judge.pipeline.pipeline.query_exec_tracker = self.query_exec_tracker From 
3177843f01b1aff24462b306bf3390f61636e14f Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 12:52:20 +0200 Subject: [PATCH 4/6] chore: add datetime in prompt --- .../judge_agent/pipeline/judge_prompt_generation.py | 9 ++++++++- .../prompts/templates/judge_agent_prompt.tmpl | 1 + .../unit_tests/ee/judge_agent/test_judge_prompt_gen.py | 10 +++++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py index 23dc7c11e..a8ab9b565 100644 --- a/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py +++ b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py @@ -1,3 +1,4 @@ +import datetime from typing import Any from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt @@ -30,8 +31,14 @@ def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any: self.context = kwargs.get("context") self.logger: Logger = kwargs.get("logger") + now = datetime.datetime.now() + human_readable_datetime = now.strftime("%A, %B %d, %Y %I:%M %p") + prompt = JudgeAgentPrompt( - query=input_data.query, code=input_data.code, context=self.context + query=input_data.query, + code=input_data.code, + context=self.context, + date=human_readable_datetime, ) self.logger.log(f"Using prompt: {prompt}") diff --git a/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl index a444781cd..315f7057d 100644 --- a/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl +++ b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl @@ -1,3 +1,4 @@ +Today is {{date}} ### QUERY {{query}} ### GENERATED CODE diff --git a/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py index d670115bb..d300ec371 100644 --- a/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py +++ b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py @@ -1,3 +1,4 @@ +import re from typing import Optional from unittest.mock import patch @@ -155,9 +156,16 @@ def test_validate_input_semantic_prompt(self, sample_df, context, logger): input_data=input_data, context=context, logger=logger ) + match = re.search( + r"Today is ([A-Za-z]+, [A-Za-z]+ \d{1,2}, \d{4} \d{2}:\d{2} [APM]{2})", + response.output.to_string(), + ) + datetime_str = match.group(1) + assert ( response.output.to_string() - == """### QUERY -What is test? + == f"""Today is {datetime_str} +### QUERY +What is test? 
### GENERATED CODE print('Code Data') + +Reason step by step and at the end answer: +1. Explain what the code does +2. Explain what the user query asks for +3. Strictly compare the query with the code that is generated +Always return <Yes> or <No> depending on whether the code exactly meets the requirements""" + ) From 9a0c63ca79076698d7bc9642be4935f4f021a074 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 14:23:49 +0200 Subject: [PATCH 5/6] add documentation --- docs/judge-agent.mdx | 64 ++++++++++++++++++++++++++++++++++++++++++++ docs/mint.json | 2 +- 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 docs/judge-agent.mdx diff --git a/docs/judge-agent.mdx b/docs/judge-agent.mdx new file mode 100644 index 000000000..6d713416a --- /dev/null +++ b/docs/judge-agent.mdx @@ -0,0 +1,64 @@ +--- +title: "Judge Agent" +description: "Enhance the PandasAI library with the JudgeAgent that evaluates the generated code" +--- + +## Introduction to the Judge Agent + +The `JudgeAgent` extends the capabilities of the PandasAI library by adding an extra judgement step in the agent pipeline that validates the generated code against the user query. + +> **Note:** Usage of the Judge Agent may be subject to a license. For more details, refer to the [license documentation](https://github.com/Sinaptik-AI/pandas-ai/blob/master/pandasai/ee/LICENSE). + +## Instantiating the Judge Agent + +JudgeAgent can be used as a standalone as well as with the other agents as well. For other agents judge agents needs to passed as a param to other agents. + +### Using with other agents + +```python +import os + +import pandas as pd + +from pandasai.agent.agent import Agent +from pandasai.ee.agents.judge_agent import JudgeAgent + +os.environ["PANDASAI_API_KEY"] = "$2a****************************" + +github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv") + +judge = JudgeAgent() +agent = Agent([github_stars], judge=judge) + +print(agent.chat("return total stars count")) +``` + +### Using as a standalone + +```python +import os + +import pandas as pd + +from pandasai.ee.agents.judge_agent import JudgeAgent +from pandasai.llm.openai import OpenAI + +# can be used with all LLMs +llm = OpenAI("openai_key") +judge_agent = JudgeAgent(config={"llm": llm}) +judge_agent.evaluate( + query="return total github star count for year 2023", + code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc" + data = execute_sql_query(sql_query) + plt.plot(data['starred_at_by_month'], data['user_count']) + plt.xlabel('Month') + plt.ylabel('User Count') + plt.title('GitHub Star Count Per Month - Year 2023') + plt.legend(loc='best') + plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png') + result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'} + """, +) +``` + +Judge Agent integration with other agents also gives the flexibility to use a different LLM for each agent. diff --git a/docs/mint.json b/docs/mint.json index 0cb659455..ad8b4a736 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -33,7 +33,7 @@ }, { "group": "Advanced agents", - "pages": ["semantic-agent"] + "pages": ["semantic-agent", "judge-agent"] }, { "group": "Advanced usage", From 4d332fc1e43bad28c6bc7c5e56a50de5e1fa1335 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 18 Jun 2024 16:13:42 +0200 Subject: [PATCH 6/6] docs(judge): update judge documentation --- docs/judge-agent.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/judge-agent.mdx b/docs/judge-agent.mdx index 6d713416a..67f31bc6a 100644 --- 
a/docs/judge-agent.mdx +++ b/docs/judge-agent.mdx @@ -11,7 +11,7 @@ The `JudgeAgent` extends the capabilities of the PandasAI library by adding an e ## Instantiating the Judge Agent -JudgeAgent can be used as a standalone as well as with the other agents as well. For other agents judge agents needs to passed as a param to other agents. +JudgeAgent can be used both as a standalone agent and in conjunction with other agents. To use it with other agents, pass JudgeAgent as a parameter to them. ### Using with other agents
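
For illustration, here is a minimal sketch of that flexibility, pairing a judge and a main agent that run on different models. The model names, API key, and dataframe below are assumed placeholders, not part of the patch series; `Agent` and `JudgeAgent` are used exactly as in the examples above:

```python
import pandas as pd

from pandasai.agent.agent import Agent
from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI

# Assumed split: a cheaper model generates the code, while a stronger
# model judges whether the generated code matches the user query.
sales = pd.DataFrame({"country": ["US", "UK"], "revenue": [5000, 3200]})

coder_llm = OpenAI("openai_key", model="gpt-3.5-turbo")
judge_llm = OpenAI("openai_key", model="gpt-4")

# The judge carries its own config (and thus its own LLM) and is passed
# to the main agent through the `judge` parameter introduced in patch 1.
judge = JudgeAgent(config={"llm": judge_llm})
agent = Agent([sales], config={"llm": coder_llm}, judge=judge)

print(agent.chat("Which country has the highest revenue?"))
```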