From 8dce080032c3f57ebe70b48e2a1683a84de630b6 Mon Sep 17 00:00:00 2001 From: Anna Xiong Date: Thu, 2 Nov 2023 12:10:12 -0400 Subject: [PATCH] prompt refinement and dependency update (#144) --- ...101_175538_anna_xiong_prompt_refinement.md | 48 +++++++++++++++++++ vizro-ai/pyproject.toml | 2 +- vizro-ai/snyk/requirements.txt | 2 +- vizro-ai/src/vizro_ai/chains/_llm_models.py | 8 +--- .../vizro_ai/components/dataframe_craft.py | 16 +++---- .../src/vizro_ai/components/visual_code.py | 5 +- 6 files changed, 63 insertions(+), 18 deletions(-) create mode 100644 vizro-ai/changelog.d/20231101_175538_anna_xiong_prompt_refinement.md diff --git a/vizro-ai/changelog.d/20231101_175538_anna_xiong_prompt_refinement.md b/vizro-ai/changelog.d/20231101_175538_anna_xiong_prompt_refinement.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-ai/changelog.d/20231101_175538_anna_xiong_prompt_refinement.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-ai/pyproject.toml b/vizro-ai/pyproject.toml index f411d2e70..ec978163d 100644 --- a/vizro-ai/pyproject.toml +++ b/vizro-ai/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "pandas", "tabulate", "openai>=0.27.8", - "langchain==0.0.317", + "langchain==0.0.325", "python-dotenv>=1.0.0", # TODO decide env var management to see if we need this "vizro>=0.1.4", # TODO set upper bound later "ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382 diff --git a/vizro-ai/snyk/requirements.txt b/vizro-ai/snyk/requirements.txt index 392bcd627..845ee3e01 100644 --- a/vizro-ai/snyk/requirements.txt +++ b/vizro-ai/snyk/requirements.txt @@ -1,7 +1,7 @@ pandas tabulate openai>=0.27.8 -langchain==0.0.317 +langchain==0.0.325 python-dotenv>=1.0.0 vizro>=0.1.4 ipython>=8.10.0 diff --git a/vizro-ai/src/vizro_ai/chains/_llm_models.py b/vizro-ai/src/vizro_ai/chains/_llm_models.py index 0d34967e7..73210fc8d 100644 --- a/vizro-ai/src/vizro_ai/chains/_llm_models.py +++ b/vizro-ai/src/vizro_ai/chains/_llm_models.py @@ -1,11 +1,10 @@ from typing import Callable, Dict, List, Union from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI from pydantic import BaseModel, Field # TODO add new wrappers in if new model support is added -LLM_MODELS = Union[ChatOpenAI, OpenAI] +LLM_MODELS = Union[ChatOpenAI] # TODO constant of model inventory, can be converted to yaml and link to docs PREDEFINED_MODELS: List[Dict[str, any]] = [ @@ -19,11 +18,6 @@ "max_tokens": 8192, "wrapper": ChatOpenAI, }, - { - "name": "text-davinci-003", - "max_tokens": 8192, - "wrapper": OpenAI, - }, ] diff --git a/vizro-ai/src/vizro_ai/components/dataframe_craft.py b/vizro-ai/src/vizro_ai/components/dataframe_craft.py index d9b56e9d2..65f95f445 100755 --- a/vizro-ai/src/vizro_ai/components/dataframe_craft.py +++ b/vizro-ai/src/vizro_ai/components/dataframe_craft.py @@ -25,14 +25,14 @@ class DataFrameCraft(BaseModel): # 2. Define prompt -dataframe_prompt = ( - "write pandas dataframe manipulation code for the given df, df info:{df_schema}, {df_head}, " - "and user question {input}?, DO NOT create a new dataframe" - "DO NOT include plot here, make sure each column exists and will have names and re-indexed " - "when there is aggregation" - "only write dataframe manipulation if required for visualization" - "DO NOT wrap into a function, use line by line code" -) +dataframe_prompt = """Context: You are working with a pandas DataFrame in Python named df. +DataFrame Details Schema: {df_schema}, Sample Data: {df_head}, User Query: {input} +Instructions: 1.Write code to manipulate the df DataFrame according to the user's query. +2.Do not create any new DataFrames; work only with df. +3.Ensure that any aggregated columns are named appropriately and re-indexed if necessary. +4.If a visualization is implied by the user's query, only write the necessary DataFrame manipulation +code for that visualization. 5.Do not include any plotting code. +6. Produce the code in a line-by-line format, not wrapped inside a function.""" # 3. Define Component diff --git a/vizro-ai/src/vizro_ai/components/visual_code.py b/vizro-ai/src/vizro_ai/components/visual_code.py index 89ba9d599..1570ba5a3 100644 --- a/vizro-ai/src/vizro_ai/components/visual_code.py +++ b/vizro-ai/src/vizro_ai/components/visual_code.py @@ -21,7 +21,10 @@ class VizroCode(BaseModel): # 2. Define prompt visual_code_prompt = ( - "Give plotly code for {chart_types}, and user question {input}, using this as df result: {df_code}?" + "Context: You are working with a pandas dataframe in Python. The name of the dataframe is `df`." + "Instructions: Given the code snippet {df_code}, generate Plotly visualization code to produce a {chart_types} " + "chart that addresses user query: {input}. " + "Please ensure the Plotly code aligns with the provided DataFrame details." )