Skip to content

Commit

Permalink
feat: add yahoo finance connector
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Sep 4, 2023
1 parent 792ae7b commit 3a92d33
Show file tree
Hide file tree
Showing 9 changed files with 494 additions and 36 deletions.
9 changes: 8 additions & 1 deletion pandasai/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,12 @@

from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = ["BaseConnector", "SQLConnector", "MySQLConnector", "PostgreSQLConnector"]
__all__ = [
"BaseConnector",
"SQLConnector",
"MySQLConnector",
"PostgreSQLConnector",
"YahooFinanceConnector",
]
18 changes: 18 additions & 0 deletions pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@
from abc import ABC, abstractmethod
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from pydantic import BaseModel
from typing import Optional


class ConnectorConfig(BaseModel):
    """
    Connector configuration.

    Shared by the connectors in this package. The SQL connector combines
    ``dialect``, ``driver``, ``username``, ``password``, ``host``, ``port`` and
    ``database`` into its engine URL; ``table`` names the data set to read.
    """

    # Optional because some connectors set these themselves before validation.
    dialect: Optional[str] = None
    driver: Optional[str] = None
    username: str
    password: str
    host: str
    port: int
    database: str
    table: str
    # Explicitly Optional: callers may pass ``where=None`` and the default is
    # None, so the annotation must admit None for validation to accept it.
    # NOTE(review): presumably a list of filter conditions, each a list of
    # strings -- confirm the exact layout against the code that consumes it.
    where: Optional[list[list[str]]] = None


class BaseConnector(ABC):
Expand Down
40 changes: 11 additions & 29 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,14 @@

import os
import pandas as pd
from .base import BaseConnector
from .base import BaseConnector, ConnectorConfig
from sqlalchemy import create_engine, sql
from pydantic import BaseModel
from typing import Optional
from functools import cached_property, cache
import hashlib
from ..helpers.path import find_project_root
import time


class SQLConfig(BaseModel):
    """
    SQL configuration.

    Holds the pieces of the database URL (dialect, driver, credentials, host,
    port, database) plus the table to read from.
    """

    dialect: Optional[str] = None
    driver: Optional[str] = None
    username: str
    password: str
    host: str
    port: str
    database: str
    table: str
    # Explicitly Optional: the default is None, so the annotation must admit
    # None for an explicit ``where=None`` to pass validation.
    # NOTE(review): presumably a list of filter conditions -- confirm the exact
    # layout against the code that consumes it.
    where: Optional[list[list[str]]] = None


class SQLConnector(BaseConnector):
"""
SQL connectors are used to connect to SQL databases in different dialects.
Expand All @@ -41,14 +23,14 @@ class SQLConnector(BaseConnector):
_columns_count: int = None
_cache_interval: int = 600 # 10 minutes

def __init__(self, config: SQLConfig, cache_interval: int = 600):
def __init__(self, config: ConnectorConfig, cache_interval: int = 600):
"""
Initialize the SQL connector with the given configuration.
Args:
config (SQLConfig): The configuration for the SQL connector.
config (ConnectorConfig): The configuration for the SQL connector.
"""
config = SQLConfig(**config)
config = ConnectorConfig(**config)
super().__init__(config)

if config.dialect is None:
Expand All @@ -57,12 +39,12 @@ def __init__(self, config: SQLConfig, cache_interval: int = 600):
if config.driver:
self._engine = create_engine(
f"{config.dialect}+{config.driver}://{config.username}:{config.password}"
f"@{config.host}:{config.port}/{config.database}"
f"@{config.host}:{str(config.port)}/{config.database}"
)
else:
self._engine = create_engine(
f"{config.dialect}://{config.username}:{config.password}@{config.host}"
f":{config.port}/{config.database}"
f":{str(config.port)}/{config.database}"
)
self._connection = self._engine.connect()
self._cache_interval = cache_interval
Expand All @@ -84,7 +66,7 @@ def __repr__(self):
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"driver={self._config.driver} username={self._config.username} "
f"password={self._config.password} host={self._config.host} "
f"port={self._config.port} database={self._config.database} "
f"port={str(self._config.port)} database={self._config.database} "
f"table={self._config.table}>"
)

Expand Down Expand Up @@ -350,12 +332,12 @@ class MySQLConnector(SQLConnector):
MySQL connectors are used to connect to MySQL databases.
"""

