Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: add output_type parameter (#519) #562

Merged
merged 10 commits on Sep 18, 2023
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ nav:
- Documents Building: building_docs.md
- License: license.md
extra:
version: "1.2.1"
version: "1.2.2"
plugins:
- search
- mkdocstrings:
Expand Down
2 changes: 1 addition & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def last_prompt(self) -> str:

def clear_cache(filename: str = None):
"""Clear the cache"""
cache = Cache(filename or "cache")
cache = Cache(filename or "cache_db")
cache.clear()


Expand Down
15 changes: 10 additions & 5 deletions pandasai/connectors/yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import yfinance as yf
import pandas as pd
from .base import ConnectorConfig, BaseConnector
import time
Expand All @@ -15,6 +14,13 @@ class YahooFinanceConnector(BaseConnector):
_cache_interval: int = 600 # 10 minutes

def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
try:
import yfinance
except ImportError:
raise ImportError(
"Could not import yfinance python package. "
"Please install it with `pip install yfinance`."
)
yahoo_finance_config = ConnectorConfig(
dialect="yahoo_finance",
username="",
Expand All @@ -27,6 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
)
self._cache_interval = cache_interval
super().__init__(yahoo_finance_config)
self.ticker = yfinance.Ticker(self._config.table)

def head(self):
"""
Expand All @@ -36,8 +43,7 @@ def head(self):
DataFrameType: The head of the data source that the connector is
connected to.
"""
ticker = yf.Ticker(self._config.table)
head_data = ticker.history(period="5d")
head_data = self.ticker.history(period="5d")
return head_data

def _get_cache_path(self, include_additional_filters: bool = False):
Expand Down Expand Up @@ -105,8 +111,7 @@ def execute(self):
return pd.read_csv(cached_path)

# Use yfinance to retrieve historical stock data
ticker = yf.Ticker(self._config.table)
stock_data = ticker.history(period="max")
stock_data = self.ticker.history(period="max")

# Save the result to the cache
stock_data.to_csv(self._get_cache_path(), index=False)
Expand Down
41 changes: 21 additions & 20 deletions pandasai/helpers/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Cache module for caching queries."""
import glob
import os
import shelve
import glob
import duckdb
from .path import find_project_root


Expand All @@ -13,17 +12,20 @@ class Cache:
filename (str): filename to store the cache.
"""

def __init__(self, filename="cache"):
# define cache directory and create directory if it does not exist
def __init__(self, filename="cache_db"):
# Define cache directory and create directory if it does not exist
try:
cache_dir = os.path.join((find_project_root()), "cache")
cache_dir = os.path.join(find_project_root(), "cache")
except ValueError:
cache_dir = os.path.join(os.getcwd(), "cache")

os.makedirs(cache_dir, mode=0o777, exist_ok=True)

self.filepath = os.path.join(cache_dir, filename)
self.cache = shelve.open(self.filepath)
self.filepath = os.path.join(cache_dir, filename + ".db")
self.connection = duckdb.connect(self.filepath)
self.connection.execute(
"CREATE TABLE IF NOT EXISTS cache (key STRING, value STRING)"
)

def set(self, key: str, value: str) -> None:
"""Set a key value pair in the cache.
Expand All @@ -32,8 +34,7 @@ def set(self, key: str, value: str) -> None:
key (str): key to store the value.
value (str): value to store in the cache.
"""

self.cache[key] = value
self.connection.execute("INSERT INTO cache VALUES (?, ?)", [key, value])

def get(self, key: str) -> str:
"""Get a value from the cache.
Expand All @@ -44,31 +45,31 @@ def get(self, key: str) -> str:
Returns:
str: value from the cache.
"""

return self.cache.get(key)
result = self.connection.execute("SELECT value FROM cache WHERE key=?", [key])
row = result.fetchone()
if row:
return row[0]
else:
return None

def delete(self, key: str) -> None:
"""Delete a key value pair from the cache.

Args:
key (str): key to delete the value from the cache.
"""

if key in self.cache:
del self.cache[key]
self.connection.execute("DELETE FROM cache WHERE key=?", [key])

def close(self) -> None:
"""Close the cache."""

self.cache.close()
self.connection.close()

def clear(self) -> None:
"""Clean the cache."""

self.cache.clear()
self.connection.execute("DELETE FROM cache")

def destroy(self) -> None:
"""Destroy the cache."""
self.cache.close()
self.connection.close()
for cache_file in glob.glob(self.filepath + ".*"):
os.remove(cache_file)
8 changes: 4 additions & 4 deletions pandasai/llm/huggingface_text_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class HuggingFaceTextGen(LLM):
top_k: Optional[int] = None
top_p: Optional[float] = 0.8
typical_p: Optional[float] = 0.8
temperature: float = 1E-3 # must be strictly positive
temperature: float = 1e-3 # must be strictly positive
repetition_penalty: Optional[float] = None
truncate: Optional[int] = None
stop_sequences: List[str] = None
Expand All @@ -29,7 +29,7 @@ def __init__(self, inference_server_url: str, **kwargs):
try:
import text_generation

for (key, val) in kwargs.items():
for key, val in kwargs.items():
if key in self.__annotations__:
setattr(self, key, val)

Expand Down Expand Up @@ -76,8 +76,8 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
for stop_seq in self.stop_sequences:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
:res.generated_text.index(stop_seq)
]
: res.generated_text.index(stop_seq)
]
return res.generated_text

