diff --git a/docs/connectors.md b/docs/connectors.md
index 242de42eb..1850b9a7c 100644
--- a/docs/connectors.md
+++ b/docs/connectors.md
@@ -187,3 +187,32 @@ yahoo_connector = YahooFinanceConnector("MSFT")
 df = SmartDataframe(yahoo_connector)
 df.chat("What is the closing price for yesterday?")
 ```
+
+## Airtable Connector
+
+The Airtable connector allows you to connect to Airtable tables, by simply passing the `base_id`, `api_key`, and `table` name of the table you want to analyze.
+
+To use the Airtable connector, you only need to import it into your Python code and pass it to a `SmartDataframe` or `SmartDatalake` object:
+
+```python
+from pandasai.connectors import AirtableConnector
+from pandasai import SmartDataframe
+
+
+airtable_connectors = AirtableConnector(
+    config={
+        "api_key": "AIRTABLE_API_TOKEN",
+        "table":"AIRTABLE_TABLE_NAME",
+        "base_id":"AIRTABLE_BASE_ID",
+        "where" : [
+            # this is optional and filters the data to
+            # reduce the size of the dataframe
+            ["Status" ,"=","In progress"]
+        ]
+    }
+)
+
+df = SmartDataframe(airtable_connectors)
+
+df.chat("How many rows are there in data ?")
+```
\ No newline at end of file
diff --git a/examples/from_airtable.py b/examples/from_airtable.py
new file mode 100644
index 000000000..8a2cf1c88
--- /dev/null
+++ b/examples/from_airtable.py
@@ -0,0 +1,23 @@
+from pandasai.connectors import AirtableConnector
+from pandasai.llm import OpenAI
+from pandasai import SmartDataframe
+
+
+airtable_connectors = AirtableConnector(
+    config={
+        "api_key": "AIRTABLE_API_TOKEN",
+        "table": "AIRTABLE_TABLE_NAME",
+        "base_id": "AIRTABLE_BASE_ID",
+        "where": [
+            # this is optional and filters the data to
+            # reduce the size of the dataframe
+            ["Status", "=", "In progress"]
+        ],
+    }
+)
+
+llm = OpenAI("OPENAI_API_KEY")
+df = SmartDataframe(airtable_connectors, config={"llm": llm})
+
+response = df.chat("How many rows are there in data ?")
+print(response)
diff --git a/pandasai/connectors/__init__.py 
b/pandasai/connectors/__init__.py
index 1a28f8dcc..fb80c8628 100644
--- a/pandasai/connectors/__init__.py
+++ b/pandasai/connectors/__init__.py
@@ -9,6 +9,7 @@
 from .snowflake import SnowFlakeConnector
 from .databricks import DatabricksConnector
 from .yahoo_finance import YahooFinanceConnector
+from .airtable import AirtableConnector
 
 __all__ = [
     "BaseConnector",
@@ -18,4 +19,5 @@
     "YahooFinanceConnector",
     "SnowFlakeConnector",
     "DatabricksConnector",
+    "AirtableConnector",
 ]
diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py
new file mode 100644
index 000000000..417e609e9
--- /dev/null
+++ b/pandasai/connectors/airtable.py
@@ -0,0 +1,279 @@
+"""
+Airtable connectors are used to connect to Airtable records.
+"""
+
+from .base import AirtableConnectorConfig, BaseConnector, BaseConnectorConfig
+from typing import Union, Optional
+import requests
+import pandas as pd
+import os
+from ..helpers.path import find_project_root
+import time
+import hashlib
+from ..exceptions import InvalidRequestError
+from functools import cache, cached_property
+
+
+class AirtableConnector(BaseConnector):
+    """
+    Airtable connector for retrieving record data.
+    """
+
+    _rows_count: int = None
+    _columns_count: int = None
+    _instance = None
+
+    def __init__(
+        self,
+        config: Optional[Union[AirtableConnectorConfig, dict]] = None,
+        cache_interval: int = 600,
+    ):
+        if isinstance(config, dict):
+            if "api_key" in config and "base_id" in config and "table" in config:
+                config = AirtableConnectorConfig(**config)
+            else:
+                raise KeyError(
+                    "Please specify all api_key,table,base_id properly in config ." 
+                )
+
+        elif not config:
+            airtable_env_vars = {
+                "api_key": "AIRTABLE_API_TOKEN",
+                "base_id": "AIRTABLE_BASE_ID",
+                "table": "AIRTABLE_TABLE_NAME",
+            }
+            config = AirtableConnectorConfig(
+                **self._populate_config_from_env(config, airtable_env_vars)
+            )
+
+        self._root_url: str = "https://api.airtable.com/v0/"
+        self._cache_interval = cache_interval
+
+        super().__init__(config)
+
+    def _init_connection(self, config: BaseConnectorConfig):
+        """
+        Validate the connection to Airtable with a HEAD request.
+        """
+        config = config.dict()
+        url = f"{self._root_url}{config['base_id']}/{config['table']}"
+        response = requests.head(
+            url=url, headers={"Authorization": f"Bearer {config['api_key']}"}
+        )
+        if response.status_code == 200:
+            self.logger.log(
+                """
+                Connected to Airtable.
+                """
+            )
+        else:
+            raise InvalidRequestError(
+                f"""Failed to connect to Airtable.
+                Status code: {response.status_code},
+                message: {response.text}"""
+            )
+
+    def _get_cache_path(self, include_additional_filters: bool = False):
+        """
+        Return the path of the cache file.
+
+        Returns:
+            str: The path of the cache file.
+        """
+        cache_dir = os.path.join(os.getcwd(), "")
+        try:
+            cache_dir = os.path.join((find_project_root()), "cache")
+        except ValueError:
+            cache_dir = os.path.join(os.getcwd(), "cache")
+        return os.path.join(cache_dir, f"{self._config.table}_data.parquet")
+
+    def _cached(self, include_additional_filters: bool = False):
+        """
+        Returns the cached Airtable data if it exists and
+        is not older than the cache interval.
+
+        Returns:
+            DataFrame | None : The cached data if
+            it exists and is not older than the cache
+            interval, None otherwise.
+        """
+        cache_path = self._get_cache_path(include_additional_filters)
+        if not os.path.exists(cache_path):
+            return None
+
+        # If the cached file is older than the cache interval, delete it. 
+ if os.path.getmtime(cache_path) < time.time() - self._cache_interval: + if self.logger: + self.logger.log(f"Deleting expired cached data from {cache_path}") + os.remove(cache_path) + return None + + if self.logger: + self.logger.log(f"Loading cached data from {cache_path}") + + return cache_path + + def _save_cache(self, df): + """ + Save the given DataFrame to the cache. + + Args: + df (DataFrame): The DataFrame to save to the cache. + """ + filename = self._get_cache_path( + include_additional_filters=self._additional_filters is not None + and len(self._additional_filters) > 0 + ) + df.to_parquet(filename) + + @property + def fallback_name(self): + """ + Returns the fallback table name of the connector. + + Returns : + str : The fallback table name of the connector. + """ + return self._config.table + + def execute(self): + """ + Execute the connector and return the result. + + Returns: + DataFrameType: The result of the connector. + """ + cached = self._cached() or self._cached(include_additional_filters=True) + if cached: + return pd.read_parquet(cached) + + if isinstance(self._instance, pd.DataFrame): + return self._instance + else: + self._instance = self._fetch_data() + + return self._instance + + def _build_formula(self): + """ + Build Airtable query formula for filtering. + """ + + condition_strings = [] + if self._config.where is not None: + for i in self._config.where: + filter_query = f"{i[0]}{i[1]}'{i[2]}'" + condition_strings.append(filter_query) + filter_formula = f'AND({",".join(condition_strings)})' + return filter_formula + + def _request_api(self, params): + url = f"{self._root_url}{self._config.base_id}/{self._config.table}" + response = requests.get( + url=url, + headers={"Authorization": f"Bearer {self._config.api_key}"}, + params=params, + ) + return response + + def _fetch_data(self): + """ + Fetches data from the Airtable server via API and converts it to a DataFrame. 
+        """
+
+        params = {
+            "pageSize": 100,
+            "offset": "0"
+        }
+
+        if self._config.where is not None:
+            params["filterByFormula"] = self._build_formula()
+
+        data = []
+        while True:
+            response = self._request_api(params=params)
+
+            if response.status_code != 200:
+                raise InvalidRequestError(
+                    f"Failed to connect to Airtable. "
+                    f"Status code: {response.status_code}, "
+                    f"message: {response.text}"
+                )
+
+            res = response.json()
+            records = res.get("records", [])
+            data.extend({"id": record["id"], **record["fields"]} for record in records)
+
+            if len(records) < 100 or "offset" not in res:
+                break
+
+            if "offset" in res:
+                params["offset"] = res["offset"]
+
+        return pd.DataFrame(data)
+
+    @cache
+    def head(self):
+        """
+        Return the head of the table that
+        the connector is connected to.
+
+        Returns:
+            DataFrameType: The head of the data source
+            that the connector is connected to.
+        """
+        data = self._request_api(params={"maxRecords": 5})
+        return pd.DataFrame(
+            [
+                {"id": record["id"], **record["fields"]}
+                for record in data.json()["records"]
+            ]
+        )
+
+    @cached_property
+    def rows_count(self):
+        """
+        Return the number of rows in the data source that the connector is
+        connected to.
+
+        Returns:
+            int: The number of rows in the data source that the connector is
+            connected to.
+        """
+        if self._rows_count is not None:
+            return self._rows_count
+        data = self.execute()
+        self._rows_count = len(data)
+        return self._rows_count
+
+    @cached_property
+    def columns_count(self):
+        """
+        Return the number of columns in the data source that the connector is
+        connected to.
+
+        Returns:
+            int: The number of columns in the data source that the connector is
+            connected to. 
+        """
+        if self._columns_count is not None:
+            return self._columns_count
+        data = self.head()
+        self._columns_count = len(data.columns)
+        return self._columns_count
+
+    @property
+    def column_hash(self):
+        """
+        Return the hash code that is unique to the columns of the data source
+        that the connector is connected to.
+
+        Returns:
+            str: The hash code that is unique to the columns of the data source
+            that the connector is connected to.
+        """
+        if not isinstance(self._instance, pd.DataFrame):
+            self._instance = self.execute()
+        columns_str = "|".join(self._instance.columns)
+        columns_str += "WHERE" + self._build_formula()
+        return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()
diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py
index f60bf0052..4de547486 100644
--- a/pandasai/connectors/base.py
+++ b/pandasai/connectors/base.py
@@ -20,6 +20,16 @@ class BaseConnectorConfig(BaseModel):
     where: list[list[str]] = None
 
 
+class AirtableConnectorConfig(BaseConnectorConfig):
+    """
+    Connector configuration for Airtable data.
+    """
+
+    api_key: str
+    base_id: str
+    database: str = "airtable_data"
+
+
 class SQLBaseConnectorConfig(BaseConnectorConfig):
     """
     Base Connector configuration.
diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py
index 74de081dc..57c0734b9 100644
--- a/pandasai/exceptions.py
+++ b/pandasai/exceptions.py
@@ -5,6 +5,16 @@
 """
 
 
+class InvalidRequestError(Exception):
+
+    """
+    Raised when the request is not successful. 
+ + Args : + Exception (Exception): InvalidRequestError + """ + + class APIKeyNotFoundError(Exception): """ diff --git a/pandasai/helpers/openai_info.py b/pandasai/helpers/openai_info.py index 581f6a97f..c5438e5a1 100644 --- a/pandasai/helpers/openai_info.py +++ b/pandasai/helpers/openai_info.py @@ -45,9 +45,9 @@ def get_openai_token_cost_for_model( - model_name: str, - num_tokens: int, - is_completion: bool = False, + model_name: str, + num_tokens: int, + is_completion: bool = False, ) -> float: """ Get the cost in USD for a given model and number of tokens. @@ -63,9 +63,9 @@ def get_openai_token_cost_for_model( """ model_name = model_name.lower() if is_completion and ( - model_name.startswith("gpt-4") - or model_name.startswith("gpt-3.5") - or model_name.startswith("gpt-35") + model_name.startswith("gpt-4") + or model_name.startswith("gpt-3.5") + or model_name.startswith("gpt-35") ): # The cost of completion token is different from # the cost of prompt tokens. diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py new file mode 100644 index 000000000..41b0bd0f5 --- /dev/null +++ b/tests/connectors/test_airtable.py @@ -0,0 +1,123 @@ +import unittest +from pandasai.connectors import AirtableConnector +from pandasai.connectors.base import AirtableConnectorConfig +import pandas as pd +from unittest.mock import patch +import json + + +class TestAirTableConnector(unittest.TestCase): + def setUp(self) -> None: + # Define your ConnectorConfig instance here + self.config = AirtableConnectorConfig( + api_key="your_token", + base_id="your_baseid", + table="your_table_name", + where= [ + ["Status", "=", "In progress"] + ] + ).dict() + self.root_url = "https://api.airtable.com/v0/" + self.expected_data_json = """ + { + "records": [ + { + "id": "recnAIoHRTmpecLgY", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": { + "Name": "Quarterly launch", + "Status": "Done" + } + }, + { + "id": "recmRf57B2p3F9j8o", + "createdTime": 
"2023-10-09T13:04:58.000Z", + "fields": { + "Name": "Customer research", + "Status": "In progress" + } + }, + { + "id": "recsxnHUagIce7nB2", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": { + "Name": "Campaign analysis", + "Status": "To do" + } + } + ], + "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2" + } + """ + # Create an instance of Connector + self.connector = AirtableConnector(config=self.config) + + def test_constructor_and_properties(self): + self.assertEqual(self.connector._config, self.config) + self.assertEqual(self.connector._root_url, self.root_url) + self.assertEqual(self.connector._cache_interval, 600) + + def test_fallback_name(self): + self.assertEqual(self.connector.fallback_name, self.config["table"]) + + @patch("requests.get") + def test_execute(self, mock_request_get): + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + execute_data = self.connector.execute() + self.assertEqual(type(execute_data), pd.DataFrame) + self.assertEqual(len(execute_data), 3) + + @patch("requests.get") + def test_head(self, mock_request_get): + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + execute_data = self.connector.head() + self.assertEqual(type(execute_data), pd.DataFrame) + self.assertLessEqual(len(execute_data), 5) + + def test_fallback_name_property(self): + # Test fallback_name property + fallback_name = self.connector.fallback_name + self.assertEqual(fallback_name, self.config["table"]) + + @patch("requests.get") + def test_rows_count_property(self, mock_request_get): + # Test rows_count property + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + rows_count = self.connector.rows_count + self.assertEqual(rows_count, 3) + + @patch("requests.get") + def 
test_columns_count_property(self, mock_request_get): + # Test columns_count property + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + columns_count = self.connector.columns_count + self.assertEqual(columns_count, 3) + + def test_build_formula_method(self): + formula = self.connector._build_formula() + expected_formula = "AND(Status='In progress')" + self.assertEqual(formula,expected_formula) + + @patch("requests.get") + def test_column_hash(self,mock_request_get): + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + returned_hash = self.connector.column_hash + self.assertEqual( + returned_hash, + "e4cdc9402a0831fb549d7fdeaaa089b61aeaf61e14b8a044bc027219b2db941e" + )