From 3aabd24f418f7b0349942ceaebdba1f3e03f7ce2 Mon Sep 17 00:00:00 2001 From: mspronesti Date: Thu, 14 Sep 2023 21:00:13 +0200 Subject: [PATCH] fix(connectors): make sqlalchemy non-optional and check yfinance imports --- pandasai/connectors/yahoo_finance.py | 15 ++++++++++----- pyproject.toml | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pandasai/connectors/yahoo_finance.py b/pandasai/connectors/yahoo_finance.py index 9beccd66f..2d77b3735 100644 --- a/pandasai/connectors/yahoo_finance.py +++ b/pandasai/connectors/yahoo_finance.py @@ -1,5 +1,4 @@ import os -import yfinance as yf import pandas as pd from .base import ConnectorConfig, BaseConnector import time @@ -15,6 +14,13 @@ class YahooFinanceConnector(BaseConnector): _cache_interval: int = 600 # 10 minutes def __init__(self, stock_ticker, where=None, cache_interval: int = 600): + try: + import yfinance + except ImportError: + raise ImportError( + "Could not import yfinance python package. " + "Please install it with `pip install yfinance`." + ) yahoo_finance_config = ConnectorConfig( dialect="yahoo_finance", username="", @@ -27,6 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600): ) self._cache_interval = cache_interval super().__init__(yahoo_finance_config) + self.ticker = yfinance.Ticker(self._config.table) def head(self): """ @@ -36,8 +43,7 @@ def head(self): DataFrameType: The head of the data source that the connector is connected to. """ - ticker = yf.Ticker(self._config.table) - head_data = ticker.history(period="5d") + head_data = self.ticker.history(period="5d") return head_data def _get_cache_path(self, include_additional_filters: bool = False): @@ -105,8 +111,7 @@ def execute(self): return pd.read_csv(cached_path) # Use yfinance to retrieve historical stock data - ticker = yf.Ticker(self._config.table) - stock_data = ticker.history(period="max") + stock_data = self.ticker.history(period="max") # Save the result to the cache stock_data.to_csv(self._get_cache_path(), index=False) diff --git a/pyproject.toml b/pyproject.toml index 666590bfa..efe0cdcf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ openai = "^0.27.5" ipython = "^8.13.1" matplotlib = "^3.7.1" pydantic = "^1" +sqlalchemy = "^2.0.19" google-generativeai = {version = "^0.1.0rc2", optional = true} google-cloud-aiplatform = {version = "^1.26.1", optional = true} langchain = {version = "^0.0.199", optional = true} @@ -32,7 +33,6 @@ streamlit = {version = "^1.23.1", optional = true} beautifulsoup4 = { version = "^4.12.2", optional = true } text-generation = { version = ">=0.6.0", optional = true } openpyxl = { version = "^3.0.7", optional = true } -sqlalchemy = { version = "^2.0.19", optional = true } pymysql = { version = "^1.1.0", optional = true } psycopg2 = { version = "^2.9.7", optional = true } yfinance = { version = "^0.2.28", optional = true } @@ -51,7 +51,7 @@ coverage = "^7.2.7" google-cloud-aiplatform = "^1.26.1" [tool.poetry.extras] -connectors = ["sqlalchemy", "pymysql", "psycopg2"] +connectors = ["pymysql", "psycopg2"] google-ai = ["google-generativeai", "google-cloud-aiplatform"] google-sheets = ["beautifulsoup4"] excel = ["openpyxl"]