Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tidy] Improve code style in response models #626

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ComponentPlan(BaseModel):
component_id: str = Field(
pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case description of this component."
)
# TODO: for improvement, we could dynamically create the pydantic model at runtime so that we can
# validate the df_name against the available dataframes
df_name: str = Field(
...,
description="""
Expand All @@ -52,7 +54,9 @@ def create(self, model, all_df_metadata) -> Union[vm.Card, vm.AgGrid, vm.Figure]
return vm.Graph(
id=self.component_id,
figure=vizro_ai.plot(
df=all_df_metadata.get_df(self.df_name), user_input=self.component_description
df=all_df_metadata.get_df(self.df_name),
user_input=self.component_description,
max_debug_retry=2,
),
)
elif self.component_type == "AgGrid":
Expand Down
186 changes: 110 additions & 76 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was more imagining something along the following lines:

from typing import Any


try:
    from pydantic.v1 import BaseModel, Field, create_model
except ImportError:  # pragma: no cov
    from pydantic import BaseModel, Field
    
class FooModel(BaseModel):
    foo: str
    bar: int = 123

class ProxyModel:
    def __new__(cls, color) -> Any:
        return create_model(
            'BarModel',
            apple='russet',
            banana=color,
            __base__=FooModel,
        )
        

Proxy = ProxyModel(color="orange")

print(Proxy.__fields__.keys())
print(Proxy.__fields__.values())

So in the above toy example, Proxy can just act as a normal response model no?

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Controls plan model."""

import logging
from typing import List, Optional
from typing import Any, Dict, List, Optional, Type

import pandas as pd
import vizro.models as vm
Expand All @@ -16,76 +16,68 @@
logger = logging.getLogger(__name__)


def _create_filter_proxy(df_cols, df_schema, controllable_components) -> BaseModel:
"""Create a filter proxy model."""
class FilterProxyModel:
"""Filter proxy model."""

def validate_targets(v):
"""Validate the targets."""
if v not in controllable_components:
raise ValueError(f"targets must be one of {controllable_components}")
return v
@classmethod
def _create_model(
cls, df_cols: List[str], df_schema: Dict[str, Any], controllable_components: List[str]
) -> Type[BaseModel]:
def validate_targets(v):
if v not in controllable_components:
raise ValueError(f"targets must be one of {controllable_components}")
return v

def validate_targets_not_empty(v):
"""Validate the targets not empty."""
if not controllable_components:
raise ValueError(
"""
This might be due to the filter target is not found in the controllable components.
returning default values.
"""
)
return v

def validate_column(v):
"""Validate the column."""
if v not in df_cols:
raise ValueError(f"column must be one of {df_cols}")
return v

@root_validator(allow_reuse=True)
def validate_date_picker_column(cls, values):
"""Validate the column for date picker."""
column = values.get("column")
selector = values.get("selector")
if selector and selector.type == "date_picker":
if not pd.api.types.is_datetime64_any_dtype(df_schema[column]):
def validate_targets_not_empty(v):
if not controllable_components:
raise ValueError(
f"""
The column '{column}' is not of datetime type. Selector type 'date_picker' is
not allowed. Use 'dropdown' instead.
"""
This might be due to the filter target is not found in the controllable components.
returning default values.
"""
)
return values

return create_model(
"FilterProxy",
targets=(
List[str],
Field(
...,
description=f"""
Target component to be affected by filter.
Must be one of {controllable_components}. ALWAYS REQUIRED.
""",
return v

def validate_column(v):
if v not in df_cols:
raise ValueError(f"column must be one of {df_cols}")
return v

@root_validator(allow_reuse=True)
def validate_date_picker_column(cls, values):
column = values.get("column")
selector = values.get("selector")
if selector and selector.type == "date_picker":
if not pd.api.types.is_datetime64_any_dtype(df_schema[column]):
raise ValueError(
f"""
The column '{column}' is not of datetime type. Selector type 'date_picker' is
not allowed. Use 'dropdown' instead.
"""
)
return values

return create_model(
"FilterProxy",
targets=(
List[str],
Field(
...,
description=f"""
Target component to be affected by filter.
Must be one of {controllable_components}. ALWAYS REQUIRED.
""",
),
),
),
column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")),
__validators__={
"validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets),
"validator2": validator("column", allow_reuse=True)(validate_column),
"validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty),
"validator4": validate_date_picker_column,
},
__base__=vm.Filter,
)


def _create_filter(filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter:
result_proxy = _create_filter_proxy(
df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components
)
proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema)
return vm.Filter.parse_obj(proxy.dict(exclude_unset=True))
column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")),
__validators__={
"validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets),
"validator2": validator("column", allow_reuse=True)(validate_column),
"validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty),
"validator4": validate_date_picker_column,
},
__base__=vm.Filter,
)


class ControlPlan(BaseModel):
Expand All @@ -100,31 +92,58 @@ class ControlPlan(BaseModel):
to control a specific component, include the relevant component details.
""",
)
df_name: str = Field(
target_components_id: List[str] = Field(
...,
description="""
The name of the dataframe that the target component will use.
If the dataframe is not used, please specify that.
The id of the target components that this control will affect.
""",
)

def create(self, model, controllable_components, all_df_metadata) -> Optional[vm.Filter]:
def _get_target_df_name(self, components_plan, controllable_components):
target_controllable = set(self.target_components_id) & set(controllable_components)
df_names = {
component_plan.df_name
for component_plan in components_plan
if component_plan.component_id in target_controllable
}

if len(df_names) > 1:
logger.warning(
f"""
[FALLBACK] Multiple dataframes found in the target components: {df_names}.
Choose one dataframe to build the filter.
"""
Comment on lines +102 to +115
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I do not understand this. Should the logic not be:

  • model determines which components should be controlled
  • we then check if those components can be controlled, and take a subset at worst
  • the remaining dfs are what we get the metadata for (and the could be multiple no?)

)

return next(iter(df_names)) if df_names else None

def _create_filter(self, filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter:
FilterProxy = FilterProxyModel._create_model(
df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components
)
proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=FilterProxy, df_info=df_schema)
return vm.Filter.parse_obj(proxy.dict(exclude_unset=True))

def create(self, model, controllable_components, all_df_metadata, components_plan) -> Optional[vm.Filter]:
"""Create the control."""
filter_prompt = f"""
Create a filter from the following instructions: <{self.control_description}>. Do not make up
things that are optional and DO NOT configure actions, action triggers or action chains.
If no options are specified, leave them out.
"""

df_name = self._get_target_df_name(components_plan, controllable_components)

try:
_df_schema = all_df_metadata.get_df_schema(self.df_name)
_df_schema = all_df_metadata.get_df_schema(df_name)
_df_cols = list(_df_schema.keys())
except KeyError:
logger.warning(f"Dataframe {self.df_name} not found in metadata, returning default values.")
logger.warning(f"Dataframe {df_name} not found in metadata, returning default values.")
return None

try:
if self.control_type == "Filter":
res = _create_filter(
res = self._create_filter(
filter_prompt=filter_prompt,
model=model,
df_cols=_df_cols,
Expand All @@ -147,24 +166,39 @@ def create(self, model, controllable_components, all_df_metadata) -> Optional[vm
if __name__ == "__main__":
import pandas as pd
from dotenv import load_dotenv
from vizro.tables import dash_ag_grid
from vizro_ai._llm_models import _get_llm_model
from vizro_ai.dashboard._response_models.components import ComponentPlan
from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata

load_dotenv()

model = _get_llm_model()

all_df_metadata = AllDfMetadata({})
all_df_metadata.all_df_metadata["gdp_chart"] = DfMetadata(
all_df_metadata.all_df_metadata["world_gdp"] = DfMetadata(
df_schema={"a": "int64", "b": "int64"},
df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}),
df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}),
)
components_plan = [
ComponentPlan(
component_type="AgGrid",
component_description="Create a table that shows GDP data.",
component_id="gdp_table",
df_name="world_gdp",
)
]
vm.AgGrid(id="gdp_table", figure=dash_ag_grid(data_frame="world_gdp"))
control_plan = ControlPlan(
control_type="Filter",
control_description="Create a filter that filters the data by column 'a'.",
df_name="gdp_chart",
target_components_id=["gdp_table"],
)
control = control_plan.create(
model, ["gdp_chart"], all_df_metadata
) # error: Target gdp_chart not found in model_manager.
model,
["gdp_table"],
all_df_metadata,
components_plan,
)
print(control.__repr__()) # noqa: T201
54 changes: 26 additions & 28 deletions vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,6 @@
logger = logging.getLogger(__name__)


def _convert_to_grid(layout_grid_template_areas: List[str], component_ids: List[str]) -> List[List[int]]:
component_map = {component: index for index, component in enumerate(component_ids)}
grid = []

for row in layout_grid_template_areas:
grid_row = []
for cell in row.split():
if cell == ".":
grid_row.append(-1)
else:
try:
grid_row.append(component_map[cell])
except KeyError:
logger.warning(
f"""
[FALLBACK] Component {cell} not found in component_ids: {component_ids}.
Returning default values.
"""
)
return []
grid.append(grid_row)

return grid


class LayoutPlan(BaseModel):
"""Layout plan model, which only applies to Vizro Components(Graph, AgGrid, Card)."""

Expand All @@ -55,15 +30,38 @@ class LayoutPlan(BaseModel):
""",
)

def _convert_to_grid(self, component_ids: List[str]) -> List[List[int]]:
component_map = {component: index for index, component in enumerate(component_ids)}
grid = []

for row in self.layout_grid_template_areas:
grid_row = []
for raw_cell in row.split():
cell = raw_cell.strip("'\"")
if cell == ".":
grid_row.append(-1)
else:
try:
grid_row.append(component_map[cell])
except KeyError:
logger.warning(
f"""
[FALLBACK] Component {cell} not found in component_ids: {component_ids}.
Returning default values.
"""
)
return []
grid.append(grid_row)

return grid

def create(self, component_ids: List[str]) -> Optional[vm.Layout]:
"""Create the layout."""
if not self.layout_grid_template_areas:
return None

try:
grid = _convert_to_grid(
layout_grid_template_areas=self.layout_grid_template_areas, component_ids=component_ids
)
grid = self._convert_to_grid(component_ids=component_ids)
actual = vm.Layout(grid=grid)
except ValidationError as e:
logger.warning(
Expand Down
Loading
Loading