Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add output_type parameter (#519) #562

Merged
merged 10 commits into from
Sep 18, 2023
45 changes: 45 additions & 0 deletions pandasai/helpers/output_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Union

from ._output_types import (
NumberOutputType,
DataFrameOutputType,
PlotOutputType,
StringOutputType,
DefaultOutputType,
)

# Registry of the explicitly supported output type names, mapping each name
# to the class that describes and validates it. `output_type_factory` looks
# names up here and falls back to `DefaultOutputType` for anything not listed.
output_types_map = {
    "number": NumberOutputType,
    "dataframe": DataFrameOutputType,
    "plot": PlotOutputType,
    "string": StringOutputType,
}


def output_type_factory(
    output_type,
) -> Union[
    NumberOutputType,
    DataFrameOutputType,
    PlotOutputType,
    StringOutputType,
    DefaultOutputType,
]:
    """
    Build the output type helper that corresponds to `output_type`.

    The class is looked up in `output_types_map`. Any name that is not
    registered there (including `None`) resolves to `DefaultOutputType`.

    Args:
        output_type (str): Name of the requested output type.

    Returns:
        (Union[
            NumberOutputType,
            DataFrameOutputType,
            PlotOutputType,
            StringOutputType,
            DefaultOutputType
        ]): A newly created instance of the resolved class.
    """
    output_type_cls = output_types_map.get(output_type, DefaultOutputType)
    return output_type_cls()
nautics889 marked this conversation as resolved.
Show resolved Hide resolved
160 changes: 160 additions & 0 deletions pandasai/helpers/output_types/_output_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import re
from decimal import Decimal
from abc import abstractmethod, ABC
from typing import Any, Iterable

import pandas as pd
import polars as pl


class BaseOutputType(ABC):
    """
    Abstract base for output type helpers.

    A concrete helper supplies a textual hint (`template_hint`), a type
    name (`name`) and a value check (`_validate_value`); `validate()`
    combines the type and value checks for a result dict.
    """

    @property
    @abstractmethod
    def template_hint(self) -> str:
        """Text describing the expected 'type' and 'value' of a result."""
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Name that the result's 'type' field is compared against."""
        ...

    def _validate_type(self, actual_type: str) -> bool:
        """Tell whether `actual_type` equals this helper's `name`."""
        return actual_type == self.name

    @abstractmethod
    def _validate_value(self, actual_value):
        """Tell whether `actual_value` is acceptable for this output type."""
        ...

    def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
        """
        Check the 'type' and 'value' entries of a result dict.

        Both checks are always performed, so a single call can report a
        type mismatch and a value mismatch at the same time.

        Args:
            result (dict[str, Any]): The result of code execution in
                dict representation. Should have the following schema:
                {
                    "type": <output_type_name>,
                    "value": <generated_value>
                }
                Missing keys are read as `None`.

        Returns:
            (tuple[bool, Iterable[str]]):
                A flag telling whether both checks passed, and the list
                of messages describing each 'type' or 'value' mismatch
                (empty when everything matches).
        """
        actual_type = result.get("type")
        actual_value = result.get("value")
        messages = []

        type_matches = self._validate_type(actual_type)
        if not type_matches:
            messages.append(
                f"The result dict contains inappropriate 'type'. "
                f"Expected '{self.name}', actual '{actual_type}'."
            )

        value_matches = self._validate_value(actual_value)
        if not value_matches:
            messages.append(
                f"Actual value {actual_value!r} seems to be inappropriate "
                f"for the type '{self.name}'."
            )

        return all([type_matches, value_matches]), messages


class NumberOutputType(BaseOutputType):
    """Output type helper for results whose value must be a number."""

    @property
    def template_hint(self):
        """Text describing the expected 'type' and 'value' of a result."""
        return """- type (must be "number")
- value (must be a number)"""

    @property
    def name(self):
        """Name that the result's 'type' field is compared against."""
        return "number"

    def _validate_value(self, actual_value: Any) -> bool:
        """
        Tell whether `actual_value` is a real number.

        Accepts everything the built-in `int` and `float` checks accept,
        `Decimal`, and any other type registered as `numbers.Real`. The
        latter matters for NumPy scalars (e.g. `numpy.int64`,
        `numpy.float32`), which are what pandas aggregations typically
        return and which are not subclasses of `int`/`float`.

        Args:
            actual_value (Any): The 'value' entry of the result dict.

        Returns:
            (bool): True if the value is a real number, False otherwise.
        """
        # Local import: keeps this change self-contained without touching
        # the module-level import block.
        import numbers

        # `Decimal` is listed explicitly because it is registered only as
        # `numbers.Number`, not as `numbers.Real`.
        return isinstance(actual_value, (numbers.Real, Decimal))


class DataFrameOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "dataframe")
- value (must be a pandas dataframe)"""

@property
def name(self):
return "dataframe"

def _validate_value(self, actual_value: Any) -> bool:
if isinstance(actual_value, (pd.DataFrame, pl.DataFrame)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather use df_config.df_type like this:

if df_type(actual_value) is not None:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also means you don't have to import polars and pandas!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent, since this frees the module from unnecessary third-party imports.


Done

return True
return False


class PlotOutputType(BaseOutputType):
    """Output type helper for results whose value is a path to a plot image."""

    @property
    def template_hint(self):
        """Text describing the expected 'type' and 'value' of a result."""
        return (
            '- type (must be "plot")\n'
            "- value (must be a string containing the path of the plot image)"
        )

    @property
    def name(self):
        """Name that the result's 'type' field is compared against."""
        return "plot"

    def _validate_value(self, actual_value: Any) -> bool:
        """
        Tell whether `actual_value` is a string shaped like a file path.

        Args:
            actual_value (Any): The 'value' entry of the result dict.

        Returns:
            (bool): True if the value is a string matching the path
                pattern below, False otherwise.
        """
        # First alternative: one or more "/segment" parts. Second
        # alternative: a leading run without whitespace or "/", optionally
        # followed by "/segment" parts. Segments are limited to word
        # characters, "." and "-".
        path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
        return (
            isinstance(actual_value, str)
            and re.match(path_to_plot_pattern, actual_value) is not None
        )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PlotOutputType class uses a regular expression to validate that the actual value is a string representing a valid file path. This is a good approach, but it might be worth considering whether there are any edge cases that this regex doesn't cover. For example, it doesn't seem to handle spaces in file paths.


