diff --git a/examples/from_snowflake.py b/examples/from_snowflake.py new file mode 100644 index 000000000..3df9968ef --- /dev/null +++ b/examples/from_snowflake.py @@ -0,0 +1,30 @@ +"""Example of using PandasAI with a Snowflake""" + +from pandasai import SmartDataframe +from pandasai.llm import OpenAI +from pandasai.connectors import SnowFlakeConnector + + +snowflake_connector = SnowFlakeConnector( + config={ + "account": "ehxzojy-ue47135", + "database": "SNOWFLAKE_SAMPLE_DATA", + "username": "test", + "password": "*****", + "table": "lineitem", + "warehouse": "COMPUTE_WH", + "dbSchema": "tpch_sf1", + "where": [ + # this is optional and filters the data to + # reduce the size of the dataframe + ["l_quantity", ">", "49"] + ], + } +) + +OPEN_AI_API = "Your-API-Key" +llm = OpenAI(api_token=OPEN_AI_API) +df = SmartDataframe(snowflake_connector, config={"llm": llm}) + +response = df.chat("Count line status is F") +print(response) diff --git a/examples/from_yahoo_finance.py b/examples/from_yahoo_finance.py new file mode 100644 index 000000000..78b07495e --- /dev/null +++ b/examples/from_yahoo_finance.py @@ -0,0 +1,14 @@ +from pandasai.connectors.yahoo_finance import YahooFinanceConnector +from pandasai import SmartDataframe +from pandasai.llm import OpenAI + + +yahoo_connector = YahooFinanceConnector("MSFT") + + +OPEN_AI_API = "OPEN_API_KEY" +llm = OpenAI(api_token=OPEN_AI_API) +df = SmartDataframe(yahoo_connector, config={"llm": llm}) + +response = df.chat("closing price yesterday") +print(response) diff --git a/pandasai/connectors/__init__.py b/pandasai/connectors/__init__.py index 00d6e2083..80beb2a01 100644 --- a/pandasai/connectors/__init__.py +++ b/pandasai/connectors/__init__.py @@ -5,7 +5,7 @@ """ from .base import BaseConnector -from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector +from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector, SnowFlakeConnector from .yahoo_finance import YahooFinanceConnector __all__ = [ @@ -14,4 +14,5 @@ "MySQLConnector", "PostgreSQLConnector", "YahooFinanceConnector", + "SnowFlakeConnector" ] diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py index 55701b43f..178e6474a 100644 --- a/pandasai/connectors/base.py +++ b/pandasai/connectors/base.py @@ -6,25 +6,54 @@ from ..helpers.df_info import DataFrameType from ..helpers.logger import Logger from pydantic import BaseModel -from typing import Optional +from typing import Optional, Union -class ConnectorConfig(BaseModel): +class BaseConnectorConfig(BaseModel): """ - Connector configuration. + Base Connector configuration. """ dialect: Optional[str] = None driver: Optional[str] = None - username: str - password: str - host: str - port: int database: str table: str where: list[list[str]] = None +class YahooFinanceConnectorConfig(BaseConnectorConfig): + """ + Connector configuration for Yahoo Finance. + """ + + host: str + port: int + + +class SQLConnectorConfig(BaseConnectorConfig): + """ + Connector configuration. + """ + + host: str + port: int + username: str + password: str + + +class SnowFlakeConnectorConfig(BaseConnectorConfig): + """ + Connector configuration for SnowFlake. + """ + + account: str + database: str + username: str + password: str + dbSchema: str + warehouse: str + + class BaseConnector(ABC): """ Base connector class to be extended by all connectors. @@ -43,6 +72,23 @@ def __init__(self, config): """ self._config = config + def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): + """Loads passed Configuration to object + + Args: + config (BaseConnectorConfig): Construct config in structure + + Returns: + config: BaseConnectorConfig + """ + pass + + def _init_connection(self, config: BaseConnectorConfig): + """ + make connection to database + """ + pass + @abstractmethod def head(self): """ diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index bf604a575..edc3b53a0 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -5,7 +5,8 @@ import re import os import pandas as pd -from .base import BaseConnector, ConnectorConfig +from .base import BaseConnector, SQLConnectorConfig +from .base import BaseConnectorConfig, SnowFlakeConnectorConfig from sqlalchemy import create_engine, sql, text, select, asc from functools import cached_property, cache import hashlib @@ -25,19 +26,46 @@ class SQLConnector(BaseConnector): _columns_count: int = None _cache_interval: int = 600 # 10 minutes - def __init__(self, config: Union[ConnectorConfig, dict], cache_interval: int = 600): + def __init__( + self, config: Union[BaseConnectorConfig, dict], cache_interval: int = 600 + ): """ Initialize the SQL connector with the given configuration. Args: config (ConnectorConfig): The configuration for the SQL connector. """ - config = ConnectorConfig(**config) + config = self._load_connector_config(config) super().__init__(config) if config.dialect is None: raise Exception("SQL dialect must be specified") + self._init_connection(config) + + self._cache_interval = cache_interval + + def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): + """ + Loads passed Configuration to object + + Args: + config (BaseConnectorConfig): Construct config in structure + + Returns: + config: BaseConenctorConfig + """ + return SQLConnectorConfig(**config) + + def _init_connection(self, config: SQLConnectorConfig): + """ + Initialize Database Connection + + Args: + config (SQLConnectorConfig): Configurations to load database + + """ + if config.driver: self._engine = create_engine( f"{config.dialect}+{config.driver}://{config.username}:{config.password}" @@ -48,8 +76,8 @@ def __init__(self, config: Union[ConnectorConfig, dict], cache_interval: int = 6 f"{config.dialect}://{config.username}:{config.password}@{config.host}" f":{str(config.port)}/{config.database}" ) + self._connection = self._engine.connect() - self._cache_interval = cache_interval def __del__(self): """ @@ -340,7 +368,7 @@ class MySQLConnector(SQLConnector): MySQL connectors are used to connect to MySQL databases. """ - def __init__(self, config: ConnectorConfig): + def __init__(self, config: SQLConnectorConfig): """ Initialize the MySQL connector with the given configuration. @@ -369,7 +397,7 @@ class PostgreSQLConnector(SQLConnector): PostgreSQL connectors are used to connect to PostgreSQL databases. """ - def __init__(self, config: ConnectorConfig): + def __init__(self, config: SQLConnectorConfig): """ Initialize the PostgreSQL connector with the given configuration. @@ -413,3 +441,88 @@ def head(self): # Return the head of the data source return pd.read_sql(query, self._connection) + + +class SnowFlakeConnector(SQLConnector): + """ + SnowFlake connectors are used to connect to SnowFlake Data Cloud. + """ + + def __init__(self, config: SnowFlakeConnectorConfig): + """ + Initialize the SnowFlake connector with the given configuration. + + Args: + config (ConnectorConfig): The configuration for the SnowFlake connector. + """ + config["dialect"] = "snowflake" + + if "account" not in config and os.getenv("SNOWFLAKE_HOST"): + config["account"] = os.getenv("SNOWFLAKE_HOST") + if "database" not in config and os.getenv("SNOWFLAKE_DATABASE"): + config["database"] = os.getenv("SNOWFLAKE_DATABASE") + if "warehouse" not in config and os.getenv("SNOWFLAKE_WAREHOUSE"): + config["warehouse"] = os.getenv("SNOWFLAKE_WAREHOUSE") + if "dbSchema" not in config and os.getenv("SNOWFLAKE_SCHEMA"): + config["dbSchema"] = os.getenv("SNOWFLAKE_SCHEMA") + if "username" not in config and os.getenv("SNOWFLAKE_USERNAME"): + config["username"] = os.getenv("SNOWFLAKE_USERNAME") + if "password" not in config and os.getenv("SNOWFLAKE_PASSWORD"): + config["password"] = os.getenv("SNOWFLAKE_PASSWORD") + + super().__init__(config) + + def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): + return SnowFlakeConnectorConfig(**config) + + def _init_connection(self, config: SnowFlakeConnectorConfig): + """ + Initialize Database Connection + + Args: + config (SQLConnectorConfig): Configurations to load database + + """ + self._engine = create_engine( + f"{config.dialect}://{config.username}:{config.password}@{config.account}/?warehouse={config.warehouse}&database={config.database}&schema={config.dbSchema}" + ) + + self._connection = self._engine.connect() + + @cache + def head(self): + """ + Return the head of the data source that the connector is connected to. + This information is passed to the LLM to provide the schema of the data source. + + Returns: + DataFrame: The head of the data source. + """ + + if self.logger: + self.logger.log( + f"Getting head of {self._config.table} " + f"using dialect {self._config.dialect}" + ) + + # Run a SQL query to get all the columns names and 5 random rows + query = self._build_query(limit=5, order="RANDOM()") + + # Return the head of the data source + return pd.read_sql(query, self._connection) + + def __repr__(self): + """ + Return the string representation of the SnowFlake connector. + + Returns: + str: The string representation of the SnowFlake connector. + """ + return ( + f"<{self.__class__.__name__} dialect={self._config.dialect} " + f"username={self._config.username} " + f"password={self._config.password} Account={self._config.account} " + f"warehouse={self._config.warehouse} " + f"database={self._config.database} schema={str(self._config.dbSchema)} " + f"table={self._config.table}>" + ) diff --git a/pandasai/connectors/yahoo_finance.py b/pandasai/connectors/yahoo_finance.py index 2d77b3735..cae4e263d 100644 --- a/pandasai/connectors/yahoo_finance.py +++ b/pandasai/connectors/yahoo_finance.py @@ -1,6 +1,7 @@ import os import pandas as pd -from .base import ConnectorConfig, BaseConnector + +from .base import YahooFinanceConnectorConfig, BaseConnector import time from ..helpers.path import find_project_root import hashlib @@ -21,10 +22,8 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600): "Could not import yfinance python package. " "Please install it with `pip install yfinance`." ) - yahoo_finance_config = ConnectorConfig( + yahoo_finance_config = YahooFinanceConnectorConfig( dialect="yahoo_finance", - username="", - password="", host="yahoo.finance.com", port=443, database="stock_data", @@ -34,7 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600): self._cache_interval = cache_interval super().__init__(yahoo_finance_config) self.ticker = yfinance.Ticker(self._config.table) - + def head(self): """ Return the head of the data source that the connector is connected to. diff --git a/pyproject.toml b/pyproject.toml index db97544d7..0361557a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ coverage = "^7.2.7" google-cloud-aiplatform = "^1.26.1" [tool.poetry.extras] -connectors = ["pymysql", "psycopg2"] +connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy"] google-ai = ["google-generativeai", "google-cloud-aiplatform"] google-sheets = ["beautifulsoup4"] excel = ["openpyxl"] diff --git a/tests/connectors/test_base.py b/tests/connectors/test_base.py index c110375d3..f80ba5b11 100644 --- a/tests/connectors/test_base.py +++ b/tests/connectors/test_base.py @@ -1,5 +1,6 @@ import pytest from pandasai.connectors import BaseConnector +from pandasai.connectors.base import BaseConnectorConfig from pandasai.helpers import Logger @@ -13,6 +14,13 @@ def __init__(self, host, port, database, table): # Mock subclass of BaseConnector for testing class MockConnector(BaseConnector): + + def _load_connector_config(self, config: BaseConnectorConfig): + pass + + def _init_connection(self, config: BaseConnectorConfig): + pass + def head(self): pass diff --git a/tests/connectors/test_snowflake.py b/tests/connectors/test_snowflake.py new file mode 100644 index 000000000..d34de3d2d --- /dev/null +++ b/tests/connectors/test_snowflake.py @@ -0,0 +1,96 @@ +import unittest +import pandas as pd +from unittest.mock import Mock, patch +from pandasai.connectors.base import SnowFlakeConnectorConfig +from pandasai.connectors.sql import SnowFlakeConnector + + +class TestSQLConnector(unittest.TestCase): + @patch("pandasai.connectors.sql.create_engine", autospec=True) + @patch("pandasai.connectors.sql.sql", autospec=True) + def setUp(self, mock_sql, mock_create_engine): + # Create a mock engine and connection + self.mock_engine = Mock() + self.mock_connection = Mock() + self.mock_engine.connect.return_value = self.mock_connection + mock_create_engine.return_value = self.mock_engine + + # Define your ConnectorConfig instance here + self.config = SnowFlakeConnectorConfig( + dialect="snowflake", + account="ehxzojy-ue47135", + username="your_username", + password="your_password", + database="SNOWFLAKE_SAMPLE_DATA", + warehouse="COMPUTED", + dbSchema="tpch_sf1", + table="lineitem", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + self.connector = SnowFlakeConnector(self.config) + + def test_constructor_and_properties(self): + # Test constructor and properties + self.assertEqual(self.connector._config, self.config) + self.assertEqual(self.connector._engine, self.mock_engine) + self.assertEqual(self.connector._connection, self.mock_connection) + self.assertEqual(self.connector._cache_interval, 600) + + def test_repr_method(self): + # Test __repr__ method + expected_repr = ( + "" + ) + self.assertEqual(repr(self.connector), expected_repr) + + def test_build_query_method(self): + # Test _build_query method + query = self.connector._build_query(limit=5, order="RANDOM()") + expected_query = """SELECT * +FROM lineitem +WHERE column_name = :value_0 ORDER BY RANDOM() ASC + LIMIT :param_1""" + + self.assertEqual(str(query), expected_query) + + @patch("pandasai.connectors.sql.pd.read_sql", autospec=True) + def test_head_method(self, mock_read_sql): + expected_data = pd.DataFrame({"Column1": [1, 2, 3], "Column2": [4, 5, 6]}) + mock_read_sql.return_value = expected_data + head_data = self.connector.head() + pd.testing.assert_frame_equal(head_data, expected_data) + + def test_rows_count_property(self): + # Test rows_count property + self.connector._rows_count = None + self.mock_connection.execute.return_value.fetchone.return_value = ( + 50, + ) # Sample rows count + rows_count = self.connector.rows_count + self.assertEqual(rows_count, 50) + + def test_columns_count_property(self): + # Test columns_count property + self.connector._columns_count = None + self.mock_connection.execute.return_value.fetchone.return_value = ( + 8, + ) # Sample columns count + columns_count = self.connector.columns_count + self.assertEqual(columns_count, 8) + + def test_column_hash_property(self): + # Test column_hash property + mock_df = Mock() + mock_df.columns = ["Column1", "Column2"] + self.connector.head = Mock(return_value=mock_df) + column_hash = self.connector.column_hash + self.assertIsNotNone(column_hash) + + def test_fallback_name_property(self): + # Test fallback_name property + fallback_name = self.connector.fallback_name + self.assertEqual(fallback_name, "lineitem") diff --git a/tests/connectors/test_sql.py b/tests/connectors/test_sql.py index f2152aa60..485b57c10 100644 --- a/tests/connectors/test_sql.py +++ b/tests/connectors/test_sql.py @@ -1,7 +1,7 @@ import unittest import pandas as pd from unittest.mock import Mock, patch -from pandasai.connectors.base import ConnectorConfig +from pandasai.connectors.base import SQLConnectorConfig from pandasai.connectors.sql import SQLConnector @@ -16,7 +16,7 @@ def setUp(self, mock_sql, mock_create_engine): mock_create_engine.return_value = self.mock_engine # Define your ConnectorConfig instance here - self.config = ConnectorConfig( + self.config = SQLConnectorConfig( dialect="mysql", driver="pymysql", username="your_username",