Skip to content

Commit

Permalink
fix(pull): clean and reuse Session class
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Dec 5, 2024
1 parent eebfcb5 commit 5fd17ac
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 19 deletions.
15 changes: 11 additions & 4 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from zipfile import ZipFile

import pandas as pd
import requests

from pandasai.exceptions import DatasetNotFound
from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
from pandasai.helpers.path import find_project_root
from pandasai.helpers.request import get_pandaai_session
from .agent import Agent
from .helpers.cache import Cache
from .dataframe.base import DataFrame
Expand Down Expand Up @@ -85,10 +85,17 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
if not os.path.exists(dataset_full_path):
api_key = os.environ.get("PANDAAI_API_KEY", None)
api_url = os.environ.get("PANDAAI_API_URL", None)
if not api_url or not api_key:
raise PandasAIApiKeyError(
"Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to push dataset to the remote server"
)

request_session = get_pandaai_session()

headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}

file_data = requests.get(
f"{api_url}/datasets/pull", headers=headers, params={"path": dataset_path}
file_data = request_session.get(
"/datasets/pull", headers=headers, params={"path": dataset_path}
)
if file_data.status_code != 200:
raise DatasetNotFound("Dataset not found!")
Expand Down
29 changes: 26 additions & 3 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,25 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
self._cache_data(df, cache_file)

table_name = self.schema["source"]["table"]
table_description = self.schema.get("description", None)

return DataFrame(df, schema=self.schema, name=table_name, path=dataset_path)
return DataFrame(
df,
schema=self.schema,
name=table_name,
description=table_description,
path=dataset_path,
)
else:
# Initialize new dataset loader for virtualization
data_loader = self.copy()
table_name = self.schema["source"]["table"]
table_description = self.schema.get("description", None)
return VirtualDataFrame(
schema=self.schema,
data_loader=data_loader,
name=table_name,
description=table_description,
path=dataset_path,
)

Expand Down Expand Up @@ -88,10 +97,24 @@ def _is_cache_valid(self, cache_file: str) -> bool:

def _read_cache(self, cache_file: str) -> DataFrame:
    """Load a cached dataset file from disk into a DataFrame.

    The on-disk format is taken from the dataset schema's ``destination``
    section. The table name and optional description from the schema are
    forwarded so the returned DataFrame keeps its metadata (the scraped
    diff showed the older single-line returns that dropped this metadata;
    this is the post-commit version).

    Args:
        cache_file: Path to the cached file on disk.

    Returns:
        DataFrame: The cached data wrapped with schema metadata.

    Raises:
        ValueError: If the schema declares an unsupported cache format.
    """
    cache_format = self.schema["destination"]["format"]
    table_name = self.schema["source"]["table"]
    # "description" is optional in the schema; default to None.
    table_description = self.schema.get("description", None)
    if cache_format == "parquet":
        return DataFrame(
            pd.read_parquet(cache_file),
            schema=self.schema,
            path=self.dataset_path,
            name=table_name,
            description=table_description,
        )
    elif cache_format == "csv":
        return DataFrame(
            pd.read_csv(cache_file),
            schema=self.schema,
            path=self.dataset_path,
            name=table_name,
            description=table_description,
        )
    else:
        raise ValueError(f"Unsupported cache format: {cache_format}")

Expand Down
28 changes: 20 additions & 8 deletions pandasai/dataframe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import pandas as pd
from typing import TYPE_CHECKING, List, Optional, Union, Dict, ClassVar

import requests
import yaml


from pandasai.config import Config
import hashlib
from pandasai.exceptions import DatasetNotFound
from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
from pandasai.helpers.dataframe_serializer import (
DataframeSerializer,
DataframeSerializerType,
Expand Down Expand Up @@ -256,18 +255,27 @@ def push(self):

def pull(self):
api_key = os.environ.get("PANDAAI_API_KEY", None)
api_url = os.environ.get("PANDAAI_API_URL", None)

if not api_key:
raise PandasAIApiKeyError(
"Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to push dataset to the remote server"
)

request_session = get_pandaai_session()

headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}

file_data = requests.get(
f"{api_url}/datasets/pull", headers=headers, params={"path": self.path}
file_data = request_session.get(
"/datasets/pull", headers=headers, params={"path": self.path}
)
if file_data.status_code != 200:
raise DatasetNotFound("Remote dataset not found to pull!")

with ZipFile(BytesIO(file_data.content)) as zip_file:
for file_name in zip_file.namelist():
target_path = os.path.join(self.path, file_name)
target_path = os.path.join(
find_project_root(), "datasets", self.path, file_name
)

# Check if the file already exists
if os.path.exists(target_path):
Expand All @@ -281,6 +289,10 @@ def pull(self):
f.write(zip_file.read(file_name))

# reloads the Dataframe
from pandasai import load
from pandasai import DatasetLoader

self = load(self.path, virtualized=not isinstance(self, DataFrame))
dataset_loader = DatasetLoader()
df = dataset_loader.load(self.path, virtualized=not isinstance(self, DataFrame))
self.__init__(
df, schema=df.schema, name=df.name, description=df.description, path=df.path
)
9 changes: 7 additions & 2 deletions pandasai/helpers/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
self._logger = logger or Logger()

def get(self, path=None, **kwargs):
    """Issue a GET request via make_request and return the full result.

    NOTE: the response is returned as-is — there is no automatic
    ``["data"]`` unwrapping (the diff shows that behavior was removed);
    callers such as the vectorstore now extract ``["data"]`` themselves.

    Args:
        path: Request path relative to the session's base URL.
        **kwargs: Extra keyword arguments forwarded to make_request
            (e.g. ``params``, ``headers``).
    """
    return self.make_request("GET", path, **kwargs)

def post(self, path=None, **kwargs):
    """Issue a POST request via make_request and return its result.

    Args:
        path: Request path relative to the session's base URL.
        **kwargs: Extra keyword arguments forwarded to make_request
            (e.g. ``json``, ``headers``).
    """
    return self.make_request("POST", path, **kwargs)
Expand Down Expand Up @@ -79,7 +79,12 @@ def make_request(
**kwargs,
)

data = response.json()
try:
data = response.json()
except ValueError:
if response.status_code == 200:
return response

if response.status_code not in [200, 201]:
if "message" in data:
raise PandasAIApiCallError(data["message"])
Expand Down
4 changes: 2 additions & 2 deletions pandasai/vectorstores/bamboo_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_relevant_qa_documents(self, question: str, k: int = None) -> List[dict]:
try:
docs = self._session.get(
"/training-data/qa/relevant-qa", params={"query": question, "count": k}
)
)["data"]
return docs["docs"]
except Exception:
self._logger.log("Querying without using training data.", logging.ERROR)
Expand All @@ -77,7 +77,7 @@ def get_relevant_docs_documents(
docs = self._session.get(
"/training-docs/docs/relevant-docs",
params={"query": question, "count": k},
)
)["data"]
return docs["docs"]
except Exception:
self._logger.log("Querying without using training docs.", logging.ERROR)
Expand Down

0 comments on commit 5fd17ac

Please sign in to comment.