feat: Airtable connector Support #635

Merged
29 changes: 29 additions & 0 deletions docs/connectors.md
@@ -187,3 +187,32 @@ yahoo_connector = YahooFinanceConnector("MSFT")
df = SmartDataframe(yahoo_connector)
df.chat("What is the closing price for yesterday?")
```

## Airtable Connector

The Airtable connector allows you to connect to Airtable tables by passing the `base_id`, `api_key`, and `table` of the table you want to analyze.

To use the Airtable connector, you only need to import it into your Python code and pass it to a `SmartDataframe` or `SmartDatalake` object:

```python
from pandasai.connectors import AirtableConnector
from pandasai import SmartDataframe


airtable_connectors = AirtableConnector(
    config={
        "api_key": "AIRTABLE_API_TOKEN",
        "table": "AIRTABLE_TABLE_NAME",
        "base_id": "AIRTABLE_BASE_ID",
        "where": [
            # optional; filters the data to reduce the size of the dataframe
            ["Status", "==", "In progress"]
        ],
    }
)

df = SmartDataframe(airtable_connectors)

df.chat("How many rows are there in the data?")
```
Contributor:

The new hunk introduces the Airtable connector and provides a code snippet demonstrating its usage. The code seems to be correct and well-explained. However, the api_key, table, and base_id are hardcoded strings. It's recommended to use environment variables or secure storage for sensitive data like API keys.

airtable_connectors = AirtableConnector(
    config={
-        "api_key": "AIRTABLE_API_TOKEN",
-        "table":"AIRTABLE_TABLE_NAME",
-        "base_id":"AIRTABLE_BASE_ID",
+        "api_key": os.getenv("AIRTABLE_API_TOKEN"),
+        "table": os.getenv("AIRTABLE_TABLE_NAME"),
+        "base_id": os.getenv("AIRTABLE_BASE_ID"),
        "where" : [
            # this is optional and filters the data to
            # reduce the size of the dataframe
            ["Status" ,"==","In progress"]
        ]
    }
)

Don't forget to import the os module at the beginning of your script.
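As a concrete illustration of the environment-variable approach suggested above, here is a minimal sketch. The helper name `airtable_config_from_env` and the failure behavior are assumptions for illustration, not part of pandasai's API.

```python
import os


def airtable_config_from_env() -> dict:
    """Build the Airtable connector config from environment variables.

    Raises KeyError if any required variable is unset, so a missing
    credential fails loudly instead of being sent as a literal string.
    """
    required = ["AIRTABLE_API_TOKEN", "AIRTABLE_BASE_ID", "AIRTABLE_TABLE_NAME"]
    missing = [name for name in required if not os.getenv(name)]
    if missing:
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")
    return {
        "api_key": os.environ["AIRTABLE_API_TOKEN"],
        "base_id": os.environ["AIRTABLE_BASE_ID"],
        "table": os.environ["AIRTABLE_TABLE_NAME"],
    }
```

The resulting dict can then be passed straight to `AirtableConnector(config=...)`.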

23 changes: 23 additions & 0 deletions examples/from_airtable.py
@@ -0,0 +1,23 @@
from pandasai.connectors import AirtableConnector
from pandasai.llm import OpenAI
from pandasai import SmartDataframe


airtable_connectors = AirtableConnector(
    config={
        "api_key": "AIRTABLE_API_TOKEN",
        "table": "AIRTABLE_TABLE_NAME",
        "base_id": "AIRTABLE_BASE_ID",
        "where": [
            # optional; filters the data to reduce the size of the dataframe
            ["Status", "==", "In progress"]
        ],
    }
Contributor:

The AirtableConnector is initialized with a hardcoded configuration. This might not be an issue if this is just an example, but in a production environment, sensitive data like api_key and base_id should be stored securely and not hardcoded. Consider using environment variables or a secure configuration management system.

)

llm = OpenAI("OPENAI_API_KEY")
Contributor:

The OpenAI instance is initialized with a hardcoded API key. As with the AirtableConnector, this is a security concern. API keys should be stored securely and not hardcoded.

df = SmartDataframe(airtable_connectors, config={"llm": llm})

response = df.chat("How many rows are there in the data?")
Contributor:

The chat method is called on the SmartDataframe object without any error handling. If an error occurs during the execution of this method, it could cause the program to crash. Consider wrapping this call in a try/except block to handle potential exceptions.

print(response)
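A minimal sketch of the try/except wrapper the reviewer asks for. The `run_chat` helper and its `fallback` parameter are illustrative assumptions, not pandasai API; catching bare `Exception` is deliberately broad here because the connector and LLM layers raise different error types.

```python
def run_chat(df, question: str, fallback=None):
    """Ask the dataframe a question, returning `fallback` on any error."""
    try:
        return df.chat(question)
    except Exception as exc:  # connector- or LLM-level errors alike
        print(f"chat() failed: {exc}")
        return fallback
```

In the example script this would replace the bare `df.chat(...)` call, so a transient API failure prints a message instead of crashing.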
2 changes: 2 additions & 0 deletions pandasai/connectors/__init__.py
@@ -9,6 +9,7 @@
from .snowflake import SnowFlakeConnector
from .databricks import DatabricksConnector
from .yahoo_finance import YahooFinanceConnector
from .airtable import AirtableConnector

__all__ = [
"BaseConnector",
@@ -18,4 +19,5 @@
"YahooFinanceConnector",
"SnowFlakeConnector",
"DatabricksConnector",
"AirtableConnector",
]
232 changes: 232 additions & 0 deletions pandasai/connectors/airtable.py
@@ -0,0 +1,232 @@
"""
Airtable connectors are used to connect airtable records.
"""

from .base import AirtableConnectorConfig, BaseConnector, BaseConnectorConfig
from typing import Union, Optional
import requests
import pandas as pd
import os
from ..helpers.path import find_project_root
import time
import hashlib
from ..exceptions import InvalidRequestError


class AirtableConnector(BaseConnector):
    """
    Airtable connector for retrieving record data.
    """

    def __init__(
        self,
        config: Optional[Union[AirtableConnectorConfig, dict]] = None,
Collaborator:

Let's add a cache mechanism similar to the other connectors!

Collaborator:

@Tanmaypatil123 check out yahoo finance connector for references

        cache_interval: int = 600,
    ):
Contributor (on lines +26 to +30):

The cache_interval parameter is not documented in the function docstring. Please add a description for it.


        if isinstance(config, dict):
            if "api_key" in config and "base_id" in config and "table" in config:
                config = AirtableConnectorConfig(**config)
            else:
                raise KeyError(
                    "Please specify all api_key,table,base_id properly in config ."
                )
Contributor (on lines +31 to +37):

The check for the presence of "api_key", "base_id", and "table" in the config dictionary is not robust. If any of these keys have a value that evaluates to False (like an empty string), the condition will fail. Consider using in to check for key existence.

Contributor (on lines +26 to +37):

The __init__ method checks if the config parameter is a dictionary and if it contains the keys api_key, base_id, and table. If not, it raises a KeyError. This could be improved by providing a more specific error message for each missing key. This way, the user will know exactly which key is missing from the configuration.

-            if "api_key" in config and "base_id" in config and "table" in config:
-                config = AirtableConnectorConfig(**config)
-            else:
-                raise KeyError(
-                    "Please specify all api_key,table,base_id properly in config ."
-                )
+            missing_keys = [key for key in ["api_key", "base_id", "table"] if key not in config]
+            if missing_keys:
+                raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")
+            config = AirtableConnectorConfig(**config)

Contributor (on lines +27 to +37):

The config parameter is optional and can be either a dictionary or an AirtableConnectorConfig object. If it's a dictionary, it must contain the keys api_key, base_id, and table. If these keys are not present, a KeyError is raised. Consider providing a more detailed error message to help the user understand what each key represents.

- raise KeyError("Please specify all api_key,table,base_id properly in config .")
+ raise KeyError("Please specify all api_key, table, and base_id properly in the config. api_key is your Airtable API key, base_id is the ID of the base you are connecting to, and table is the name of the table within the base.")

Contributor (on lines +32 to +37):

The check for the presence of "api_key", "base_id", and "table" in the config dictionary is not robust. If any of these keys have a value that evaluates to False (like an empty string), the check will pass, but the subsequent code may fail. Consider using dict.get() to check for the presence of these keys.


        elif not config:
            airtable_env_vars = {
                "api_key": "AIRTABLE_API_TOKEN",
                "base_id": "AIRTABLE_BASE_ID",
                "table": "AIRTABLE_TABLE_NAME",
            }
            config = AirtableConnectorConfig(
                **self._populate_config_from_env(config, airtable_env_vars)
            )
Contributor (on lines +39 to +47):

The config parameter is not used in the _populate_config_from_env method call. If it's not needed, consider removing it.

Contributor (on lines +38 to +47):

The code is not handling the case where the config is not a dictionary and is not None. This could lead to unexpected behavior if a user passes an invalid configuration. Consider adding an else clause to handle this case.


        self._root_url: str = "https://api.airtable.com/v0/"
        self._cache_interval = cache_interval

        super().__init__(config)

    def _init_connection(self, config: BaseConnectorConfig):
        """
        Make a connection to the Airtable API.
        """
        config = config.dict()
        url = f"{self._root_url}{config['base_id']}/{config['table']}"
        response = requests.head(
            url=url, headers={"Authorization": f"Bearer {config['api_key']}"}
        )
        if response.status_code == 200:
            self.logger.log(
                """
                Connected to Airtable.
                """
            )
        else:
            raise InvalidRequestError(
                f"""Failed to connect to Airtable.
                    Status code: {response.status_code},
                    message: {response.text}"""
            )
Contributor (on lines +61 to +74):

The code makes a HEAD request to the Airtable API to check the connection. However, it does not handle potential network errors that could occur during the request, such as a timeout or a connection error. It would be better to wrap the request in a try/except block and handle these potential errors.

Contributor (on lines +58 to +74):

The _init_connection method is making a HEAD request to the Airtable API to check the connection. This is a good practice, but it would be better to handle more specific HTTP status codes. For example, a 401 status code means that the API key is invalid, and a 404 status code means that the base or table does not exist. Handling these specific cases can provide more informative error messages to the user.

Contributor (on lines +54 to +74):

The _init_connection method sends a HEAD request to the Airtable API to check the connection. If the status code is not 200, it raises an InvalidRequestError. However, the error message could be improved by including the URL that failed to connect. This would help in debugging connection issues.

-                f"""Failed to connect to Airtable. 
-                    Status code: {response.status_code}, 
-                    message: {response.text}"""
+                f"""Failed to connect to Airtable at {url}. 
+                    Status code: {response.status_code}, 
+                    message: {response.text}"""

Contributor (on lines +59 to +74):

The connection initialization does not handle potential network errors that could occur when making the request. Consider adding a try-except block to handle exceptions like requests.exceptions.RequestException.
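Pulling the suggestions above together, here is a hedged sketch of a more defensive connection check, written as standalone helpers rather than the connector's actual method. The status-code-to-message mapping and the `ConnectionError` choice are assumptions for illustration.

```python
def connection_error_message(status_code: int, url: str) -> str:
    """Map common Airtable HTTP status codes to actionable messages."""
    reasons = {
        401: "invalid or expired API key",
        403: "the API key has no access to this base",
        404: "base or table not found",
    }
    reason = reasons.get(status_code, "unexpected response")
    return f"Failed to connect to Airtable at {url}: {reason} (HTTP {status_code})"


def check_airtable_connection(url: str, api_key: str, timeout: float = 10.0) -> None:
    """HEAD the table URL, raising a descriptive error on any failure."""
    import requests  # third-party; already a dependency of the connector

    try:
        response = requests.head(
            url, headers={"Authorization": f"Bearer {api_key}"}, timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        raise ConnectionError(f"Network error reaching Airtable: {exc}") from exc
    if response.status_code != 200:
        raise ConnectionError(connection_error_message(response.status_code, url))
```

The explicit `timeout` matters: without it, a stalled connection hangs the `_init_connection` call indefinitely.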


    def _get_cache_path(self, include_additional_filters: bool = False):
        """
        Return the path of the cache file.

        Returns:
            str: The path of the cache file.
        """
        try:
            cache_dir = os.path.join(find_project_root(), "cache")
        except ValueError:
            cache_dir = os.path.join(os.getcwd(), "cache")
Contributor (on lines +83 to +87):

The code is trying to find the project root and if it fails, it defaults to the current working directory. This could lead to inconsistent behavior depending on where the script is run from. Consider making the cache directory a configurable option.

        return os.path.join(cache_dir, f"{self._config.table}_data.parquet")
Contributor (on lines +83 to +88):

The _get_cache_path method does not handle potential os exceptions. Consider adding a try-except block to handle possible file system errors.


    def _cached(self):
        """
        Return the cached Airtable data if it exists and
        is not older than the cache interval.

        Returns:
            str | None: The cache file path if it exists and is not older
            than the cache interval, None otherwise.
        """
        cache_path = self._get_cache_path()
        if not os.path.exists(cache_path):
            return None

        # If the file is older than the cache interval, delete it.
        if os.path.getmtime(cache_path) < time.time() - self._cache_interval:
            if self.logger:
                self.logger.log(f"Deleting expired cached data from {cache_path}")
            os.remove(cache_path)
            return None
Contributor:

The cache invalidation logic is based on the modification time of the cache file. However, this approach might not work as expected if the system time is changed. A more reliable approach would be to store the cache creation time within the cache file itself and use that for invalidation.

Contributor (on lines +104 to +109):

The cache invalidation logic is based on the modification time of the cache file. This could lead to unexpected behavior if the cache file is manually modified. Consider using a separate metadata file or a database to store the cache creation time.

Contributor (on lines +105 to +109):

The cache invalidation strategy is based on the modification time of the cache file. This could lead to issues if the system clock is changed or if the file is manually modified. Consider using a more robust cache invalidation strategy, like storing the cache creation time in the file itself or in a separate metadata file.
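A sketch of the sidecar-metadata idea raised in the comments above: the cache creation time is written next to the parquet file instead of relying on the file's modification time. The `<cache>.meta.json` naming convention and the injectable `now` parameter are assumptions for illustration.

```python
import json
import os
import time


def write_cache_meta(cache_path: str, now: float = None) -> None:
    """Record the cache creation time in a sidecar JSON file."""
    meta = {"created_at": time.time() if now is None else now}
    with open(cache_path + ".meta.json", "w") as f:
        json.dump(meta, f)


def cache_is_fresh(cache_path: str, interval: float, now: float = None) -> bool:
    """True if the sidecar exists and the cache is within `interval` seconds."""
    meta_path = cache_path + ".meta.json"
    if not os.path.exists(meta_path):
        return False
    with open(meta_path) as f:
        created_at = json.load(f)["created_at"]
    current = time.time() if now is None else now
    return current - created_at < interval
```

`_save_cache` would call `write_cache_meta` after `df.to_parquet`, and `_cached` would consult `cache_is_fresh` instead of `os.path.getmtime`.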


        if self.logger:
            self.logger.log(f"Loading cached data from {cache_path}")

        return cache_path

    def _save_cache(self, df):
        """
        Save the given DataFrame to the cache.

        Args:
            df (DataFrame): The DataFrame to save to the cache.
        """
        filename = self._get_cache_path(
            include_additional_filters=self._additional_filters is not None
            and len(self._additional_filters) > 0
        )
        df.to_parquet(filename)
Contributor (on lines +123 to +127):

The cache file name does not take into account the where clause of the config. This could lead to incorrect results if the where clause changes but the cache file is not invalidated. Consider including a hash of the where clause in the cache file name.
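A sketch of folding a hash of the `where` clause into the cache file name, as this comment suggests. The naming scheme and 12-character digest truncation are assumptions for illustration.

```python
import hashlib


def cache_filename(table, where=None) -> str:
    """Cache file name that changes whenever the filters change.

    `repr(sorted(where))` gives a stable string for the same set of
    filter triples, so equal filters always map to the same file.
    """
    if not where:
        return f"{table}_data.parquet"
    digest = hashlib.sha256(repr(sorted(where)).encode("utf-8")).hexdigest()[:12]
    return f"{table}_{digest}_data.parquet"
```

With this, changing the `where` config simply misses the old cache file rather than serving stale filtered data.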


    @property
    def fallback_name(self):
        """
        Returns the fallback table name of the connector.

        Returns:
            str: The fallback table name of the connector.
        """
        return self._config.table

    def execute(self):
        """
        Execute the connector and return the result.

        Returns:
            DataFrameType: The result of the connector.
        """
        return self.fetch_data()
Collaborator:

You have cached it but never used it. You can check for the cache like:

cached = self._cached() or self._cached(include_additional_filters=True)
if cached:
    return pd.read_parquet(cached)


    def fetch_data(self):
        """
        Fetches data from the Airtable server through the
        API and converts it to a DataFrame.
        """
        url = f"{self._root_url}{self._config.base_id}/{self._config.table}"
        response = requests.get(
            url=url,
            headers={"Authorization": f"Bearer {self._config.api_key}"},
        )
        if response.status_code == 200:
            data = response.json()
            data = self.preprocess(data=data)
            self._save_cache(data)
        else:
            raise InvalidRequestError(
                f"""Failed to connect to Airtable.
                    Status code: {response.status_code},
                    message: {response.text}"""
            )
Contributor:

The code does not handle potential network errors that might occur during the requests.get call. It would be better to wrap this call in a try-except block and handle potential network errors.

-        response = requests.get(
-            url=url,
-            headers={"Authorization": f"Bearer {self._config.api_key}"},
-        )
-        if response.status_code == 200:
-            data = response.json()
-            data = self.preprocess(data=data)
-            self._save_cache(data)
-        else:
-            raise InvalidRequestError(
-                f"""Failed to connect to Airtable. 
-                    Status code: {response.status_code}, 
-                    message: {response.text}"""
-            )
+        try:
+            response = requests.get(
+                url=url,
+                headers={"Authorization": f"Bearer {self._config.api_key}"},
+            )
+            response.raise_for_status()
+            data = response.json()
+            data = self.preprocess(data=data)
+            self._save_cache(data)
+        except requests.exceptions.RequestException as e:
+            raise InvalidRequestError(
+                f"""Failed to connect to Airtable. 
+                    Error: {str(e)}"""
+            )

        return data

    def preprocess(self, data):
        """
        Preprocess the JSON response data
        to prepare the dataframe correctly.
        """
        records = [
            {"id": record["id"], **record["fields"]} for record in data["records"]
        ]

        df = pd.DataFrame(records)

        if self._config.where:
            for i in self._config.where:
                filter_string = f"{i[0]} {i[1]} '{i[2]}'"
                df = df.query(filter_string)
Contributor:

The code assumes that the 'where' attribute in the config is a list of conditions. However, it does not validate if this is the case. If 'where' is not a list or does not contain valid conditions, the code will fail at runtime. It would be better to add a check to ensure 'where' is a list and contains valid conditions.


        return df
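A sketch of the validation this comment asks for, run before any condition reaches `df.query`. The operator whitelist and the `[column, operator, value]` triple shape mirror the documented `where` format; the helper name is an assumption.

```python
ALLOWED_OPS = {"==", "!=", ">", ">=", "<", "<="}


def validate_where(where) -> None:
    """Raise ValueError unless `where` is a list of [column, op, value] triples."""
    if not isinstance(where, list):
        raise ValueError("`where` must be a list of [column, operator, value] triples")
    for condition in where:
        if not (isinstance(condition, (list, tuple)) and len(condition) == 3):
            raise ValueError(f"invalid condition: {condition!r}")
        column, op, _value = condition
        if not isinstance(column, str) or op not in ALLOWED_OPS:
            raise ValueError(f"invalid condition: {condition!r}")
```

Calling this at config time turns a cryptic `df.query` parse error into an immediate, readable one.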

Collaborator:

Let's also add a def fallback_name, similar to all the other connectors.

Example:

    @property
    def fallback_name(self):
        """
        Return the fallback name of the connector.

        Returns:
            str: The fallback name of the connector.
        """
        return self._config.table

    def head(self):
Collaborator:

It should be cached using @cache; for reference, check sql.py.

"""
Return the head of the table that
the connector is connected to.

Returns :
DatFrameType: The head of the data source
that the conector is connected to .
"""
Contributor (on lines +216 to +224):

The head method is decorated with the cache decorator from the functools module. This means that the method result is cached and the same result is returned for subsequent calls with the same arguments. However, the head method does not take any arguments, so the caching does not have any effect. Consider removing the cache decorator.

        return self.fetch_data().head()
Collaborator:

Here, fetch only specific records (say, the first 5 or a random 5 rows) instead of all of them.
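A sketch of that server-side limit using the Airtable REST API's `maxRecords` query parameter, so `head()` never downloads the whole table. The helper names are assumptions; the URL and record shape mirror the connector's `fetch_data`/`preprocess`.

```python
import pandas as pd


def records_to_frame(payload: dict) -> pd.DataFrame:
    """Convert an Airtable records payload to a DataFrame."""
    return pd.DataFrame([{"id": r["id"], **r["fields"]} for r in payload["records"]])


def fetch_head(base_id: str, table: str, api_key: str, n: int = 5) -> pd.DataFrame:
    """Fetch only the first `n` records instead of the whole table."""
    import requests  # third-party; already a dependency of the connector

    response = requests.get(
        f"https://api.airtable.com/v0/{base_id}/{table}",
        headers={"Authorization": f"Bearer {api_key}"},
        params={"maxRecords": n},  # server-side limit keeps head() cheap
        timeout=10,
    )
    response.raise_for_status()
    return records_to_frame(response.json())
```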


    @property
    def rows_count(self):
        """
        Return the number of rows in the data source that the connector is
        connected to.

        Returns:
            int: The number of rows in the data source that the connector is
            connected to.
        """
        data = self.execute()
Collaborator:

Keep rows_count as a property but store the result on the instance, so we don't have to call execute() again and again.

        return len(data)
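A sketch of memoising the row count on the instance, as the comment above suggests. The mixin shape and `_rows_count` attribute name are illustrative assumptions, not the connector's actual structure.

```python
class RowCountMixin:
    """Caches the row count on first access instead of re-executing."""

    _rows_count = None

    def execute(self):  # stands in for the connector's real fetch
        raise NotImplementedError

    @property
    def rows_count(self) -> int:
        # First access triggers one execute(); later accesses reuse it.
        if self._rows_count is None:
            self._rows_count = len(self.execute())
        return self._rows_count
```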

    @property
    def columns_count(self):
        """
        Return the number of columns in the data source that the connector is
        connected to.

        Returns:
            int: The number of columns in the data source that the connector is
            connected to.
        """
        data = self.execute()
Collaborator:

Same as rows_count: avoid calling it again and again. For this, use the head function instead.

Collaborator:

Here, for the column count, use head(): it is also cached using @cache, and it will have less data.

        return len(data.columns)

    @property
    def column_hash(self):
        """
        Return the hash code that is unique to the columns of the data source
        that the connector is connected to.

        Returns:
            int: The hash code that is unique to the columns of the data source
            that the connector is connected to.
        """
Contributor (on lines +266 to +274):

The column_hash property is computing a hash of the column names and the filter formula. This is a good practice for caching purposes. However, it assumes that self.instance is a DataFrame. Consider adding error handling for cases where self.instance is not a DataFrame.

        data = self.execute()
Collaborator:

Should also include the where queries as well. Otherwise it will return wrong data from the cache.

        columns_str = "|".join(data.columns)
        return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()
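A sketch of folding the `where` filters into the hash, as the comment suggests, so differently filtered data never shares a hash. The standalone function shape is an assumption for illustration.

```python
import hashlib


def column_hash(columns, where=None) -> str:
    """Hash that is sensitive to both the columns and the active filters."""
    parts = "|".join(columns)
    if where:
        # repr(sorted(...)) is stable for equal filter sets
        parts += "|" + repr(sorted(where))
    return hashlib.sha256(parts.encode("utf-8")).hexdigest()
```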
10 changes: 10 additions & 0 deletions pandasai/connectors/base.py
@@ -20,6 +20,16 @@ class BaseConnectorConfig(BaseModel):
    where: list[list[str]] = None
Contributor:

The where parameter is not used in the AirtableConnectorConfig class. If it's not needed, consider removing it to avoid confusion.



class AirtableConnectorConfig(BaseConnectorConfig):
    """
    Connector configuration for Airtable data.
    """

    api_key: str
    base_id: str
    database: str = "airtable_data"


class SQLBaseConnectorConfig(BaseConnectorConfig):
    """
    Base Connector configuration.
Contributor (on lines 33 to 35):

The SQLBaseConnectorConfig class seems to be duplicated. If it's not intended, consider removing it to maintain code cleanliness.

10 changes: 10 additions & 0 deletions pandasai/exceptions.py
@@ -5,6 +5,16 @@
"""


class InvalidRequestError(Exception):

    """
    Raised when the request is not successful.

    Args:
        Exception (Exception): InvalidRequestError
    """


class APIKeyNotFoundError(Exception):

    """
12 changes: 6 additions & 6 deletions pandasai/helpers/openai_info.py
@@ -45,9 +45,9 @@


def get_openai_token_cost_for_model(
model_name: str,
num_tokens: int,
is_completion: bool = False,
model_name: str,
num_tokens: int,
is_completion: bool = False,
) -> float:
"""
Get the cost in USD for a given model and number of tokens.
@@ -63,9 +63,9 @@ def get_openai_token_cost_for_model(
"""
model_name = model_name.lower()
if is_completion and (
model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
):
# The cost of completion token is different from
# the cost of prompt tokens.
gventuri marked this conversation as resolved.