class StringOutputType(BaseOutputType):
    """Output type helper for results whose value must be a string."""

    @property
    def template_hint(self):
        """Text describing the expected 'type' and 'value' of a result."""
        return (
            '- type (must be "string")\n'
            "- value (must be a conversational answer, as a string)"
        )

    @property
    def name(self):
        """Name that the result's 'type' field is compared against."""
        return "string"

    def _validate_value(self, actual_value: Any) -> bool:
        """Tell whether `actual_value` is a `str` instance."""
        return isinstance(actual_value, str)


class DefaultOutputType(BaseOutputType):
    """
    Output type helper that performs no validation.

    Every type, every value and every result dict is accepted.
    """

    @property
    def template_hint(self):
        """Text listing the possible 'type' and 'value' forms of a result."""
        return (
            '- type (possible values "text", "number", "dataframe", "plot")\n'
            "- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)"  # noqa E501
        )

    @property
    def name(self):
        """Name of this output type."""
        return "default"

    def _validate_type(self, actual_type: str) -> bool:
        """Accept any type name."""
        return True

    def _validate_value(self, actual_value: Any) -> bool:
        """Accept any value."""
        return True

    def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable]:
        """
        Accept any result dict without inspecting it.

        Args:
            result (dict[str, Any]): The result of code execution in
                dict representation; it is not examined.

        Returns:
            (tuple[bool, Iterable]): Always `True` paired with an empty
                collection of messages, since the `DefaultOutputType`
                is supposed to have no validation.
        """
        no_messages = ()
        return True, no_messages
22 changes: 4 additions & 18 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,12 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:

Updated code:
""" # noqa: E501
_output_type_map = {
"number": """- type (must be "number")
- value (must be a number)""",
"dataframe": """- type (must be "dataframe")
- value (must be a pandas dataframe)""",
"plot": """- type (must be "plot")
- value (must be a string containing the path of the plot image)""",
"string": """- type (must be "string")
- value (must be a conversational answer, as a string)""",
}
_output_type_default = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)""" # noqa E501

def __init__(self):

def __init__(self, **kwargs):
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"
output_type_hint = kwargs.pop("output_type_hint")

self.set_var("default_import", default_import)
self.set_var("engine_df_name", engine_df_name)

@classmethod
def get_output_type_hint(cls, output_type):
return cls._output_type_map.get(output_type, cls._output_type_default)
self.set_var("output_type_hint", output_type_hint)
12 changes: 11 additions & 1 deletion pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,17 @@ def chat(self, query: str, output_type: Optional[str] = None):

Args:
query (str): Query to run on the dataframe
output_type (Optional[str]):
output_type (Optional[str]): Add a hint for LLM of which
type should be returned by `analyze_data()` in generated
code. Possible values: "number", "dataframe", "plot", "string":
* number - specifies that user expects to get a number
as a response object
* dataframe - specifies that user expects to get
pandas/polars dataframe as a response object
* plot - specifies that user expects LLM to build
a plot
* string - specifies that user expects to get text
as a response object