@property
Expand Down
19 changes: 17 additions & 2 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
{output_type_hint}
Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }}
\"\"\"
```
Expand All @@ -70,10 +69,26 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:

Updated code:
""" # noqa: E501
_output_type_map = {
"number": """- type (must be "number")
- value (must be a number)""",
"dataframe": """- type (must be "dataframe")
- value (must be a pandas dataframe)""",
"plot": """- type (must be "plot")
- value (must be a string containing the path of the plot image)""",
"string": """- type (must be "string")
- value (must be a conversational answer, as a string)""",
}
_output_type_default = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)""" # noqa E501

def __init__(self):
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"

self.set_var("default_import", default_import)
self.set_var("engine_df_name", engine_df_name)

@classmethod
def get_output_type_hint(cls, output_type):
return cls._output_type_map.get(output_type, cls._output_type_default)
5 changes: 3 additions & 2 deletions pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,17 +315,18 @@ def add_middlewares(self, *middlewares: Optional[Middleware]):
"""
self.lake.add_middlewares(*middlewares)

def chat(self, query: str):
def chat(self, query: str, output_type: Optional[str] = None):
"""
Run a query on the dataframe.

Args:
query (str): Query to run on the dataframe
output_type (Optional[str]):

Raises:
ValueError: If the query is empty
"""
Comment on lines 336 to 338
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no validation for the output_type parameter. If an invalid value is passed, it could lead to unexpected behavior or errors. Consider adding validation to check if the provided output_type is one of the expected values.

if output_type and output_type not in OutputType.values():
    raise ValueError(f"Invalid output type: {output_type}. Expected one of {OutputType.values()}")

return self.lake.chat(query)
return self.lake.chat(query, output_type)

def column_hash(self) -> str:
"""
Expand Down
5 changes: 4 additions & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def _get_cache_key(self) -> str:

return cache_key

def chat(self, query: str):
def chat(self, query: str, output_type: Optional[str] = None):
"""
Run a query on the dataframe.

Args:
query (str): Query to run on the dataframe
output_type (Optional[str]):

Raises:
ValueError: If the query is empty
Expand All @@ -280,10 +281,12 @@ def chat(self, query: str):
self.logger.log("Using cached response")
code = self._cache.get(self._get_cache_key())
else:
prompt_cls = GeneratePythonCodePrompt
default_values = {
# TODO: find a better way to determine the engine,
"engine": self._dfs[0].engine,
"save_charts_path": self._config.save_charts_path.rstrip("/"),
"output_type_hint": prompt_cls.get_output_type_hint(output_type),
nautics889 marked this conversation as resolved.
Show resolved Hide resolved
}
generate_python_code_instruction = self._get_prompt(
"generate_python_code",
Expand Down
Loading
Loading