forked from Sinaptik-AI/pandas-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add save_charts option (Sinaptik-AI#188)
* test: add save_chart tests Create tests to confirm plt.savefig() is injected into LLM output. Affix a letter to the timestamp when multiple charts are created. * feat: add save_chart.py module Initial commit of save_chart.py. - compare_ast() Compare two AST nodes for equality. - add_save_chart() Add a line to the code that saves charts to a file, if plt.show() is called. Tests failing! * test: add compare_ast test - Add compare_ast test. - Update save_chart tests to use compare_ast(). Tests failing! * feat: update add_save_chart() to use unique names Affix a letter character to the end of the filename when more than one plt.show() call exists in the code. fix: correct project root variable * feat: add save_charts argument to PandasAI Add save_charts argument to the init call of PandasAI. When set to True, a call to `plt.savefig()` is injected before any calls to `plt.show()` in the run code. * docs: add instructions to save charts * feat: print chart save path Add a print expression to the run code so the user can locate saved charts. Update test to check for the print expression.
- Loading branch information
1 parent
06796f1
commit 7ae7f1d
Showing
4 changed files
with
178 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""Helper functions to save charts to a file, if plt.show() is called.""" | ||
import ast | ||
import os | ||
from datetime import datetime | ||
from itertools import zip_longest | ||
from os.path import dirname | ||
from typing import Union | ||
|
||
import astor | ||
|
||
|
||
def compare_ast(
    node1: Union[ast.expr, list[ast.expr], ast.stmt, ast.AST],
    node2: Union[ast.expr, list[ast.expr], ast.stmt, ast.AST],
    ignore_args=False,
) -> bool:
    """Return True if two AST nodes (or lists of nodes) are structurally equal.

    Position attributes (``lineno``, ``end_lineno``, ``col_offset``,
    ``end_col_offset``) and the expression context (``ctx``) are never
    compared.

    Adapted from https://stackoverflow.com/a/66733795/11080806

    Args:
        node1: First AST node, list of nodes, or plain field value.
        node2: Second AST node, list of nodes, or plain field value.
        ignore_args: When True, the ``args`` and ``keywords`` fields of every
            node are skipped, so e.g. ``plt.show()`` matches both
            ``plt.show(fig)`` and ``plt.show(block=True)``.

    Returns:
        True when both values have the same type and all compared fields are
        equal, False otherwise.
    """
    if type(node1) is not type(node2):
        return False

    if isinstance(node1, ast.AST):
        skipped = {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}
        if ignore_args:
            # Keyword arguments are arguments too; skip both fields.
            skipped |= {"args", "keywords"}

        # Compare the union of both nodes' attributes so the result does not
        # depend on argument order, and never raise on a missing attribute.
        missing = object()
        for name in (vars(node1).keys() | vars(node2).keys()) - skipped:
            left = getattr(node1, name, missing)
            right = getattr(node2, name, missing)
            if left is missing or right is missing:
                return False
            if not compare_ast(left, right, ignore_args):
                return False
        return True

    if isinstance(node1, list):
        # zip_longest pads the shorter list with None, which then fails the
        # type check above, so lists of different length compare unequal.
        return all(
            compare_ast(n1, n2, ignore_args) for n1, n2 in zip_longest(node1, node2)
        )

    return node1 == node2
|
||
|
||
class _SaveFigInjector(ast.NodeTransformer):
    """Insert a ``plt.savefig(...)`` statement before every ``plt.show()`` statement."""

    def __init__(self, chart_save_dir: str, timestamp: str, show_count: int) -> None:
        self._chart_save_dir = chart_save_dir
        self._timestamp = timestamp
        # A letter suffix is only needed to keep multiple charts apart.
        self._use_suffix = show_count > 1
        self._suffix_code = ord("a")
        self._show_stmt = ast.parse("plt.show()").body[0]

    def visit_Expr(self, node: ast.Expr):
        """Return ``[savefig, node]`` for a ``plt.show()`` statement, else ``node``."""
        if not compare_ast(node, self._show_stmt, ignore_args=True):
            return node

        filename = f"chart_{self._timestamp}"
        if self._use_suffix:
            # NOTE: the suffix runs a, b, c, ...; beyond 26 charts it continues
            # with the characters that follow "z" in the code table.
            filename += f"_{chr(self._suffix_code)}"
            self._suffix_code += 1

        chart_path = os.path.join(self._chart_save_dir, f"{filename}.png")
        # repr() yields a valid string literal whatever the path contains
        # (backslashes, quotes), unlike splicing the path into r'...'.
        save_stmt = ast.parse(f"plt.savefig({chart_path!r})").body[0]
        # NodeTransformer splices a returned list into the enclosing body.
        return [save_stmt, node]


def add_save_chart(code: str) -> str:
    """Inject ``plt.savefig()`` before each ``plt.show()`` statement in *code*.

    Charts are written to ``<project_root>/exports/charts/<YYYY-MM-DD>/`` as
    ``chart_<YYYY-MM-DD-HH-MM-SS>.png``; when the code shows more than one
    chart, a letter (``_a``, ``_b``, ...) is appended to keep names unique.
    A final ``print`` statement reporting the save directory is appended to
    the code. ``plt.show()`` statements nested inside functions, loops or
    conditionals are handled as well as top-level ones.

    Args:
        code: Python source code to transform.

    Returns:
        The transformed source, or *code* unchanged when it contains no
        ``plt.show()`` statement (in which case no directory is created).

    Raises:
        SyntaxError: If *code* is not valid Python.
    """
    tree = ast.parse(code)
    show_stmt = ast.parse("plt.show()").body[0]

    # count number of plt.show() calls
    show_count = sum(
        compare_ast(node, show_stmt, ignore_args=True) for node in ast.walk(tree)
    )

    # if there are no plt.show() calls, return the original code untouched
    if show_count == 0:
        return code

    # Take a single timestamp so the directory date and the file name cannot
    # disagree when the call straddles midnight.
    now = datetime.now()

    # define chart save directory (portable across operating systems)
    project_root = dirname(dirname(dirname(__file__)))
    chart_save_dir = os.path.join(
        project_root, "exports", "charts", now.strftime("%Y-%m-%d")
    )
    os.makedirs(chart_save_dir, exist_ok=True)

    injector = _SaveFigInjector(
        chart_save_dir, now.strftime("%Y-%m-%d-%H-%M-%S"), show_count
    )
    new_tree = injector.visit(tree)

    # tell the user where the charts ended up
    message = f"Charts saved to: {chart_save_dir}"
    new_tree.body.append(ast.parse(f"print({message!r})").body[0])

    return astor.to_source(new_tree).strip()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""Unit tests for the save_chart module.""" | ||
import ast | ||
import os.path | ||
|
||
from pandasai.helpers.save_chart import add_save_chart, compare_ast | ||
|
||
|
||
class TestSaveChart:
    """Exercise compare_ast() and add_save_chart() from the save_chart module."""

    def test_compare_ast(self):
        """Calls to the same function match when arguments are ignored."""
        bare_show = ast.parse("plt.show()").body[0]
        show_with_args = ast.parse("plt.show(*some-args)").body[0]
        assert compare_ast(bare_show, show_with_args, ignore_args=True)

        print_with_arg = ast.parse("print(r'hello/word.jpeg')").body[0]
        bare_print = ast.parse("print()").body[0]
        assert compare_ast(print_with_arg, bare_print, ignore_args=True)

    def test_save_chart(self):
        """A single plt.show() gains one savefig before it and a trailing print."""
        chart_code = """
import matplotlib.pyplot as plt
import pandas as pd
df = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
df.plot()
plt.show()
"""
        original_len = len(ast.parse(chart_code).body)
        result_body = ast.parse(add_save_chart(chart_code)).body
        show_stmt = ast.parse("plt.show()").body[0]
        show_positions = [
            idx
            for idx, stmt in enumerate(result_body)
            if compare_ast(stmt, show_stmt, ignore_args=True)
        ]
        first_show = show_positions[0]
        savefig_stmt = ast.parse("plt.savefig()").body[0]

        # one savefig statement plus one print statement were added
        assert len(result_body) == original_len + 2
        # the statement right before plt.show() is the injected plt.savefig()
        assert compare_ast(result_body[first_show - 1], savefig_stmt, ignore_args=True)
        # the last statement reports where the charts were saved
        print_stmt = ast.parse("print()").body[0]
        assert compare_ast(result_body[-1], print_stmt, ignore_args=True)

    def test_save_multiple_charts(self):
        """Two plt.show() calls yield two savefig calls with suffixes a and b."""
        chart_code = """
import matplotlib.pyplot as plt
import pandas as pd
df = pd.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
df.plot('a')
plt.show()
df.plot('b')
plt.show()
"""
        original_len = len(ast.parse(chart_code).body)
        result_body = ast.parse(add_save_chart(chart_code)).body
        show_stmt = ast.parse("plt.show()").body[0]
        show_positions = [
            idx
            for idx, stmt in enumerate(result_body)
            if compare_ast(stmt, show_stmt, ignore_args=True)
        ]
        savefig_stmt = ast.parse("plt.savefig()").body[0]

        # two savefig statements plus one print statement were added
        assert len(result_body) == original_len + 3

        # check first node is plt.savefig() and filename ends with a
        first_save = result_body[show_positions[0] - 1]
        assert compare_ast(first_save, savefig_stmt, ignore_args=True)
        first_save_args = [arg.value for arg in first_save.value.args]
        first_stem = os.path.splitext(first_save_args[0])[0]
        assert first_stem[-1] == "a"

        # check second node is plt.savefig() and filename ends with b
        second_save = result_body[show_positions[1] - 1]
        assert compare_ast(second_save, savefig_stmt, ignore_args=True)
        second_save_args = [arg.value for arg in second_save.value.args]
        second_stem = os.path.splitext(second_save_args[0])[0]
        assert second_stem[-1] == "b"