Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add output_type parameter (#519) #562

Merged
merged 10 commits into from
Sep 18, 2023
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ nav:
- Documents Building: building_docs.md
- License: license.md
extra:
version: "1.2.1"
version: "1.2.2"
plugins:
- search
- mkdocstrings:
Expand Down
2 changes: 1 addition & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def last_prompt(self) -> str:

def clear_cache(filename: str = None):
"""Clear the cache"""
cache = Cache(filename or "cache")
cache = Cache(filename or "cache_db")
cache.clear()


Expand Down
15 changes: 10 additions & 5 deletions pandasai/connectors/yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import yfinance as yf
import pandas as pd
from .base import ConnectorConfig, BaseConnector
import time
Expand All @@ -15,6 +14,13 @@ class YahooFinanceConnector(BaseConnector):
_cache_interval: int = 600 # 10 minutes

def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
try:
import yfinance
except ImportError:
raise ImportError(
"Could not import yfinance python package. "
"Please install it with `pip install yfinance`."
)
yahoo_finance_config = ConnectorConfig(
dialect="yahoo_finance",
username="",
Expand All @@ -27,6 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
)
self._cache_interval = cache_interval
super().__init__(yahoo_finance_config)
self.ticker = yfinance.Ticker(self._config.table)

def head(self):
"""
Expand All @@ -36,8 +43,7 @@ def head(self):
DataFrameType: The head of the data source that the connector is
connected to.
"""
ticker = yf.Ticker(self._config.table)
head_data = ticker.history(period="5d")
head_data = self.ticker.history(period="5d")
return head_data

def _get_cache_path(self, include_additional_filters: bool = False):
Expand Down Expand Up @@ -105,8 +111,7 @@ def execute(self):
return pd.read_csv(cached_path)

# Use yfinance to retrieve historical stock data
ticker = yf.Ticker(self._config.table)
stock_data = ticker.history(period="max")
stock_data = self.ticker.history(period="max")

# Save the result to the cache
stock_data.to_csv(self._get_cache_path(), index=False)
Expand Down
41 changes: 21 additions & 20 deletions pandasai/helpers/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Cache module for caching queries."""
import glob
import os
import shelve
import glob
import duckdb
from .path import find_project_root


Expand All @@ -13,17 +12,20 @@ class Cache:
filename (str): filename to store the cache.
"""

def __init__(self, filename="cache"):
# define cache directory and create directory if it does not exist
def __init__(self, filename="cache_db"):
# Define cache directory and create directory if it does not exist
try:
cache_dir = os.path.join((find_project_root()), "cache")
cache_dir = os.path.join(find_project_root(), "cache")
except ValueError:
cache_dir = os.path.join(os.getcwd(), "cache")

os.makedirs(cache_dir, mode=0o777, exist_ok=True)

self.filepath = os.path.join(cache_dir, filename)
self.cache = shelve.open(self.filepath)
self.filepath = os.path.join(cache_dir, filename + ".db")
self.connection = duckdb.connect(self.filepath)
self.connection.execute(
"CREATE TABLE IF NOT EXISTS cache (key STRING, value STRING)"
)

def set(self, key: str, value: str) -> None:
"""Set a key value pair in the cache.
Expand All @@ -32,8 +34,7 @@ def set(self, key: str, value: str) -> None:
key (str): key to store the value.
value (str): value to store in the cache.
"""

self.cache[key] = value
self.connection.execute("INSERT INTO cache VALUES (?, ?)", [key, value])

def get(self, key: str) -> str:
"""Get a value from the cache.
Expand All @@ -44,31 +45,31 @@ def get(self, key: str) -> str:
Returns:
str: value from the cache.
"""

return self.cache.get(key)
result = self.connection.execute("SELECT value FROM cache WHERE key=?", [key])
row = result.fetchone()
if row:
return row[0]
else:
return None

def delete(self, key: str) -> None:
"""Delete a key value pair from the cache.

Args:
key (str): key to delete the value from the cache.
"""

if key in self.cache:
del self.cache[key]
self.connection.execute("DELETE FROM cache WHERE key=?", [key])

def close(self) -> None:
"""Close the cache."""

self.cache.close()
self.connection.close()

def clear(self) -> None:
"""Clean the cache."""

self.cache.clear()
self.connection.execute("DELETE FROM cache")

def destroy(self) -> None:
"""Destroy the cache."""
self.cache.close()
self.connection.close()
for cache_file in glob.glob(self.filepath + ".*"):
os.remove(cache_file)
45 changes: 45 additions & 0 deletions pandasai/helpers/output_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Union

from ._output_types import (
NumberOutputType,
DataFrameOutputType,
PlotOutputType,
StringOutputType,
DefaultOutputType,
)

output_types_map = {
"number": NumberOutputType,
"dataframe": DataFrameOutputType,
"plot": PlotOutputType,
"string": StringOutputType,
}


def output_type_factory(
output_type,
) -> Union[
NumberOutputType,
DataFrameOutputType,
PlotOutputType,
StringOutputType,
DefaultOutputType,
]:
"""
Factory function to get appropriate instance for output type.

Uses `output_types_map` to determine the output type class.

Args:
output_type (str): A name of the output type.

Returns:
(Union[
NumberOutputType,
DataFrameOutputType,
PlotOutputType,
StringOutputType,
DefaultOutputType
]): An instance of the output type.
"""
return output_types_map.get(output_type, DefaultOutputType)()
nautics889 marked this conversation as resolved.
Show resolved Hide resolved
160 changes: 160 additions & 0 deletions pandasai/helpers/output_types/_output_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import re
from decimal import Decimal
from abc import abstractmethod, ABC
from typing import Any, Iterable

import pandas as pd
import polars as pl


class BaseOutputType(ABC):
@property
@abstractmethod
def template_hint(self) -> str:
...

@property
@abstractmethod
def name(self) -> str:
...

def _validate_type(self, actual_type: str) -> bool:
if actual_type != self.name:
return False
return True

@abstractmethod
def _validate_value(self, actual_value):
...

def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
gventuri marked this conversation as resolved.
Show resolved Hide resolved
"""
Validate 'type' and 'value' from the result dict.

Args:
result (dict[str, Any]): The result of code execution in
dict representation. Should have the following schema:
{
"type": <output_type_name>,
"value": <generated_value>
}

Returns:
(tuple(bool, Iterable(str)):
Boolean value whether the result matches output type
and collection of logs containing messages about
'type' or 'value' mismatches.
"""
validation_logs = []
actual_type, actual_value = result.get("type"), result.get("value")

type_ok = self._validate_type(actual_type)
if not type_ok:
validation_logs.append(
f"The result dict contains inappropriate 'type'. "
f"Expected '{self.name}', actual '{actual_type}'."
)
value_ok = self._validate_value(actual_value)
if not value_ok:
validation_logs.append(
f"Actual value {repr(actual_value)} seems to be inappropriate "
f"for the type '{self.name}'."
)

return all((type_ok, value_ok)), validation_logs


class NumberOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "number")
- value (must be a number)"""

@property
def name(self):
return "number"

def _validate_value(self, actual_value: Any) -> bool:
if isinstance(actual_value, (int, float, Decimal)):
return True
return False


class DataFrameOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "dataframe")
- value (must be a pandas dataframe)"""

