Skip to content

Commit

Permalink
Fixed cached methods and code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanmaypatil123 committed Oct 13, 2023
1 parent 02007ae commit 3f8f55f
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions pandasai/connectors/airtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
import time
import hashlib
from ..exceptions import InvalidRequestError
from functools import cache, cached_property


class AirtableConnector(BaseConnector):
"""
Airtable connector for retrieving record data.
"""

_rows_count: int = None
_columns_count: int = None
instance = None

def __init__(
self,
config: Optional[Union[AirtableConnectorConfig, dict]] = None,
Expand Down Expand Up @@ -82,7 +87,7 @@ def _get_cache_path(self, include_additional_filters: bool = False):
cache_dir = os.path.join(os.getcwd(), "cache")
return os.path.join(cache_dir, f"{self._config.table}_data.parquet")

def _cached(self):
def _cached(self, include_additional_filters: bool = False):
"""
Returns the cached Airtable data if it exists and
is not older than the cache interval.
Expand All @@ -92,7 +97,7 @@ def _cached(self):
it exists and is not older than the cache
interval, None otherwise.
"""
cache_path = self._get_cache_path()
cache_path = self._get_cache_path(include_additional_filters)
if not os.path.exists(cache_path):
return None

Expand Down Expand Up @@ -138,7 +143,15 @@ def execute(self):
Returns:
DataFrameType: The result of the connector.
"""
return self.fetch_data()
cached = self._cached() or self._cached(include_additional_filters=True)
if cached:
return pd.read_parquet(cached)

if isinstance(self.instance, pd.DataFrame):
return self.instance
else:
self.instance = self.fetch_data()
return self.instance

def build_formula(self):
"""
Expand All @@ -159,7 +172,6 @@ def fetch_data(self):
"""
url = f"{self._root_url}{self._config.base_id}/{self._config.table}"
params = {}

if self._config.where:
params["filterByFormula"] = self.build_formula()
response = requests.get(
Expand Down Expand Up @@ -191,6 +203,7 @@ def preprocess(self, data):
df = pd.DataFrame(records)
return df

@cache
def head(self):
"""
Return the head of the table that
Expand All @@ -200,9 +213,14 @@ def head(self):
DataFrameType: The head of the data source
that the connector is connected to.
"""
return self.fetch_data().head()
# return self.fetch_data().head()
if isinstance(self.instance, pd.DataFrame):
return self.instance.head()
else:
self.instance = self.fetch_data()
return self.instance.head()

@property
@cached_property
def rows_count(self):
"""
Return the number of rows in the data source that the connector is
Expand All @@ -212,10 +230,13 @@ def rows_count(self):
int: The number of rows in the data source that the connector is
connected to.
"""
if self._rows_count is not None:
return self._rows_count
data = self.execute()
return len(data)
self._rows_count = len(data)
return self._rows_count

@property
@cached_property
def columns_count(self):
"""
Return the number of columns in the data source that the connector is
Expand All @@ -225,8 +246,11 @@ def columns_count(self):
int: The number of columns in the data source that the connector is
connected to.
"""
if self._columns_count is not None:
return self._columns_count
data = self.execute()
return len(data.columns)
self._columns_count = len(data.columns)
return self._columns_count

@property
def column_hash(self):
Expand All @@ -238,6 +262,8 @@ def column_hash(self):
int: The hash code that is unique to the columns of the data source
that the connector is connected to.
"""
data = self.execute()
columns_str = "|".join(data.columns)
if not isinstance(self.instance, pd.DataFrame):
self.instance = self.execute()
columns_str = "|".join(self.instance.columns)
columns_str += "WHERE" + self.build_formula()
return hashlib.sha256(columns_str.encode("utf-8")).hexdigest()

0 comments on commit 3f8f55f

Please sign in to comment.