def __init__(self, config: SQLConfig):
def __init__(self, config: ConnectorConfig):
"""
Initialize the MySQL connector with the given configuration.
Args:
config (SQLConfig): The configuration for the MySQL connector.
config (ConnectorConfig): The configuration for the MySQL connector.
"""
config["dialect"] = "mysql"
config["driver"] = "pymysql"
Expand All @@ -379,12 +361,12 @@ class PostgreSQLConnector(SQLConnector):
PostgreSQL connectors are used to connect to PostgreSQL databases.
"""

def __init__(self, config: SQLConfig):
def __init__(self, config: ConnectorConfig):
"""
Initialize the PostgreSQL connector with the given configuration.
Args:
config (SQLConfig): The configuration for the PostgreSQL connector.
config (ConnectorConfig): The configuration for the PostgreSQL connector.
"""
config["dialect"] = "postgresql"
config["driver"] = "psycopg2"
Expand Down
164 changes: 164 additions & 0 deletions pandasai/connectors/yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import os
import yfinance as yf
import pandas as pd
from .base import ConnectorConfig, BaseConnector
import time
from ..helpers.path import find_project_root
import hashlib


class YahooFinanceConnector(BaseConnector):
    """
    Yahoo Finance connector for retrieving stock data.

    The ticker symbol is stored in the ``table`` field of the connector
    configuration. The full price history is cached as a CSV file and reused
    until the file is older than the cache interval.
    """

    _cache_interval: int = 600  # 10 minutes

    def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
        """
        Initialize the Yahoo Finance connector.

        Args:
            stock_ticker (str): Ticker symbol to load; stored as the config's
                ``table`` value.
            where: Stored in the configuration as-is; no method of this class
                reads it.
            cache_interval (int): Maximum age of the cache file, in seconds.
        """
        yahoo_finance_config = ConnectorConfig(
            dialect="yahoo_finance",
            username="",
            password="",
            host="yahoo.finance.com",
            port=443,
            database="stock_data",
            table=stock_ticker,
            where=where,
        )
        self._cache_interval = cache_interval
        super().__init__(yahoo_finance_config)

    def head(self):
        """
        Return the head of the data source that the connector is connected to.

        Returns:
            DataFrameType: The history of the configured ticker for the
                ``"5d"`` period, fetched directly (the cache is not used).
        """
        ticker = yf.Ticker(self._config.table)
        return ticker.history(period="5d")

    def _get_cache_path(self, include_additional_filters: bool = False):
        """
        Return the path of the cache file for Yahoo Finance data.

        The cache directory is ``<project root>/cache`` when the project root
        can be found, otherwise ``<current working directory>/cache``. The
        directory is created if it does not exist.

        Args:
            include_additional_filters (bool): Accepted but currently unused;
                the path depends only on the ticker.

        Returns:
            str: The path of the cache file.
        """
        try:
            cache_dir = os.path.join(find_project_root(), "cache")
        except ValueError:
            # No project root found: fall back to the working directory.
            cache_dir = os.path.join(os.getcwd(), "cache")

        os.makedirs(cache_dir, mode=0o777, exist_ok=True)

        return os.path.join(cache_dir, f"{self._config.table}_data.csv")

    def _cached(self):
        """
        Return the cache file path if a fresh cache file exists.

        A cache file older than the cache interval is deleted.

        Returns:
            str|None: The path of the cache file if it exists and is not older
                than the cache interval, None otherwise.
        """
        cache_path = self._get_cache_path()
        if not os.path.exists(cache_path):
            return None

        # If the file is older than the cache interval, delete it
        if os.path.getmtime(cache_path) < time.time() - self._cache_interval:
            if self.logger:
                self.logger.log(f"Deleting expired cached data from {cache_path}")
            os.remove(cache_path)
            return None

        if self.logger:
            self.logger.log(f"Loading cached data from {cache_path}")

        return cache_path

    def execute(self):
        """
        Execute the connector and return the result.

        Loads the data from the cache file when a fresh one exists; otherwise
        downloads the full history for the ticker and writes it to the cache.

        Returns:
            DataFrameType: The result of the connector.
        """
        cached_path = self._cached()
        if cached_path:
            # The index was written as the first CSV column; restore it so a
            # cache hit has the same columns as a fresh download.
            return pd.read_csv(cached_path, index_col=0, parse_dates=True)

        # Use yfinance to retrieve historical stock data
        ticker = yf.Ticker(self._config.table)
        stock_data = ticker.history(period="max")

        # Save the result to the cache, keeping the index (presumably the
        # dates of each row -- confirm against yfinance) so it is not lost.
        stock_data.to_csv(self._get_cache_path())

        return stock_data

    @property
    def rows_count(self):
        """
        Return the number of rows in the data source that the connector is
        connected to.

        Returns:
            int: The number of rows returned by ``execute()``.
        """
        return len(self.execute())

    @property
    def columns_count(self):
        """
        Return the number of columns in the data source that the connector is
        connected to.

        Returns:
            int: The number of columns returned by ``execute()``.
        """
        return len(self.execute().columns)

    @property
    def column_hash(self):
        """
        Return the hash code that is unique to the columns of the data source
        that the connector is connected to.

        Returns:
            str: Hex SHA-256 digest of the column names joined with ``|``.
        """
        columns_str = "|".join(self.execute().columns)
        return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()

    @property
    def fallback_name(self):
        """
        Return the fallback name of the connector.

        Returns:
            str: The configured ticker symbol.
        """
        return self._config.table
21 changes: 20 additions & 1 deletion pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _load_dataframe(self, df):
self.dataframe = None
self.connector = df
self.connector.logger = self._logger
self.dataframe_loaded = False
self._df_loaded = False
elif isinstance(df, str):
self.dataframe = self._import_from_file(df)
elif isinstance(df, pd.Series):
Expand Down Expand Up @@ -110,6 +110,9 @@ def _import_from_file(self, file_path: str):
raise ValueError("Invalid file format.")

def _load_engine(self):
"""
Load the engine of the dataframe (Pandas or Polars)
"""
engine = df_type(self._df)

if engine is None:
Expand All @@ -120,6 +123,15 @@ def _load_engine(self):
self._engine = engine

def _validate_and_convert_dataframe(self, df: DataFrameType) -> DataFrameType:
"""
Validate the dataframe and convert it to a Pandas or Polars dataframe.
Args:
df (DataFrameType): Pandas or Polars dataframe or path to a file
Returns:
DataFrameType: Pandas or Polars dataframe
"""
if isinstance(df, str):
return self._import_from_file(df)
elif isinstance(df, (list, dict)):
Expand Down Expand Up @@ -328,6 +340,9 @@ def column_hash(self) -> str:
def save(self, name: str = None):
"""
Saves the dataframe configuration to be used for later
Args:
name (str, optional): Name of the dataframe configuration. Defaults to None.
"""

config_manager = DfConfigManager(self)
Expand All @@ -336,6 +351,10 @@ def save(self, name: str = None):
def load_connector(self, temporary: bool = False):
"""
Load a connector into the smart dataframe
Args:
temporary (bool, optional): Whether the connector is temporary or not.
Defaults to False.
"""
self._core.load_connector(temporary)

Expand Down
Loading

0 comments on commit 3a92d33

Please sign in to comment.