-
Notifications
You must be signed in to change notification settings - Fork 144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tidy] Improve code style in response models #626
Open
lingyielia
wants to merge
7
commits into
main
Choose a base branch
from
tidy/improve_code_style
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
b8b7531
add todos
lingyielia 71e71e3
Merge branch 'main' of github.com:mckinsey/vizro into tidy/improve_co…
lingyielia 7ed0685
add 4o mini and small improvement
lingyielia e379632
refactor filter proxy and control plan
lingyielia 59d8e37
tidy
lingyielia 712f96d
take out 4o-mini related content
lingyielia 68cd73c
merge
lingyielia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
48 changes: 48 additions & 0 deletions
48
vizro-ai/changelog.d/20240809_143449_lingyi_zhang_improve_code_style.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
<!-- | ||
A new scriv changelog fragment. | ||
|
||
Uncomment the section that is right (remove the HTML comment wrapper). | ||
--> | ||
|
||
<!-- | ||
### Highlights ✨ | ||
|
||
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> | ||
<!-- | ||
### Removed | ||
|
||
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> | ||
<!-- | ||
### Added | ||
|
||
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> | ||
<!-- | ||
### Changed | ||
|
||
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> | ||
<!-- | ||
### Deprecated | ||
|
||
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> | ||
<!-- | ||
### Fixed | ||
|
||
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> | ||
<!-- | ||
### Security | ||
|
||
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1)) | ||
|
||
--> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
"""Controls plan model.""" | ||
|
||
import logging | ||
from typing import List, Optional | ||
from typing import Any, Dict, List, Optional, Type | ||
|
||
import pandas as pd | ||
import vizro.models as vm | ||
|
@@ -16,76 +16,68 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _create_filter_proxy(df_cols, df_schema, controllable_components) -> BaseModel: | ||
"""Create a filter proxy model.""" | ||
class FilterProxyModel: | ||
"""Filter proxy model.""" | ||
|
||
def validate_targets(v): | ||
"""Validate the targets.""" | ||
if v not in controllable_components: | ||
raise ValueError(f"targets must be one of {controllable_components}") | ||
return v | ||
@classmethod | ||
def _create_model( | ||
cls, df_cols: List[str], df_schema: Dict[str, Any], controllable_components: List[str] | ||
) -> Type[BaseModel]: | ||
def validate_targets(v): | ||
if v not in controllable_components: | ||
raise ValueError(f"targets must be one of {controllable_components}") | ||
return v | ||
|
||
def validate_targets_not_empty(v): | ||
"""Validate the targets not empty.""" | ||
if not controllable_components: | ||
raise ValueError( | ||
""" | ||
This might be due to the filter target is not found in the controllable components. | ||
returning default values. | ||
""" | ||
) | ||
return v | ||
|
||
def validate_column(v): | ||
"""Validate the column.""" | ||
if v not in df_cols: | ||
raise ValueError(f"column must be one of {df_cols}") | ||
return v | ||
|
||
@root_validator(allow_reuse=True) | ||
def validate_date_picker_column(cls, values): | ||
"""Validate the column for date picker.""" | ||
column = values.get("column") | ||
selector = values.get("selector") | ||
if selector and selector.type == "date_picker": | ||
if not pd.api.types.is_datetime64_any_dtype(df_schema[column]): | ||
def validate_targets_not_empty(v): | ||
if not controllable_components: | ||
raise ValueError( | ||
f""" | ||
The column '{column}' is not of datetime type. Selector type 'date_picker' is | ||
not allowed. Use 'dropdown' instead. | ||
""" | ||
This might be due to the filter target is not found in the controllable components. | ||
returning default values. | ||
""" | ||
) | ||
return values | ||
|
||
return create_model( | ||
"FilterProxy", | ||
targets=( | ||
List[str], | ||
Field( | ||
..., | ||
description=f""" | ||
Target component to be affected by filter. | ||
Must be one of {controllable_components}. ALWAYS REQUIRED. | ||
""", | ||
return v | ||
|
||
def validate_column(v): | ||
if v not in df_cols: | ||
raise ValueError(f"column must be one of {df_cols}") | ||
return v | ||
|
||
@root_validator(allow_reuse=True) | ||
def validate_date_picker_column(cls, values): | ||
column = values.get("column") | ||
selector = values.get("selector") | ||
if selector and selector.type == "date_picker": | ||
if not pd.api.types.is_datetime64_any_dtype(df_schema[column]): | ||
raise ValueError( | ||
f""" | ||
The column '{column}' is not of datetime type. Selector type 'date_picker' is | ||
not allowed. Use 'dropdown' instead. | ||
""" | ||
) | ||
return values | ||
|
||
return create_model( | ||
"FilterProxy", | ||
targets=( | ||
List[str], | ||
Field( | ||
..., | ||
description=f""" | ||
Target component to be affected by filter. | ||
Must be one of {controllable_components}. ALWAYS REQUIRED. | ||
""", | ||
), | ||
), | ||
), | ||
column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")), | ||
__validators__={ | ||
"validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets), | ||
"validator2": validator("column", allow_reuse=True)(validate_column), | ||
"validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty), | ||
"validator4": validate_date_picker_column, | ||
}, | ||
__base__=vm.Filter, | ||
) | ||
|
||
|
||
def _create_filter(filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter: | ||
result_proxy = _create_filter_proxy( | ||
df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components | ||
) | ||
proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema) | ||
return vm.Filter.parse_obj(proxy.dict(exclude_unset=True)) | ||
column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")), | ||
__validators__={ | ||
"validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets), | ||
"validator2": validator("column", allow_reuse=True)(validate_column), | ||
"validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty), | ||
"validator4": validate_date_picker_column, | ||
}, | ||
__base__=vm.Filter, | ||
) | ||
|
||
|
||
class ControlPlan(BaseModel): | ||
|
@@ -100,31 +92,58 @@ class ControlPlan(BaseModel): | |
to control a specific component, include the relevant component details. | ||
""", | ||
) | ||
df_name: str = Field( | ||
target_components_id: List[str] = Field( | ||
..., | ||
description=""" | ||
The name of the dataframe that the target component will use. | ||
If the dataframe is not used, please specify that. | ||
The id of the target components that this control will affect. | ||
""", | ||
) | ||
|
||
def create(self, model, controllable_components, all_df_metadata) -> Optional[vm.Filter]: | ||
def _get_target_df_name(self, components_plan, controllable_components): | ||
target_controllable = set(self.target_components_id) & set(controllable_components) | ||
df_names = { | ||
component_plan.df_name | ||
for component_plan in components_plan | ||
if component_plan.component_id in target_controllable | ||
} | ||
|
||
if len(df_names) > 1: | ||
logger.warning( | ||
f""" | ||
[FALLBACK] Multiple dataframes found in the target components: {df_names}. | ||
Choose one dataframe to build the filter. | ||
""" | ||
Comment on lines
+102
to
+115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I do not understand this. Should the logic not be:
|
||
) | ||
|
||
return next(iter(df_names)) if df_names else None | ||
|
||
def _create_filter(self, filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter: | ||
FilterProxy = FilterProxyModel._create_model( | ||
df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components | ||
) | ||
proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=FilterProxy, df_info=df_schema) | ||
return vm.Filter.parse_obj(proxy.dict(exclude_unset=True)) | ||
|
||
def create(self, model, controllable_components, all_df_metadata, components_plan) -> Optional[vm.Filter]: | ||
"""Create the control.""" | ||
filter_prompt = f""" | ||
Create a filter from the following instructions: <{self.control_description}>. Do not make up | ||
things that are optional and DO NOT configure actions, action triggers or action chains. | ||
If no options are specified, leave them out. | ||
""" | ||
|
||
df_name = self._get_target_df_name(components_plan, controllable_components) | ||
|
||
try: | ||
_df_schema = all_df_metadata.get_df_schema(self.df_name) | ||
_df_schema = all_df_metadata.get_df_schema(df_name) | ||
_df_cols = list(_df_schema.keys()) | ||
except KeyError: | ||
logger.warning(f"Dataframe {self.df_name} not found in metadata, returning default values.") | ||
logger.warning(f"Dataframe {df_name} not found in metadata, returning default values.") | ||
return None | ||
|
||
try: | ||
if self.control_type == "Filter": | ||
res = _create_filter( | ||
res = self._create_filter( | ||
filter_prompt=filter_prompt, | ||
model=model, | ||
df_cols=_df_cols, | ||
|
@@ -147,24 +166,39 @@ def create(self, model, controllable_components, all_df_metadata) -> Optional[vm | |
if __name__ == "__main__": | ||
import pandas as pd | ||
from dotenv import load_dotenv | ||
from vizro.tables import dash_ag_grid | ||
from vizro_ai._llm_models import _get_llm_model | ||
from vizro_ai.dashboard._response_models.components import ComponentPlan | ||
from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata | ||
|
||
load_dotenv() | ||
|
||
model = _get_llm_model() | ||
|
||
all_df_metadata = AllDfMetadata({}) | ||
all_df_metadata.all_df_metadata["gdp_chart"] = DfMetadata( | ||
all_df_metadata.all_df_metadata["world_gdp"] = DfMetadata( | ||
df_schema={"a": "int64", "b": "int64"}, | ||
df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), | ||
df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), | ||
) | ||
components_plan = [ | ||
ComponentPlan( | ||
component_type="AgGrid", | ||
component_description="Create a table that shows GDP data.", | ||
component_id="gdp_table", | ||
df_name="world_gdp", | ||
) | ||
] | ||
vm.AgGrid(id="gdp_table", figure=dash_ag_grid(data_frame="world_gdp")) | ||
control_plan = ControlPlan( | ||
control_type="Filter", | ||
control_description="Create a filter that filters the data by column 'a'.", | ||
df_name="gdp_chart", | ||
target_components_id=["gdp_table"], | ||
) | ||
control = control_plan.create( | ||
model, ["gdp_chart"], all_df_metadata | ||
) # error: Target gdp_chart not found in model_manager. | ||
model, | ||
["gdp_table"], | ||
all_df_metadata, | ||
components_plan, | ||
) | ||
print(control.__repr__()) # noqa: T201 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was more imagining something along the following lines:
So in the above toy example,
Proxy
can just act as a normal response model no?