Skip to content

Commit

Permalink
feat(SnowFlake): snowflake connector (#574)
Browse files Browse the repository at this point in the history
* Add Basic code for SnowFlake Connector

* feat[Snowflake]: Adding SnowFlake Connector

* fix: missing , in where clause of example

* test: snowflake parser improvements

* fix: Yahoo connector

* fix: ruff issues

* fix: example of yahoo finance

* Adding test cases for snowflake

* fix doc string
  • Loading branch information
ArslanSaleem authored Sep 19, 2023
1 parent 5464a4d commit 58559db
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 22 deletions.
30 changes: 30 additions & 0 deletions examples/from_snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Example of using PandasAI with a Snowflake"""

from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from pandasai.connectors import SnowFlakeConnector


snowflake_connector = SnowFlakeConnector(
config={
"account": "ehxzojy-ue47135",
"database": "SNOWFLAKE_SAMPLE_DATA",
"username": "test",
"password": "*****",
"table": "lineitem",
"warehouse": "COMPUTE_WH",
"dbSchema": "tpch_sf1",
"where": [
# this is optional and filters the data to
# reduce the size of the dataframe
["l_quantity", ">", "49"]
],
}
)

OPEN_AI_API = "Your-API-Key"
llm = OpenAI(api_token=OPEN_AI_API)
df = SmartDataframe(snowflake_connector, config={"llm": llm})

response = df.chat("Count line status is F")
print(response)
14 changes: 14 additions & 0 deletions examples/from_yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pandasai.connectors.yahoo_finance import YahooFinanceConnector
from pandasai import SmartDataframe
from pandasai.llm import OpenAI


yahoo_connector = YahooFinanceConnector("MSFT")


OPEN_AI_API = "OPEN_API_KEY"
llm = OpenAI(api_token=OPEN_AI_API)
df = SmartDataframe(yahoo_connector, config={"llm": llm})

response = df.chat("closing price yesterday")
print(response)
3 changes: 2 additions & 1 deletion pandasai/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector, SnowFlakeConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = [
Expand All @@ -14,4 +14,5 @@
"MySQLConnector",
"PostgreSQLConnector",
"YahooFinanceConnector",
"SnowFlakeConnector"
]
60 changes: 53 additions & 7 deletions pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,54 @@
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from pydantic import BaseModel
from typing import Optional
from typing import Optional, Union


class ConnectorConfig(BaseModel):
class BaseConnectorConfig(BaseModel):
"""
Connector configuration.
Base Connector configuration.
"""

dialect: Optional[str] = None
driver: Optional[str] = None
username: str
password: str
host: str
port: int
database: str
table: str
where: list[list[str]] = None


class YahooFinanceConnectorConfig(BaseConnectorConfig):
"""
Connector configuration for Yahoo Finance.
"""

host: str
port: int


class SQLConnectorConfig(BaseConnectorConfig):
"""
Connector configuration.
"""

host: str
port: int
username: str
password: str


class SnowFlakeConnectorConfig(BaseConnectorConfig):
"""
Connector configuration for SnowFlake.
"""

account: str
database: str
username: str
password: str
dbSchema: str
warehouse: str


class BaseConnector(ABC):
"""
Base connector class to be extended by all connectors.
Expand All @@ -43,6 +72,23 @@ def __init__(self, config):
"""
self._config = config

def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
"""Loads passed Configuration to object
Args:
config (BaseConnectorConfig): Construct config in structure
Returns:
config: BaseConnectorConfig
"""
pass

def _init_connection(self, config: BaseConnectorConfig):
"""
make connection to database
"""
pass

@abstractmethod
def head(self):
"""
Expand Down
125 changes: 119 additions & 6 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import re
import os
import pandas as pd
from .base import BaseConnector, ConnectorConfig
from .base import BaseConnector, SQLConnectorConfig
from .base import BaseConnectorConfig, SnowFlakeConnectorConfig
from sqlalchemy import create_engine, sql, text, select, asc
from functools import cached_property, cache
import hashlib
Expand All @@ -25,19 +26,46 @@ class SQLConnector(BaseConnector):
_columns_count: int = None
_cache_interval: int = 600 # 10 minutes

def __init__(self, config: Union[ConnectorConfig, dict], cache_interval: int = 600):
def __init__(
self, config: Union[BaseConnectorConfig, dict], cache_interval: int = 600
):
"""
Initialize the SQL connector with the given configuration.
Args:
config (ConnectorConfig): The configuration for the SQL connector.
"""
config = ConnectorConfig(**config)
config = self._load_connector_config(config)
super().__init__(config)

if config.dialect is None:
raise Exception("SQL dialect must be specified")

self._init_connection(config)

self._cache_interval = cache_interval

def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
"""
Loads passed Configuration to object
Args:
config (BaseConnectorConfig): Construct config in structure
Returns:
config: BaseConenctorConfig
"""
return SQLConnectorConfig(**config)

def _init_connection(self, config: SQLConnectorConfig):
"""
Initialize Database Connection
Args:
config (SQLConnectorConfig): Configurations to load database
"""

if config.driver:
self._engine = create_engine(
f"{config.dialect}+{config.driver}://{config.username}:{config.password}"
Expand All @@ -48,8 +76,8 @@ def __init__(self, config: Union[ConnectorConfig, dict], cache_interval: int = 6
f"{config.dialect}://{config.username}:{config.password}@{config.host}"
f":{str(config.port)}/{config.database}"
)

self._connection = self._engine.connect()
self._cache_interval = cache_interval

def __del__(self):
"""
Expand Down Expand Up @@ -340,7 +368,7 @@ class MySQLConnector(SQLConnector):
MySQL connectors are used to connect to MySQL databases.
"""

def __init__(self, config: ConnectorConfig):
def __init__(self, config: SQLConnectorConfig):
"""
Initialize the MySQL connector with the given configuration.
Expand Down Expand Up @@ -369,7 +397,7 @@ class PostgreSQLConnector(SQLConnector):
PostgreSQL connectors are used to connect to PostgreSQL databases.
"""

def __init__(self, config: ConnectorConfig):
def __init__(self, config: SQLConnectorConfig):
"""
Initialize the PostgreSQL connector with the given configuration.
Expand Down Expand Up @@ -413,3 +441,88 @@ def head(self):

# Return the head of the data source
return pd.read_sql(query, self._connection)


class SnowFlakeConnector(SQLConnector):
"""
SnowFlake connectors are used to connect to SnowFlake Data Cloud.
"""

def __init__(self, config: SnowFlakeConnectorConfig):
"""
Initialize the SnowFlake connector with the given configuration.
Args:
config (ConnectorConfig): The configuration for the SnowFlake connector.
"""
config["dialect"] = "snowflake"

if "account" not in config and os.getenv("SNOWFLAKE_HOST"):
config["account"] = os.getenv("SNOWFLAKE_HOST")
if "database" not in config and os.getenv("SNOWFLAKE_DATABASE"):
config["database"] = os.getenv("SNOWFLAKE_DATABASE")
if "warehouse" not in config and os.getenv("SNOWFLAKE_WAREHOUSE"):
config["warehouse"] = os.getenv("SNOWFLAKE_WAREHOUSE")
if "dbSchema" not in config and os.getenv("SNOWFLAKE_SCHEMA"):
config["dbSchema"] = os.getenv("SNOWFLAKE_SCHEMA")
if "username" not in config and os.getenv("SNOWFLAKE_USERNAME"):
config["username"] = os.getenv("SNOWFLAKE_USERNAME")
if "password" not in config and os.getenv("SNOWFLAKE_PASSWORD"):
config["password"] = os.getenv("SNOWFLAKE_PASSWORD")

super().__init__(config)

def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
return SnowFlakeConnectorConfig(**config)

def _init_connection(self, config: SnowFlakeConnectorConfig):
"""
Initialize Database Connection
Args:
config (SQLConnectorConfig): Configurations to load database
"""
self._engine = create_engine(
f"{config.dialect}://{config.username}:{config.password}@{config.account}/?warehouse={config.warehouse}&database={config.database}&schema={config.dbSchema}"
)

self._connection = self._engine.connect()

@cache
def head(self):
"""
Return the head of the data source that the connector is connected to.
This information is passed to the LLM to provide the schema of the data source.
Returns:
DataFrame: The head of the data source.
"""

if self.logger:
self.logger.log(
f"Getting head of {self._config.table} "
f"using dialect {self._config.dialect}"
)

# Run a SQL query to get all the columns names and 5 random rows
query = self._build_query(limit=5, order="RANDOM()")

# Return the head of the data source
return pd.read_sql(query, self._connection)

def __repr__(self):
"""
Return the string representation of the SnowFlake connector.
Returns:
str: The string representation of the SnowFlake connector.
"""
return (
f"<{self.__class__.__name__} dialect={self._config.dialect} "
f"username={self._config.username} "
f"password={self._config.password} Account={self._config.account} "
f"warehouse={self._config.warehouse} "
f"database={self._config.database} schema={str(self._config.dbSchema)} "
f"table={self._config.table}>"
)
9 changes: 4 additions & 5 deletions pandasai/connectors/yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pandas as pd
from .base import ConnectorConfig, BaseConnector

from .base import YahooFinanceConnectorConfig, BaseConnector
import time
from ..helpers.path import find_project_root
import hashlib
Expand All @@ -21,10 +22,8 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
"Could not import yfinance python package. "
"Please install it with `pip install yfinance`."
)
yahoo_finance_config = ConnectorConfig(
yahoo_finance_config = YahooFinanceConnectorConfig(
dialect="yahoo_finance",
username="",
password="",
host="yahoo.finance.com",
port=443,
database="stock_data",
Expand All @@ -34,7 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
self._cache_interval = cache_interval
super().__init__(yahoo_finance_config)
self.ticker = yfinance.Ticker(self._config.table)

def head(self):
"""
Return the head of the data source that the connector is connected to.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ coverage = "^7.2.7"
google-cloud-aiplatform = "^1.26.1"

[tool.poetry.extras]
connectors = ["pymysql", "psycopg2"]
connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy"]
google-ai = ["google-generativeai", "google-cloud-aiplatform"]
google-sheets = ["beautifulsoup4"]
excel = ["openpyxl"]
Expand Down
8 changes: 8 additions & 0 deletions tests/connectors/test_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from pandasai.connectors import BaseConnector
from pandasai.connectors.base import BaseConnectorConfig
from pandasai.helpers import Logger


Expand All @@ -13,6 +14,13 @@ def __init__(self, host, port, database, table):

# Mock subclass of BaseConnector for testing
class MockConnector(BaseConnector):

def _load_connector_config(self, config: BaseConnectorConfig):
pass

def _init_connection(self, config: BaseConnectorConfig):
pass

def head(self):
pass

Expand Down
Loading

0 comments on commit 58559db

Please sign in to comment.