Skip to content

Commit

Permalink
prompt refinement and dependency update (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
Anna-Xiong authored Nov 2, 2023
1 parent acbba54 commit 8dce080
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment every section that applies to this change (remove the surrounding HTML comment wrapper) and replace the placeholder bullet with your own entry.
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
2 changes: 1 addition & 1 deletion vizro-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"pandas",
"tabulate",
"openai>=0.27.8",
"langchain==0.0.317",
"langchain==0.0.325",
"python-dotenv>=1.0.0", # TODO decide env var management to see if we need this
"vizro>=0.1.4", # TODO set upper bound later
"ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382
Expand Down
2 changes: 1 addition & 1 deletion vizro-ai/snyk/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pandas
tabulate
openai>=0.27.8
langchain==0.0.317
langchain==0.0.325
python-dotenv>=1.0.0
vizro>=0.1.4
ipython>=8.10.0
Expand Down
8 changes: 1 addition & 7 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Callable, Dict, List, Union

from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from pydantic import BaseModel, Field

# TODO add new wrappers in if new model support is added
# Type alias for the supported LLM wrapper classes.
# NOTE(review): this view shows two consecutive bindings of LLM_MODELS. When
# both are present the second one wins, and typing reduces a single-argument
# Union to the argument itself, so the alias ends up being just ChatOpenAI.
LLM_MODELS = Union[ChatOpenAI, OpenAI]
LLM_MODELS = Union[ChatOpenAI]

# TODO constant of model inventory, can be converted to yaml and link to docs
PREDEFINED_MODELS: List[Dict[str, any]] = [
Expand All @@ -19,11 +18,6 @@
"max_tokens": 8192,
"wrapper": ChatOpenAI,
},
{
"name": "text-davinci-003",
"max_tokens": 8192,
"wrapper": OpenAI,
},
]


Expand Down
16 changes: 8 additions & 8 deletions vizro-ai/src/vizro_ai/components/dataframe_craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class DataFrameCraft(BaseModel):


# 2. Define prompt
# Prompt template asking for pandas DataFrame manipulation code.
# Placeholders filled by the caller: {df_schema}, {df_head}, {input}.
# NOTE: the fragments are joined with no separator, so a fragment that does not
# end in a space runs straight into the next one (e.g. "...dataframeDO NOT...").
dataframe_prompt = "".join(
    [
        "write pandas dataframe manipulation code for the given df, ",
        "df info:{df_schema}, {df_head}, and user question {input}?, ",
        "DO NOT create a new dataframe",
        "DO NOT include plot here, ",
        "make sure each column exists and will have names ",
        "and re-indexed when there is aggregation",
        "only write dataframe manipulation if required for visualization",
        "DO NOT wrap into a function, use line by line code",
    ]
)
# Prompt template asking for pandas DataFrame manipulation code on `df`.
# Placeholders filled by the caller: {df_schema}, {df_head}, {input}.
# Each tuple entry is one line of the prompt; entries are joined with a newline
# and there is no trailing newline.
dataframe_prompt = "\n".join(
    (
        "Context: You are working with a pandas DataFrame in Python named df.",
        "DataFrame Details Schema: {df_schema}, Sample Data: {df_head}, User Query: {input}",
        "Instructions: 1.Write code to manipulate the df DataFrame according to the user's query.",
        "2.Do not create any new DataFrames; work only with df.",
        "3.Ensure that any aggregated columns are named appropriately and re-indexed if necessary.",
        "4.If a visualization is implied by the user's query, only write the necessary DataFrame manipulation",
        "code for that visualization. 5.Do not include any plotting code.",
        "6. Produce the code in a line-by-line format, not wrapped inside a function.",
    )
)


# 3. Define Component
Expand Down
5 changes: 4 additions & 1 deletion vizro-ai/src/vizro_ai/components/visual_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ class VizroCode(BaseModel):

# 2. Define prompt
# Prompt template asking for Plotly visualization code.
# Placeholders used in the text: {chart_types}, {input}, {df_code}.
# NOTE(review): the adjacent string literals below are concatenated implicitly
# with no separator, so the text reads "...{df_code}?Context: ..." and
# "...is `df`.Instructions: ..." - confirm whether a space is intended there.
# NOTE(review): the first literal looks like the pre-commit wording shown next
# to its replacement in this diff view - verify against the actual file.
visual_code_prompt = (
"Give plotly code for {chart_types}, and user question {input}, using this as df result: {df_code}?"
"Context: You are working with a pandas dataframe in Python. The name of the dataframe is `df`."
"Instructions: Given the code snippet {df_code}, generate Plotly visualization code to produce a {chart_types} "
"chart that addresses user query: {input}. "
"Please ensure the Plotly code aligns with the provided DataFrame details."
)


Expand Down

0 comments on commit 8dce080

Please sign in to comment.