From 3f8f55f090c38958b56783ea9d279220a1e97d9c Mon Sep 17 00:00:00 2001 From: Tanmaypatil123 Date: Fri, 13 Oct 2023 14:18:24 +0530 Subject: [PATCH] Fixed cached methods and code improvements --- pandasai/connectors/airtable.py | 48 +++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 2e5a8c865..3137bf068 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -11,6 +11,7 @@ import time import hashlib from ..exceptions import InvalidRequestError +from functools import cache, cached_property class AirtableConnector(BaseConnector): @@ -18,6 +19,10 @@ class AirtableConnector(BaseConnector): Airtable connector to retrieving record data. """ + _rows_count: int = None + _columns_count: int = None + instance = None + def __init__( self, config: Optional[Union[AirtableConnectorConfig, dict]] = None, @@ -82,7 +87,7 @@ def _get_cache_path(self, include_additional_filters: bool = False): cache_dir = os.path.join(os.getcwd(), "cache") return os.path.join(cache_dir, f"{self._config.table}_data.parquet") - def _cached(self): + def _cached(self, include_additional_filters: bool = False): """ Returns the cached Airtable data if it exists and is not older than the cache interval. @@ -92,7 +97,7 @@ def _cached(self): it exists and is not older than the cache interval, None otherwise. """ - cache_path = self._get_cache_path() + cache_path = self._get_cache_path(include_additional_filters) if not os.path.exists(cache_path): return None @@ -138,7 +143,15 @@ def execute(self): Returns: DataFrameType: The result of the connector. """ - return self.fetch_data() + cached = self._cached() or self._cached(include_additional_filters=True) + if cached: + return pd.read_parquet(cached) + + if isinstance(self.instance, pd.DataFrame): + return self.instance + else: + self.instance = self.fetch_data() + return self.instance def build_formula(self): """ @@ -159,7 +172,6 @@ def fetch_data(self): """ url = f"{self._root_url}{self._config.base_id}/{self._config.table}" params = {} - if self._config.where: params["filterByFormula"] = self.build_formula() response = requests.get( @@ -191,6 +203,7 @@ def preprocess(self, data): df = pd.DataFrame(records) return df + @cache def head(self): """ Return the head of the table that @@ -200,9 +213,14 @@ def head(self): DatFrameType: The head of the data source that the conector is connected to . """ - return self.fetch_data().head() + # return self.fetch_data().head() + if isinstance(self.instance, pd.DataFrame): + return self.instance.head() + else: + self.instance = self.fetch_data() + return self.instance.head() - @property + @cached_property def rows_count(self): """ Return the number of rows in the data source that the connector is @@ -212,10 +230,13 @@ def rows_count(self): int: The number of rows in the data source that the connector is connected to. """ + if self._rows_count is not None: + return self._rows_count data = self.execute() - return len(data) + self._rows_count = len(data) + return self._rows_count - @property + @cached_property def columns_count(self): """ Return the number of columns in the data source that the connector is @@ -225,8 +246,11 @@ def columns_count(self): int: The number of columns in the data source that the connector is connected to. """ + if self._columns_count is not None: + return self._columns_count data = self.execute() - return len(data.columns) + self._columns_count = len(data.columns) + return self._columns_count @property def column_hash(self): @@ -238,6 +262,8 @@ def column_hash(self): int: The hash code that is unique to the columns of the data source that the connector is connected to. """ - data = self.execute() - columns_str = "|".join(data.columns) + if not isinstance(self.instance, pd.DataFrame): + self.instance = self.execute() + columns_str = "|".join(self.instance.columns) + columns_str += "WHERE" + self.build_formula() return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()