Raises:
ValueError: If the query is empty
Expand Down
38 changes: 33 additions & 5 deletions pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import logging
import os

from ..helpers.output_types import output_type_factory
from ..llm.base import LLM
from ..llm.langchain import LangchainLLM

Expand Down Expand Up @@ -206,6 +207,7 @@ def _get_prompt(
key: str,
default_prompt: Type[Prompt],
default_values: Optional[dict] = None,
output_type_hint: Optional[str] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would pass it as part of default values in the generate python code prompt only, so that this is accessible both on the default prompt and in the custom prompts. This might not be needed in every prompt in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, you're right: that keeps the additional context in one place, default_values.
Apologies for spreading the setting logic around like this; it's just that the implementation of ._get_prompt() is really complicated, in my opinion. As I've mentioned, some variables are set in __init__() and the rest are set in _get_prompt() itself, when iterating over default_values. That was a rather tricky point back when I had the output types mapping inside the prompt class.


Done

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nautics889 yeah, totally, it's super involved, we'll definitely need to simplify that a little bit in the future 😄

) -> Prompt:
"""
Return a prompt by key.
Expand All @@ -214,7 +216,9 @@ def _get_prompt(
key (str): The key of the prompt
default_prompt (Type[Prompt]): The default prompt to use
default_values (Optional[dict], optional): The default values to use for the
prompt. Defaults to None.
prompt. Defaults to None.
output_type_hint (Optional[str]): Interpolate an according output
type hint to the prompt for LLM.

Returns:
Prompt: The prompt
Expand All @@ -223,7 +227,11 @@ def _get_prompt(
default_values = {}

custom_prompt = self._config.custom_prompts.get(key)
prompt = custom_prompt if custom_prompt else default_prompt()
prompt = (
custom_prompt
if custom_prompt
else default_prompt(output_type_hint=output_type_hint)
)
gventuri marked this conversation as resolved.
Show resolved Hide resolved

# set default values for the prompt
if "dfs" not in default_values:
Expand Down Expand Up @@ -257,7 +265,19 @@ def chat(self, query: str, output_type: Optional[str] = None):

Args:
query (str): Query to run on the dataframe
output_type (Optional[str]):
output_type (Optional[str]): Add a hint for LLM which
type should be returned by `analyze_data()` in generated
code. Possible values: "number", "dataframe", "plot", "string":
* number - specifies that user expects to get a number
as a response object
* dataframe - specifies that user expects to get
pandas/polars dataframe as a response object
* plot - specifies that user expects LLM to build
a plot
* string - specifies that user expects to get text
as a response object
If none `output_type` is specified, the type can be any
of the above or "text".

Raises:
ValueError: If the query is empty
Expand All @@ -273,6 +293,8 @@ def chat(self, query: str, output_type: Optional[str] = None):
self._memory.add(query, True)

try:
output_type_helper = output_type_factory(output_type)

if (
self._config.enable_cache
and self._cache
Expand All @@ -281,17 +303,16 @@ def chat(self, query: str, output_type: Optional[str] = None):
self.logger.log("Using cached response")
code = self._cache.get(self._get_cache_key())
else:
prompt_cls = GeneratePythonCodePrompt
default_values = {
# TODO: find a better way to determine the engine,
"engine": self._dfs[0].engine,
"save_charts_path": self._config.save_charts_path.rstrip("/"),
"output_type_hint": prompt_cls.get_output_type_hint(output_type),
}
generate_python_code_instruction = self._get_prompt(
"generate_python_code",
default_prompt=GeneratePythonCodePrompt,
default_values=default_values,
output_type_hint=output_type_helper.template_hint,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output_type_hint is now passed to the _get_prompt() method. This change assumes that all prompts will accept this argument. If there are any prompts that do not accept this argument, this will cause an error.

- default_values=default_values,
+ default_values=default_values,
+ output_type_hint=output_type_helper.template_hint,


code = self._llm.generate_code(generate_python_code_instruction)
Expand Down Expand Up @@ -344,6 +365,13 @@ def chat(self, query: str, output_type: Optional[str] = None):
code_to_run = self._retry_run_code(code, e)

if result is not None:
if isinstance(result, dict):
validation_ok, validation_logs = output_type_helper.validate(result)
if not validation_ok:
self.logger.log(
"\n".join(validation_logs), level=logging.WARNING
)

gventuri marked this conversation as resolved.
Show resolved Hide resolved
self.last_result = result
self.logger.log(f"Answer: {result}")
gventuri marked this conversation as resolved.
Show resolved Hide resolved
except Exception as exception:
Expand Down
Loading