Skip to content

Commit

Permalink
feat: add save_charts option (Sinaptik-AI#188)
Browse files Browse the repository at this point in the history
* test: add save_chart tests

Create tests to confirm plt.savefig() is injected into llm output.
Affix letter to timestamp when multiple charts are created.

* feat: add save_chart.py module

Initial commit of save_chart.py.
- compare_ast()
Compare two AST nodes for equality.
- add_save_chart()
Add a line to the code that saves charts to a file, if plt.show() is called.

Tests failing!

* test: add compare_ast test

- Add compare_ast test.
- Update save_chart tests to use compare_ast().

Tests failing!

* feat: update add_save_chart() to use unique names

Affix a letter character to the end of the filename when more than one plt.show() call exists in the code.

fix: correct project root variable

* feat: add save_charts argument to PandasAI

Add save_charts argument to init call of PandasAI.
When set to True, a call to `plt.savefig()` is injected before any calls to `plt.show()`
in the run code.

* docs: add instructions to save charts

* feat: print chart save path

Add print expression to run code so user can locate saved charts.
Update test to check for print expression.
  • Loading branch information
jonbiemond authored May 29, 2023
1 parent 06796f1 commit 7ae7f1d
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 3 deletions.
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ pip install pandasai

> Disclaimer: GDP data was collected from [this source](https://ourworldindata.org/grapher/gross-domestic-product?tab=table), published by World Development Indicators - World Bank (2022.05.26) and collected at National accounts data - World Bank / OECD. It relates to the year of 2020. Happiness indexes were extracted from [the World Happiness Report](https://ftnnews.com/images/stories/documents/2020/WHR20.pdf). Another useful [link](https://data.world/makeovermonday/2020w19-world-happiness-report-2020).
PandasAI is designed to be used in conjunction with [pandas](https://github.com/pandas-dev/pandas). It makes Pandas conversational, allowing you to ask questions about your data and get answers back, in the form of pandas DataFrames. For example, you can ask PandasAI to find all the rows in a DataFrame where the value of a column is greater than 5, and it will return a DataFrame containing only those rows:
PandasAI is designed to be used in conjunction with [pandas](https://github.com/pandas-dev/pandas). It makes Pandas conversational, allowing you to ask questions about your data and get answers back, in the form of pandas DataFrames.

### Queries

For example, you can ask PandasAI to find all the rows in a DataFrame where the value of a column is greater than 5, and it will return a DataFrame containing only those rows:

```python
import pandas as pd
Expand Down Expand Up @@ -78,17 +82,23 @@ The above code will return the following:
19012600725504
```

### Charts

You can also ask PandasAI to draw a graph:

```python
pandas_ai(
df,
"Plot the histogram of countries showing for each the gpd, using different colors for each bar",
"Plot the histogram of countries showing for each the gdp, using different colors for each bar",
)
```

![Chart](images/histogram-chart.png?raw=true)

You can save any charts generated by PandasAI by setting the `save_charts` parameter to `True` in the `PandasAI` constructor. For example, `PandasAI(llm, save_charts=True)`. Charts are saved in `./pandasai/exports/charts`.

### Multiple DataFrames

You can also pass multiple dataframes to PandasAI and ask questions that relate them to each other.

```python
Expand Down
10 changes: 9 additions & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" PandasAI is a wrapper around a LLM to make dataframes convesational """
""" PandasAI is a wrapper around a LLM to make dataframes conversational """
import ast
import io
import re
Expand All @@ -13,6 +13,7 @@
from .exceptions import LLMNotFoundError
from .helpers.anonymizer import anonymize_dataframe_head
from .helpers.notebook import Notebook
from .helpers.save_chart import add_save_chart
from .llm.base import LLM
from .prompts.correct_error_prompt import CorrectErrorPrompt
from .prompts.generate_python_code import GeneratePythonCodePrompt
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
conversational=True,
verbose=False,
enforce_privacy=False,
save_charts=False,
):
if llm is None:
raise LLMNotFoundError(
Expand All @@ -57,6 +59,7 @@ def __init__(
self._is_conversational_answer = conversational
self._verbose = verbose
self._enforce_privacy = enforce_privacy
self._save_charts = save_charts

self.notebook = Notebook()
self._in_notebook = self.notebook.in_notebook()
Expand Down Expand Up @@ -223,6 +226,11 @@ def run_code(
"""Run the code in the current context and return the result"""

multiple: bool = isinstance(data_frame, list)

# Add save chart code
if self._save_charts:
code = add_save_chart(code)

# Get the code to run removing unsafe imports and df overwrites
code_to_run = self.clean_code(code)
self.last_run_code = code_to_run
Expand Down
80 changes: 80 additions & 0 deletions pandasai/helpers/save_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Helper functions to save charts to a file, if plt.show() is called."""
import ast
import os
from datetime import datetime
from itertools import zip_longest
from os.path import dirname
from typing import Union

import astor


def compare_ast(
    node1: Union[ast.expr, list[ast.expr], ast.stmt, ast.AST],
    node2: Union[ast.expr, list[ast.expr], ast.stmt, ast.AST],
    ignore_args=False,
) -> bool:
    """Return whether two AST nodes (or lists of nodes) are structurally equal.

    Source position attributes and the expression context (``ctx``) are
    never compared. When ``ignore_args`` is true, any attribute named
    ``args`` is skipped as well, so two calls to the same function compare
    equal regardless of their positional arguments.

    Adapted from https://stackoverflow.com/a/66733795/11080806

    Args:
        node1: First node, list of nodes, or plain attribute value.
        node2: Second node, list of nodes, or plain attribute value.
        ignore_args: Skip every ``args`` attribute during the comparison.

    Returns:
        ``True`` when both inputs have the same type and equal content,
        ``False`` otherwise.
    """
    # Values of different types can never match; this also rejects the
    # ``None`` padding that zip_longest produces for lists of unequal length.
    if type(node1) is not type(node2):
        return False

    if isinstance(node1, list):
        # The type check above guarantees node2 is a list of the same type.
        for left, right in zip_longest(node1, node2):
            if not compare_ast(left, right, ignore_args):
                return False
        return True

    if not isinstance(node1, ast.AST):
        # Plain attribute value (identifier, constant, None, ...).
        return node1 == node2

    skipped = {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}
    if ignore_args:
        skipped.add("args")

    for field, value in vars(node1).items():
        if field in skipped:
            continue
        if not compare_ast(value, getattr(node2, field), ignore_args):
            return False
    return True


def add_save_chart(code: str) -> str:
    """Inject ``plt.savefig()`` before each top-level ``plt.show()`` in ``code``.

    Every module-level ``plt.show(...)`` statement gets a preceding
    ``plt.savefig(<path>)`` statement, and a final ``print`` statement
    reporting the save directory is appended. Files are named
    ``chart_<timestamp>.png``; when more than one chart is saved, a letter
    (``_a``, ``_b``, ...) is affixed to keep the names unique.

    Only statements directly in the module body are instrumented;
    ``plt.show()`` calls nested inside functions, loops or conditionals are
    left untouched and are not counted.

    Side effect: creates ``<project root>/exports/charts/<YYYY-MM-DD>`` when
    at least one chart will be saved.

    Args:
        code: Python source code to instrument.

    Returns:
        ``code`` unchanged when it contains no top-level ``plt.show()``
        call, otherwise the regenerated source with the extra statements.

    Raises:
        SyntaxError: If ``code`` is not valid Python.
    """
    tree = ast.parse(code)

    # Parse the reference statement once instead of once per compared node.
    show_stmt = ast.parse("plt.show()").body[0]
    show_positions = {
        index
        for index, node in enumerate(tree.body)
        if compare_ast(node, show_stmt, ignore_args=True)
    }

    # Nothing to save: leave the code (and the file system) untouched.
    if not show_positions:
        return code

    # Take a single clock reading so the directory date and the file
    # timestamp cannot disagree around midnight.
    now = datetime.now()
    date = now.strftime("%Y-%m-%d")
    timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")

    # Build the path from components so it is valid on every platform,
    # not only where backslash is the separator.
    project_root = dirname(dirname(dirname(__file__)))
    chart_save_dir = os.path.join(project_root, "exports", "charts", date)
    os.makedirs(chart_save_dir, exist_ok=True)

    needs_suffix = len(show_positions) > 1
    counter = ord("a")
    new_body = []
    for index, node in enumerate(tree.body):
        if index in show_positions:
            filename = f"chart_{timestamp}"
            if needs_suffix:
                filename += f"_{chr(counter)}"
                counter += 1
            chart_path = os.path.join(chart_save_dir, f"{filename}.png")
            # repr() yields a correctly escaped string literal, so quotes or
            # backslashes in the path cannot break the generated code.
            new_body.append(ast.parse(f"plt.savefig({chart_path!r})").body[0])
        new_body.append(node)

    message = f"Charts saved to: {chart_save_dir}"
    new_body.append(ast.parse(f"print({message!r})").body[0])

    new_tree = ast.Module(body=new_body, type_ignores=[])
    return astor.to_source(new_tree).strip()
77 changes: 77 additions & 0 deletions tests/test_save_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Unit tests for the save_chart module."""
import ast
import os.path

from pandasai.helpers.save_chart import add_save_chart, compare_ast


class TestSaveChart:
    """Unit tests for the save_chart module."""

    def test_compare_ast(self):
        """Calls to the same function match when arguments are ignored."""
        source_pairs = (
            ("plt.show()", "plt.show(*some-args)"),
            ("print(r'hello/word.jpeg')", "print()"),
        )
        for left_source, right_source in source_pairs:
            left = ast.parse(left_source).body[0]
            right = ast.parse(right_source).body[0]
            assert compare_ast(left, right, ignore_args=True)

    def test_save_chart(self):
        """A single plt.show() gains one savefig call and a trailing print."""
        chart_code = """
import matplotlib.pyplot as plt
import pandas as pd
df = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
df.plot()
plt.show()
"""
        original_statement_count = len(ast.parse(chart_code).body)
        instrumented = ast.parse(add_save_chart(chart_code))

        show_stmt = ast.parse("plt.show()").body[0]
        show_positions = []
        for position, statement in enumerate(instrumented.body):
            if compare_ast(statement, show_stmt, ignore_args=True):
                show_positions.append(position)
        first_show = show_positions[0]

        savefig_stmt = ast.parse("plt.savefig()").body[0]

        # One savefig statement plus the final print statement were added.
        assert len(instrumented.body) == original_statement_count + 2

        # The statement right before plt.show() must be plt.savefig(...).
        statement_before_show = instrumented.body[first_show - 1]
        assert compare_ast(statement_before_show, savefig_stmt, ignore_args=True)

        # The last statement reports where the charts were saved.
        print_stmt = ast.parse("print()").body[0]
        assert compare_ast(instrumented.body[-1], print_stmt, ignore_args=True)

    def test_save_multiple_charts(self):
        """Each plt.show() gets its own savefig call with a lettered filename."""
        chart_code = """
import matplotlib.pyplot as plt
import pandas as pd
df = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
df.plot('a')
plt.show()
df.plot('b')
plt.show()
"""
        original_statement_count = len(ast.parse(chart_code).body)
        instrumented = ast.parse(add_save_chart(chart_code))

        show_stmt = ast.parse("plt.show()").body[0]
        show_positions = []
        for position, statement in enumerate(instrumented.body):
            if compare_ast(statement, show_stmt, ignore_args=True):
                show_positions.append(position)

        savefig_stmt = ast.parse("plt.savefig()").body[0]

        # Two savefig statements plus the final print statement were added.
        assert len(instrumented.body) == original_statement_count + 3

        # The first chart's filename must end with "a", the second's with "b".
        for which, expected_suffix in enumerate("ab"):
            saved = instrumented.body[show_positions[which] - 1]
            assert compare_ast(saved, savefig_stmt, ignore_args=True)
            call_args = [argument.value for argument in saved.value.args]
            stem = os.path.splitext(call_args[0])[0]
            assert stem[-1] == expected_suffix

0 comments on commit 7ae7f1d

Please sign in to comment.