From 5fd17ac11326c649bb3477dd48c59c2a455f8eff Mon Sep 17 00:00:00 2001
From: ArslanSaleem
Date: Thu, 5 Dec 2024 10:51:01 +0100
Subject: [PATCH] fix(pull): clean and reuse Session class

---
 pandasai/__init__.py                        | 15 ++++++++---
 pandasai/data_loader/loader.py              | 29 ++++++++++++++++++---
 pandasai/dataframe/base.py                  | 28 ++++++++++++++------
 pandasai/helpers/request.py                 | 10 ++++++--
 pandasai/vectorstores/bamboo_vectorstore.py |  4 +--
 5 files changed, 67 insertions(+), 19 deletions(-)

diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index e335fc8ac..ad3e50f5e 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -9,10 +9,10 @@
 from zipfile import ZipFile
 
 import pandas as pd
-import requests
 
-from pandasai.exceptions import DatasetNotFound
+from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
 from pandasai.helpers.path import find_project_root
+from pandasai.helpers.request import get_pandaai_session
 from .agent import Agent
 from .helpers.cache import Cache
 from .dataframe.base import DataFrame
@@ -85,10 +85,17 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
     if not os.path.exists(dataset_full_path):
         api_key = os.environ.get("PANDAAI_API_KEY", None)
         api_url = os.environ.get("PANDAAI_API_URL", None)
+        if not api_url or not api_key:
+            raise PandasAIApiKeyError(
+                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull the dataset from the remote server"
+            )
+
+        request_session = get_pandaai_session()
+
         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
 
-        file_data = requests.get(
-            f"{api_url}/datasets/pull", headers=headers, params={"path": dataset_path}
+        file_data = request_session.get(
+            "/datasets/pull", headers=headers, params={"path": dataset_path}
         )
         if file_data.status_code != 200:
             raise DatasetNotFound("Dataset not found!")
diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index f1b1c7c33..d7a133320 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -35,16 +35,25 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
             self._cache_data(df, cache_file)
 
             table_name = self.schema["source"]["table"]
+            table_description = self.schema.get("description", None)
 
-            return DataFrame(df, schema=self.schema, name=table_name, path=dataset_path)
+            return DataFrame(
+                df,
+                schema=self.schema,
+                name=table_name,
+                description=table_description,
+                path=dataset_path,
+            )
         else:
             # Initialize new dataset loader for virtualization
             data_loader = self.copy()
             table_name = self.schema["source"]["table"]
+            table_description = self.schema.get("description", None)
 
             return VirtualDataFrame(
                 schema=self.schema,
                 data_loader=data_loader,
                 name=table_name,
+                description=table_description,
                 path=dataset_path,
             )
@@ -88,10 +97,24 @@ def _is_cache_valid(self, cache_file: str) -> bool:
 
     def _read_cache(self, cache_file: str) -> DataFrame:
         cache_format = self.schema["destination"]["format"]
+        table_name = self.schema["source"]["table"]
+        table_description = self.schema.get("description", None)
 
         if cache_format == "parquet":
-            return DataFrame(pd.read_parquet(cache_file), path=self.dataset_path)
+            return DataFrame(
+                pd.read_parquet(cache_file),
+                schema=self.schema,
+                path=self.dataset_path,
+                name=table_name,
+                description=table_description,
+            )
         elif cache_format == "csv":
-            return DataFrame(pd.read_csv(cache_file), path=self.dataset_path)
+            return DataFrame(
+                pd.read_csv(cache_file),
+                schema=self.schema,
+                path=self.dataset_path,
+                name=table_name,
+                description=table_description,
+            )
         else:
            raise ValueError(f"Unsupported cache format: {cache_format}")
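
Note on the loader changes above: the optional top-level "description" field in a
dataset's schema is now threaded into every frame the loader builds, and the
cache-read path passes schema, name, and description instead of dropping them. A
minimal usage sketch, assuming a dataset saved under
<project-root>/datasets/my-org/heart whose schema.yaml declares a description
(the "my-org/heart" path is a placeholder, not part of this patch):

    from pandasai import DatasetLoader

    loader = DatasetLoader()
    df = loader.load("my-org/heart")  # "organization/dataset" folder under datasets/
    print(df.name)                    # schema["source"]["table"]
    print(df.description)             # schema.get("description"), now forwarded
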
diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py
index b7ae0dd8b..ecdfb5620 100644
--- a/pandasai/dataframe/base.py
+++ b/pandasai/dataframe/base.py
@@ -6,13 +6,12 @@
 import pandas as pd
 from typing import TYPE_CHECKING, List, Optional, Union, Dict, ClassVar
 
-import requests
 import yaml
 
 from pandasai.config import Config
 import hashlib
-from pandasai.exceptions import DatasetNotFound
+from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
 from pandasai.helpers.dataframe_serializer import (
     DataframeSerializer,
     DataframeSerializerType,
 )
@@ -256,18 +255,27 @@ def push(self):
 
     def pull(self):
         api_key = os.environ.get("PANDAAI_API_KEY", None)
-        api_url = os.environ.get("PANDAAI_API_URL", None)
+
+        if not api_key:
+            raise PandasAIApiKeyError(
+                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull the dataset from the remote server"
+            )
+
+        request_session = get_pandaai_session()
+
         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
 
-        file_data = requests.get(
-            f"{api_url}/datasets/pull", headers=headers, params={"path": self.path}
+        file_data = request_session.get(
+            "/datasets/pull", headers=headers, params={"path": self.path}
         )
         if file_data.status_code != 200:
             raise DatasetNotFound("Remote dataset not found to pull!")
 
         with ZipFile(BytesIO(file_data.content)) as zip_file:
             for file_name in zip_file.namelist():
-                target_path = os.path.join(self.path, file_name)
+                target_path = os.path.join(
+                    find_project_root(), "datasets", self.path, file_name
+                )
 
                 # Check if the file already exists
                 if os.path.exists(target_path):
@@ -281,6 +289,10 @@ def pull(self):
                     f.write(zip_file.read(file_name))
 
         # reloads the Dataframe
-        from pandasai import load
+        from pandasai import DatasetLoader
 
-        self = load(self.path, virtualized=not isinstance(self, DataFrame))
+        dataset_loader = DatasetLoader()
+        df = dataset_loader.load(self.path, virtualized=not isinstance(self, DataFrame))
+        self.__init__(
+            df, schema=df.schema, name=df.name, description=df.description, path=df.path
+        )
diff --git a/pandasai/helpers/request.py b/pandasai/helpers/request.py
index fad3e0641..82ddb269f 100644
--- a/pandasai/helpers/request.py
+++ b/pandasai/helpers/request.py
@@ -35,7 +35,7 @@ def __init__(
         self._logger = logger or Logger()
 
     def get(self, path=None, **kwargs):
-        return self.make_request("GET", path, **kwargs)["data"]
+        return self.make_request("GET", path, **kwargs)
 
     def post(self, path=None, **kwargs):
         return self.make_request("POST", path, **kwargs)
@@ -79,7 +79,13 @@ def make_request(
             **kwargs,
         )
 
-        data = response.json()
+        try:
+            data = response.json()
+        except ValueError:
+            if response.status_code == 200:
+                return response
+            raise PandasAIApiCallError(response.text)
+
         if response.status_code not in [200, 201]:
             if "message" in data:
                 raise PandasAIApiCallError(data["message"])
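
Note on the request helper changes above: Session.get now returns whatever
make_request produces instead of eagerly unwrapping ["data"], and make_request
hands back the raw response object when a 200 reply has a non-JSON body, which is
what the ZIP-streaming /datasets/pull endpoint needs; a non-JSON error reply now
raises PandasAIApiCallError instead of falling through to an unbound "data"
variable. A sketch of the two calling conventions this creates, assuming
PANDAAI_API_URL and PANDAAI_API_KEY are set in the environment; the dataset path
and query text are placeholders:

    import os
    from io import BytesIO
    from zipfile import ZipFile

    from pandasai.helpers.request import get_pandaai_session

    session = get_pandaai_session()
    headers = {
        "accept": "application/json",
        "x-authorization": f"Bearer {os.environ['PANDAAI_API_KEY']}",
    }

    # Binary endpoint: JSON parsing fails on the ZIP body, the status is 200,
    # so the raw response comes back and .content holds the archive bytes.
    file_data = session.get(
        "/datasets/pull", headers=headers, params={"path": "my-org/heart"}
    )
    if file_data.status_code == 200:
        ZipFile(BytesIO(file_data.content)).printdir()

    # JSON endpoint: the parsed body comes back and the caller unwraps "data"
    # itself, as the vectorstore change below now does.
    docs = session.get(
        "/training-data/qa/relevant-qa", params={"query": "churn drivers", "count": 3}
    )["data"]
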
"/training-docs/docs/relevant-docs", params={"query": question, "count": k}, - ) + )["data"] return docs["docs"] except Exception: self._logger.log("Querying without using training docs.", logging.ERROR)