Skip to content

Commit

Permalink
[Tidy] Restructure vizro-ai (#561)
Browse files Browse the repository at this point in the history
Co-authored-by: nadijagraca <[email protected]>
  • Loading branch information
Anna-Xiong and nadijagraca authored Jul 27, 2024
1 parent f01bb3d commit 0e86730
Show file tree
Hide file tree
Showing 29 changed files with 80 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
File renamed without changes.
6 changes: 3 additions & 3 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import plotly.graph_objects as go
from langchain_openai import ChatOpenAI

from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai.components import GetCodeExplanation, GetDebugger
from vizro_ai.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai._llm_models import _get_llm_model
from vizro_ai.plot.components import GetCodeExplanation, GetDebugger
from vizro_ai.plot.task_pipeline._pipeline_manager import PipelineManager
from vizro_ai.utils.helper import (
DebugFailure,
PlotOutputs,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager
from vizro_ai.plot.components import VizroAiComponentBase
from vizro_ai.plot.schema_manager import SchemaManager

# initialization of schema manager, and register schema needed
# preprocess: llm kwargs for function description schema + partial vars
Expand Down Expand Up @@ -109,7 +109,7 @@ def _chart_to_use(load_args) -> str:
if __name__ == "__main__":
import vizro.plotly.express as px

from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model

llm_to_use = _get_llm_model()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager
from vizro_ai.plot.components import VizroAiComponentBase
from vizro_ai.plot.schema_manager import SchemaManager

# 1. Define schema
openai_schema_manager = SchemaManager()
Expand Down Expand Up @@ -89,7 +89,7 @@ def run(self, code_snippet: str, chain_input: str = "") -> str:
if __name__ == "__main__":
import vizro.plotly.express as px

from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model

llm_to_use = _get_llm_model()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager
from vizro_ai.plot.components import VizroAiComponentBase
from vizro_ai.plot.schema_manager import SchemaManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,7 +106,7 @@ def _add_capture_code(code_string: str) -> str:
if __name__ == "__main__":
import vizro.plotly.express as px

from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model

llm_to_use = _get_llm_model()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager
from vizro_ai.plot.components import VizroAiComponentBase
from vizro_ai.plot.schema_manager import SchemaManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,7 +126,7 @@ def _format_dataframe_string(s: str) -> str:
if __name__ == "__main__":
import vizro.plotly.express as px

from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model

llm_to_use = _get_llm_model()
df = px.data.gapminder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager
from vizro_ai.plot.components import VizroAiComponentBase
from vizro_ai.plot.schema_manager import SchemaManager

# 1. Define schema
openai_schema_manager = SchemaManager()
Expand Down Expand Up @@ -98,7 +98,7 @@ def _text_cleanup(load_args) -> str:


if __name__ == "__main__":
from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model

llm_to_use = _get_llm_model()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from langchain_core.language_models.chat_models import BaseChatModel

from vizro_ai.chains._chain_utils import _log_time
from vizro_ai.components import VizroAiComponentBase
from vizro_ai.schema_manager import SchemaManager
from vizro_ai.plot.components import VizroAiComponentBase
from vizro_ai.plot.schema_manager import SchemaManager

# 1. Define schema
openai_schema_manager = SchemaManager()
Expand Down Expand Up @@ -104,7 +104,7 @@ def _clean_visual_code(raw_code: str) -> str:
if __name__ == "__main__":
import vizro.plotly.express as px

from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model

llm_to_use = _get_llm_model()
df = px.data.gapminder()
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Pipeline Manager."""

from langchain_core.language_models.chat_models import BaseChatModel
from vizro_ai.components import GetChartSelection, GetCustomChart, GetDataFrameCraft, GetVisualCode
from vizro_ai.task_pipeline._pipeline import Pipeline
from vizro_ai.plot.components import GetChartSelection, GetCustomChart, GetDataFrameCraft, GetVisualCode
from vizro_ai.plot.task_pipeline._pipeline import Pipeline


class PipelineManager:
Expand Down
3 changes: 1 addition & 2 deletions vizro-ai/src/vizro_ai/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

import pandas as pd
import plotly.graph_objects as go

from .safeguard import _safeguard_check
from vizro_ai.plot._utils._safeguard import _safeguard_check


@dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import pytest
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.components import GetChartSelection
from vizro_ai.plot.components import GetChartSelection


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.components import GetDebugger
from vizro_ai.plot.components import GetDebugger


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.components import GetCustomChart
from vizro_ai.plot.components import GetCustomChart


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import pytest
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.components import GetDataFrameCraft
from vizro_ai.plot.components import GetDataFrameCraft


def dataframe_code():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.components import GetCodeExplanation
from vizro_ai.plot.components import GetCodeExplanation


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from langchain_community.llms.fake import FakeListLLM
from vizro_ai.components import GetVisualCode
from vizro_ai.plot.components import GetVisualCode


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re

import pytest
from vizro_ai.utils.safeguard import _safeguard_check
from vizro_ai.plot._utils._safeguard import _safeguard_check


class TestMaliciousImports:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from vizro_ai.chains._llm_models import _get_llm_model
from vizro_ai._llm_models import _get_llm_model


@pytest.mark.parametrize(
Expand Down

0 comments on commit 0e86730

Please sign in to comment.