Skip to content

Commit

Permalink
Merge branch 'main' into HenriqueBranco/fix-ci-test
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Oct 22, 2023
2 parents 240caac + c66cb06 commit b4c01fd
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 25 deletions.
5 changes: 1 addition & 4 deletions pandasai/connectors/airtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,7 @@ def _fetch_data(self):
Fetches data from the Airtable server via API and converts it to a DataFrame.
"""

params = {
"pageSize": 100,
"offset": "0"
}
params = {"pageSize": 100, "offset": "0"}

if self._config.where is not None:
params["filterByFormula"] = self._build_formula()
Expand Down
2 changes: 1 addition & 1 deletion pandasai/helpers/code_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def execute_code(
code,
logger=self._logger,
file_name=str(prompt_id),
save_charts_path=self._config.save_charts_path,
save_charts_path_str=self._config.save_charts_path,
)

# Get the code to run removing unsafe imports and df overwrites
Expand Down
15 changes: 9 additions & 6 deletions pandasai/helpers/save_chart.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Helper functions to save charts to a file, if plt.show() is called."""
from pathlib import Path
from .logger import Logger
from os import makedirs, path


def add_save_chart(
code: str,
logger: Logger,
file_name: str,
save_charts_path: str = None,
save_charts_path_str: str = None,
) -> str:
"""
Add line to code that save charts to a file, if plt.show() is called.
Expand All @@ -22,12 +22,15 @@ def add_save_chart(
str: Code with line added.
"""

save_charts_path = Path(save_charts_path_str) if save_charts_path_str else None

if save_charts_path is not None:
if not path.exists(save_charts_path):
makedirs(save_charts_path)
save_charts_path.mkdir(parents=True, exist_ok=True)

save_charts_file = save_charts_path / f"{file_name}.png"

save_charts_file = path.join(save_charts_path, file_name + ".png")
code = code.replace("temp_chart.png", save_charts_file)
code = code.replace("temp_chart.png", save_charts_file.as_posix())
logger.log(f"Saving charts to {save_charts_file}")

return code
1 change: 0 additions & 1 deletion pandasai/llm/google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def _generate_text(self, prompt: str) -> str:
else:
raise UnsupportedModelError(self.model)


return str(completion)

@property
Expand Down
1 change: 0 additions & 1 deletion pandasai/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
else:
raise UnsupportedModelError(self.model)


return response

@property
Expand Down
5 changes: 4 additions & 1 deletion pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,10 @@ def _get_sample_head(self) -> DataFrameType:
sampler = DataSampler(head)
sampled_head = sampler.sample(rows_to_display)

return self._truncate_head_columns(sampled_head)
if self.lake.config.enforce_privacy:
return sampled_head
else:
return self._truncate_head_columns(sampled_head)

def _load_from_config(self, name: str):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/callbacks/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def setUp(self):

def tearDown(self):
if os.path.exists(self.filename):
self.callback.file.close()
os.remove(self.filename)

def test_on_code(self):
Expand Down
14 changes: 6 additions & 8 deletions tests/connectors/test_airtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ def setUp(self) -> None:
api_key="your_token",
base_id="your_baseid",
table="your_table_name",
where= [
["Status", "=", "In progress"]
]
where=[["Status", "=", "In progress"]],
).dict()
self.root_url = "https://api.airtable.com/v0/"
self.expected_data_json = """
Expand Down Expand Up @@ -104,20 +102,20 @@ def test_columns_count_property(self, mock_request_get):
mock_request_get.return_value.status_code = 200
columns_count = self.connector.columns_count
self.assertEqual(columns_count, 3)

def test_build_formula_method(self):
formula = self.connector._build_formula()
expected_formula = "AND(Status='In progress')"
self.assertEqual(formula,expected_formula)
self.assertEqual(formula, expected_formula)

@patch("requests.get")
def test_column_hash(self,mock_request_get):
def test_column_hash(self, mock_request_get):
mock_request_get.return_value.json.return_value = json.loads(
self.expected_data_json
)
mock_request_get.return_value.status_code = 200
returned_hash = self.connector.column_hash
self.assertEqual(
returned_hash,
"e4cdc9402a0831fb549d7fdeaaa089b61aeaf61e14b8a044bc027219b2db941e"
"e4cdc9402a0831fb549d7fdeaaa089b61aeaf61e14b8a044bc027219b2db941e",
)
8 changes: 5 additions & 3 deletions tests/helpers/test_save_chart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import tempfile
from pandasai.helpers.logger import Logger
from pandasai.helpers.save_chart import add_save_chart
Expand Down Expand Up @@ -34,13 +35,14 @@ def test_add_save_chart_with_user_defined_path(self):
file_name = "temp_chart"

with tempfile.TemporaryDirectory() as temp_dir:
save_charts_path = os.path.join(temp_dir, "charts")
assert not os.path.exists(save_charts_path)
save_charts_path = Path(temp_dir) / "charts"
assert not save_charts_path.exists()
full_path_to_file = save_charts_path / f"{file_name}.png"

expected_code = f"""
import matplotlib.pyplot as plt
plt.plot([1, 2, 3], [4, 5, 6])
plt.savefig("{os.path.join(save_charts_path, file_name)}.png")
plt.savefig("{full_path_to_file.as_posix()}")
plt.show()
"""
result = add_save_chart(code, logger, file_name, save_charts_path)
Expand Down

0 comments on commit b4c01fd

Please sign in to comment.