Skip to content

Commit

Permalink
refactor: to_dict() parameters
Browse files Browse the repository at this point in the history
* (refactor): update behaviour of `to_dict()` method, make it possible
  to accept `into` and `as_series` parameters (the last one is for
  polars dataframes).
* (tests): add tests for casting the dataframe to a dictionary, add
  tests for passing parameters in the proxy-call to dataframe's
  `to_dict()` method.
  • Loading branch information
nautics889 committed Sep 10, 2023
1 parent f5c4be0 commit 58be64b
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
16 changes: 14 additions & 2 deletions pandasai/smart_dataframe/abstract_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


class DataframeAbstract(ABC):
_engine: str

@property
def dataframe(self):
raise NotImplementedError("This method must be implemented in the child class")
Expand Down Expand Up @@ -159,8 +161,18 @@ def to_json(self, path):
def to_sql(self, name, con):
return self.dataframe.to_sql(name=name, con=con)

def to_dict(self, orient):
return self.dataframe.to_dict(orient=orient)
def to_dict(self, orient="dict", into=dict, as_series=True):
"""
A proxy-call to the dataframe's `.to_dict()`.
"""
if self._engine == "pandas":
return self.dataframe.to_dict(orient=orient, into=into)
elif self._engine == "polars":
return self.dataframe.to_dict(as_series=as_series)
raise RuntimeError(
f"{self.__class__} object has unknown engine type. "
f"Possible engines: 'pandas', 'polars'. Actual '{self._engine}'."
)

def to_numpy(self):
return self.dataframe.to_numpy()
Expand Down
52 changes: 52 additions & 0 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import sys
from collections import defaultdict
from typing import Optional
from unittest.mock import patch, Mock
from uuid import UUID
Expand Down Expand Up @@ -107,6 +108,16 @@ def smart_dataframe(self, llm, sample_df, sample_head):
sample_head=sample_head,
)

@pytest.fixture
def smart_dataframe_mocked_df(self, llm, sample_df, sample_head):
smart_df = SmartDataframe(
sample_df,
config={"llm": llm, "enable_cache": False},
sample_head=sample_head,
)
smart_df._core._df = Mock()
return smart_df

@pytest.fixture
def custom_middleware(self):
class CustomMiddleware(Middleware):
Expand Down Expand Up @@ -209,6 +220,47 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
last_prompt = df.last_prompt.replace("\r\n", "\n")
assert last_prompt == expected_prompt

def test_to_dict(self, smart_dataframe: SmartDataframe):
expected_keys = ("country", "gdp", "happiness_index")

result_dict = smart_dataframe.to_dict()

assert isinstance(result_dict, dict)
assert all(key in result_dict for key in expected_keys)

@pytest.mark.parametrize(
"to_dict_params,expected_passing_params,engine_type",
[
({}, {"orient": "dict", "into": dict}, "pandas"),
({}, {"as_series": True}, "polars"),
({"orient": "dict"}, {"orient": "dict", "into": dict}, "pandas"),
(
{"orient": "dict", "into": defaultdict},
{"orient": "dict", "into": defaultdict},
"pandas",
),
({"as_series": False}, {"as_series": False}, "polars"),
(
{"as_series": False, "orient": "dict", "into": defaultdict},
{"as_series": False},
"polars",
),
],
)
def test_to_dict_passing_parameters(
self,
smart_dataframe_mocked_df: SmartDataframe,
to_dict_params,
engine_type,
expected_passing_params,
):
smart_dataframe_mocked_df._engine = engine_type
smart_dataframe_mocked_df.to_dict(**to_dict_params)
# noinspection PyUnresolvedReferences
smart_dataframe_mocked_df.dataframe.to_dict.assert_called_once_with(
**expected_passing_params
)

def test_extract_code(self, llm):
code = """```python
result = {'happiness': 0.5, 'gdp': 0.8}
Expand Down

0 comments on commit 58be64b

Please sign in to comment.