Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(SnowFlake): snowflake connector #574

Merged
merged 10 commits into from
Sep 19, 2023
30 changes: 30 additions & 0 deletions examples/from_snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Example of using PandasAI with a Snowflake data source."""

from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from pandasai.connectors import SnowFlakeConnector


# Build the connector from a plain dict of settings. The values below are
# example/placeholder values; replace them with your own before running.
snowflake_connector = SnowFlakeConnector(
    config={
        "account": "ehxzojy-ue47135",
        "database": "SNOWFLAKE_SAMPLE_DATA",
        "username": "test",
        "password": "*****",
        "table": "lineitem",
        "warehouse": "COMPUTE_WH",
        "dbSchema": "tpch_sf1",
        "where": [
            # this is optional and filters the data to
            # reduce the size of the dataframe
            ["l_quantity", ">", "49"]
        ],
    }
)

# Placeholder: replace with a real OpenAI API key.
OPEN_AI_API = "Your-API-Key"
gventuri marked this conversation as resolved.
Show resolved Hide resolved
# Wrap the Snowflake connector in a SmartDataframe that uses the OpenAI LLM.
llm = OpenAI(api_token=OPEN_AI_API)
df = SmartDataframe(snowflake_connector, config={"llm": llm})

# Ask a natural-language question about the connected table and print the reply.
response = df.chat("Count line status is F")
print(response)
gventuri marked this conversation as resolved.
Show resolved Hide resolved
14 changes: 14 additions & 0 deletions examples/from_yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Example of using PandasAI with the Yahoo Finance connector."""
from pandasai.connectors.yahoo_finance import YahooFinanceConnector
from pandasai import SmartDataframe
from pandasai.llm import OpenAI


# The connector takes the stock ticker symbol as its first argument.
yahoo_connector = YahooFinanceConnector("MSFT")


# Placeholder: replace with a real OpenAI API key.
OPEN_AI_API = "OPEN_API_KEY"
llm = OpenAI(api_token=OPEN_AI_API)
df = SmartDataframe(yahoo_connector, config={"llm": llm})

# Ask a natural-language question about the ticker data and print the reply.
response = df.chat("closing price yesterday")
print(response)
gventuri marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion pandasai/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector, SnowFlakeConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = [
Expand All @@ -14,4 +14,5 @@
"MySQLConnector",
"PostgreSQLConnector",
"YahooFinanceConnector",
"SnowFlakeConnector"
]
60 changes: 53 additions & 7 deletions pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,54 @@
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from pydantic import BaseModel
from typing import Optional
from typing import Optional, Union


class ConnectorConfig(BaseModel):
class BaseConnectorConfig(BaseModel):
"""
Connector configuration.
Base Connector configuration.
"""

dialect: Optional[str] = None
driver: Optional[str] = None
username: str
password: str
host: str
port: int
database: str
table: str
where: list[list[str]] = None


class YahooFinanceConnectorConfig(BaseConnectorConfig):
    """
    Connector configuration for Yahoo Finance.

    Extends BaseConnectorConfig with the required ``host`` and ``port``
    fields.
    """

    host: str
    port: int


class SQLConnectorConfig(BaseConnectorConfig):
    """
    Connector configuration for SQL databases reached by host and port.

    Extends BaseConnectorConfig with the server address (``host``, ``port``)
    and the credentials (``username``, ``password``); all four are required.
    """

    host: str
    port: int
    username: str
    password: str


class SnowFlakeConnectorConfig(BaseConnectorConfig):
    """
    Connector configuration for SnowFlake.

    Extends BaseConnectorConfig with the fields the SnowFlake connector
    reads when it builds its connection URL; all of them are required.
    """

    # Placed where a host name would go in the connection URL
    account: str
    database: str
    username: str
    password: str
    # Sent as the ``schema`` query parameter of the connection URL
    dbSchema: str
    # Sent as the ``warehouse`` query parameter of the connection URL
    warehouse: str

gventuri marked this conversation as resolved.
Show resolved Hide resolved

class BaseConnector(ABC):
"""
Base connector class to be extended by all connectors.
Expand All @@ -43,6 +72,23 @@ def __init__(self, config):
"""
self._config = config

def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
    """Hook for turning the passed configuration into a config object.

    This base implementation is a no-op and returns None; subclasses are
    expected to override it and return their own config object.

    Args:
        config (Union[BaseConnectorConfig, dict]): The configuration as
            passed by the caller.

    Returns:
        None in this base implementation.
    """
    pass
gventuri marked this conversation as resolved.
Show resolved Hide resolved

