From 58be64b54401057d88e185dd966382cbfa7549c9 Mon Sep 17 00:00:00 2001 From: Nautics889 Date: Sun, 10 Sep 2023 16:53:52 +0300 Subject: [PATCH] refactor: `to_dict()` parameters * (refactor): update behaviour of `to_dict()` method, make it possible to accept `into` and `as_series` parameters (the last one is for polars dataframes). * (tests): add tests for casting the dataframe to a dictionary, add tests for passing parameters in the proxy-call to dataframe's `to_dict()` method. --- pandasai/smart_dataframe/abstract_df.py | 16 +++++++- tests/test_smartdataframe.py | 52 +++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/pandasai/smart_dataframe/abstract_df.py b/pandasai/smart_dataframe/abstract_df.py index ce95b2f76..65397e05c 100644 --- a/pandasai/smart_dataframe/abstract_df.py +++ b/pandasai/smart_dataframe/abstract_df.py @@ -2,6 +2,8 @@ class DataframeAbstract(ABC): + _engine: str + @property def dataframe(self): raise NotImplementedError("This method must be implemented in the child class") @@ -159,8 +161,18 @@ def to_json(self, path): def to_sql(self, name, con): return self.dataframe.to_sql(name=name, con=con) - def to_dict(self, orient): - return self.dataframe.to_dict(orient=orient) + def to_dict(self, orient="dict", into=dict, as_series=True): + """ + A proxy-call to the dataframe's `.to_dict()`. + """ + if self._engine == "pandas": + return self.dataframe.to_dict(orient=orient, into=into) + elif self._engine == "polars": + return self.dataframe.to_dict(as_series=as_series) + raise RuntimeError( + f"{self.__class__} object has unknown engine type. " + f"Possible engines: 'pandas', 'polars'. Actual '{self._engine}'." + ) def to_numpy(self): return self.dataframe.to_numpy() diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 1f2c8a0fc..c5e5a2958 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -2,6 +2,7 @@ import json import os import sys +from collections import defaultdict from typing import Optional from unittest.mock import patch, Mock from uuid import UUID @@ -107,6 +108,16 @@ def smart_dataframe(self, llm, sample_df, sample_head): sample_head=sample_head, ) + @pytest.fixture + def smart_dataframe_mocked_df(self, llm, sample_df, sample_head): + smart_df = SmartDataframe( + sample_df, + config={"llm": llm, "enable_cache": False}, + sample_head=sample_head, + ) + smart_df._core._df = Mock() + return smart_df + @pytest.fixture def custom_middleware(self): class CustomMiddleware(Middleware): @@ -209,6 +220,47 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: last_prompt = df.last_prompt.replace("\r\n", "\n") assert last_prompt == expected_prompt + def test_to_dict(self, smart_dataframe: SmartDataframe): + expected_keys = ("country", "gdp", "happiness_index") + + result_dict = smart_dataframe.to_dict() + + assert isinstance(result_dict, dict) + assert all(key in result_dict for key in expected_keys) + + @pytest.mark.parametrize( + "to_dict_params,expected_passing_params,engine_type", + [ + ({}, {"orient": "dict", "into": dict}, "pandas"), + ({}, {"as_series": True}, "polars"), + ({"orient": "dict"}, {"orient": "dict", "into": dict}, "pandas"), + ( + {"orient": "dict", "into": defaultdict}, + {"orient": "dict", "into": defaultdict}, + "pandas", + ), + ({"as_series": False}, {"as_series": False}, "polars"), + ( + {"as_series": False, "orient": "dict", "into": defaultdict}, + {"as_series": False}, + "polars", + ), + ], + ) + def test_to_dict_passing_parameters( + self, + smart_dataframe_mocked_df: SmartDataframe, + to_dict_params, + engine_type, + expected_passing_params, + ): + smart_dataframe_mocked_df._engine = engine_type + smart_dataframe_mocked_df.to_dict(**to_dict_params) + # noinspection PyUnresolvedReferences + smart_dataframe_mocked_df.dataframe.to_dict.assert_called_once_with( + **expected_passing_params + ) + def test_extract_code(self, llm): code = """```python result = {'happiness': 0.5, 'gdp': 0.8}