From 197740ba741c1bb9691baa6833b9d128dad2989c Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Wed, 11 Oct 2023 00:44:51 +0530 Subject: [PATCH 01/12] feat : Airtable connector Support --- pandasai/connectors/__init__.py | 2 + pandasai/connectors/airtable.py | 92 +++++++++++++++++++++++++++++++ pandasai/connectors/base.py | 10 ++++ pandasai/helpers/openai_info.py | 12 ++-- tests/connectors/test_airtable.py | 84 ++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 6 deletions(-) create mode 100644 pandasai/connectors/airtable.py create mode 100644 tests/connectors/test_airtable.py diff --git a/pandasai/connectors/__init__.py b/pandasai/connectors/__init__.py index 1a28f8dcc..fb80c8628 100644 --- a/pandasai/connectors/__init__.py +++ b/pandasai/connectors/__init__.py @@ -9,6 +9,7 @@ from .snowflake import SnowFlakeConnector from .databricks import DatabricksConnector from .yahoo_finance import YahooFinanceConnector +from .airtable import AirtableConnector __all__ = [ "BaseConnector", @@ -18,4 +19,5 @@ "YahooFinanceConnector", "SnowFlakeConnector", "DatabricksConnector", + "AirtableConnector", ] diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py new file mode 100644 index 000000000..d97b26138 --- /dev/null +++ b/pandasai/connectors/airtable.py @@ -0,0 +1,92 @@ +""" +Airtable connectors are used to connect airtable records. +""" + +from .base import AirtableConnectorConfig, BaseConnector, BaseConnectorConfig +from typing import Union, Optional +import requests +import pandas as pd + + +class AirtableConnector(BaseConnector): + """ + Airtable connector to retrieving record data. + """ + + def __init__( + self, + config: Optional[Union[AirtableConnectorConfig, dict]] = None, + ): + if isinstance(config, dict): + if config["token"] and config["baseID"] and config["table"]: + config = AirtableConnectorConfig(**config) + + elif not config: + airtable_env_vars = { + "token": "AIRTABLE_AUTH_TOKEN", + "baseID": "AIRTABLE_BASE_ID", + "table": "AIRTABLE_TABLE_NAME", + } + config = AirtableConnectorConfig( + **self._populate_config_from_env(config, airtable_env_vars) + ) + + self._root_url: str = "https://api.airtable.com/v0/" + + super().__init__(config) + + def _init_connection(self, config: BaseConnectorConfig): + """ + make connection to database + """ + config = config.dict() + _session = requests.Session() + _session.headers = {"Authorization": f"Bearer {config['token']}"} + url = f"{self._root_url}{config['baseID']}/{config['table']}" + response = _session.head(url=url) + if response.status_code == 200: + self._session = _session + else: + raise ValueError( + f"""Failed to connect to Airtable. + Status code: {response.status_code}, + message: {response.text}""" + ) + + def execute(self): + """ + Execute the connector and return the result. + + Returns: + DataFrameType: The result of the connector. + """ + url = f"{self._root_url}{self.config['baseID']}/{self.config['table']}" + if self._session: + _response = self._session.get(url) + if _response.status_code == 200: + data = _response.json() + ## Following column selection is done + ## to prepare output in favaourable format. + records = [ + {"id": record["id"], **record["fields"]} + for record in data["records"] + ] + self._response = pd.DataFrame(records) + else: + raise ValueError( + f"""Failed to connect to Airtable. 
+ Status code: {_response.status_code}, + message: {_response.text}""" + ) + return self._response + + def head(self): + """ + Return the head of the table that + the connector is connected to. + + Returns : + DatFrameType: The head of the data source + that the conector is connected to . + """ + return self._response.head() diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py index f60bf0052..0019805c6 100644 --- a/pandasai/connectors/base.py +++ b/pandasai/connectors/base.py @@ -20,6 +20,16 @@ class BaseConnectorConfig(BaseModel): where: list[list[str]] = None +class AirtableConnectorConfig(BaseConnectorConfig): + """ + Connecter configuration for Airtable data. + """ + + token: str + baseID: str + database: str = "airtable_data" + + class SQLBaseConnectorConfig(BaseConnectorConfig): """ Base Connector configuration. diff --git a/pandasai/helpers/openai_info.py b/pandasai/helpers/openai_info.py index 581f6a97f..c5438e5a1 100644 --- a/pandasai/helpers/openai_info.py +++ b/pandasai/helpers/openai_info.py @@ -45,9 +45,9 @@ def get_openai_token_cost_for_model( - model_name: str, - num_tokens: int, - is_completion: bool = False, + model_name: str, + num_tokens: int, + is_completion: bool = False, ) -> float: """ Get the cost in USD for a given model and number of tokens. @@ -63,9 +63,9 @@ def get_openai_token_cost_for_model( """ model_name = model_name.lower() if is_completion and ( - model_name.startswith("gpt-4") - or model_name.startswith("gpt-3.5") - or model_name.startswith("gpt-35") + model_name.startswith("gpt-4") + or model_name.startswith("gpt-3.5") + or model_name.startswith("gpt-35") ): # The cost of completion token is different from # the cost of prompt tokens. diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py new file mode 100644 index 000000000..1748a827c --- /dev/null +++ b/tests/connectors/test_airtable.py @@ -0,0 +1,84 @@ +import unittest +import pandas as pd +from unittest.mock import Mock, patch +from pandasai.connectors.base import AirtableConnectorConfig +from pandasai.connectors import AirtableConnector + + +class TestAirTableConnector(unittest.TestCase): + def setUp(self): + # Define your ConnectorConfig instance here + self.config = AirtableConnectorConfig( + token="your_token", baseID="your_baseid", table="your_table_name" + ).dict() + self.root_url = "https://api.airtable.com/v0/" + # Create an instance of Connector + self.connector = AirtableConnector(config=self.config) + + def test_constructor_and_properties(self): + self.assertEqual(self.connector._config, self.config) + self.assertEqual(self.connector._root_url, self.root_url) + + def test_execute(self): + expected_data_json = { + "records": [ + { + "id": "recnAIoHRTmpecLgY", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": {"Name": "Quarterly launch", "Status": "Done"}, + }, + { + "id": "recmRf57B2p3F9j8o", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": {"Name": "Customer research", "Status": "In progress"}, + }, + { + "id": "recsxnHUagIce7nB2", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": {"Name": "Campaign analysis", "Status": "To do"}, + }, + ], + "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2", + } + print(expected_data_json) + records = [ + {"id": record["id"], **record["fields"]} + for record in expected_data_json["records"] + ] + expected_data = pd.DataFrame(records) + self.connector.execute = Mock(return_value=expected_data) + execute_data = self.connector.execute() + self.assertEqual(execute_data.equals(expected_data), 
True)

    def test_head(self):
        expected_data_json = {
            "records": [
                {
                    "id": "recnAIoHRTmpecLgY",
                    "createdTime": "2023-10-09T13:04:58.000Z",
                    "fields": {"Name": "Quarterly launch", "Status": "Done"},
                },
                {
                    "id": "recmRf57B2p3F9j8o",
                    "createdTime": "2023-10-09T13:04:58.000Z",
                    "fields": {"Name": "Customer research", "Status": "In progress"},
                },
                {
                    "id": "recsxnHUagIce7nB2",
                    "createdTime": "2023-10-09T13:04:58.000Z",
                    "fields": {"Name": "Campaign analysis", "Status": "To do"},
                },
            ],
            "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2",
        }

        records = [
            {"id": record["id"], **record["fields"]}
            for record in expected_data_json["records"]
        ]
        expected_data = pd.DataFrame(records)
        self.connector.head = Mock(return_value=expected_data)
        head_data = self.connector.head()

        self.assertEqual(head_data.equals(expected_data), True)
        self.assertLessEqual(len(head_data), 5)

From 826aa02338c66cbd869612dee93a9710b4bfb16c Mon Sep 17 00:00:00 2001
From: Tanmaypatil123
Date: Wed, 11 Oct 2023 00:56:28 +0530
Subject: [PATCH 02/12] Removed unnecessary print statements

---
 tests/connectors/test_airtable.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py
index 1748a827c..69a462c82 100644
--- a/tests/connectors/test_airtable.py
+++ b/tests/connectors/test_airtable.py
@@ -40,7 +40,6 @@ def test_execute(self):
             ],
             "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2",
         }
-        print(expected_data_json)
         records = [
             {"id": record["id"], **record["fields"]}
             for record in expected_data_json["records"]

From cd78117b7bc5391b56d3bd3252a7d0085005712d Mon Sep 17 00:00:00 2001
From: Tanmaypatil123
Date: Wed, 11 Oct 2023 19:03:40 +0530
Subject: [PATCH 03/12] Fixed tests and implementation

---
 docs/connectors.md                |  29 +++++
 examples/from_airtable.py         |  23 ++++
 pandasai/connectors/airtable.py   | 193 +++++++++++++++++++++++++-----
 pandasai/connectors/base.py       |   4 +-
 tests/connectors/test_airtable.py | 146 ++++++++++++----------
 5 files changed, 300 insertions(+), 95 deletions(-)
 create mode 100644 examples/from_airtable.py

diff --git a/docs/connectors.md b/docs/connectors.md
index 242de42eb..1850b9a7c 100644
--- a/docs/connectors.md
+++ b/docs/connectors.md
@@ -187,3 +187,32 @@ yahoo_connector = YahooFinanceConnector("MSFT")
 df = SmartDataframe(yahoo_connector)
 df.chat("What is the closing price for yesterday?")
 ```
+
+## Airtable Connector
+
+The Airtable connector allows you to connect to tables in an Airtable base by simply passing the `base_id`, `api_key` and `table_name` of the table you want to analyze.
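+
+Credentials can be passed explicitly, as shown below, or picked up from the
+`AIRTABLE_API_TOKEN`, `AIRTABLE_BASE_ID` and `AIRTABLE_TABLE_NAME` environment
+variables (the names used by the connector's constructor). A minimal sketch of
+the environment-based setup:
+
+```python
+import os
+
+from pandasai.connectors import AirtableConnector
+
+# Assumed to be set before the connector is constructed; with these in
+# place, the config argument can be omitted entirely.
+os.environ["AIRTABLE_API_TOKEN"] = "your_token"
+os.environ["AIRTABLE_BASE_ID"] = "your_base_id"
+os.environ["AIRTABLE_TABLE_NAME"] = "your_table_name"
+
+airtable_connector = AirtableConnector()
+```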
+ +To use the Airtable connector, you only need to import it into your Python code and pass it to a `SmartDataframe` or `SmartDatalake` object: + +```python +from pandasai.connectors import AirtableConnector +from pandasai import SmartDataframe + + +airtable_connectors = AirtableConnector( + config={ + "api_key": "AIRTABLE_API_TOKEN", + "table":"AIRTABLE_TABLE_NAME", + "base_id":"AIRTABLE_BASE_ID", + "where" : [ + # this is optional and filters the data to + # reduce the size of the dataframe + ["Status" ,"=","In progress"] + ] + } +) + +df = SmartDataframe(airtable_connectors) + +df.chat("How many rows are there in data ?") +``` \ No newline at end of file diff --git a/examples/from_airtable.py b/examples/from_airtable.py new file mode 100644 index 000000000..8a2cf1c88 --- /dev/null +++ b/examples/from_airtable.py @@ -0,0 +1,23 @@ +from pandasai.connectors import AirtableConnector +from pandasai.llm import OpenAI +from pandasai import SmartDataframe + + +airtable_connectors = AirtableConnector( + config={ + "api_key": "AIRTABLE_API_TOKEN", + "table": "AIRTABLE_TABLE_NAME", + "base_id": "AIRTABLE_BASE_ID", + "where": [ + # this is optional and filters the data to + # reduce the size of the dataframe + ["Status", "=", "In progress"] + ], + } +) + +llm = OpenAI("OPENAI_API_KEY") +df = SmartDataframe(airtable_connectors, config={"llm": llm}) + +response = df.chat("How many rows are there in data ?") +print(response) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index d97b26138..30d36fa33 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -6,6 +6,10 @@ from typing import Union, Optional import requests import pandas as pd +import os +from ..helpers.path import find_project_root +import time +import hashlib class AirtableConnector(BaseConnector): @@ -16,15 +20,16 @@ class AirtableConnector(BaseConnector): def __init__( self, config: Optional[Union[AirtableConnectorConfig, dict]] = None, + cache_interval: int = 600, ): if isinstance(config, dict): - if config["token"] and config["baseID"] and config["table"]: + if config["api_key"] and config["base_id"] and config["table"]: config = AirtableConnectorConfig(**config) elif not config: airtable_env_vars = { - "token": "AIRTABLE_AUTH_TOKEN", - "baseID": "AIRTABLE_BASE_ID", + "api_key": "AIRTABLE_API_TOKEN", + "base_id": "AIRTABLE_BASE_ID", "table": "AIRTABLE_TABLE_NAME", } config = AirtableConnectorConfig( @@ -32,6 +37,7 @@ def __init__( ) self._root_url: str = "https://api.airtable.com/v0/" + self._cache_interval = cache_interval super().__init__(config) @@ -40,19 +46,86 @@ def _init_connection(self, config: BaseConnectorConfig): make connection to database """ config = config.dict() - _session = requests.Session() - _session.headers = {"Authorization": f"Bearer {config['token']}"} - url = f"{self._root_url}{config['baseID']}/{config['table']}" - response = _session.head(url=url) + url = f"{self._root_url}{config['base_id']}/{config['table']}" + response = requests.head( + url=url, headers={"Authorization": f"Bearer {config['api_key']}"} + ) if response.status_code == 200: - self._session = _session + self.logger.log( + """ + Connected to Airtable. + """ + ) else: raise ValueError( f"""Failed to connect to Airtable. - Status code: {response.status_code}, - message: {response.text}""" + Status code: {response.status_code}, + message: {response.text}""" ) + def _get_cache_path(self, include_additional_filters: bool): + """ + Return the path of the cache file. 
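+
+        For example, a table named "tasks" resolves to
+        "<project_root>/cache/tasks_data.parquet", falling back to
+        "<cwd>/cache/tasks_data.parquet" when no project root is found.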
+ + Returns : + str : The path of the cache file. + """ + cache_dir = os.path.join(os.getcwd(), "") + try: + cache_dir = os.path.join((find_project_root()), "cache") + except ValueError: + cache_dir = os.path.join(os.getcwd(), "cache") + return os.path.join(cache_dir, f"{self._config.table}_data.parquet") + + def _cached(self): + """ + Returns the cached Airtable data if it exists and + is not older than the cache interval. + + Returns : + DataFrame | None : The cached data if + it exists and is not older than the cache + interval, None otherwise. + """ + cache_path = self._get_cache_path() + if not os.path.exists(cache_path): + return None + + # If the file is older than 1 day , delete it. + if os.path.getmtime(cache_path) < time.time() - self._cache_interval: + if self.logger: + self.logger.log(f"Deleting expired cached data from {cache_path}") + os.remove(cache_path) + return None + + if self.logger: + self.logger.log(f"Loading cached data from {cache_path}") + + return cache_path + + def _save_cache(self, df): + """ + Save the given DataFrame to the cache. + + Args: + df (DataFrame): The DataFrame to save to the cache. + """ + filename = self._get_cache_path( + include_additional_filters=self._additional_filters is not None + and len(self._additional_filters) > 0 + ) + df.to_parquet(filename) + + @property + def fallback_name(self): + """ + Returns the fallback table name of the connector. + + Returns : + str : The fallback table name of the connector. + """ + return self._config.table + def execute(self): """ Execute the connector and return the result. @@ -60,25 +133,31 @@ def execute(self): Returns: DataFrameType: The result of the connector. """ - url = f"{self._root_url}{self.config['baseID']}/{self.config['table']}" - if self._session: - _response = self._session.get(url) - if _response.status_code == 200: - data = _response.json() - ## Following column selection is done - ## to prepare output in favaourable format. - records = [ - {"id": record["id"], **record["fields"]} - for record in data["records"] - ] - self._response = pd.DataFrame(records) - else: - raise ValueError( - f"""Failed to connect to Airtable. - Status code: {_response.status_code}, - message: {_response.text}""" - ) - return self._response + url = f"{self._root_url}{self._config.base_id}/{self._config.table}" + response = requests.get( + url=url, headers={"Authorization": f"Bearer {self._config.api_key}"} + ) + if response.status_code == 200: + data = response.json() + data = self.preprocess(data=data) + self._save_cache(data) + else: + raise ValueError( + f"""Failed to connect to Airtable. + Status code: {response.status_code}, + message: {response.text}""" + ) + return data + + def preprocess(self, data): + """ + Preprocesses Json response data + To prepare dataframe correctly. + """ + records = [ + {"id": record["id"], **record["fields"]} for record in data["records"] + ] + return pd.DataFrame(records) def head(self): """ @@ -89,4 +168,58 @@ def head(self): DatFrameType: The head of the data source that the conector is connected to . """ - return self._response.head() + url = f"{self._root_url}{self._config.base_id}/{self._config.table}" + response = requests.get( + url=url, headers={"Authorization": f"Bearer {self._config.api_key}"} + ) + if response.status_code == 200: + data = response.json() + data = self.preprocess(data=data) + else: + raise ValueError( + f"""Failed to connect to Airtable. 
+ Status code: {response.status_code}, + message: {response.text}""" + ) + + return data.head() + + @property + def rows_count(self): + """ + Return the number of rows in the data source that the connector is + connected to. + + Returns: + int: The number of rows in the data source that the connector is + connected to. + """ + data = self.execute() + return len(data) + + @property + def columns_count(self): + """ + Return the number of columns in the data source that the connector is + connected to. + + Returns: + int: The number of columns in the data source that the connector is + connected to. + """ + data = self.execute() + return len(data.columns) + + @property + def column_hash(self): + """ + Return the hash code that is unique to the columns of the data source + that the connector is connected to. + + Returns: + int: The hash code that is unique to the columns of the data source + that the connector is connected to. + """ + data = self.execute() + columns_str = "|".join(data.columns) + return hashlib.sha256(columns_str.encode("utf-8")).hexdigest() diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py index 0019805c6..4de547486 100644 --- a/pandasai/connectors/base.py +++ b/pandasai/connectors/base.py @@ -25,8 +25,8 @@ class AirtableConnectorConfig(BaseConnectorConfig): Connecter configuration for Airtable data. """ - token: str - baseID: str + api_key: str + base_id: str database: str = "airtable_data" diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py index 69a462c82..5f885dee9 100644 --- a/tests/connectors/test_airtable.py +++ b/tests/connectors/test_airtable.py @@ -1,83 +1,103 @@ import unittest -import pandas as pd -from unittest.mock import Mock, patch -from pandasai.connectors.base import AirtableConnectorConfig from pandasai.connectors import AirtableConnector +from pandasai.connectors.base import AirtableConnectorConfig +import pandas as pd +from unittest.mock import patch +import json class TestAirTableConnector(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # Define your ConnectorConfig instance here self.config = AirtableConnectorConfig( - token="your_token", baseID="your_baseid", table="your_table_name" + api_key="your_token", + base_id="your_baseid", + table="your_table_name", ).dict() self.root_url = "https://api.airtable.com/v0/" + self.expected_data_json = """ + { + "records": [ + { + "id": "recnAIoHRTmpecLgY", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": { + "Name": "Quarterly launch", + "Status": "Done" + } + }, + { + "id": "recmRf57B2p3F9j8o", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": { + "Name": "Customer research", + "Status": "In progress" + } + }, + { + "id": "recsxnHUagIce7nB2", + "createdTime": "2023-10-09T13:04:58.000Z", + "fields": { + "Name": "Campaign analysis", + "Status": "To do" + } + } + ], + "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2" + } + """ # Create an instance of Connector self.connector = AirtableConnector(config=self.config) def test_constructor_and_properties(self): self.assertEqual(self.connector._config, self.config) self.assertEqual(self.connector._root_url, self.root_url) + self.assertEqual(self.connector._cache_interval, 600) + + def test_fallback_name(self): + self.assertEqual(self.connector.fallback_name, self.config["table"]) - def test_execute(self): - expected_data_json = { - "records": [ - { - "id": "recnAIoHRTmpecLgY", - "createdTime": "2023-10-09T13:04:58.000Z", - "fields": {"Name": "Quarterly launch", "Status": "Done"}, - }, - { - 
"id": "recmRf57B2p3F9j8o", - "createdTime": "2023-10-09T13:04:58.000Z", - "fields": {"Name": "Customer research", "Status": "In progress"}, - }, - { - "id": "recsxnHUagIce7nB2", - "createdTime": "2023-10-09T13:04:58.000Z", - "fields": {"Name": "Campaign analysis", "Status": "To do"}, - }, - ], - "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2", - } - records = [ - {"id": record["id"], **record["fields"]} - for record in expected_data_json["records"] - ] - expected_data = pd.DataFrame(records) - self.connector.execute = Mock(return_value=expected_data) + @patch("requests.get") + def test_execute(self, mock_request_get): + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 execute_data = self.connector.execute() - self.assertEqual(execute_data.equals(expected_data), True) + self.assertEqual(type(execute_data), pd.DataFrame) + self.assertEqual(len(execute_data), 3) + + @patch("requests.get") + def test_head(self, mock_request_get): + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + execute_data = self.connector.head() + self.assertEqual(type(execute_data), pd.DataFrame) + self.assertLessEqual(len(execute_data), 5) - def test_head(self): - expected_data_json = { - "records": [ - { - "id": "recnAIoHRTmpecLgY", - "createdTime": "2023-10-09T13:04:58.000Z", - "fields": {"Name": "Quarterly launch", "Status": "Done"}, - }, - { - "id": "recmRf57B2p3F9j8o", - "createdTime": "2023-10-09T13:04:58.000Z", - "fields": {"Name": "Customer research", "Status": "In progress"}, - }, - { - "id": "recsxnHUagIce7nB2", - "createdTime": "2023-10-09T13:04:58.000Z", - "fields": {"Name": "Campaign analysis", "Status": "To do"}, - }, - ], - "offset": "itrowYGFfoBEIob3C/recsxnHUagIce7nB2", - } + def test_fallback_name_property(self): + # Test fallback_name property + fallback_name = self.connector.fallback_name + self.assertEqual(fallback_name, self.config["table"]) - records = [ - {"id": record["id"], **record["fields"]} - for record in expected_data_json["records"] - ] - expected_data = pd.DataFrame(records) - self.connector.head = Mock(return_value=expected_data) - head_data = self.connector.head() + @patch("requests.get") + def test_rows_count_property(self, mock_request_get): + # Test rows_count property + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + rows_count = self.connector.rows_count + self.assertEqual(rows_count, 3) - self.assertEqual(head_data.equals(expected_data), True) - self.assertLessEqual(len(head_data), 5) + @patch("requests.get") + def test_columns_count_property(self, mock_request_get): + # Test columns_count property + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + rows_count = self.connector.columns_count + self.assertEqual(rows_count, 3) From 1ec1dd5375b6f314a2c66cc3aa7542585633cb82 Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Thu, 12 Oct 2023 15:12:55 +0530 Subject: [PATCH 04/12] where filter and pagination added --- docs/connectors.md | 2 +- examples/from_airtable.py | 2 +- pandasai/connectors/airtable.py | 50 +++++++++++++++++++-------------- pandasai/connectors/base.py | 1 + pandasai/exceptions.py | 10 +++++++ 5 files changed, 42 insertions(+), 23 deletions(-) diff --git a/docs/connectors.md b/docs/connectors.md 
index 1850b9a7c..8c802b479 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -207,7 +207,7 @@ airtable_connectors = AirtableConnector( "where" : [ # this is optional and filters the data to # reduce the size of the dataframe - ["Status" ,"=","In progress"] + ["Status" ,"==","In progress"] ] } ) diff --git a/examples/from_airtable.py b/examples/from_airtable.py index 8a2cf1c88..0bb78e139 100644 --- a/examples/from_airtable.py +++ b/examples/from_airtable.py @@ -11,7 +11,7 @@ "where": [ # this is optional and filters the data to # reduce the size of the dataframe - ["Status", "=", "In progress"] + ["Status", "==", "In progress"] ], } ) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 30d36fa33..746541389 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -10,6 +10,7 @@ from ..helpers.path import find_project_root import time import hashlib +from ..exceptions import InvalidRequestError class AirtableConnector(BaseConnector): @@ -23,8 +24,12 @@ def __init__( cache_interval: int = 600, ): if isinstance(config, dict): - if config["api_key"] and config["base_id"] and config["table"]: + if "api_key" in config and "base_id" in config and "table" in config: config = AirtableConnectorConfig(**config) + else: + raise KeyError( + "Please specify all api_key,table,base_id properly in config ." + ) elif not config: airtable_env_vars = { @@ -57,13 +62,13 @@ def _init_connection(self, config: BaseConnectorConfig): """ ) else: - raise ValueError( + raise InvalidRequestError( f"""Failed to connect to Airtable. Status code: {response.status_code}, message: {response.text}""" ) - def _get_cache_path(self, include_additional_filters: bool): + def _get_cache_path(self, include_additional_filters: bool = False): """ Return the path of the cache file. @@ -133,16 +138,25 @@ def execute(self): Returns: DataFrameType: The result of the connector. """ + return self.fetch_data() + + def fetch_data(self): + """ + Feteches data from airtable server through + API and converts it to DataFrame. + """ url = f"{self._root_url}{self._config.base_id}/{self._config.table}" response = requests.get( - url=url, headers={"Authorization": f"Bearer {self._config.api_key}"} + url=url, + headers={"Authorization": f"Bearer {self._config.api_key}"}, + params={"maxRecords": self._config.max_records}, ) if response.status_code == 200: data = response.json() data = self.preprocess(data=data) self._save_cache(data) else: - raise ValueError( + raise InvalidRequestError( f"""Failed to connect to Airtable. Status code: {response.status_code}, message: {response.text}""" @@ -157,7 +171,15 @@ def preprocess(self, data): records = [ {"id": record["id"], **record["fields"]} for record in data["records"] ] - return pd.DataFrame(records) + + df = pd.DataFrame(records) + + if self._config.where: + for i in self._config.where: + filter_string = f"{i[0]} {i[1]} '{i[2]}'" + df = df.query(filter_string) + + return df def head(self): """ @@ -168,21 +190,7 @@ def head(self): DatFrameType: The head of the data source that the conector is connected to . """ - url = f"{self._root_url}{self._config.base_id}/{self._config.table}" - response = requests.get( - url=url, headers={"Authorization": f"Bearer {self._config.api_key}"} - ) - if response.status_code == 200: - data = response.json() - data = self.preprocess(data=data) - else: - raise ValueError( - f"""Failed to connect to Airtable. 
- Status code: {response.status_code}, - message: {response.text}""" - ) - - return data.head() + return self.fetch_data().head() @property def rows_count(self): diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py index 4de547486..21ba4a571 100644 --- a/pandasai/connectors/base.py +++ b/pandasai/connectors/base.py @@ -28,6 +28,7 @@ class AirtableConnectorConfig(BaseConnectorConfig): api_key: str base_id: str database: str = "airtable_data" + max_records: Optional[int] = 100 class SQLBaseConnectorConfig(BaseConnectorConfig): diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 74de081dc..57c0734b9 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -5,6 +5,16 @@ """ +class InvalidRequestError(Exception): + + """ + Raised when the request is not succesfull. + + Args : + Exception (Exception): InvalidRequestError + """ + + class APIKeyNotFoundError(Exception): """ From 217129615e3e5fa1b7f55e315d9f731300bbfb8e Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Thu, 12 Oct 2023 16:08:02 +0530 Subject: [PATCH 05/12] Removed Pagination --- pandasai/connectors/airtable.py | 1 - pandasai/connectors/base.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 746541389..8c4e776c9 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -149,7 +149,6 @@ def fetch_data(self): response = requests.get( url=url, headers={"Authorization": f"Bearer {self._config.api_key}"}, - params={"maxRecords": self._config.max_records}, ) if response.status_code == 200: data = response.json() diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py index 21ba4a571..4de547486 100644 --- a/pandasai/connectors/base.py +++ b/pandasai/connectors/base.py @@ -28,7 +28,6 @@ class AirtableConnectorConfig(BaseConnectorConfig): api_key: str base_id: str database: str = "airtable_data" - max_records: Optional[int] = 100 class SQLBaseConnectorConfig(BaseConnectorConfig): From a1f25b83d3c34a0f2c61278ca6aee4010e0f81d8 Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Thu, 12 Oct 2023 20:57:49 +0530 Subject: [PATCH 06/12] Refactored where filter functionality --- pandasai/connectors/airtable.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 8c4e776c9..2e5a8c865 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -140,15 +140,32 @@ def execute(self): """ return self.fetch_data() + def build_formula(self): + """ + Build Airtable query formula for filtering. + """ + + condition_strings = [] + for i in self._config.where: + filter_query = f"{i[0]}{i[1]}'{i[2]}'" + condition_strings.append(filter_query) + filter_formula = f'AND({",".join(condition_strings)})' + return filter_formula + def fetch_data(self): """ Feteches data from airtable server through API and converts it to DataFrame. 
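
        A sketch of the request this method issues (the filter parameter
        is added only when a `where` clause is configured):

            GET https://api.airtable.com/v0/{base_id}/{table}?filterByFormula=AND(...)
            Authorization: Bearer {api_key}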
""" url = f"{self._root_url}{self._config.base_id}/{self._config.table}" + params = {} + + if self._config.where: + params["filterByFormula"] = self.build_formula() response = requests.get( url=url, headers={"Authorization": f"Bearer {self._config.api_key}"}, + params=params, ) if response.status_code == 200: data = response.json() @@ -172,12 +189,6 @@ def preprocess(self, data): ] df = pd.DataFrame(records) - - if self._config.where: - for i in self._config.where: - filter_string = f"{i[0]} {i[1]} '{i[2]}'" - df = df.query(filter_string) - return df def head(self): From 02007ae871d2b79a5ade76a3ad216dd83bc83182 Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Thu, 12 Oct 2023 21:16:38 +0530 Subject: [PATCH 07/12] Docs and example change --- docs/connectors.md | 2 +- examples/from_airtable.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/connectors.md b/docs/connectors.md index 8c802b479..1850b9a7c 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -207,7 +207,7 @@ airtable_connectors = AirtableConnector( "where" : [ # this is optional and filters the data to # reduce the size of the dataframe - ["Status" ,"==","In progress"] + ["Status" ,"=","In progress"] ] } ) diff --git a/examples/from_airtable.py b/examples/from_airtable.py index 0bb78e139..8a2cf1c88 100644 --- a/examples/from_airtable.py +++ b/examples/from_airtable.py @@ -11,7 +11,7 @@ "where": [ # this is optional and filters the data to # reduce the size of the dataframe - ["Status", "==", "In progress"] + ["Status", "=", "In progress"] ], } ) From 3f8f55f090c38958b56783ea9d279220a1e97d9c Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Fri, 13 Oct 2023 14:18:24 +0530 Subject: [PATCH 08/12] Fixed cached methods and code improvements --- pandasai/connectors/airtable.py | 48 +++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 2e5a8c865..3137bf068 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -11,6 +11,7 @@ import time import hashlib from ..exceptions import InvalidRequestError +from functools import cache, cached_property class AirtableConnector(BaseConnector): @@ -18,6 +19,10 @@ class AirtableConnector(BaseConnector): Airtable connector to retrieving record data. """ + _rows_count: int = None + _columns_count: int = None + instance = None + def __init__( self, config: Optional[Union[AirtableConnectorConfig, dict]] = None, @@ -82,7 +87,7 @@ def _get_cache_path(self, include_additional_filters: bool = False): cache_dir = os.path.join(os.getcwd(), "cache") return os.path.join(cache_dir, f"{self._config.table}_data.parquet") - def _cached(self): + def _cached(self, include_additional_filters: bool = False): """ Returns the cached Airtable data if it exists and is not older than the cache interval. @@ -92,7 +97,7 @@ def _cached(self): it exists and is not older than the cache interval, None otherwise. """ - cache_path = self._get_cache_path() + cache_path = self._get_cache_path(include_additional_filters) if not os.path.exists(cache_path): return None @@ -138,7 +143,15 @@ def execute(self): Returns: DataFrameType: The result of the connector. 
""" - return self.fetch_data() + cached = self._cached() or self._cached(include_additional_filters=True) + if cached: + return pd.read_parquet(cached) + + if isinstance(self.instance, pd.DataFrame): + return self.instance + else: + self.instance = self.fetch_data() + return self.instance def build_formula(self): """ @@ -159,7 +172,6 @@ def fetch_data(self): """ url = f"{self._root_url}{self._config.base_id}/{self._config.table}" params = {} - if self._config.where: params["filterByFormula"] = self.build_formula() response = requests.get( @@ -191,6 +203,7 @@ def preprocess(self, data): df = pd.DataFrame(records) return df + @cache def head(self): """ Return the head of the table that @@ -200,9 +213,14 @@ def head(self): DatFrameType: The head of the data source that the conector is connected to . """ - return self.fetch_data().head() + # return self.fetch_data().head() + if isinstance(self.instance, pd.DataFrame): + return self.instance.head() + else: + self.instance = self.fetch_data() + return self.instance.head() - @property + @cached_property def rows_count(self): """ Return the number of rows in the data source that the connector is @@ -212,10 +230,13 @@ def rows_count(self): int: The number of rows in the data source that the connector is connected to. """ + if self._rows_count is not None: + return self._rows_count data = self.execute() - return len(data) + self._rows_count = len(data) + return self._rows_count - @property + @cached_property def columns_count(self): """ Return the number of columns in the data source that the connector is @@ -225,8 +246,11 @@ def columns_count(self): int: The number of columns in the data source that the connector is connected to. """ + if self._columns_count is not None: + return self._columns_count data = self.execute() - return len(data.columns) + self._columns_count = len(data.columns) + return self._columns_count @property def column_hash(self): @@ -238,6 +262,8 @@ def column_hash(self): int: The hash code that is unique to the columns of the data source that the connector is connected to. """ - data = self.execute() - columns_str = "|".join(data.columns) + if not isinstance(self.instance, pd.DataFrame): + self.instance = self.execute() + columns_str = "|".join(self.instance.columns) + columns_str += "WHERE" + self.build_formula() return hashlib.sha256(columns_str.encode("utf-8")).hexdigest() From 8c0e07b5b68a64334ac42f23a94930929c2ed999 Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Fri, 13 Oct 2023 18:01:51 +0530 Subject: [PATCH 09/12] Added Pagination to connector --- pandasai/connectors/airtable.py | 111 ++++++++++++++++++------------ tests/connectors/test_airtable.py | 2 +- 2 files changed, 68 insertions(+), 45 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 3137bf068..78cd30005 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -21,7 +21,7 @@ class AirtableConnector(BaseConnector): _rows_count: int = None _columns_count: int = None - instance = None + _instance = None def __init__( self, @@ -147,60 +147,86 @@ def execute(self): if cached: return pd.read_parquet(cached) - if isinstance(self.instance, pd.DataFrame): - return self.instance + if isinstance(self._instance, pd.DataFrame): + return self._instance else: - self.instance = self.fetch_data() - return self.instance + self._instance = self._fetch_data() - def build_formula(self): + return self._instance + + def _build_formula(self): """ Build Airtable query formula for filtering. 
""" condition_strings = [] - for i in self._config.where: - filter_query = f"{i[0]}{i[1]}'{i[2]}'" - condition_strings.append(filter_query) + if self._config.where is not None: + for i in self._config.where: + filter_query = f"{i[0]}{i[1]}'{i[2]}'" + condition_strings.append(filter_query) filter_formula = f'AND({",".join(condition_strings)})' return filter_formula - def fetch_data(self): - """ - Feteches data from airtable server through - API and converts it to DataFrame. - """ + def _request_api(self, params): url = f"{self._root_url}{self._config.base_id}/{self._config.table}" - params = {} - if self._config.where: - params["filterByFormula"] = self.build_formula() response = requests.get( url=url, headers={"Authorization": f"Bearer {self._config.api_key}"}, params=params, ) - if response.status_code == 200: - data = response.json() - data = self.preprocess(data=data) - self._save_cache(data) - else: - raise InvalidRequestError( - f"""Failed to connect to Airtable. - Status code: {response.status_code}, - message: {response.text}""" - ) + return response + + def _fetch_data(self): + """ + Feteches data from airtable server through + API and converts it to DataFrame. + """ + + params = {} + if self._config.where is not None: + params["filterByFormula"] = self._build_formula() + + params["pageSize"] = 100 + params["offset"] = "0" + + data = [] + while True: + response = self._request_api(params=params) + + if response.status_code == 200: + res = response.json() + data.append(res) + if len(res["records"]) < 100: + break + else: + raise InvalidRequestError( + f"""Failed to connect to Airtable. + Status code: {response.status_code}, + message: {response.text}""" + ) + + if "offset" in res: + params["offset"] = res["offset"] + + data = self._preprocess(data=data) return data - def preprocess(self, data): + def _preprocess(self, data): """ Preprocesses Json response data To prepare dataframe correctly. """ - records = [ - {"id": record["id"], **record["fields"]} for record in data["records"] - ] - - df = pd.DataFrame(records) + columns = set() + data_dict_list = [] + for item in data: + for entry in item["records"]: + data_dict = {"id": entry["id"], "createdTime": entry["createdTime"]} + for field_name, field_value in entry["fields"].items(): + data_dict[field_name] = field_value + columns.add(field_name) + data_dict_list.append(data_dict) + + df = pd.DataFrame(data_dict_list) return df @cache @@ -213,12 +239,9 @@ def head(self): DatFrameType: The head of the data source that the conector is connected to . """ - # return self.fetch_data().head() - if isinstance(self.instance, pd.DataFrame): - return self.instance.head() - else: - self.instance = self.fetch_data() - return self.instance.head() + data = self._request_api(params={"maxRecords": 5}) + data = self._preprocess([data.json()]) + return data @cached_property def rows_count(self): @@ -248,7 +271,7 @@ def columns_count(self): """ if self._columns_count is not None: return self._columns_count - data = self.execute() + data = self.head() self._columns_count = len(data.columns) return self._columns_count @@ -262,8 +285,8 @@ def column_hash(self): int: The hash code that is unique to the columns of the data source that the connector is connected to. 
""" - if not isinstance(self.instance, pd.DataFrame): - self.instance = self.execute() - columns_str = "|".join(self.instance.columns) - columns_str += "WHERE" + self.build_formula() + if not isinstance(self._instance, pd.DataFrame): + self._instance = self.execute() + columns_str = "|".join(self._instance.columns) + columns_str += "WHERE" + self._build_formula() return hashlib.sha256(columns_str.encode("utf-8")).hexdigest() diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py index 5f885dee9..edbe6c7b7 100644 --- a/tests/connectors/test_airtable.py +++ b/tests/connectors/test_airtable.py @@ -100,4 +100,4 @@ def test_columns_count_property(self, mock_request_get): ) mock_request_get.return_value.status_code = 200 rows_count = self.connector.columns_count - self.assertEqual(rows_count, 3) + self.assertEqual(rows_count, 4) From f6955eb4e2cc2ee795240a055ef05cd7bd080f41 Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Fri, 13 Oct 2023 20:03:33 +0530 Subject: [PATCH 10/12] Removed preprocces method --- pandasai/connectors/airtable.py | 36 +++++++++++-------------------- tests/connectors/test_airtable.py | 4 ++-- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 78cd30005..7a3d2aecf 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -195,7 +195,12 @@ def _fetch_data(self): if response.status_code == 200: res = response.json() - data.append(res) + data.extend( + [ + {"id": record["id"], **record["fields"]} + for record in res["records"] + ] + ) if len(res["records"]) < 100: break else: @@ -208,26 +213,7 @@ def _fetch_data(self): if "offset" in res: params["offset"] = res["offset"] - data = self._preprocess(data=data) - return data - - def _preprocess(self, data): - """ - Preprocesses Json response data - To prepare dataframe correctly. - """ - columns = set() - data_dict_list = [] - for item in data: - for entry in item["records"]: - data_dict = {"id": entry["id"], "createdTime": entry["createdTime"]} - for field_name, field_value in entry["fields"].items(): - data_dict[field_name] = field_value - columns.add(field_name) - data_dict_list.append(data_dict) - - df = pd.DataFrame(data_dict_list) - return df + return pd.DataFrame(data) @cache def head(self): @@ -240,8 +226,12 @@ def head(self): that the conector is connected to . 
""" data = self._request_api(params={"maxRecords": 5}) - data = self._preprocess([data.json()]) - return data + return pd.DataFrame( + [ + {"id": record["id"], **record["fields"]} + for record in data.json()["records"] + ] + ) @cached_property def rows_count(self): diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py index edbe6c7b7..bbf376d8f 100644 --- a/tests/connectors/test_airtable.py +++ b/tests/connectors/test_airtable.py @@ -99,5 +99,5 @@ def test_columns_count_property(self, mock_request_get): self.expected_data_json ) mock_request_get.return_value.status_code = 200 - rows_count = self.connector.columns_count - self.assertEqual(rows_count, 4) + columns_count = self.connector.columns_count + self.assertEqual(columns_count, 3) From b2fe3a72bec205ac5d72be1c64e1cc69cd29116d Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 16 Oct 2023 12:07:30 +0500 Subject: [PATCH 11/12] add offset not in res break condition --- pandasai/connectors/airtable.py | 37 +++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 7a3d2aecf..417e609e9 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -178,38 +178,35 @@ def _request_api(self, params): def _fetch_data(self): """ - Feteches data from airtable server through - API and converts it to DataFrame. + Fetches data from the Airtable server via API and converts it to a DataFrame. """ - params = {} + params = { + "pageSize": 100, + "offset": "0" + } + if self._config.where is not None: params["filterByFormula"] = self._build_formula() - params["pageSize"] = 100 - params["offset"] = "0" - data = [] while True: response = self._request_api(params=params) - if response.status_code == 200: - res = response.json() - data.extend( - [ - {"id": record["id"], **record["fields"]} - for record in res["records"] - ] - ) - if len(res["records"]) < 100: - break - else: + if response.status_code != 200: raise InvalidRequestError( - f"""Failed to connect to Airtable. - Status code: {response.status_code}, - message: {response.text}""" + f"Failed to connect to Airtable. 
" + f"Status code: {response.status_code}, " + f"message: {response.text}" ) + res = response.json() + records = res.get("records", []) + data.extend({"id": record["id"], **record["fields"]} for record in records) + + if len(records) < 100 or "offset" not in res: + break + if "offset" in res: params["offset"] = res["offset"] From fc5c3266395de160fbe29b1f64eaecb20c99eacb Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Tue, 17 Oct 2023 18:29:56 +0530 Subject: [PATCH 12/12] Added more test coverge for where clause and column hashing --- tests/connectors/test_airtable.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/connectors/test_airtable.py b/tests/connectors/test_airtable.py index bbf376d8f..41b0bd0f5 100644 --- a/tests/connectors/test_airtable.py +++ b/tests/connectors/test_airtable.py @@ -13,6 +13,9 @@ def setUp(self) -> None: api_key="your_token", base_id="your_baseid", table="your_table_name", + where= [ + ["Status", "=", "In progress"] + ] ).dict() self.root_url = "https://api.airtable.com/v0/" self.expected_data_json = """ @@ -101,3 +104,20 @@ def test_columns_count_property(self, mock_request_get): mock_request_get.return_value.status_code = 200 columns_count = self.connector.columns_count self.assertEqual(columns_count, 3) + + def test_build_formula_method(self): + formula = self.connector._build_formula() + expected_formula = "AND(Status='In progress')" + self.assertEqual(formula,expected_formula) + + @patch("requests.get") + def test_column_hash(self,mock_request_get): + mock_request_get.return_value.json.return_value = json.loads( + self.expected_data_json + ) + mock_request_get.return_value.status_code = 200 + returned_hash = self.connector.column_hash + self.assertEqual( + returned_hash, + "e4cdc9402a0831fb549d7fdeaaa089b61aeaf61e14b8a044bc027219b2db941e" + )