def _init_connection(self, config: BaseConnectorConfig):
    """
    Hook for opening the connection to the data source.

    This base implementation is a no-op; subclasses that need a connection
    override it.

    Args:
        config (BaseConnectorConfig): The loaded connector configuration.
    """
    pass

@abstractmethod
def head(self):
"""
Expand Down
125 changes: 119 additions & 6 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import re
import os
import pandas as pd
from .base import BaseConnector, ConnectorConfig
from .base import BaseConnector, SQLConnectorConfig
from .base import BaseConnectorConfig, SnowFlakeConnectorConfig
from sqlalchemy import create_engine, sql, text, select, asc
from functools import cached_property, cache
import hashlib
Expand All @@ -25,19 +26,46 @@ class SQLConnector(BaseConnector):
_columns_count: int = None
_cache_interval: int = 600 # 10 minutes

def __init__(self, config: Union[ConnectorConfig, dict], cache_interval: int = 600):
def __init__(
self, config: Union[BaseConnectorConfig, dict], cache_interval: int = 600
):
gventuri marked this conversation as resolved.
Show resolved Hide resolved
"""
Initialize the SQL connector with the given configuration.

Args:
config (ConnectorConfig): The configuration for the SQL connector.
"""
config = ConnectorConfig(**config)
config = self._load_connector_config(config)
super().__init__(config)

if config.dialect is None:
raise Exception("SQL dialect must be specified")

self._init_connection(config)

self._cache_interval = cache_interval

def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
    """
    Build a SQLConnectorConfig from the passed configuration.

    The configuration is unpacked with ``**``, so it has to be a mapping
    whose keys are SQLConnectorConfig field names.

    Args:
        config (Union[BaseConnectorConfig, dict]): The configuration values.

    Returns:
        SQLConnectorConfig: The config object built from ``config``.
    """
    return SQLConnectorConfig(**config)

def _init_connection(self, config: SQLConnectorConfig):
"""
Initialize Database Connection

Args:
config (SQLConnectorConfig): Configurations to load database

"""

if config.driver:
self._engine = create_engine(
f"{config.dialect}+{config.driver}://{config.username}:{config.password}"
gventuri marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -48,8 +76,8 @@ def __init__(self, config: Union[ConnectorConfig, dict], cache_interval: int = 6
f"{config.dialect}://{config.username}:{config.password}@{config.host}"
f":{str(config.port)}/{config.database}"
)
gventuri marked this conversation as resolved.
Show resolved Hide resolved

self._connection = self._engine.connect()
self._cache_interval = cache_interval

def __del__(self):
"""
Expand Down Expand Up @@ -340,7 +368,7 @@ class MySQLConnector(SQLConnector):
MySQL connectors are used to connect to MySQL databases.
"""

def __init__(self, config: ConnectorConfig):
def __init__(self, config: SQLConnectorConfig):
"""
Initialize the MySQL connector with the given configuration.

Expand Down Expand Up @@ -369,7 +397,7 @@ class PostgreSQLConnector(SQLConnector):
PostgreSQL connectors are used to connect to PostgreSQL databases.
"""

def __init__(self, config: ConnectorConfig):
def __init__(self, config: SQLConnectorConfig):
"""
Initialize the PostgreSQL connector with the given configuration.

Expand Down Expand Up @@ -413,3 +441,88 @@ def head(self):

# Return the head of the data source
return pd.read_sql(query, self._connection)


