-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1e54979
commit 77e20e0
Showing
9 changed files
with
677 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
47 changes: 47 additions & 0 deletions
47
vizro-ai/tests/unit/vizro-ai/components/test_chart_selection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pandas as pd | ||
import pytest | ||
from langchain.llms.fake import FakeListLLM | ||
from vizro_ai.components import GetChartSelection | ||
|
||
|
||
@pytest.fixture | ||
def fake_llm(): | ||
# This is to simulate the response of LLM | ||
response = ['{"chart_type": "bar"}'] | ||
return FakeListLLM(responses=response) | ||
|
||
|
||
class TestChartSelectionInstantiation: | ||
def test_instantiation(self): | ||
chart_selection = GetChartSelection(llm=fake_llm) | ||
assert chart_selection.llm == fake_llm | ||
|
||
def setup_method(self, fake_llm): | ||
self.get_chart_selection = GetChartSelection(llm=fake_llm) | ||
|
||
def test_pre_process(self): | ||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) | ||
llm_kwargs, partial_vars = self.get_chart_selection._pre_process(df) | ||
expected_partial_vars = {"df_schema": "A: int64\nB: int64", "df_head": df.head().to_markdown()} | ||
assert partial_vars == expected_partial_vars | ||
|
||
@pytest.mark.parametrize( | ||
"load_args, expected_chart_name", | ||
[ | ||
({"chart_type": "line"}, "line"), | ||
({"chart_type": "bar"}, "bar"), | ||
({"chart_type": ["line", "bar"]}, "line,bar"), | ||
], | ||
) | ||
def test_post_process(self, load_args, expected_chart_name): | ||
chart_names = self.get_chart_selection._post_process(load_args) | ||
assert chart_names == expected_chart_name | ||
|
||
|
||
class TestChartSelection: | ||
def test_fake_response(self, gapminder, fake_llm): | ||
get_chart_selection = GetChartSelection(fake_llm) | ||
target_chart = get_chart_selection.run( | ||
df=gapminder, chain_input="choose a best chart for describe the composition" | ||
) | ||
assert target_chart == "bar" |
57 changes: 57 additions & 0 deletions
57
vizro-ai/tests/unit/vizro-ai/components/test_code_validation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pytest | ||
from langchain.llms.fake import FakeListLLM | ||
from vizro_ai.components import GetDebugger | ||
|
||
|
||
@pytest.fixture | ||
def fake_llm(): | ||
"""This is to simulate the response of LLM.""" | ||
response = ['{{"fixed_code": "{}"}}'.format("print(df[['country', 'continent']])")] | ||
return FakeListLLM(responses=response) | ||
|
||
|
||
@pytest.fixture | ||
def fake_code_snippet(): | ||
return "print(df['country', 'continent'])" | ||
|
||
|
||
@pytest.fixture | ||
def fake_error_msg(): | ||
return "KeyError: ('country', 'continent')" | ||
|
||
|
||
class TestCodeValidationInstantiation: | ||
def test_instantiation(self): | ||
chart_selection = GetDebugger(llm=fake_llm) | ||
assert chart_selection.llm == fake_llm | ||
|
||
def setup_method(self, fake_llm): | ||
self.get_debugger = GetDebugger(llm=fake_llm) | ||
|
||
def test_pre_process(self): | ||
llm_kwargs, partial_vars = self.get_debugger._pre_process(fake_code_snippet) | ||
assert partial_vars == {"code_snippet": fake_code_snippet} | ||
|
||
@pytest.mark.parametrize( | ||
"load_args, expected_fixed_code", | ||
[ | ||
( | ||
{"fixed_code": "print('unit test for expected fixed code')"}, | ||
"print('unit test for expected fixed code')", | ||
), | ||
( | ||
{"fixed_code": "import pandas as pd\n" "\n" "print(df[['country', 'continent']])\n"}, | ||
"import pandas as pd\n" "\n" "print(df[['country', 'continent']])\n", | ||
), | ||
], | ||
) | ||
def test_post_process(self, load_args, expected_fixed_code): | ||
fixed_code = self.get_debugger._post_process(load_args) | ||
assert fixed_code == expected_fixed_code | ||
|
||
|
||
class TestChartSelection: | ||
def test_fake_response(self, fake_llm, fake_code_snippet, fake_error_msg): | ||
get_debugger = GetDebugger(fake_llm) | ||
fixed_code = get_debugger.run(chain_input=fake_error_msg, code_snippet=fake_code_snippet) | ||
assert fixed_code == "print(df[['country', 'continent']])" |
130 changes: 130 additions & 0 deletions
130
vizro-ai/tests/unit/vizro-ai/components/test_custom_chart_wrap.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import pytest | ||
from langchain.llms.fake import FakeListLLM | ||
from vizro_ai.components import GetCustomChart | ||
|
||
|
||
@pytest.fixture | ||
def output_visual_component_1(): | ||
return """import vizro.plotly.express as px | ||
import pandas as pd | ||
df = df.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'}) | ||
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents') | ||
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP') | ||
fig.show()""" | ||
|
||
|
||
@pytest.fixture | ||
def output_custom_chart_LLM_1(): | ||
return """import vizro.plotly.express as px | ||
import pandas as pd | ||
def custom_chart(data_frame): | ||
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'}) | ||
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents') | ||
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP') | ||
return fig""" | ||
|
||
|
||
@pytest.fixture | ||
def expected_final_output_1(): | ||
return """from vizro.models.types import capture | ||
import vizro.plotly.express as px | ||
import pandas as pd | ||
@capture('graph') | ||
def custom_chart(data_frame): | ||
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'}) | ||
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents') | ||
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP') | ||
return fig | ||
fig = custom_chart(data_frame=df)""" | ||
|
||
|
||
@pytest.fixture | ||
def output_custom_chart_LLM_2(): | ||
return """ | ||
import vizro.plotly.express as px | ||
import pandas as pd | ||
def custom_chart(data_frame): | ||
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'}) | ||
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents') | ||
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP') | ||
return fig""" | ||
|
||
|
||
@pytest.fixture | ||
def expected_final_output_2(): | ||
return """from vizro.models.types import capture | ||
import vizro.plotly.express as px | ||
import pandas as pd | ||
@capture('graph') | ||
def custom_chart(data_frame): | ||
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'}) | ||
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents') | ||
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP') | ||
return fig | ||
fig = custom_chart(data_frame=df)""" | ||
|
||
|
||
@pytest.fixture | ||
def output_custom_chart_LLM_3(): | ||
return """import vizro.plotly.express as px | ||
import pandas as pd | ||
def some_chart_name(data_frame): | ||
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'}) | ||
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents') | ||
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP') | ||
return fig""" | ||
|
||
|
||
@pytest.fixture | ||
def fake_llm(output_custom_chart_LLM_1): | ||
"""This is to simulate the response of LLM.""" | ||
response = ['{{"custom_chart_code": "{}"}}'.format(output_custom_chart_LLM_1)] | ||
return FakeListLLM(responses=response) | ||
|
||
|
||
class TestGetCustomChartMethods: | ||
def test_instantiation(self): | ||
"""Test initialization of GetCustomChart.""" | ||
get_custom_chart = GetCustomChart(llm=fake_llm) | ||
assert get_custom_chart.llm == fake_llm | ||
|
||
def setup_method(self, fake_llm): | ||
self.get_custom_chart = GetCustomChart(llm=fake_llm) | ||
|
||
def test_pre_process(self): | ||
llm_kwargs, partial_vars = self.get_custom_chart._pre_process() | ||
assert partial_vars == {} | ||
assert isinstance(llm_kwargs, dict) | ||
|
||
@pytest.mark.parametrize( | ||
"input,output", | ||
[ | ||
("output_custom_chart_LLM_1", "expected_final_output_1"), | ||
("output_custom_chart_LLM_2", "expected_final_output_2"), | ||
], | ||
) | ||
def test_post_process(self, input, output, request): | ||
input = request.getfixturevalue(input) | ||
output = request.getfixturevalue(output) | ||
loaded_args = {"custom_chart_code": input} | ||
processed_code = self.get_custom_chart._post_process(loaded_args) | ||
assert processed_code == output | ||
|
||
def test_post_process_fail(self, output_custom_chart_LLM_3): | ||
loaded_args = {"custom_chart_code": output_custom_chart_LLM_3} | ||
with pytest.raises(ValueError, match="def custom_chart is not added correctly by the LLM. Try again."): | ||
self.get_custom_chart._post_process(loaded_args) | ||
|
||
|
||
class TestGetCustomChartRun: | ||
def test_fake_run(self, fake_llm, expected_final_output_1): | ||
get_custom_chart = GetCustomChart(fake_llm) | ||
# Note that the chain input is not used in this component as we fake the LLM response | ||
processed_code = get_custom_chart.run(chain_input="XXX") | ||
assert processed_code == expected_final_output_1 |
97 changes: 97 additions & 0 deletions
97
vizro-ai/tests/unit/vizro-ai/components/test_dataframe_craft.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import re | ||
|
||
import pandas as pd | ||
import pytest | ||
from langchain.llms.fake import FakeListLLM | ||
from vizro_ai.components import GetDataFrameCraft | ||
|
||
|
||
def dataframe_code(): | ||
return """ | ||
data_frame = data_frame.groupby('continent')['gdpPercap'].sum().reset_index() | ||
data_frame = data_frame.rename(columns={'gdpPercap': 'total_gdp'}) | ||
data_frame.plot(kind='bar', x='continent', y='total_gdp', color='skyblue', legend=False)""" | ||
|
||
|
||
@pytest.fixture | ||
def fake_llm(): | ||
dataframe_code_before_postprocess = re.sub( | ||
r"[\x00-\x1f]", lambda m: "\\u{:04x}".format(ord(m.group(0))), dataframe_code() | ||
) | ||
response = ['{{"dataframe_code": "{}"}}'.format(dataframe_code_before_postprocess)] | ||
return FakeListLLM(responses=response) | ||
|
||
|
||
@pytest.fixture | ||
def input_df(): | ||
input_df = pd.DataFrame( | ||
{ | ||
"contintent": ["Asia", "Asia", "America", "Europe"], | ||
"country": ["China", "India", "US", "UK"], | ||
"gdpPercap": [102, 110, 300, 200], | ||
} | ||
) | ||
return input_df | ||
|
||
|
||
class TestDataFrameCraftMethods: | ||
def test_instantiation(self): | ||
dataframe_craft = GetDataFrameCraft(llm=fake_llm) | ||
assert dataframe_craft.llm == fake_llm | ||
|
||
def setup_method(self, fake_llm): | ||
self.get_dataframe_craft = GetDataFrameCraft(llm=fake_llm) | ||
|
||
def test_pre_process(self, input_df): | ||
llm_kwargs_to_use, partial_vars = self.get_dataframe_craft._pre_process(df=input_df) | ||
expected_partial_vars = { | ||
"df_schema": "contintent: object\ncountry: object\ngdpPercap: int64", | ||
"df_head": input_df.head().to_markdown(), | ||
} | ||
assert partial_vars == expected_partial_vars | ||
|
||
@pytest.mark.parametrize( | ||
"code_string, expected_code_string", | ||
[ | ||
( | ||
"df = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]})", | ||
"import pandas as pd\ndf = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()", | ||
), | ||
( | ||
"df = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()", | ||
"import pandas as pd\ndf = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()", | ||
), | ||
( | ||
"data_frame = pd.DataFrame({'test1': [1, 1, 2], 'test2': [3, 4, 5]})\n" | ||
"data_frame = data_frame.groupby('test1')['test2'].sum()", | ||
"import pandas as pd\ndata_frame = pd.DataFrame({'test1': [1, 1, 2], 'test2': [3, 4, 5]})\n" | ||
"df = data_frame.groupby('test1')['test2'].sum().reset_index()", | ||
), | ||
( | ||
"import pandas as pd\n" | ||
"df = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).plot(kind='bar', x='test1', y='test2')", | ||
"import pandas as pd\ndf = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()", | ||
), | ||
], | ||
) | ||
def test_post_process(self, code_string, expected_code_string, input_df): | ||
load_args = {"dataframe_code": code_string} | ||
df_code = self.get_dataframe_craft._post_process(load_args, input_df) | ||
|
||
assert df_code == expected_code_string | ||
|
||
|
||
class TestDataFrameCraftResponse: | ||
def test_fake_response(self, input_df, fake_llm): | ||
get_dataframe_craft = GetDataFrameCraft(fake_llm) | ||
df_code = get_dataframe_craft.run( | ||
chain_input="choose a best chart for describe the composition of gdp in continent, " | ||
"and horizontal line for avg gdp", | ||
df=input_df, | ||
) | ||
assert ( | ||
df_code == "import pandas as pd\n " | ||
"data_frame = data_frame.groupby('continent')['gdpPercap'].sum().reset_index()\n " | ||
"data_frame = data_frame.rename(columns={'gdpPercap': 'total_gdp'})\n" | ||
"df = data_frame.reset_index()" | ||
) |
Oops, something went wrong.