feat(Dataframe): pull method to fetch dataset from remote server (#1446)
* feat(dataframe): save dataframe to path

* feat(dataframe): save dataframe to path

* feat(dataframe): save path in dataframe

* feat(push): push dataset to the remote server

* feat(pull): pull dataset files

* fix(pull): clean and reuse Session class

* fix(pull): clean error messages

* Update pandasai/__init__.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Gabriele Venturi <[email protected]>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
3 people authored Dec 5, 2024
1 parent cb55a84 commit 61773a8
Showing 6 changed files with 135 additions and 17 deletions.
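For context, a minimal sketch of how the new pull flow is meant to be used; the dataset slug "my-org/companies" and the endpoint value are hypothetical:

import os
import pandasai as pai

# Both variables are required; load() and pull() raise PandasAIApiKeyError without them.
os.environ["PANDAAI_API_URL"] = "https://api.example.com"  # hypothetical endpoint
os.environ["PANDAAI_API_KEY"] = "PAI-..."                  # hypothetical key

# Loads from ./datasets/my-org/companies, pulling from the server if missing locally.
df = pai.load("my-org/companies")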
29 changes: 29 additions & 0 deletions pandasai/__init__.py
@@ -3,9 +3,16 @@
PandasAI is a wrapper around an LLM to make dataframes conversational
"""

from io import BytesIO
import os
from typing import List
from zipfile import ZipFile

import pandas as pd

from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
from pandasai.helpers.path import find_project_root
from pandasai.helpers.request import get_pandaai_session
from .agent import Agent
from .helpers.cache import Cache
from .dataframe.base import DataFrame
@@ -74,6 +81,28 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
        DataFrame: A new PandasAI DataFrame instance with loaded data.
    """
    global _dataset_loader
    dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)
    if not os.path.exists(dataset_full_path):
        api_key = os.environ.get("PANDAAI_API_KEY", None)
        api_url = os.environ.get("PANDAAI_API_URL", None)
        if not api_url or not api_key:
            raise PandasAIApiKeyError(
                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull dataset from the remote server"
            )

        request_session = get_pandaai_session()

        headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}

        file_data = request_session.get(
            "/datasets/pull", headers=headers, params={"path": dataset_path}
        )
        if file_data.status_code != 200:
            raise DatasetNotFound("Dataset not found!")

        with ZipFile(BytesIO(file_data.content)) as zip_file:
            zip_file.extractall(dataset_full_path)

    return _dataset_loader.load(dataset_path, virtualized)
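Both failure modes surface as typed exceptions, so callers can tell a missing key from a missing dataset; a sketch, reusing the hypothetical setup from the example above:

from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError

try:
    df = pai.load("my-org/companies")
except PandasAIApiKeyError:
    ...  # dataset absent locally and PANDAAI_API_URL / PANDAAI_API_KEY not set
except DatasetNotFound:
    ...  # server answered /datasets/pull with a non-200 status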


29 changes: 26 additions & 3 deletions pandasai/data_loader/loader.py
@@ -35,16 +35,25 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
                self._cache_data(df, cache_file)

            table_name = self.schema["source"]["table"]
            table_description = self.schema.get("description", None)

            return DataFrame(df, schema=self.schema, name=table_name, path=dataset_path)
            return DataFrame(
                df,
                schema=self.schema,
                name=table_name,
                description=table_description,
                path=dataset_path,
            )
        else:
            # Initialize new dataset loader for virtualization
            data_loader = self.copy()
            table_name = self.schema["source"]["table"]
            table_description = self.schema.get("description", None)
            return VirtualDataFrame(
                schema=self.schema,
                data_loader=data_loader,
                name=table_name,
                description=table_description,
                path=dataset_path,
            )

@@ -88,10 +97,24 @@ def _is_cache_valid(self, cache_file: str) -> bool:

    def _read_cache(self, cache_file: str) -> DataFrame:
        cache_format = self.schema["destination"]["format"]
        table_name = self.schema["source"]["table"]
        table_description = self.schema.get("description", None)
        if cache_format == "parquet":
            return DataFrame(pd.read_parquet(cache_file))
            return DataFrame(
                pd.read_parquet(cache_file),
                schema=self.schema,
                path=self.dataset_path,
                name=table_name,
                description=table_description,
            )
        elif cache_format == "csv":
            return DataFrame(pd.read_csv(cache_file))
            return DataFrame(
                pd.read_csv(cache_file),
                schema=self.schema,
                path=self.dataset_path,
                name=table_name,
                description=table_description,
            )
        else:
            raise ValueError(f"Unsupported cache format: {cache_format}")
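For reference, a sketch of the minimal schema dict these branches consume; the field names come from the diff, the values are made up and any surrounding keys are assumptions:

schema = {
    "source": {"table": "companies"},        # becomes the DataFrame's name
    "description": "Sample company data",    # optional; becomes the DataFrame's description
    "destination": {"format": "parquet"},    # "parquet" or "csv"; anything else raises ValueError
}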

57 changes: 49 additions & 8 deletions pandasai/dataframe/base.py
@@ -1,6 +1,8 @@
from __future__ import annotations
from io import BytesIO
import os
import re
from zipfile import ZipFile
import pandas as pd
from typing import TYPE_CHECKING, List, Optional, Union, Dict, ClassVar

@@ -9,13 +11,13 @@

from pandasai.config import Config
import hashlib
from pandasai.exceptions import PandasAIApiKeyError
from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError
from pandasai.helpers.dataframe_serializer import (
    DataframeSerializer,
    DataframeSerializerType,
)
from pandasai.helpers.path import find_project_root
from pandasai.helpers.request import Session
from pandasai.helpers.request import get_pandaai_session


if TYPE_CHECKING:
@@ -220,14 +222,9 @@ def save(
        print(f"Dataset saved successfully to path: {dataset_directory}")

    def push(self):
        api_url = os.environ.get("PANDAAI_API_URL", None)
        api_key = os.environ.get("PANDAAI_API_KEY", None)
        if not api_url or not api_key:
            raise PandasAIApiKeyError(
                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to push dataset to the remote server"
            )

        request_session = Session(endpoint_url=api_url, api_key=api_key)
        request_session = get_pandaai_session()

        params = {
            "path": self.path,
@@ -255,3 +252,47 @@ def push(self):
            params=params,
            headers=headers,
        )

    def pull(self):
        api_key = os.environ.get("PANDAAI_API_KEY", None)

        if not api_key:
            raise PandasAIApiKeyError(
                "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to pull dataset from the remote server"
            )

        request_session = get_pandaai_session()

        headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}

        file_data = request_session.get(
            "/datasets/pull", headers=headers, params={"path": self.path}
        )
        if file_data.status_code != 200:
            raise DatasetNotFound("Remote dataset not found!")

        with ZipFile(BytesIO(file_data.content)) as zip_file:
            for file_name in zip_file.namelist():
                target_path = os.path.join(
                    find_project_root(), "datasets", self.path, file_name
                )

                # Check if the file already exists
                if os.path.exists(target_path):
                    print(f"Replacing existing file: {target_path}")

                # Ensure target directory exists
                os.makedirs(os.path.dirname(target_path), exist_ok=True)

                # Extract the file
                with open(target_path, "wb") as f:
                    f.write(zip_file.read(file_name))

        # Reload the dataframe (the local import avoids a circular dependency)
        from pandasai import DatasetLoader

        dataset_loader = DatasetLoader()
        df = dataset_loader.load(self.path, virtualized=not isinstance(self, DataFrame))
        self.__init__(
            df, schema=df.schema, name=df.name, description=df.description, path=df.path
        )
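A usage sketch for the new method, assuming a dataframe that was loaded with a dataset path (the slug is the same hypothetical one as above):

df = pai.load("my-org/companies")
df.pull()  # re-downloads the zip, overwrites local files, and re-initializes df in place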
13 changes: 11 additions & 2 deletions pandasai/exceptions.py
@@ -209,8 +209,9 @@ class PandasAIApiKeyError(Exception):
        Exception (Exception): PandasAIApiKeyError
    """

    def __init__(self):
        message = PANDASBI_SETUP_MESSAGE
    def __init__(self, message: str = None):
        if not message:
            message = PANDASBI_SETUP_MESSAGE
        super().__init__(message)


@@ -264,3 +265,11 @@ class MaliciousCodeGenerated(Exception):
    Args:
        Exception (Exception): MaliciousCodeGenerated
    """


class DatasetNotFound(Exception):
    """
    Raise error if dataset not found
    Args:
        Exception (Exception): DatasetNotFound
    """
20 changes: 18 additions & 2 deletions pandasai/helpers/request.py
@@ -35,7 +35,7 @@ def __init__(
        self._logger = logger or Logger()

    def get(self, path=None, **kwargs):
        return self.make_request("GET", path, **kwargs)["data"]
        return self.make_request("GET", path, **kwargs)

    def post(self, path=None, **kwargs):
        return self.make_request("POST", path, **kwargs)
@@ -79,7 +79,12 @@ def make_request(
                **kwargs,
            )

            data = response.json()
            try:
                data = response.json()
            except ValueError:
                if response.status_code == 200:
                    return response
                raise PandasAIApiCallError(
                    f"Request failed with status code: {response.status_code}"
                )

            if response.status_code not in [200, 201]:
                if "message" in data:
                    raise PandasAIApiCallError(data["message"])
@@ -91,3 +96,14 @@
        except requests.exceptions.RequestException as e:
            self._logger.log(f"Request failed: {traceback.format_exc()}", logging.ERROR)
            raise PandasAIApiCallError(f"Request failed: {e}") from e


def get_pandaai_session():
    api_url = os.environ.get("PANDAAI_API_URL", None)
    api_key = os.environ.get("PANDAAI_API_KEY", None)
    if not api_url or not api_key:
        raise PandasAIApiKeyError(
            "Set PANDAAI_API_URL and PANDAAI_API_KEY in environment to push/pull datasets to/from the remote server"
        )

    return Session(endpoint_url=api_url, api_key=api_key)
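Since get() no longer unwraps ["data"], callers now choose between JSON payloads and raw responses; a sketch mirroring the updated call sites (paths as in the diff, query values made up):

session = get_pandaai_session()

# JSON endpoint: make_request() parses the body; the caller unwraps "data" itself.
docs = session.get("/training-data/qa/relevant-qa", params={"query": "revenue", "count": 3})["data"]

# Binary endpoint: the body is a zip, response.json() fails, and the raw Response is returned.
resp = session.get("/datasets/pull", params={"path": "my-org/companies"})
assert resp.status_code == 200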
4 changes: 2 additions & 2 deletions pandasai/vectorstores/bamboo_vectorstore.py
@@ -58,7 +58,7 @@ def get_relevant_qa_documents(self, question: str, k: int = None) -> List[dict]:
        try:
            docs = self._session.get(
                "/training-data/qa/relevant-qa", params={"query": question, "count": k}
            )
            )["data"]
            return docs["docs"]
        except Exception:
            self._logger.log("Querying without using training data.", logging.ERROR)
@@ -77,7 +77,7 @@ def get_relevant_docs_documents(
            docs = self._session.get(
                "/training-docs/docs/relevant-docs",
                params={"query": question, "count": k},
            )
            )["data"]
            return docs["docs"]
        except Exception:
            self._logger.log("Querying without using training docs.", logging.ERROR)
