diff --git a/docs/judge-agent.mdx b/docs/judge-agent.mdx
new file mode 100644
index 000000000..67f31bc6a
--- /dev/null
+++ b/docs/judge-agent.mdx
@@ -0,0 +1,64 @@
+---
+title: "Judge Agent"
+description: "Enhance the PandasAI library with the JudgeAgent, which evaluates the generated code"
+---
+
+## Introduction to the Judge Agent
+
+The `JudgeAgent` extends the capabilities of the PandasAI library by adding an extra judgement step to the agent's pipeline that validates the generated code against the query.
+
+> **Note:** Usage of the Judge Agent may be subject to a license. For more details, refer to the [license documentation](https://github.com/Sinaptik-AI/pandas-ai/blob/master/pandasai/ee/LICENSE).
+
+## Instantiating the Judge Agent
+
+The `JudgeAgent` can be used both as a standalone agent and in conjunction with other agents. To use it with another agent, pass the `JudgeAgent` instance as a parameter when constructing that agent.
+
+### Using with other agents
+
+```python
+import os
+
+import pandas as pd
+
+from pandasai.agent.agent import Agent
+from pandasai.ee.agents.judge_agent import JudgeAgent
+
+os.environ["PANDASAI_API_KEY"] = "$2a****************************"
+
+github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv")
+
+judge = JudgeAgent()
+agent = Agent([github_stars], judge=judge)
+
+print(agent.chat("return total stars count"))
+```
+
+### Using as a standalone
+
+```python
+import os
+
+import pandas as pd
+
+from pandasai.ee.agents.judge_agent import JudgeAgent
+from pandasai.llm.openai import OpenAI
+
+# can be used with any LLM
+llm = OpenAI("openai_key")
+judge_agent = JudgeAgent(config={"llm": llm})
+judge_agent.evaluate(
+    query="return total github star count for year 2023",
+    code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc"
+    data = execute_sql_query(sql_query)
+    plt.plot(data['starred_at_by_month'], data['user_count'])
+    plt.xlabel('Month')
+    plt.ylabel('User Count')
+    plt.title('GitHub Star Count Per Month - Year 2023')
+    plt.legend(loc='best')
+    plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png')
+    result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'}
+    """,
+)
+```
+
+Integrating the Judge Agent with other agents also gives you the flexibility to use a different LLM for each agent.
diff --git a/docs/mint.json b/docs/mint.json
index 0cb659455..ad8b4a736 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -33,7 +33,7 @@
     },
     {
       "group": "Advanced agents",
-      "pages": ["semantic-agent"]
+      "pages": ["semantic-agent", "judge-agent"]
     },
     {
       "group": "Advanced usage",
diff --git a/examples/judge_agent.py b/examples/judge_agent.py
new file mode 100644
index 000000000..4d38d648d
--- /dev/null
+++ b/examples/judge_agent.py
@@ -0,0 +1,34 @@
+import os
+
+import pandas as pd
+
+from pandasai.agent.agent import Agent
+from pandasai.ee.agents.judge_agent import JudgeAgent
+from pandasai.llm.openai import OpenAI
+
+os.environ["PANDASAI_API_KEY"] = "$2a****************************"
+
+github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv")
+
+judge = JudgeAgent()
+agent = Agent([github_stars], judge=judge)
+
+print(agent.chat("return total stars count"))
+
+
+# Using Judge standalone
+llm = OpenAI("openai_key")
+judge_agent = JudgeAgent(config={"llm": llm})
+judge_agent.evaluate(
+    query="return total github star count 
for year 2023", + code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc" + data = execute_sql_query(sql_query) + plt.plot(data['starred_at_by_month'], data['user_count']) + plt.xlabel('Month') + plt.ylabel('User Count') + plt.title('GitHub Star Count Per Month - Year 2023') + plt.legend(loc='best') + plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png') + result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'} + """, +) diff --git a/pandasai/agent/agent.py b/pandasai/agent/agent.py index 179ae8c23..8a3fbfcc8 100644 --- a/pandasai/agent/agent.py +++ b/pandasai/agent/agent.py @@ -3,6 +3,7 @@ import pandas as pd from pandasai.agent.base import BaseAgent +from pandasai.agent.base_judge import BaseJudge from pandasai.connectors.base import BaseConnector from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.schemas.df_config import Config @@ -20,6 +21,7 @@ def __init__( pipeline: Optional[Type[GenerateChatPipeline]] = None, vectorstore: Optional[VectorStore] = None, description: str = None, + judge: BaseJudge = None, ): super().__init__(dfs, config, memory_size, vectorstore, description) @@ -31,6 +33,7 @@ def __init__( on_code_generation=self._callbacks.on_code_generation, before_code_execution=self._callbacks.before_code_execution, on_result=self._callbacks.on_result, + judge=judge, ) if pipeline else GenerateChatPipeline( @@ -40,6 +43,7 @@ def __init__( on_code_generation=self._callbacks.on_code_generation, before_code_execution=self._callbacks.before_code_execution, on_result=self._callbacks.on_result, + judge=judge, ) ) diff --git a/pandasai/agent/base_judge.py b/pandasai/agent/base_judge.py new file mode 100644 index 000000000..d4c6b6136 --- /dev/null +++ b/pandasai/agent/base_judge.py @@ -0,0 +1,18 @@ +from pandasai.helpers.logger import Logger +from pandasai.pipelines.pipeline import Pipeline +from pandasai.pipelines.pipeline_context import PipelineContext + + +class BaseJudge: + context: PipelineContext + pipeline: Pipeline + logger: Logger + + def __init__( + self, + pipeline: Pipeline, + ) -> None: + self.pipeline = pipeline + + def evaluate(self, query: str, code: str) -> bool: + raise NotImplementedError diff --git a/pandasai/ee/agents/judge_agent/__init__.py b/pandasai/ee/agents/judge_agent/__init__.py new file mode 100644 index 000000000..a47d45045 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/__init__.py @@ -0,0 +1,30 @@ +from typing import Optional, Union + +from pandasai.agent.base_judge import BaseJudge +from pandasai.config import load_config_from_json +from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline +from pandasai.pipelines.abstract_pipeline import AbstractPipeline +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.schemas.df_config import Config + + +class JudgeAgent(BaseJudge): + def __init__( + self, + config: Optional[Union[Config, dict]] = None, + pipeline: AbstractPipeline = None, + ) -> None: + context = None + if config: + if isinstance(config, dict): + config = Config(**load_config_from_json(config)) + + context = PipelineContext(None, config) + + pipeline = pipeline or 
JudgePipeline(context=context) + super().__init__(pipeline) + + def evaluate(self, query: str, code: str) -> bool: + input_data = JudgePipelineInput(query, code) + return self.pipeline.run(input_data) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py b/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py new file mode 100644 index 000000000..8797a6ee0 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py @@ -0,0 +1,34 @@ +from typing import Optional + +from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( + JudgePromptGeneration, +) +from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall +from pandasai.helpers.logger import Logger +from pandasai.helpers.query_exec_tracker import QueryExecTracker +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.pipeline import Pipeline +from pandasai.pipelines.pipeline_context import PipelineContext + + +class JudgePipeline: + def __init__( + self, + context: Optional[PipelineContext] = None, + logger: Optional[Logger] = None, + query_exec_tracker: QueryExecTracker = None, + ): + self.query_exec_tracker = query_exec_tracker + + self.pipeline = Pipeline( + context=context, + logger=logger, + query_exec_tracker=self.query_exec_tracker, + steps=[ + JudgePromptGeneration(), + LLMCall(), + ], + ) + + def run(self, input: JudgePipelineInput): + return self.pipeline.run(input) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py new file mode 100644 index 000000000..a8ab9b565 --- /dev/null +++ b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py @@ -0,0 +1,50 @@ +import datetime +from typing import Any + +from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt +from pandasai.helpers.logger import Logger +from pandasai.pipelines.base_logic_unit import BaseLogicUnit +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.logic_unit_output import LogicUnitOutput + + +class JudgePromptGeneration(BaseLogicUnit): + """ + Code Prompt Generation Stage + """ + + pass + + def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Last logic unit output + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. 
+
+        :return: LogicUnitOutput(prompt)
+        """
+        self.context = kwargs.get("context")
+        self.logger: Logger = kwargs.get("logger")
+
+        now = datetime.datetime.now()
+        human_readable_datetime = now.strftime("%A, %B %d, %Y %I:%M %p")
+
+        prompt = JudgeAgentPrompt(
+            query=input_data.query,
+            code=input_data.code,
+            context=self.context,
+            date=human_readable_datetime,
+        )
+        self.logger.log(f"Using prompt: {prompt}")
+
+        return LogicUnitOutput(
+            prompt,
+            True,
+            "Prompt Generated Successfully",
+            {"content_type": "prompt", "value": prompt.to_string()},
+        )
diff --git a/pandasai/ee/agents/judge_agent/pipeline/llm_call.py b/pandasai/ee/agents/judge_agent/pipeline/llm_call.py
new file mode 100644
index 000000000..47758b263
--- /dev/null
+++ b/pandasai/ee/agents/judge_agent/pipeline/llm_call.py
@@ -0,0 +1,64 @@
+from typing import Any
+
+from pandasai.exceptions import InvalidOutputValueMismatch
+from pandasai.helpers.logger import Logger
+from pandasai.pipelines.base_logic_unit import BaseLogicUnit
+from pandasai.pipelines.logic_unit_output import LogicUnitOutput
+from pandasai.pipelines.pipeline_context import PipelineContext
+
+
+class LLMCall(BaseLogicUnit):
+    """
+    LLM Call Stage
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def execute(self, input: Any, **kwargs) -> Any:
+        """
+        This method will return output according to
+        Implementation.
+
+        :param input: Your input data.
+        :param kwargs: A dictionary of keyword arguments.
+            - 'logger' (any): The logger for logging.
+            - 'config' (Config): Global configurations for the test
+            - 'context' (any): The execution context.
+
+        :return: The result of the execution.
+        """
+        pipeline_context: PipelineContext = kwargs.get("context")
+        logger: Logger = kwargs.get("logger")
+
+        retry_count = 0
+        while retry_count <= pipeline_context.config.max_retries:
+            response = pipeline_context.config.llm.call(input, pipeline_context)
+
+            logger.log(
+                f"""LLM response:
+                {response}
+                """
+            )
+            try:
+                result = False
+                if "<Yes>" in response:
+                    result = True
+                elif "<No>" in response:
+                    result = False
+                else:
+                    raise InvalidOutputValueMismatch("Invalid response of LLM Call")
+
+                pipeline_context.add("llm_call", response)
+
+                return LogicUnitOutput(
+                    result,
+                    True,
+                    "Code Generated Successfully",
+                    {"content_type": "string", "value": response},
+                )
+            except Exception:
+                if retry_count == pipeline_context.config.max_retries:
+                    raise
+
+                retry_count += 1
diff --git a/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py b/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py
new file mode 100644
index 000000000..91616aaf8
--- /dev/null
+++ b/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+
+from jinja2 import Environment, FileSystemLoader
+
+from pandasai.prompts.base import BasePrompt
+
+
+class JudgeAgentPrompt(BasePrompt):
+    """Prompt used by the judge agent to evaluate generated code against a query."""
+
+    template_path = "judge_agent_prompt.tmpl"
+
+    def __init__(self, **kwargs):
+        """Initialize the prompt."""
+        self.props = kwargs
+
+        if self.template:
+            env = Environment()
+            self.prompt = env.from_string(self.template)
+        elif self.template_path:
+            # find path to template file
+            current_dir_path = Path(__file__).parent
+
+            path_to_template = current_dir_path / "templates"
+            env = Environment(loader=FileSystemLoader(path_to_template))
+            self.prompt = env.get_template(self.template_path)
+
+        self._resolved_prompt = None
+
+    def to_json(self):
+        context = self.props["context"]
+        memory = context.memory
+        conversations = memory.to_json()
+        system_prompt = memory.get_system_prompt()
+        return {
+            "conversation": conversations,
+            "system_prompt": system_prompt,
+            "prompt": self.to_string(),
+        }
diff --git a/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl
new file mode 100644
index 000000000..315f7057d
--- /dev/null
+++ b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl
@@ -0,0 +1,11 @@
+Today is {{date}}
+### QUERY
+{{query}}
+### GENERATED CODE
+{{code}}
+
+Reason step by step and at the end answer:
+1. Explain what the code does
+2. Explain what the user query asks for
+3. Strictly compare the query with the code that is generated
+Always return <Yes> or <No> if exactly meets the requirements
diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py
index b28642cd7..dc8aee31b 100644
--- a/pandasai/ee/agents/semantic_agent/__init__.py
+++ b/pandasai/ee/agents/semantic_agent/__init__.py
@@ -4,6 +4,7 @@
 import pandas as pd
 
 from pandasai.agent.base import BaseAgent
+from pandasai.agent.base_judge import BaseJudge
 from pandasai.connectors.base import BaseConnector
 from pandasai.connectors.pandas import PandasConnector
 from pandasai.constants import PANDASBI_SETUP_MESSAGE
@@ -41,6 +42,7 @@ def __init__(
         pipeline: Optional[Type[GenerateChatPipeline]] = None,
         vectorstore: Optional[VectorStore] = None,
         description: str = None,
+        judge: BaseJudge = None,
     ):
         super().__init__(dfs, config, memory_size, vectorstore, description)
 
@@ -70,6 +72,7 @@ def __init__(
                 pipeline(
                     self.context,
                     self.logger,
+                    judge=judge,
                     on_prompt_generation=self._callbacks.on_prompt_generation,
                     on_code_generation=self._callbacks.on_code_generation,
                     before_code_execution=self._callbacks.before_code_execution,
@@ -79,6 +82,7 @@ def __init__(
             else SemanticChatPipeline(
                 self.context,
                 self.logger,
+                judge=judge,
                 on_prompt_generation=self._callbacks.on_prompt_generation,
                 on_code_generation=self._callbacks.on_code_generation,
                 before_code_execution=self._callbacks.before_code_execution,
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
index cb5789523..e9946140d 100644
--- a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
+++ b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Callable
+from typing import Any
 
 from pandasai.helpers.logger import Logger
 from pandasai.pipelines.base_logic_unit import BaseLogicUnit
@@ -12,11 +12,8 @@ class LLMCall(BaseLogicUnit):
     LLM Code Generation Stage
     """
 
-    def __init__(
-        self, on_code_generation: Callable[[str, Exception], None] = None, **kwargs
-    ):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.on_code_generation = on_code_generation
 
     def execute(self, input: Any, **kwargs) -> Any:
         """
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
index 14e9ea870..023cb3ff6 100644
--- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
+++ b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+from pandasai.agent.base_judge import BaseJudge
 from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator
 from 
pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.error_correction_pipeline import ( ErrorCorrectionPipeline, @@ -41,6 +42,7 @@ def __init__( self, context: Optional[PipelineContext] = None, logger: Optional[Logger] = None, + judge: BaseJudge = None, on_prompt_generation=None, on_code_generation=None, before_code_execution=None, @@ -49,6 +51,7 @@ def __init__( super().__init__( context, logger, + judge=judge, on_prompt_generation=on_prompt_generation, on_code_generation=on_code_generation, before_code_execution=before_code_execution, diff --git a/pandasai/pipelines/chat/generate_chat_pipeline.py b/pandasai/pipelines/chat/generate_chat_pipeline.py index 94609fd68..2b3595107 100644 --- a/pandasai/pipelines/chat/generate_chat_pipeline.py +++ b/pandasai/pipelines/chat/generate_chat_pipeline.py @@ -1,5 +1,6 @@ from typing import Optional +from pandasai.agent.base_judge import BaseJudge from pandasai.helpers.query_exec_tracker import QueryExecTracker from pandasai.pipelines.chat.chat_pipeline_input import ( ChatPipelineInput, @@ -41,6 +42,7 @@ def __init__( self, context: Optional[PipelineContext] = None, logger: Optional[Logger] = None, + judge: BaseJudge = None, on_prompt_generation=None, on_code_generation=None, before_code_execution=None, @@ -99,6 +101,17 @@ def __init__( on_prompt_generation=on_prompt_generation, ) + self.judge = judge + + if self.judge: + if self.judge.pipeline.pipeline.context: + self.judge.pipeline.pipeline.context.memory = context.memory + else: + self.judge.pipeline.pipeline.context = context + + self.judge.pipeline.pipeline.logger = logger + self.judge.pipeline.pipeline.query_exec_tracker = self.query_exec_tracker + self.context = context self._logger = logger self.last_error = None @@ -304,7 +317,19 @@ def run(self, input: ChatPipelineInput) -> dict: } ) try: - if self.code_execution_pipeline: + if self.judge: + code = self.code_generation_pipeline.run(input) + + retry_count = 0 + while retry_count < self.context.config.max_retries: + if self.judge.evaluate(query=input.query, code=code): + break + code = self.code_generation_pipeline.run(input) + retry_count += 1 + + output = self.code_execution_pipeline.run(code) + + elif self.code_execution_pipeline: output = ( self.code_generation_pipeline | self.code_execution_pipeline ).run(input) diff --git a/pandasai/pipelines/judge/judge_pipeline_input.py b/pandasai/pipelines/judge/judge_pipeline_input.py new file mode 100644 index 000000000..aaceea15c --- /dev/null +++ b/pandasai/pipelines/judge/judge_pipeline_input.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + + +@dataclass +class JudgePipelineInput: + query: str + code: str + + def __init__(self, query: str, code: str) -> None: + self.query = query + self.code = code diff --git a/pandasai/pipelines/pipeline.py b/pandasai/pipelines/pipeline.py index 59965007b..76ae0a57c 100644 --- a/pandasai/pipelines/pipeline.py +++ b/pandasai/pipelines/pipeline.py @@ -43,21 +43,21 @@ def __init__( logger: (Logger): logger """ - if not isinstance(context, PipelineContext): + if context and not isinstance(context, PipelineContext): config = Config(**load_config_from_json(config)) connectors = context context = PipelineContext(connectors, config) self._logger = ( Logger(save_logs=context.config.save_logs, verbose=context.config.verbose) - if logger is None + if logger is None and context else logger ) self._context = context self._steps = steps or [] - self._query_exec_tracker = query_exec_tracker or QueryExecTracker( - 
server_config=self._context.config.log_server + self._query_exec_tracker = query_exec_tracker or ( + context and QueryExecTracker(server_config=self._context.config.log_server) ) def add_step(self, logic: BaseLogicUnit): @@ -167,3 +167,27 @@ def __or__(self, pipeline: "Pipeline") -> Any: combined_pipeline.add_step(step) return combined_pipeline + + @property + def context(self): + return self._context + + @context.setter + def context(self, context: PipelineContext): + self._context = context + + @property + def logger(self): + return self._logger + + @logger.setter + def logger(self, logger: Logger): + self._logger = logger + + @property + def query_exec_tracker(self): + return self._query_exec_tracker + + @query_exec_tracker.setter + def query_exec_tracker(self, query_exec_tracker: QueryExecTracker): + self._query_exec_tracker = query_exec_tracker diff --git a/tests/unit_tests/ee/judge_agent/test_judge_agent.py b/tests/unit_tests/ee/judge_agent/test_judge_agent.py new file mode 100644 index 000000000..bae937445 --- /dev/null +++ b/tests/unit_tests/ee/judge_agent/test_judge_agent.py @@ -0,0 +1,229 @@ +from typing import Optional +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from pandasai.agent import Agent +from pandasai.connectors.sql import ( + PostgreSQLConnector, + SQLConnector, + SQLConnectorConfig, +) +from pandasai.ee.agents.judge_agent import JudgeAgent +from pandasai.helpers.dataframe_serializer import DataframeSerializerType +from pandasai.llm.bamboo_llm import BambooLLM +from pandasai.llm.fake import FakeLLM +from tests.unit_tests.ee.helpers.schema import ( + VIZ_QUERY_SCHEMA_STR, +) + + +class MockBambooLLM(BambooLLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.call = MagicMock(return_value=VIZ_QUERY_SCHEMA_STR) + + +class TestJudgeAgent: + "Unit tests for Agent class" + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "order_id": [ + 10248, + 10249, + 10250, + 10251, + 10252, + 10253, + 10254, + 10255, + 10256, + 10257, + ], + "customer_id": [ + "VINET", + "TOMSP", + "HANAR", + "VICTE", + "SUPRD", + "HANAR", + "CHOPS", + "RICSU", + "WELLI", + "HILAA", + ], + "employee_id": [5, 6, 4, 3, 4, 3, 4, 7, 3, 4], + "order_date": pd.to_datetime( + [ + "1996-07-04", + "1996-07-05", + "1996-07-08", + "1996-07-08", + "1996-07-09", + "1996-07-10", + "1996-07-11", + "1996-07-12", + "1996-07-15", + "1996-07-16", + ] + ), + "required_date": pd.to_datetime( + [ + "1996-08-01", + "1996-08-16", + "1996-08-05", + "1996-08-05", + "1996-08-06", + "1996-08-07", + "1996-08-08", + "1996-08-09", + "1996-08-12", + "1996-08-13", + ] + ), + "shipped_date": pd.to_datetime( + [ + "1996-07-16", + "1996-07-10", + "1996-07-12", + "1996-07-15", + "1996-07-11", + "1996-07-16", + "1996-07-23", + "1996-07-26", + "1996-07-17", + "1996-07-22", + ] + ), + "ship_via": [3, 1, 2, 1, 2, 2, 2, 3, 2, 1], + "ship_name": [ + "Vins et alcools Chevalier", + "Toms Spezialitäten", + "Hanari Carnes", + "Victuailles en stock", + "Suprêmes délices", + "Hanari Carnes", + "Chop-suey Chinese", + "Richter Supermarkt", + "Wellington Importadora", + "HILARION-Abastos", + ], + "ship_address": [ + "59 rue de l'Abbaye", + "Luisenstr. 48", + "Rua do Paço, 67", + "2, rue du Commerce", + "Boulevard Tirou, 255", + "Rua do Paço, 67", + "Hauptstr. 31", + "Starenweg 5", + "Rua do Mercado, 12", + "Carrera 22 con Ave. 
Carlos Soublette #8-35", + ], + "ship_city": [ + "Reims", + "Münster", + "Rio de Janeiro", + "Lyon", + "Charleroi", + "Rio de Janeiro", + "Bern", + "Genève", + "Resende", + "San Cristóbal", + ], + "ship_region": [ + "CJ", + None, + "RJ", + "RH", + None, + "RJ", + None, + None, + "SP", + "Táchira", + ], + "ship_postal_code": [ + "51100", + "44087", + "05454-876", + "69004", + "B-6000", + "05454-876", + "3012", + "1204", + "08737-363", + "5022", + ], + "ship_country": [ + "France", + "Germany", + "Brazil", + "France", + "Belgium", + "Brazil", + "Switzerland", + "Switzerland", + "Brazil", + "Venezuela", + ], + } + ) + + @pytest.fixture + def llm(self, output: Optional[str] = None) -> FakeLLM: + return FakeLLM(output=output) + + @pytest.fixture + def config(self, llm: FakeLLM) -> dict: + return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV} + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + + @pytest.fixture + def agent(self) -> Agent: + return JudgeAgent() + + def test_contruct_with_pipeline(self, sample_df): + JudgeAgent(pipeline=MagicMock()) diff --git a/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py new file mode 100644 index 000000000..52fc017db --- /dev/null +++ b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py @@ -0,0 +1,179 @@ +from typing import Optional +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from pandasai.connectors.sql import ( + PostgreSQLConnector, + SQLConnector, + SQLConnectorConfig, +) +from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall +from pandasai.exceptions import InvalidOutputValueMismatch +from pandasai.helpers.logger import Logger +from pandasai.llm.bamboo_llm import BambooLLM +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.logic_unit_output import LogicUnitOutput +from pandasai.pipelines.pipeline_context import PipelineContext +from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR + + +class MockBambooLLM(BambooLLM): + def __init__(self): + pass + + def call(self, *args, **kwargs): + return VIZ_QUERY_SCHEMA_STR + + +class TestJudgeLLMCall: + "Unit test for Validate Pipeline Input" + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 
19294482071552,
+                    2891615567872,
+                    2411255037952,
+                    3435817336832,
+                    1745433788416,
+                    1181205135360,
+                    1607402389504,
+                    1490967855104,
+                    4380756541440,
+                    14631844184064,
+                ],
+                "happiness_index": [
+                    6.94,
+                    7.16,
+                    6.66,
+                    7.07,
+                    6.38,
+                    6.4,
+                    7.23,
+                    7.22,
+                    5.87,
+                    5.12,
+                ],
+            }
+        )
+
+    @pytest.fixture
+    @patch("pandasai.connectors.sql.create_engine", autospec=True)
+    def sql_connector(self, create_engine):
+        # Define your ConnectorConfig instance here
+        self.config = SQLConnectorConfig(
+            dialect="mysql",
+            driver="pymysql",
+            username="your_username",
+            password="your_password",
+            host="your_host",
+            port=443,
+            database="your_database",
+            table="your_table",
+            where=[["column_name", "=", "value"]],
+        ).dict()
+
+        # Create an instance of SQLConnector
+        return SQLConnector(self.config)
+
+    @pytest.fixture
+    @patch("pandasai.connectors.sql.create_engine", autospec=True)
+    def pgsql_connector(self, create_engine):
+        # Define your ConnectorConfig instance here
+        self.config = SQLConnectorConfig(
+            dialect="pgsql",
+            driver="pymysql",
+            username="your_username",
+            password="your_password",
+            host="your_host",
+            port=443,
+            database="your_database",
+            table="your_table",
+            where=[["column_name", "=", "value"]],
+        ).dict()
+
+        # Create an instance of SQLConnector
+        return PostgreSQLConnector(self.config)
+
+    @pytest.fixture
+    def config(self, llm):
+        return {"llm": llm, "enable_cache": True}
+
+    @pytest.fixture
+    def context(self, sample_df, config):
+        return PipelineContext([sample_df], config)
+
+    @pytest.fixture
+    def logger(self):
+        return Logger(True, False)
+
+    def test_init(self, context, config):
+        # Test the initialization of the LLMCall logic unit
+        code_generator = LLMCall()
+        assert isinstance(code_generator, LLMCall)
+
+    def test_llm_call(self, sample_df, context, logger, config):
+        input_validator = LLMCall()
+
+        config["llm"].call = MagicMock(return_value="<Yes>")
+
+        context = PipelineContext([sample_df], config)
+
+        result = input_validator.execute(input="test", context=context, logger=logger)
+
+        assert isinstance(result, LogicUnitOutput)
+        assert result.output is True
+
+    def test_llm_call_no(self, sample_df, context, logger, config):
+        input_validator = LLMCall()
+
+        config["llm"].call = MagicMock(return_value="<No>")
+
+        context = PipelineContext([sample_df], config)
+
+        result = input_validator.execute(input="test", context=context, logger=logger)
+
+        assert isinstance(result, LogicUnitOutput)
+        assert result.output is False
+
+    def test_llm_call_(self, sample_df, context, logger, config):
+        input_validator = LLMCall()
+
+        config["llm"].call = MagicMock(return_value="")
+
+        context = PipelineContext([sample_df], config)
+
+        result = input_validator.execute(input="test", context=context, logger=logger)
+
+        assert isinstance(result, LogicUnitOutput)
+        assert result.output is False
+
+    def test_llm_call_with_no_tags(self, sample_df, context, logger, config):
+        input_validator = LLMCall()
+
+        config["llm"].call = MagicMock(return_value="yes")
+
+        context = PipelineContext([sample_df], config)
+
+        with pytest.raises(InvalidOutputValueMismatch):
+            input_validator.execute(input="test", context=context, logger=logger)
diff --git a/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py
new file mode 100644
index 000000000..d300ec371
--- /dev/null
+++ b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py
@@ -0,0 +1,178 @@
+import re
+from typing import Optional
+from unittest.mock import patch
+
+import pandas as pd
+import 
pytest + +from pandasai.connectors.sql import ( + PostgreSQLConnector, + SQLConnector, + SQLConnectorConfig, +) +from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( + JudgePromptGeneration, +) +from pandasai.helpers.logger import Logger +from pandasai.llm.bamboo_llm import BambooLLM +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput +from pandasai.pipelines.pipeline_context import PipelineContext +from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA, VIZ_QUERY_SCHEMA_STR + + +class MockBambooLLM(BambooLLM): + def __init__(self): + pass + + def call(self, *args, **kwargs): + return VIZ_QUERY_SCHEMA_STR + + +class TestJudgePromptGeneration: + "Unit test for Validate Pipeline Input" + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="pgsql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of the CodeGenerator + code_generator = JudgePromptGeneration() + assert isinstance(code_generator, JudgePromptGeneration) + + def test_validate_input_semantic_prompt(self, sample_df, context, logger): + semantic_prompter = JudgePromptGeneration() + + llm = MockBambooLLM() + + # context for true config + config = {"llm": llm, "enable_cache": True, "direct_sql": False} + + context = PipelineContext([sample_df], config) + + context.memory.add("hello word!", True) + + context.add("df_schema", VIZ_QUERY_SCHEMA) + + input_data = JudgePipelineInput( + query="What is test?", code="print('Code Data')" + ) + + response = semantic_prompter.execute( + input_data=input_data, context=context, logger=logger + ) + + match = re.search( + r"Today is ([A-Za-z]+, [A-Za-z]+ 
\d{1,2}, \d{4} \d{2}:\d{2} [APM]{2})",
+            response.output.to_string(),
+        )
+        datetime_str = match.group(1)
+
+        assert (
+            response.output.to_string()
+            == f"""Today is {datetime_str}
+### QUERY
+What is test?
+### GENERATED CODE
+print('Code Data')
+
+Reason step by step and at the end answer:
+1. Explain what the code does
+2. Explain what the user query asks for
+3. Strictly compare the query with the code that is generated
+Always return <Yes> or <No> if exactly meets the requirements"""
+        )
diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py
index 05894733b..89abc1ff3 100644
--- a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py
+++ b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py
@@ -9,9 +9,7 @@
     SQLConnector,
     SQLConnectorConfig,
 )
-from pandasai.ee.agents.semantic_agent.pipeline.llm_call import (
-    LLMCall,
-)
+from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall
 from pandasai.helpers.logger import Logger
 from pandasai.llm.bamboo_llm import BambooLLM
 from pandasai.llm.fake import FakeLLM
diff --git a/tests/unit_tests/pipelines/test_pipeline.py b/tests/unit_tests/pipelines/test_pipeline.py
index 16bb93317..3e60e77a1 100644
--- a/tests/unit_tests/pipelines/test_pipeline.py
+++ b/tests/unit_tests/pipelines/test_pipeline.py
@@ -5,8 +5,11 @@
 import pytest
 
 from pandasai.connectors import BaseConnector, PandasConnector
+from pandasai.ee.agents.judge_agent import JudgeAgent
+from pandasai.helpers.logger import Logger
 from pandasai.llm.fake import FakeLLM
 from pandasai.pipelines.base_logic_unit import BaseLogicUnit
+from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline
 from pandasai.pipelines.pipeline import Pipeline
 from pandasai.pipelines.pipeline_context import PipelineContext
 from pandasai.schemas.df_config import Config
@@ -77,6 +80,10 @@ def config(self, llm):
     def context(self, sample_df, config):
         return PipelineContext([sample_df], config)
 
+    @pytest.fixture
+    def logger(self):
+        return Logger(True, False)
+
     def test_init(self, context, config):
         # Test the initialization of the Pipeline
         pipeline = Pipeline(context)
@@ -156,3 +163,15 @@ def execute(self, data, logger, config, context):
         result = pipeline_2.run(5)
 
         assert result == 8
+
+    def test_pipeline_constructor_with_judge(self, context):
+        judge_agent = JudgeAgent()
+        pipeline = GenerateChatPipeline(context=context, judge=judge_agent)
+        assert pipeline.judge == judge_agent
+        assert isinstance(pipeline.context, PipelineContext)
+
+    def test_pipeline_constructor_with_no_judge(self, context):
+        pipeline = GenerateChatPipeline(context=context)
+
+        assert pipeline.judge is None
+        assert isinstance(pipeline.context, PipelineContext)
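
A note on the `BaseJudge` contract this patch introduces: any object exposing `evaluate(query: str, code: str) -> bool` can, in principle, act as a judge for the chat pipeline. The snippet below is a minimal, hypothetical sketch (not part of this diff) illustrating that contract; the `ForbiddenCallJudge` name and its blacklist heuristic are invented for the example, and it assumes this patch is applied so that `pandasai.agent.base_judge.BaseJudge` exists.

```python
# Hypothetical, reviewer-provided sketch -- not part of this patch.
# It relies only on what the diff defines: BaseJudge stores a pipeline
# reference, and subclasses override evaluate(query, code) -> bool.
from pandasai.agent.base_judge import BaseJudge


class ForbiddenCallJudge(BaseJudge):
    """Reject generated code that contains calls we never want to execute."""

    BLOCKED = ("os.system", "subprocess", "eval(", "exec(")

    def __init__(self):
        # This heuristic judge needs no LLM pipeline; BaseJudge.__init__ only
        # stores the reference, so None is sufficient for standalone use.
        super().__init__(pipeline=None)

    def evaluate(self, query: str, code: str) -> bool:
        # True accepts the code; False asks the chat pipeline to regenerate it
        # (GenerateChatPipeline retries up to config.max_retries times).
        return not any(token in code for token in self.BLOCKED)


if __name__ == "__main__":
    judge = ForbiddenCallJudge()
    print(judge.evaluate("total stars", "result = dfs[0]['stars'].sum()"))  # True
    print(judge.evaluate("total stars", "import os; os.system('rm -rf /')"))  # False
```

One caveat if such a judge is passed to `Agent`: `GenerateChatPipeline.__init__` in this patch dereferences `judge.pipeline.pipeline` to inject the context, logger, and query execution tracker, so a judge whose `pipeline` is `None` would only work standalone unless that wiring is made conditional.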