@property
def name(self):
return "dataframe"

def _validate_value(self, actual_value: Any) -> bool:
if isinstance(actual_value, (pd.DataFrame, pl.DataFrame)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather use df_config.df_type like this:

if df_type(actual_value) is not None:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also allows you not having to import polars and pandas!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent, since this releases the module from having excessive imports of third party.


Done

return True
return False


class PlotOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "plot")
- value (must be a string containing the path of the plot image)"""

@property
def name(self):
return "plot"

def _validate_value(self, actual_value: Any) -> bool:
if not isinstance(actual_value, str):
return False

path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
if re.match(path_to_plot_pattern, actual_value):
return True

return False

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PlotOutputType class uses a regular expression to validate that the actual value is a string representing a valid file path. This is a good approach, but it might be worth considering whether there are any edge cases that this regex doesn't cover. For example, it doesn't seem to handle spaces in file paths.


class StringOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "string")
- value (must be a conversational answer, as a string)"""

@property
def name(self):
return "string"

def _validate_value(self, actual_value: Any) -> bool:
if isinstance(actual_value, str):
return True
return False


class DefaultOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)""" # noqa E501

@property
def name(self):
return "default"

def _validate_type(self, actual_type: str) -> bool:
return True

def _validate_value(self, actual_value: Any) -> bool:
return True

def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable]:
"""
Validate 'type' and 'value' from the result dict.

Returns:
(bool): True since the `DefaultOutputType`
is supposed to have no validation
"""
return True, ()
8 changes: 4 additions & 4 deletions pandasai/llm/huggingface_text_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HuggingFaceTextGen(LLM):
top_k: Optional[int] = None
top_p: Optional[float] = 0.8
typical_p: Optional[float] = 0.8
temperature: float = 1E-3 # must be strictly positive
temperature: float = 1e-3 # must be strictly positive
repetition_penalty: Optional[float] = None
truncate: Optional[int] = None
stop_sequences: List[str] = None
Expand All @@ -29,7 +29,7 @@ def __init__(self, inference_server_url: str, **kwargs):
try:
import text_generation

for (key, val) in kwargs.items():
for key, val in kwargs.items():
if key in self.__annotations__:
setattr(self, key, val)

Expand Down Expand Up @@ -76,8 +76,8 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
for stop_seq in self.stop_sequences:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
:res.generated_text.index(stop_seq)
]
: res.generated_text.index(stop_seq)
]
return res.generated_text

@property
Expand Down
Loading