Skip to content

Commit

Permalink
add vizro-ai tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Anna-Xiong committed Oct 31, 2023
1 parent 1e54979 commit 77e20e0
Show file tree
Hide file tree
Showing 9 changed files with 677 additions and 1 deletion.
Empty file added vizro-ai/tests/.gitkeep
Empty file.
47 changes: 47 additions & 0 deletions vizro-ai/tests/unit/vizro-ai/components/test_chart_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pandas as pd
import pytest
from langchain.llms.fake import FakeListLLM
from vizro_ai.components import GetChartSelection


@pytest.fixture
def fake_llm():
# This is to simulate the response of LLM
response = ['{"chart_type": "bar"}']
return FakeListLLM(responses=response)


class TestChartSelectionInstantiation:
def test_instantiation(self):
chart_selection = GetChartSelection(llm=fake_llm)
assert chart_selection.llm == fake_llm

def setup_method(self, fake_llm):
self.get_chart_selection = GetChartSelection(llm=fake_llm)

def test_pre_process(self):
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
llm_kwargs, partial_vars = self.get_chart_selection._pre_process(df)
expected_partial_vars = {"df_schema": "A: int64\nB: int64", "df_head": df.head().to_markdown()}
assert partial_vars == expected_partial_vars

@pytest.mark.parametrize(
"load_args, expected_chart_name",
[
({"chart_type": "line"}, "line"),
({"chart_type": "bar"}, "bar"),
({"chart_type": ["line", "bar"]}, "line,bar"),
],
)
def test_post_process(self, load_args, expected_chart_name):
chart_names = self.get_chart_selection._post_process(load_args)
assert chart_names == expected_chart_name


class TestChartSelection:
def test_fake_response(self, gapminder, fake_llm):
get_chart_selection = GetChartSelection(fake_llm)
target_chart = get_chart_selection.run(
df=gapminder, chain_input="choose a best chart for describe the composition"
)
assert target_chart == "bar"
57 changes: 57 additions & 0 deletions vizro-ai/tests/unit/vizro-ai/components/test_code_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from langchain.llms.fake import FakeListLLM
from vizro_ai.components import GetDebugger


@pytest.fixture
def fake_llm():
"""This is to simulate the response of LLM."""
response = ['{{"fixed_code": "{}"}}'.format("print(df[['country', 'continent']])")]
return FakeListLLM(responses=response)


@pytest.fixture
def fake_code_snippet():
return "print(df['country', 'continent'])"


@pytest.fixture
def fake_error_msg():
return "KeyError: ('country', 'continent')"


class TestCodeValidationInstantiation:
def test_instantiation(self):
chart_selection = GetDebugger(llm=fake_llm)
assert chart_selection.llm == fake_llm

def setup_method(self, fake_llm):
self.get_debugger = GetDebugger(llm=fake_llm)

def test_pre_process(self):
llm_kwargs, partial_vars = self.get_debugger._pre_process(fake_code_snippet)
assert partial_vars == {"code_snippet": fake_code_snippet}

@pytest.mark.parametrize(
"load_args, expected_fixed_code",
[
(
{"fixed_code": "print('unit test for expected fixed code')"},
"print('unit test for expected fixed code')",
),
(
{"fixed_code": "import pandas as pd\n" "\n" "print(df[['country', 'continent']])\n"},
"import pandas as pd\n" "\n" "print(df[['country', 'continent']])\n",
),
],
)
def test_post_process(self, load_args, expected_fixed_code):
fixed_code = self.get_debugger._post_process(load_args)
assert fixed_code == expected_fixed_code


class TestChartSelection:
def test_fake_response(self, fake_llm, fake_code_snippet, fake_error_msg):
get_debugger = GetDebugger(fake_llm)
fixed_code = get_debugger.run(chain_input=fake_error_msg, code_snippet=fake_code_snippet)
assert fixed_code == "print(df[['country', 'continent']])"
130 changes: 130 additions & 0 deletions vizro-ai/tests/unit/vizro-ai/components/test_custom_chart_wrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import pytest
from langchain.llms.fake import FakeListLLM
from vizro_ai.components import GetCustomChart


@pytest.fixture
def output_visual_component_1():
return """import vizro.plotly.express as px
import pandas as pd
df = df.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'})
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents')
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP')
fig.show()"""


@pytest.fixture
def output_custom_chart_LLM_1():
return """import vizro.plotly.express as px
import pandas as pd
def custom_chart(data_frame):
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'})
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents')
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP')
return fig"""


@pytest.fixture
def expected_final_output_1():
return """from vizro.models.types import capture
import vizro.plotly.express as px
import pandas as pd
@capture('graph')
def custom_chart(data_frame):
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'})
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents')
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP')
return fig
fig = custom_chart(data_frame=df)"""


@pytest.fixture
def output_custom_chart_LLM_2():
return """
import vizro.plotly.express as px
import pandas as pd
def custom_chart(data_frame):
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'})
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents')
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP')
return fig"""


@pytest.fixture
def expected_final_output_2():
return """from vizro.models.types import capture
import vizro.plotly.express as px
import pandas as pd
@capture('graph')
def custom_chart(data_frame):
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'})
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents')
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP')
return fig
fig = custom_chart(data_frame=df)"""


@pytest.fixture
def output_custom_chart_LLM_3():
return """import vizro.plotly.express as px
import pandas as pd
def some_chart_name(data_frame):
df = data_frame.groupby('continent')['gdpPercap'].sum().reset_index().rename(columns={'gdpPercap': 'total_gdp'})
fig = px.bar(df, x='continent', y='total_gdp', color='continent', title='Composition of GDP in Continents')
fig.add_hline(y=df['total_gdp'].mean(), line_dash='dash', line_color='red', annotation_text='Average GDP')
return fig"""


