feat(Judge): implementation of judge agent to validate code matches t… #1238

Merged
merged 6 commits into from Jun 18, 2024
Changes from 4 commits
34 changes: 34 additions & 0 deletions examples/judge_agent.py
@@ -0,0 +1,34 @@
import os

import pandas as pd

from pandasai.agent.agent import Agent
from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv")

# Attach the judge so generated code is validated against the query before results are returned
judge = JudgeAgent()
agent = Agent([github_stars], judge=judge)

print(agent.chat("return total stars count"))


# Using Judge standalone
llm = OpenAI("openai_key")
judge_agent = JudgeAgent(config={"llm": llm})
judge_agent.evaluate(
query="return total github star count for year 2023",
code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc"
data = execute_sql_query(sql_query)
plt.plot(data['starred_at_by_month'], data['user_count'])
plt.xlabel('Month')
plt.ylabel('User Count')
plt.title('GitHub Star Count Per Month - Year 2023')
plt.legend(loc='best')
plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png')
result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'}
""",
)
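
Note for anyone trying this example locally: `evaluate` returns the plain boolean declared on `BaseJudge`, so the standalone verdict can gate execution directly. A minimal sketch (the gating logic below is illustrative, not part of this PR):

# Hypothetical gating around the standalone judge defined above.
verdict = judge_agent.evaluate(
    query="return total stars count",
    code="result = {'type': 'number', 'value': df['stars'].sum()}",
)
if verdict:
    print("Code approved by judge")
else:
    print("Code rejected; regenerate or fail loudly")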
4 changes: 4 additions & 0 deletions pandasai/agent/agent.py
@@ -3,6 +3,7 @@
import pandas as pd

from pandasai.agent.base import BaseAgent
from pandasai.agent.base_judge import BaseJudge
from pandasai.connectors.base import BaseConnector
from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline
from pandasai.schemas.df_config import Config
@@ -20,6 +21,7 @@ def __init__(
        pipeline: Optional[Type[GenerateChatPipeline]] = None,
        vectorstore: Optional[VectorStore] = None,
        description: str = None,
        judge: BaseJudge = None,
    ):
        super().__init__(dfs, config, memory_size, vectorstore, description)

@@ -31,6 +33,7 @@
                on_code_generation=self._callbacks.on_code_generation,
                before_code_execution=self._callbacks.before_code_execution,
                on_result=self._callbacks.on_result,
                judge=judge,
            )
            if pipeline
            else GenerateChatPipeline(
@@ -40,6 +43,7 @@
                on_code_generation=self._callbacks.on_code_generation,
                before_code_execution=self._callbacks.before_code_execution,
                on_result=self._callbacks.on_result,
                judge=judge,
            )
        )

18 changes: 18 additions & 0 deletions pandasai/agent/base_judge.py
@@ -0,0 +1,18 @@
from pandasai.helpers.logger import Logger
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext


class BaseJudge:
    """Interface for judges that decide whether generated code answers a query."""

    context: PipelineContext
    pipeline: Pipeline
    logger: Logger

    def __init__(
        self,
        pipeline: Pipeline,
    ) -> None:
        self.pipeline = pipeline

    def evaluate(self, query: str, code: str) -> bool:
        raise NotImplementedError

Codecov warning: added line pandasai/agent/base_judge.py#L18 was not covered by tests.
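
Since BaseJudge is the extension point here, a sketch of a custom judge that skips the LLM entirely; `DenylistJudge` and its token list are hypothetical, shown only to illustrate the interface:

from pandasai.agent.base_judge import BaseJudge


class DenylistJudge(BaseJudge):
    """Reject generated code containing obviously unsafe calls (illustrative)."""

    BANNED = ("os.system", "subprocess", "eval(", "exec(")

    def __init__(self) -> None:
        super().__init__(pipeline=None)  # no pipeline needed for a rule-based judge

    def evaluate(self, query: str, code: str) -> bool:
        return not any(token in code for token in self.BANNED)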
30 changes: 30 additions & 0 deletions pandasai/ee/agents/judge_agent/__init__.py
@@ -0,0 +1,30 @@
from typing import Optional, Union

from pandasai.agent.base_judge import BaseJudge
from pandasai.config import load_config_from_json
from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline
from pandasai.pipelines.abstract_pipeline import AbstractPipeline
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput
from pandasai.pipelines.pipeline_context import PipelineContext
from pandasai.schemas.df_config import Config


class JudgeAgent(BaseJudge):
    def __init__(
        self,
        config: Optional[Union[Config, dict]] = None,
        pipeline: AbstractPipeline = None,
    ) -> None:
        context = None
        if config:
            if isinstance(config, dict):
                config = Config(**load_config_from_json(config))

            context = PipelineContext(None, config)

        pipeline = pipeline or JudgePipeline(context=context)
        super().__init__(pipeline)

    def evaluate(self, query: str, code: str) -> bool:
        input_data = JudgePipelineInput(query, code)
        return self.pipeline.run(input_data)

Codecov warning: added lines #L20-L21, #L23, and #L29-L30 in pandasai/ee/agents/judge_agent/__init__.py were not covered by tests.
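
Given the uncovered lines flagged above, a sketch of how `evaluate` could be unit-tested without an LLM; the mock-based test is an assumption, not part of this PR:

from unittest.mock import MagicMock

from pandasai.ee.agents.judge_agent import JudgeAgent


def test_evaluate_delegates_to_pipeline():
    # Stub the pipeline so no context or LLM is required.
    fake_pipeline = MagicMock()
    fake_pipeline.run.return_value = True

    judge = JudgeAgent(pipeline=fake_pipeline)

    assert judge.evaluate(query="total stars", code="result = df['stars'].sum()")
    fake_pipeline.run.assert_called_once()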
34 changes: 34 additions & 0 deletions pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py
@@ -0,0 +1,34 @@
from typing import Optional

from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import (
JudgePromptGeneration,
)
from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall
from pandasai.helpers.logger import Logger
from pandasai.helpers.query_exec_tracker import QueryExecTracker
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext


class JudgePipeline:
    def __init__(
        self,
        context: Optional[PipelineContext] = None,
        logger: Optional[Logger] = None,
        query_exec_tracker: QueryExecTracker = None,
    ):
        self.query_exec_tracker = query_exec_tracker

        self.pipeline = Pipeline(
            context=context,
            logger=logger,
            query_exec_tracker=self.query_exec_tracker,
            steps=[
                JudgePromptGeneration(),
                LLMCall(),
            ],
        )

    def run(self, input: JudgePipelineInput):
        return self.pipeline.run(input)

Codecov warning: added line pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py#L34 was not covered by tests.
50 changes: 50 additions & 0 deletions pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py
@@ -0,0 +1,50 @@
import datetime
from typing import Any

from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput
from pandasai.pipelines.logic_unit_output import LogicUnitOutput


class JudgePromptGeneration(BaseLogicUnit):
    """
    Judge Prompt Generation Stage
    """

    def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any:
        """
        Generate the judge prompt from the pipeline input.

        :param input_data: The query and generated code to evaluate.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations.
            - 'context' (any): The execution context.

        :return: LogicUnitOutput(prompt)
        """
        self.context = kwargs.get("context")
        self.logger: Logger = kwargs.get("logger")

        now = datetime.datetime.now()
        human_readable_datetime = now.strftime("%A, %B %d, %Y %I:%M %p")

        prompt = JudgeAgentPrompt(
            query=input_data.query,
            code=input_data.code,
            context=self.context,
            date=human_readable_datetime,
        )
        self.logger.log(f"Using prompt: {prompt}")

        return LogicUnitOutput(
            prompt,
            True,
            "Prompt Generated Successfully",
            {"content_type": "prompt", "value": prompt.to_string()},
        )
64 changes: 64 additions & 0 deletions pandasai/ee/agents/judge_agent/pipeline/llm_call.py
@@ -0,0 +1,64 @@
from typing import Any

from pandasai.exceptions import InvalidOutputValueMismatch
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
from pandasai.pipelines.pipeline_context import PipelineContext


class LLMCall(BaseLogicUnit):
    """
    LLM Judgement Call Stage
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Call the LLM with the judge prompt and parse the <Yes>/<No> verdict.

        :param input: The judge prompt produced by the previous stage.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations.
            - 'context' (any): The execution context.

        :return: The result of the execution.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        retry_count = 0
        while retry_count <= pipeline_context.config.max_retries:
            response = pipeline_context.config.llm.call(input, pipeline_context)

            logger.log(
                f"""LLM response:
                    {response}
                    """
            )
            try:
                result = False
                if "<Yes>" in response:
                    result = True
                elif "<No>" in response:
                    result = False
                else:
                    raise InvalidOutputValueMismatch("Invalid response of LLM Call")

                pipeline_context.add("llm_call", response)

                return LogicUnitOutput(
                    result,
                    True,
                    "Code Validated Successfully",
                    {"content_type": "string", "value": response},
                )
            except Exception:
                if retry_count == pipeline_context.config.max_retries:
                    raise

                retry_count += 1
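
One behavioural detail worth noting: <Yes> is checked before <No>, so a response containing both tokens counts as approval, and anything containing neither raises and consumes a retry. A standalone re-implementation of just that parsing rule, for illustration only:

def parse_verdict(response: str) -> bool:
    # Mirrors the precedence in LLMCall.execute: <Yes> wins over <No>.
    if "<Yes>" in response:
        return True
    if "<No>" in response:
        return False
    raise ValueError("Invalid response of LLM Call")


assert parse_verdict("The code matches the query. <Yes>") is True
assert parse_verdict("Mismatch in the year filter. <No>") is False
assert parse_verdict("<No>, actually <Yes>") is True  # both tokens: <Yes> wins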
39 changes: 39 additions & 0 deletions pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py
@@ -0,0 +1,39 @@
from pathlib import Path

from jinja2 import Environment, FileSystemLoader

from pandasai.prompts.base import BasePrompt


class JudgeAgentPrompt(BasePrompt):
    """Prompt to judge whether generated Python code matches the user query."""

    template_path = "judge_agent_prompt.tmpl"

    def __init__(self, **kwargs):
        """Initialize the prompt."""
        self.props = kwargs

        if self.template:
            env = Environment()
            self.prompt = env.from_string(self.template)
        elif self.template_path:
            # find path to template file
            current_dir_path = Path(__file__).parent

            path_to_template = current_dir_path / "templates"
            env = Environment(loader=FileSystemLoader(path_to_template))
            self.prompt = env.get_template(self.template_path)

        self._resolved_prompt = None

    def to_json(self):
        context = self.props["context"]
        memory = context.memory
        conversations = memory.to_json()
        system_prompt = memory.get_system_prompt()
        return {
            "conversation": conversations,
            "system_prompt": system_prompt,
            "prompt": self.to_string(),
        }

Codecov warning: added lines #L18-L19 and #L31-L35 in pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py were not covered by tests.
11 changes: 11 additions & 0 deletions pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl
@@ -0,0 +1,11 @@
Today is {{date}}
### QUERY
{{query}}
### GENERATED CODE
{{code}}

Reason step by step and at the end answer:
1. Explain what the code does
2. Explain what the user query asks for
3. Strictly compare the query with the code that is generated
Always return <Yes> or <No> depending on whether the code exactly meets the requirements
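
A quick way to sanity-check the template outside the pipeline is to render it with Jinja2 directly; the inline template string below is copied from the file above, and the sample values are illustrative:

from jinja2 import Environment

TEMPLATE = """Today is {{date}}
### QUERY
{{query}}
### GENERATED CODE
{{code}}

Reason step by step and at the end answer:
1. Explain what the code does
2. Explain what the user query asks for
3. Strictly compare the query with the code that is generated
Always return <Yes> or <No> depending on whether the code exactly meets the requirements"""

rendered = Environment().from_string(TEMPLATE).render(
    date="Tuesday, June 18, 2024 10:30 AM",
    query="return total stars count",
    code="result = {'type': 'number', 'value': df['stars'].sum()}",
)
print(rendered)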
4 changes: 4 additions & 0 deletions pandasai/ee/agents/semantic_agent/__init__.py
@@ -4,6 +4,7 @@
import pandas as pd

from pandasai.agent.base import BaseAgent
from pandasai.agent.base_judge import BaseJudge
from pandasai.connectors.base import BaseConnector
from pandasai.connectors.pandas import PandasConnector
from pandasai.constants import PANDASBI_SETUP_MESSAGE
@@ -41,6 +42,7 @@ def __init__(
        pipeline: Optional[Type[GenerateChatPipeline]] = None,
        vectorstore: Optional[VectorStore] = None,
        description: str = None,
        judge: BaseJudge = None,
    ):
        super().__init__(dfs, config, memory_size, vectorstore, description)

@@ -70,6 +72,7 @@
            pipeline(
                self.context,
                self.logger,
                judge=judge,
                on_prompt_generation=self._callbacks.on_prompt_generation,
                on_code_generation=self._callbacks.on_code_generation,
                before_code_execution=self._callbacks.before_code_execution,
@@ -79,6 +82,7 @@
            else SemanticChatPipeline(
                self.context,
                self.logger,
                judge=judge,
                on_prompt_generation=self._callbacks.on_prompt_generation,
                on_code_generation=self._callbacks.on_code_generation,
                before_code_execution=self._callbacks.before_code_execution,
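
The judge threads into SemanticAgent the same way it does into Agent; a usage sketch, assuming the enterprise extras that provide SemanticAgent are installed:

from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.ee.agents.semantic_agent import SemanticAgent

# Mirrors the Agent wiring from examples/judge_agent.py above;
# github_stars is the dataframe loaded in that example.
judge = JudgeAgent()
semantic_agent = SemanticAgent([github_stars], judge=judge)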
7 changes: 2 additions & 5 deletions pandasai/ee/agents/semantic_agent/pipeline/llm_call.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Callable
+from typing import Any
 
 from pandasai.helpers.logger import Logger
 from pandasai.pipelines.base_logic_unit import BaseLogicUnit
@@ -12,11 +12,8 @@ class LLMCall(BaseLogicUnit):
     LLM Code Generation Stage
     """
 
-    def __init__(
-        self, on_code_generation: Callable[[str, Exception], None] = None, **kwargs
-    ):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.on_code_generation = on_code_generation
 
     def execute(self, input: Any, **kwargs) -> Any:
         """
pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
@@ -1,5 +1,6 @@
from typing import Optional

from pandasai.agent.base_judge import BaseJudge
from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator
from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.error_correction_pipeline import (
ErrorCorrectionPipeline,
@@ -41,6 +42,7 @@ def __init__(
        self,
        context: Optional[PipelineContext] = None,
        logger: Optional[Logger] = None,
        judge: BaseJudge = None,
        on_prompt_generation=None,
        on_code_generation=None,
        before_code_execution=None,
@@ -49,6 +51,7 @@
        super().__init__(
            context,
            logger,
            judge=judge,
            on_prompt_generation=on_prompt_generation,
            on_code_generation=on_code_generation,
            before_code_execution=before_code_execution,