class SnowFlakeConnector(SQLConnector):
    """
    SnowFlake connectors are used to connect to SnowFlake Data Cloud.
    """

    def __init__(self, config: Union[SnowFlakeConnectorConfig, dict]):
        """
        Initialize the SnowFlake connector with the given configuration.

        Settings missing from ``config`` are filled in from environment
        variables when those are set (see the table in the body).

        Args:
            config (Union[SnowFlakeConnectorConfig, dict]): The configuration
                for the SnowFlake connector. It is copied with ``dict()``
                first, so the caller's object is never modified.
        """
        # Work on a copy: adding "dialect" and the environment fallbacks must
        # not leak back into the dict owned by the caller.
        config = dict(config)
        config["dialect"] = "snowflake"

        # config key -> environment variable consulted when the key is absent
        env_fallbacks = {
            "account": "SNOWFLAKE_HOST",
            "database": "SNOWFLAKE_DATABASE",
            "warehouse": "SNOWFLAKE_WAREHOUSE",
            "dbSchema": "SNOWFLAKE_SCHEMA",
            "username": "SNOWFLAKE_USERNAME",
            "password": "SNOWFLAKE_PASSWORD",
        }
        for key, env_var in env_fallbacks.items():
            env_value = os.getenv(env_var)
            # Unset or empty variables are ignored, so validation still
            # reports the field as missing.
            if key not in config and env_value:
                config[key] = env_value

        super().__init__(config)

    def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]):
        """
        Build a SnowFlakeConnectorConfig from the passed configuration.

        Args:
            config (Union[BaseConnectorConfig, dict]): The configuration values.

        Returns:
            SnowFlakeConnectorConfig: The config object built from ``config``.
        """
        # dict() is a no-op copy for a dict and also lets a config object be
        # unpacked with ``**``.
        return SnowFlakeConnectorConfig(**dict(config))

    def _init_connection(self, config: SnowFlakeConnectorConfig):
        """
        Create the SQLAlchemy engine and open the database connection.

        Args:
            config (SnowFlakeConnectorConfig): Configuration used to build
                the connection URL.
        """
        from urllib.parse import quote

        # Credentials may contain URL-reserved characters (@, :, /, ...).
        # Percent-encode them so they cannot corrupt the connection URL.
        username = quote(config.username, safe="")
        password = quote(config.password, safe="")

        self._engine = create_engine(
            f"{config.dialect}://{username}:{password}@{config.account}/"
            f"?warehouse={config.warehouse}&database={config.database}"
            f"&schema={config.dbSchema}"
        )

        self._connection = self._engine.connect()

    # NOTE(review): functools.cache on a method keys on ``self`` and keeps the
    # instance referenced by the cache; kept as in the original code.
    @cache
    def head(self):
        """
        Return the head of the data source that the connector is connected to.
        This information is passed to the LLM to provide the schema of the data source.

        Returns:
            DataFrame: The head of the data source.
        """

        if self.logger:
            self.logger.log(
                f"Getting head of {self._config.table} "
                f"using dialect {self._config.dialect}"
            )

        # Build a query limited to 5 rows, ordered by RANDOM()
        query = self._build_query(limit=5, order="RANDOM()")

        # Return the head of the data source
        return pd.read_sql(query, self._connection)

    def __repr__(self):
        """
        Return the string representation of the SnowFlake connector.

        The password is masked so that logging or printing the connector
        does not expose the credential.

        Returns:
            str: The string representation of the SnowFlake connector.
        """
        return (
            f"<{self.__class__.__name__} dialect={self._config.dialect} "
            f"username={self._config.username} "
            f"password=*** Account={self._config.account} "
            f"warehouse={self._config.warehouse} "
            f"database={self._config.database} schema={str(self._config.dbSchema)} "
            f"table={self._config.table}>"
        )
9 changes: 4 additions & 5 deletions pandasai/connectors/yahoo_finance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pandas as pd
from .base import ConnectorConfig, BaseConnector

from .base import YahooFinanceConnectorConfig, BaseConnector
import time
from ..helpers.path import find_project_root
import hashlib
Expand All @@ -21,10 +22,8 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
"Could not import yfinance python package. "
"Please install it with `pip install yfinance`."
)
yahoo_finance_config = ConnectorConfig(
yahoo_finance_config = YahooFinanceConnectorConfig(
dialect="yahoo_finance",
username="",
password="",
host="yahoo.finance.com",
port=443,
database="stock_data",
Expand All @@ -34,7 +33,7 @@ def __init__(self, stock_ticker, where=None, cache_interval: int = 600):
self._cache_interval = cache_interval
super().__init__(yahoo_finance_config)
self.ticker = yfinance.Ticker(self._config.table)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover

def head(self):
"""
Return the head of the data source that the connector is connected to.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ coverage = "^7.2.7"
google-cloud-aiplatform = "^1.26.1"

[tool.poetry.extras]
connectors = ["pymysql", "psycopg2"]
connectors = ["pymysql", "psycopg2", "snowflake-sqlalchemy"]
google-ai = ["google-generativeai", "google-cloud-aiplatform"]
google-sheets = ["beautifulsoup4"]
excel = ["openpyxl"]
Expand Down
8 changes: 8 additions & 0 deletions tests/connectors/test_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from pandasai.connectors import BaseConnector
from pandasai.connectors.base import BaseConnectorConfig
from pandasai.helpers import Logger


Expand All @@ -13,6 +14,13 @@ def __init__(self, host, port, database, table):

# Mock subclass of BaseConnector for testing
class MockConnector(BaseConnector):

def _load_connector_config(self, config: BaseConnectorConfig):
pass

def _init_connection(self, config: BaseConnectorConfig):
pass

def head(self):
pass

Expand Down
Loading