Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: visualization generation instruction and fewshots #19

Merged
merged 2 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion coml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
FIX_INSTRUCTION,
GENERATE_INSTRUCTION,
GENERATE_INSTRUCTION_COT,
GENERATE_INSTRUCTION_VIS_MATPLOTLIB,
GENERATE_INSTRUCTION_VIS_SEABORN,
SANITY_CHECK_INSTRUCTION,
SUGGEST_INSTRUCTION,
FixContext,
Expand Down Expand Up @@ -125,7 +127,9 @@ class CoMLAgent:
def __init__(
self,
llm: BaseChatModel,
prompt_version: Literal["v1", "v2", "kaggle", "leetcode"] = "v2",
prompt_version: Literal[
"v1", "v2", "kaggle", "leetcode", "matplotlib", "seaborn"
] = "v2",
prompt_validation: Callable[[list[BaseMessage]], bool] | None = None,
num_examples: float | int = 1.0,
message_style: Literal["chatgpt", "gemini"] = "chatgpt",
Expand Down Expand Up @@ -298,6 +302,12 @@ def generate_code(
shot["answer"] = shot.pop("answer_wo_intact")
if "rationale_wo_intact" in shot:
shot["rationale"] = shot.pop("rationale_wo_intact")

if self.prompt_version == "matplotlib":
generate_instruction = GENERATE_INSTRUCTION_VIS_MATPLOTLIB
elif self.prompt_version == "seaborn":
generate_instruction = GENERATE_INSTRUCTION_VIS_SEABORN

messages.append(SystemMessage(content=generate_instruction))

for shot in self._select_examples(request, fewshots):
Expand Down
27 changes: 26 additions & 1 deletion coml/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,18 @@ def render_sanity_check_context(
- Think before you write. You should first understand the user's request, think about how to achieve it, and then write the code.
"""

# System prompt used when the agent is asked to generate matplotlib plotting code.
# This is a plain (non-f) string on purpose: it has no placeholders, and with an
# f-prefix any literal `{` / `}` added to the prompt later would be parsed as a
# replacement field.
GENERATE_INSTRUCTION_VIS_MATPLOTLIB = """You're a helpful assistant proficient in writing Python code for data visualization. Upon receiving relevant context, such as available variables and any pre-executed code, your goal is to complete the Python code to generate a visualization that meets the user's request.

Instructions:

- Utilize the `matplotlib` library to create the visualization and make sure to include `plt.show()` to display the chart.
- You must return the generated code wrapped by ``` before and after it, and do not add any explanation.
"""

# The seaborn prompt is derived from the matplotlib one so the two stay in sync.
# Only the backticked library name is swapped, so a future mention of e.g.
# `matplotlib.pyplot` elsewhere in the prompt would not be rewritten by accident.
GENERATE_INSTRUCTION_VIS_SEABORN = GENERATE_INSTRUCTION_VIS_MATPLOTLIB.replace(
    "`matplotlib`", "`seaborn`"
)

FIX_INSTRUCTION = f"""{GENERATE_INSTRUCTION.rstrip()}
- If the user thinks the generated code is problematic, you should help fix it. The user will provide you with the exception message (if any), the output of the code (if any), and a hint (if any). You should provide a line-by-line explanation of the code, and point out what is wrong with the code. You should also provide the fixed code.
- If you think the provided problematic code is actually correct, you should first explain the code, and write "THE CODE IS CORRECT." (in upper case) in the observation section. The fixed code can be omitted.
Expand Down Expand Up @@ -459,7 +471,20 @@ def cached_generate_fewshots(prompt_version: str) -> list[GenerateContext]:
with open(
Path(__file__).parent / f"prompts/generate_fewshots_{prompt_version}.json"
) as f:
return json.load(f)
fewshots = json.load(f)
for shot in fewshots:
variables = {}
if "datasets" in shot:
for name in shot["datasets"]:
dataset = pd.read_csv(
str(Path(__file__).parent / f"prompts/dataset/{name}.csv")
)
# todo: dataframe_format
variables[f"{name.split('/')[1]}_dataset"] = describe_variable(
dataset
)
shot["variables"] = variables
return fewshots


def cached_fix_fewshots() -> list[FixContext]:
Expand Down
11 changes: 11 additions & 0 deletions coml/prompts/dataset/coffee_shop/member.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Member_ID,Name,Membership_card,Age,Time_of_purchase,Level_of_membership,Address
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure the csv's are contained in the package upon next release.

1,"Ashby, Lazale",Black,29,18,5,Hartford
2,"Breton, Robert",White,67,41,4,Waterbury
3,"Campbell, Jessie",Black,34,20,6,Hartford
4,"Cobb, Sedrick",Black,51,27,2,Waterbury
5,"Hayes, Steven",White,50,44,3,Cheshire
6,"Komisarjevsky, Joshua",White,33,26,2,Cheshire
7,"Peeler, Russell",Black,42,26,6,Bridgeport
8,"Reynolds, Richard",Black,45,24,1,Waterbury
9,"Rizzo, Todd",White,35,18,4,Waterbury
10,"Webb, Daniel",Black,51,27,22,Hartford
21 changes: 21 additions & 0 deletions coml/prompts/dataset/game_injury/game.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
stadium_id,id,Season,Date,Home_team,Away_team,Score,Competition
1,1,2007,18 May 2007,Quruvchi,Pakhtakor,1–1,League
2,2,2007,22 September 2007,Pakhtakor,Quruvchi,0–0,League
3,3,2007,9 December 2007,Pakhtakor,Quruvchi,0–0 (7:6),Cup
4,4,2008,10 July 2008,Pakhtakor,Quruvchi,1–1,League
5,5,2008,16 August 2008,Bunyodkor,Pakhtakor,1–1,League
6,6,2008,31 October 2008,Bunyodkor,Pakhtakor,3–1,Cup
7,7,2009,12 July 2009,Bunyodkor,Pakhtakor,2–1,League
8,8,2009,14 October 2009,Pakhtakor,Bunyodkor,0–0,League
9,9,2009,8 August 2009,Pakhtakor,Bunyodkor,1–0,Cup
10,10,2010,14 March 2010,Bunyodkor,Pakhtakor,2–1,League
10,11,2010,31 October 2010,Pakhtakor,Bunyodkor,0–0,League
10,12,2011,7 July 2011,Pakhtakor,Bunyodkor,0–0,League
1,13,2011,21 August 2011,Bunyodkor,Pakhtakor,2–1,League
2,14,2012,11 March 2012,Bunyodkor,Pakhtakor,–,Supercup
3,15,2012,26 June 2012,Bunyodkor,Pakhtakor,2–0,League
4,16,2012,9 August 2012,Pakhtakor,Bunyodkor,1–1,League
5,17,2012,22 August 2012,Bunyodkor,Pakhtakor,1–1,Cup
11,18,2012,25 November 2012,Pakhtakor,Bunyodkor,1–3,Cup
12,19,2013,30 June 2013,Pakhtakor,Bunyodkor,0–2,League
7,20,2013,8 August 2013,Bunyodkor,Pakhtakor,1–2,League
6 changes: 6 additions & 0 deletions coml/prompts/dataset/pilot_record/pilot.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Pilot_ID,Pilot_name,Rank,Age,Nationality,Position,Join_Year,Team
1,Patrick O'Bryant,13,33,United States,Center Team,2009,Bradley
2,Jermaine O'Neal,6,40,United States,Forward-Center Team,2008,Eau Claire High School
3,Dan O'Sullivan,45,37,United States,Center Team,1999,Fordham
4,Charles Oakley,34,22,United Kingdom,Forward Team,2001,Virginia Union
5,Hakeem Olajuwon,34,32,Nigeria,Center Team,2010,Houston
13 changes: 13 additions & 0 deletions coml/prompts/dataset/scientist_1/AssignedTo.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Scientist,Project
123234877,AeH1
152934485,AeH3
222364883,Ast3
326587417,Ast3
332154719,Bte1
546523478,Che1
631231482,Ast3
654873219,Che1
745685214,AeH3
845657245,Ast1
845657246,Ast2
332569843,AeH4
15 changes: 15 additions & 0 deletions coml/prompts/dataset/scientist_1/Projects.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Code,Name,Hours
AeH1,Winds: Studying Bernoullis Principle,156
AeH2,Aerodynamics and Bridge Design,189
AeH3,Aerodynamics and Gas Mileage,256
AeH4,Aerodynamics and Ice Hockey,789
AeH5,Aerodynamics of a Football,98
AeH6,Aerodynamics of Air Hockey,89
Ast1,A Matter of Time,112
Ast2,A Puzzling Parallax,299
Ast3,Build Your Own Telescope,6546
Bte1,Juicy: Extracting Apple Juice with Pectinase,321
Bte2,A Magnetic Primer Designer,9684
Bte3,Bacterial Transformation Efficiency,321
Che1,A Silver-Cleaning Battery,545
Che2,A Soluble Separation Solution,778
47 changes: 47 additions & 0 deletions coml/prompts/generate_fewshots_matplotlib.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
[
{
"datasets": [
"scientist_1/Projects",
"scientist_1/AssignedTo"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')"
],
"request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.",
"answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nplt.bar(project_counts.index, project_counts.values)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()"
},
{
"datasets": [
"coffee_shop/member"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"member_dataset = pd.read_csv('../dataset/coffee_shop/member.csv')"
],
"request": "A scatter chart showing the correlation between the age of the customer and the time of purchase colored by membership level.",
"answer": "# Group the dataset by Membership_card\ngroups = member_dataset.groupby('Membership_card')\n\n# Create a scatter chart for each Membership_card\nfor membership_card, group in groups:\n plt.scatter(group['Age'], group['Time_of_purchase'], label=membership_card)\n\n# Set the title and labels\nplt.title('Correlation between Age and Time of Purchase')\nplt.xlabel('Age')\nplt.ylabel('Time of Purchase')\nplt.legend(loc='upper left')\n\n# Show the plot\nplt.show()"
},
{
"datasets": [
"game_injury/game"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"game_dataset = pd.read_csv('../dataset/game_injury/game.csv')"
],
"request": "Show the number of games in each season and group by away team in a group line chart. The x-axis is season.",
"answer": "# group the dataset by season and away team\ngrouped = game_dataset.groupby(['Season', 'Away_team']).size().reset_index(name='counts')\n\n# create a pivot table with season as index and away team as columns\npivot_table = pd.pivot_table(grouped, values='counts', index=['Season'], columns=['Away_team'], fill_value=0)\n\n# create the line chart\npivot_table.plot(kind='line')\n\n# set the title and labels\nplt.title('Number of Games in Each Season by Away Team')\nplt.xlabel('Season')\nplt.ylabel('Number of Games')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\n\n# show the plot\nplt.show()"
},
{
"datasets": [
"pilot_record/pilot"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"pilot_dataset = pd.read_csv('../dataset/pilot_record/pilot.csv')"
],
"request": "What is the proportion of positions of pilots? Show the result in a pie chart.",
"answer": "# Count the number of each position\nposition_counts = pilot_dataset['Position'].value_counts()\n\n# Create a pie chart\nplt.figure(figsize=(8,6))\nplt.pie(position_counts, labels = position_counts.index, autopct='%1.1f%%')\n\n# Set the title\nplt.title('Proportion of Positions')\n\n# Show the plot\nplt.show()"
}
]
14 changes: 14 additions & 0 deletions coml/prompts/generate_fewshots_seaborn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[
{
"datasets": [
"scientist_1/Projects",
"scientist_1/AssignedTo"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns",
"Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')"
],
"request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.",
"answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nsns.barplot(project_counts)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()"
}
]
Loading