Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 1.2 #535

Merged
merged 19 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions examples/from_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Example of using PandasAI with a CSV file."""

from pandasai import SmartDatalake
from pandasai.llm import OpenAI
from pandasai.connectors import MySQLConnector, PostgreSQLConnector

# With a MySQL database
loan_connector = MySQLConnector(
config={
"host": "localhost",
"port": 3306,
"database": "mydb",
"username": "root",
"password": "root",
"table": "loans",
"where": [
# this is optional and filters the data to
# reduce the size of the dataframe
["loan_status", "=", "PAIDOFF"],
],
}
gventuri marked this conversation as resolved.
Show resolved Hide resolved
)

# With a PostgreSQL database
payment_connector = PostgreSQLConnector(
config={
"host": "localhost",
"port": 5432,
"database": "mydb",
"username": "root",
"password": "root",
"table": "payments",
"where": [
# this is optional and filters the data to
# reduce the size of the dataframe
["payment_status", "=", "PAIDOFF"],
],
}
gventuri marked this conversation as resolved.
Show resolved Hide resolved
)

llm = OpenAI()
df = SmartDatalake([loan_connector, payment_connector], config={"llm": llm})
response = df.chat("How many people from the United states?")
print(response)
# Output: 247 loans have been paid off by men.
17 changes: 17 additions & 0 deletions pandasai/connectors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Connectors are used to connect to databases, external APIs, and other data sources.

The connectors package contains all the connectors that are used by the application.
"""

from .base import BaseConnector
from .sql import SQLConnector, MySQLConnector, PostgreSQLConnector
from .yahoo_finance import YahooFinanceConnector

__all__ = [
"BaseConnector",
"SQLConnector",
"MySQLConnector",
"PostgreSQLConnector",
"YahooFinanceConnector",
]
136 changes: 136 additions & 0 deletions pandasai/connectors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""
Base connector class to be extended by all connectors.
"""

from abc import ABC, abstractmethod
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from pydantic import BaseModel
from typing import Optional


class ConnectorConfig(BaseModel):
"""
Connector configuration.
"""

dialect: Optional[str] = None
driver: Optional[str] = None
username: str
password: str
host: str
port: int
database: str
table: str
where: list[list[str]] = None


class BaseConnector(ABC):
"""
Base connector class to be extended by all connectors.
"""

_config = None
_logger: Logger = None
_additional_filters: list[list[str]] = None

Comment on lines +33 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class variables _config, _logger, and _additional_filters are not typed. It's a good practice to type all variables for better readability and maintainability.

-    _config = None
-    _logger: Logger = None
-    _additional_filters: list[list[str]] = None
+    _config: ConnectorConfig = None
+    _logger: Optional[Logger] = None
+    _additional_filters: Optional[list[list[str]]] = None

def __init__(self, config):
"""
Initialize the connector with the given configuration.

Args:
config (dict): The configuration for the connector.
"""
self._config = config

Comment on lines +37 to +45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __init__ method accepts a dictionary as the configuration but the class ConnectorConfig is defined to hold the configuration. It would be better to use this class instead of a dictionary for type safety and better code organization.

-    def __init__(self, config):
+    def __init__(self, config: ConnectorConfig):

@abstractmethod
def head(self):
"""
Return the head of the data source that the connector is connected to.
This information is passed to the LLM to provide the schema of the
data source.
"""
pass

@abstractmethod
def execute(self) -> DataFrameType:
"""
Execute the given query on the data source that the connector is
connected to.
"""
pass

def set_additional_filters(self, filters: dict):
"""
Add additional filters to the connector.

Args:
filters (dict): The additional filters to add to the connector.
"""
self._additional_filters = filters if filters else []

@property
def rows_count(self):
"""
Return the number of rows in the data source that the connector is
connected to.
"""
raise NotImplementedError

@property
def columns_count(self):
"""
Return the number of columns in the data source that the connector is
connected to.
"""
raise NotImplementedError

@property
def column_hash(self):
"""
Return the hash code that is unique to the columns of the data source
that the connector is connected to.
"""
raise NotImplementedError
Comment on lines +78 to +94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The properties rows_count, columns_count, and column_hash raise NotImplementedError. If these methods are mandatory for all subclasses, they should be declared as abstract methods. If they're optional, it would be better to return None instead of raising an error.

-    @property
-    def rows_count(self):
-        """
-        Return the number of rows in the data source that the connector is
-        connected to.
-        """
-        raise NotImplementedError
-
-    @property
-    def columns_count(self):
-        """
-        Return the number of columns in the data source that the connector is
-        connected to.
-        """
-        raise NotImplementedError
-
-    @property
-    def column_hash(self):
-        """
-        Return the hash code that is unique to the columns of the data source
-        that the connector is connected to.
-        """
-        raise NotImplementedError
+    @abstractmethod
+    def rows_count(self):
+        pass
+
+    @abstractmethod
+    def columns_count(self):
+        pass
+
+    @abstractmethod
+    def column_hash(self):
+        pass


@property
def path(self):
"""
Return the path of the data source that the connector is connected to.
"""
# JDBC string
return (
self.__class__.__name__
+ "://"
+ self._config.host
+ ":"
+ str(self._config.port)
+ "/"
+ self._config.database
+ "/"
+ self._config.table
)

@property
def logger(self):
"""
Return the logger for the connector.
"""
return self._logger

@logger.setter
def logger(self, logger: Logger):
"""
Set the logger for the connector.

Args:
logger (Logger): The logger for the connector.
"""
self._logger = logger

@property
def fallback_name(self):
"""
Return the name of the table that the connector is connected to.
"""
raise NotImplementedError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The property fallback_name raises NotImplementedError. If this method is mandatory for all subclasses, it should be declared as an abstract method. If it's optional, it would be better to return None instead of raising an error.

-    @property
-    def fallback_name(self):
-        """
-        Return the name of the table that the connector is connected to.
-        """
-        raise NotImplementedError
+    @abstractmethod
+    def fallback_name(self):
+        pass

Loading
Loading