@pytest.fixture
def fake_llm(output_custom_chart_LLM_1):
"""This is to simulate the response of LLM."""
response = ['{{"custom_chart_code": "{}"}}'.format(output_custom_chart_LLM_1)]
return FakeListLLM(responses=response)


class TestGetCustomChartMethods:
def test_instantiation(self):
"""Test initialization of GetCustomChart."""
get_custom_chart = GetCustomChart(llm=fake_llm)
assert get_custom_chart.llm == fake_llm

def setup_method(self, fake_llm):
self.get_custom_chart = GetCustomChart(llm=fake_llm)

def test_pre_process(self):
llm_kwargs, partial_vars = self.get_custom_chart._pre_process()
assert partial_vars == {}
assert isinstance(llm_kwargs, dict)

@pytest.mark.parametrize(
"input,output",
[
("output_custom_chart_LLM_1", "expected_final_output_1"),
("output_custom_chart_LLM_2", "expected_final_output_2"),
],
)
def test_post_process(self, input, output, request):
input = request.getfixturevalue(input)
output = request.getfixturevalue(output)
loaded_args = {"custom_chart_code": input}
processed_code = self.get_custom_chart._post_process(loaded_args)
assert processed_code == output

def test_post_process_fail(self, output_custom_chart_LLM_3):
loaded_args = {"custom_chart_code": output_custom_chart_LLM_3}
with pytest.raises(ValueError, match="def custom_chart is not added correctly by the LLM. Try again."):
self.get_custom_chart._post_process(loaded_args)


class TestGetCustomChartRun:
def test_fake_run(self, fake_llm, expected_final_output_1):
get_custom_chart = GetCustomChart(fake_llm)
# Note that the chain input is not used in this component as we fake the LLM response
processed_code = get_custom_chart.run(chain_input="XXX")
assert processed_code == expected_final_output_1
97 changes: 97 additions & 0 deletions vizro-ai/tests/unit/vizro-ai/components/test_dataframe_craft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import re

import pandas as pd
import pytest
from langchain.llms.fake import FakeListLLM
from vizro_ai.components import GetDataFrameCraft


def dataframe_code():
return """
data_frame = data_frame.groupby('continent')['gdpPercap'].sum().reset_index()
data_frame = data_frame.rename(columns={'gdpPercap': 'total_gdp'})
data_frame.plot(kind='bar', x='continent', y='total_gdp', color='skyblue', legend=False)"""


@pytest.fixture
def fake_llm():
dataframe_code_before_postprocess = re.sub(
r"[\x00-\x1f]", lambda m: "\\u{:04x}".format(ord(m.group(0))), dataframe_code()
)
response = ['{{"dataframe_code": "{}"}}'.format(dataframe_code_before_postprocess)]
return FakeListLLM(responses=response)


@pytest.fixture
def input_df():
input_df = pd.DataFrame(
{
"contintent": ["Asia", "Asia", "America", "Europe"],
"country": ["China", "India", "US", "UK"],
"gdpPercap": [102, 110, 300, 200],
}
)
return input_df


class TestDataFrameCraftMethods:
def test_instantiation(self):
dataframe_craft = GetDataFrameCraft(llm=fake_llm)
assert dataframe_craft.llm == fake_llm

def setup_method(self, fake_llm):
self.get_dataframe_craft = GetDataFrameCraft(llm=fake_llm)

def test_pre_process(self, input_df):
llm_kwargs_to_use, partial_vars = self.get_dataframe_craft._pre_process(df=input_df)
expected_partial_vars = {
"df_schema": "contintent: object\ncountry: object\ngdpPercap: int64",
"df_head": input_df.head().to_markdown(),
}
assert partial_vars == expected_partial_vars

@pytest.mark.parametrize(
"code_string, expected_code_string",
[
(
"df = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]})",
"import pandas as pd\ndf = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()",
),
(
"df = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()",
"import pandas as pd\ndf = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()",
),
(
"data_frame = pd.DataFrame({'test1': [1, 1, 2], 'test2': [3, 4, 5]})\n"
"data_frame = data_frame.groupby('test1')['test2'].sum()",
"import pandas as pd\ndata_frame = pd.DataFrame({'test1': [1, 1, 2], 'test2': [3, 4, 5]})\n"
"df = data_frame.groupby('test1')['test2'].sum().reset_index()",
),
(
"import pandas as pd\n"
"df = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).plot(kind='bar', x='test1', y='test2')",
"import pandas as pd\ndf = pd.DataFrame({'test1': [1, 2], 'test2': [3, 4]}).reset_index()",
),
],
)
def test_post_process(self, code_string, expected_code_string, input_df):
load_args = {"dataframe_code": code_string}
df_code = self.get_dataframe_craft._post_process(load_args, input_df)

assert df_code == expected_code_string


class TestDataFrameCraftResponse:
def test_fake_response(self, input_df, fake_llm):
get_dataframe_craft = GetDataFrameCraft(fake_llm)
df_code = get_dataframe_craft.run(
chain_input="choose a best chart for describe the composition of gdp in continent, "
"and horizontal line for avg gdp",
df=input_df,
)
assert (
df_code == "import pandas as pd\n "
"data_frame = data_frame.groupby('continent')['gdpPercap'].sum().reset_index()\n "
"data_frame = data_frame.rename(columns={'gdpPercap': 'total_gdp'})\n"
"df = data_frame.reset_index()"
)
Loading

0 comments on commit 77e20e0

Please sign in to comment.