diff --git a/docs/concepts/collections.md b/docs/concepts/collections.md index ef5186ed..2959de57 100644 --- a/docs/concepts/collections.md +++ b/docs/concepts/collections.md @@ -25,6 +25,33 @@ my_collection.ask("Find me Italian recipes for soups") In this scenario, the LLM first determines the most suitable view to address the query, and then that view is used to pull the relevant data. +Sometimes the selected view does not match the question (the LLM selects the wrong view), and an error is raised. In such situations, a fallback collection can be used. +This triggers another view selection, this time among the views registered in the fallback collection. + +```python + llm = LiteLLM(model_name="gpt-3.5-turbo") + user_collection = dbally.create_collection("candidates", llm) + user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) + user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) + user_collection.add(CandidateView, lambda: (candidate_view_with_similarity_store.engine)) + + fallback_collection = dbally.create_collection("freeform candidates", llm) + fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) + user_collection.set_fallback(fallback_collection) +``` +The fallback collection processes the same question using its own set of views. Fallback collections can be chained. + +```python + second_fallback_collection = dbally.create_collection("recruitment", llm) + second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine)) + + fallback_collection.set_fallback(second_fallback_collection) + +``` + + + + !!! info The result of a query is an [`ExecutionResult`][dbally.collection.results.ExecutionResult] object, which contains the data fetched by the view. It contains a `results` attribute that holds the actual data, structured as a list of dictionaries. 
The exact structure of these dictionaries depends on the view that was used to fetch the data, which can be obtained by looking at the `view_name` attribute of the `ExecutionResult` object. diff --git a/examples/recruiting/candidate_view_with_similarity_store.py b/examples/recruiting/candidate_view_with_similarity_store.py index f50c4545..4fb394a9 100644 --- a/examples/recruiting/candidate_view_with_similarity_store.py +++ b/examples/recruiting/candidate_view_with_similarity_store.py @@ -5,9 +5,10 @@ from sqlalchemy.ext.automap import automap_base from typing_extensions import Annotated -from dbally import SqlAlchemyBaseView, decorators from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.similarity import FaissStore, SimilarityIndex, SimpleSqlAlchemyFetcher +from dbally.views import decorators +from dbally.views.sqlalchemy_base import SqlAlchemyBaseView engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") diff --git a/examples/recruiting/candidates_freeform.py b/examples/recruiting/candidates_freeform.py new file mode 100644 index 00000000..5aa67d9d --- /dev/null +++ b/examples/recruiting/candidates_freeform.py @@ -0,0 +1,42 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +from typing import List + +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base + +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig + +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") + +_Base = automap_base() +_Base.prepare(autoload_with=engine) +_Candidate = _Base.classes.candidates + + +class CandidateFreeformView(BaseText2SQLView): + """ + A view for retrieving candidates from the database. + """ + + def get_tables(self) -> List[TableConfig]: + """ + Get the tables used by the view. + + Returns: + A list of tables. 
+ """ + return [ + TableConfig( + name="candidates", + columns=[ + ColumnConfig("name", "TEXT"), + ColumnConfig("country", "TEXT"), + ColumnConfig("years_of_experience", "INTEGER"), + ColumnConfig("position", "TEXT"), + ColumnConfig("university", "TEXT"), + ColumnConfig("skills", "TEXT"), + ColumnConfig("tags", "TEXT"), + ColumnConfig("id", "INTEGER PRIMARY KEY"), + ], + ), + ] diff --git a/examples/recruiting/views.py b/examples/recruiting/views.py index 773d3f62..9765ba51 100644 --- a/examples/recruiting/views.py +++ b/examples/recruiting/views.py @@ -75,7 +75,7 @@ def is_available_within_months( # pylint: disable=W0602, C0116, W9011 end = start + relativedelta(months=months) return Candidate.available_from.between(start, end) - def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011 + def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011, C0116 return [ FewShotExample( "Which candidates studied at University of Toronto?", diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py new file mode 100644 index 00000000..f70f1eed --- /dev/null +++ b/examples/visualize_fallback_code.py @@ -0,0 +1,36 @@ +# pylint: disable=missing-function-docstring +import asyncio + +from recruiting import candidate_view_with_similarity_store, candidates_freeform +from recruiting.candidate_view_with_similarity_store import CandidateView +from recruiting.candidates_freeform import CandidateFreeformView +from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine +from recruiting.db import ENGINE as recruiting_engine +from recruiting.views import RecruitmentView + +import dbally +from dbally.audit import CLIEventHandler, OtelEventHandler +from dbally.gradio import create_gradio_interface +from dbally.llms.litellm import LiteLLM + + +async def main(): + llm = LiteLLM(model_name="gpt-3.5-turbo") + user_collection = dbally.create_collection("candidates", llm) + 
user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine)) + user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) + + fallback_collection = dbally.create_collection("freeform candidates", llm, event_handlers=[OtelEventHandler()]) + fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine)) + + second_fallback_collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) + second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine)) + + user_collection.set_fallback(fallback_collection).set_fallback(second_fallback_collection) + + gradio_interface = await create_gradio_interface(user_collection=user_collection) + gradio_interface.launch() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 824ce34d..7aca9e48 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List -from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError +from dbally.collection.exceptions import NoViewFoundError from dbally.collection.results import ExecutionResult from dbally.views import decorators from dbally.views.methods_base import MethodsBaseView @@ -40,7 +40,6 @@ "EmbeddingConnectionError", "EmbeddingResponseError", "EmbeddingStatusError", - "IndexUpdateError", "LLMError", "LLMConnectionError", "LLMResponseError", diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index 03aa9806..5c97a016 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -13,7 +13,7 @@ pprint = print # type: ignore from dbally.audit.event_handlers.base import EventHandler -from dbally.audit.events import Event, LLMEvent, 
RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent _RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"} _RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]" @@ -94,6 +94,18 @@ async def event_start(self, event: Event, request_context: None) -> None: f"[cyan bold]STORE: [grey53]{event.store}\n" f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" ) + elif isinstance(event, FallbackEvent): + self._print_syntax( + f"[grey53]\n=======================================\n" + "[grey53]=======================================\n" + f"[orange bold]Fallback event starts \n" + f"[orange bold]Triggering collection: [grey53]{event.triggering_collection_name}\n" + f"[orange bold]Triggering view name: [grey53]{event.triggering_view_name}\n" + f"[orange bold]Error description: [grey53]{event.error_description}\n" + f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n" + "[grey53]=======================================\n" + "[grey53]=======================================\n" + ) # pylint: disable=unused-argument async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None: @@ -123,8 +135,11 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict] output: The output of the request. 
request_context: Optional context passed from request_start method """ - self._print_syntax("[green bold]REQUEST OUTPUT:") - self._print_syntax(f"Number of rows: {len(output.result.results)}") + if output.result: + self._print_syntax("[green bold]REQUEST OUTPUT:") + self._print_syntax(f"Number of rows: {len(output.result.results)}") - if "sql" in output.result.context: - self._print_syntax(f"{output.result.context['sql']}", "psql") + if "sql" in output.result.context: + self._print_syntax(f"{output.result.context['sql']}", "psql") + else: + self._print_syntax("[red bold]No results found") diff --git a/src/dbally/audit/event_handlers/otel_event_handler.py b/src/dbally/audit/event_handlers/otel_event_handler.py index 00a106a2..91a709f5 100644 --- a/src/dbally/audit/event_handlers/otel_event_handler.py +++ b/src/dbally/audit/event_handlers/otel_event_handler.py @@ -7,7 +7,7 @@ from opentelemetry.util.types import AttributeValue from dbally.audit.event_handlers.base import EventHandler -from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent TRACER_NAME = "db-ally.events" FORBIDDEN_CONTEXT_KEYS = {"filter_mask"} @@ -172,8 +172,11 @@ async def event_start(self, event: Event, request_context: SpanHandler) -> SpanH .set("db-ally.similarity.fetcher", event.fetcher) .set_input("db-ally.similarity.input", event.input_value) ) + if isinstance(event, FallbackEvent): + with self._new_child_span(request_context, "fallback") as span: + return self._handle_span(span).set("db-ally.error_description", event.error_description) - raise ValueError(f"Unsuported event: {type(event)}") + raise ValueError(f"Unsupported event: {type(event)}") async def event_end(self, event: Optional[Event], request_context: SpanHandler, event_context: SpanHandler) -> None: """ diff --git a/src/dbally/audit/events.py b/src/dbally/audit/events.py index 
de397a74..3bb23e17 100644 --- a/src/dbally/audit/events.py +++ b/src/dbally/audit/events.py @@ -41,6 +41,18 @@ class SimilarityEvent(Event): output_value: Optional[str] = None +@dataclass +class FallbackEvent(Event): + """ + FallbackEvent is fired when a processed view/collection raise an exception. + """ + + triggering_collection_name: str + triggering_view_name: str + fallback_collection_name: str + error_description: str + + @dataclass class RequestStart: """ diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index c3d7b1d3..542f78e4 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -1,5 +1,6 @@ import asyncio import inspect +import logging import textwrap import time from collections import defaultdict @@ -8,9 +9,10 @@ import dbally from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker -from dbally.audit.events import RequestEnd, RequestStart +from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError -from dbally.collection.results import ExecutionResult +from dbally.collection.results import ExecutionResult, ViewExecutionResult +from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -18,13 +20,15 @@ from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation +HANDLED_EXCEPTION_TYPES = (NoViewFoundError, UnsupportedQueryError, IndexUpdateError) + class Collection: """ Collection is a container for a set of views that can be used by db-ally to answer user questions. 
Tip: - It is recommended to create new collections using the [`dbally.create_colletion`][dbally.create_collection]\ + It is recommended to create new collections using the [`dbally.create_collection`][dbally.create_collection]\ function instead of instantiating this class directly. """ @@ -36,6 +40,7 @@ def __init__( nl_responder: NLResponder, event_handlers: Optional[List[EventHandler]] = None, n_retries: int = 3, + fallback_collection: Optional["Collection"] = None, ) -> None: """ Args: @@ -49,9 +54,12 @@ def __init__( event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). + nl_responder: Object that translates RAW response from db-ally into natural language. n_retries: IQL generator may produce invalid IQL. If this is the case this argument specifies\ how many times db-ally will try to regenerate it. Previous try with the error message is\ appended to the chat history to guide next generations. + fallback_collection: collection to be asked when the ask function could not find answer in views registered + to this collection """ self.name = name self.n_retries = n_retries @@ -60,7 +68,7 @@ def __init__( self._view_selector = view_selector self._nl_responder = nl_responder self._llm = llm - + self._fallback_collection: Optional[Collection] = fallback_collection self._event_handlers = event_handlers or dbally.event_handlers T = TypeVar("T", bound=BaseView) @@ -74,7 +82,7 @@ def add(self, view: Type[T], builder: Optional[Callable[[], T]] = None, name: Op query execution. We expect Class instead of object, as otherwise Views must have been implemented\ stateless, which would be cumbersome. builder: Optional factory function that will be used to create the View instance. 
Use it when you\ - need to pass outcome of API call or database connection to the view and it can change over time. + need to pass outcome of API call or database connection to the view, and it can change over time. name: Custom name of the view (defaults to the name of the class). Raises: @@ -113,6 +121,37 @@ def build_dogs_df_view(): self._views[name] = view self._builders[name] = builder + def set_fallback(self, fallback_collection: "Collection") -> "Collection": + """ + Set fallback collection which will be asked if the ask to base collection does not succeed. + + Args: + fallback_collection: Collection to be asked in case of base collection failure. + + Returns: + The fallback collection to create chains call + """ + self._fallback_collection = fallback_collection + if fallback_collection._event_handlers != self._event_handlers: # pylint: disable=W0212 + logging.warning( + "Event handlers of the fallback collection are different from the base collection. " + "Continuity of the audit trail is not guaranteed.", + ) + + return fallback_collection + + def __rshift__(self, fallback_collection: "Collection"): + """ + Add fallback collection which will be asked if the ask to base collection does not succeed. + + Args: + fallback_collection: Collection to be asked in case of base collection failure. + + Returns: + The fallback collection to create chains call + """ + return self.set_fallback(fallback_collection) + def get(self, name: str) -> BaseView: """ Returns an instance of the view with the given name @@ -143,12 +182,173 @@ def list(self) -> Dict[str, str]: name: (textwrap.dedent(view.__doc__).strip() if view.__doc__ else "") for name, view in self._views.items() } + async def _select_view( + self, + question: str, + event_tracker: EventTracker, + llm_options: Optional[LLMOptions], + ) -> str: + """ + Select a view based on the provided question and options. + + If there is only one view available, it selects that view directly. 
Otherwise, it + uses the view selector to choose the most appropriate view. + + Args: + question: The question to be answered. + event_tracker: The event tracker for logging and tracking events. + llm_options: Options for the LLM client. + + Returns: + str: The name of the selected view. + + Raises: + ValueError: If the collection of views is empty. + """ + + views = self.list() + if len(views) == 0: + raise ValueError("Empty collection") + if len(views) == 1: + selected_view_name = next(iter(views)) + else: + selected_view_name = await self._view_selector.select_view( + question=question, + views=views, + event_tracker=event_tracker, + llm_options=llm_options, + ) + return selected_view_name + + async def _ask_view( + self, + selected_view_name: str, + question: str, + event_tracker: EventTracker, + llm_options: Optional[LLMOptions], + dry_run: bool, + ): + """ + Ask the selected view to provide an answer to the question. + + Args: + selected_view_name: The name of the selected view. + question: The question to be answered. + event_tracker: The event tracker for logging and tracking events. + llm_options: Options for the LLM client. + dry_run: If True, only generate the query without executing it. + + Returns: + Any: The result from the selected view. + """ + selected_view = self.get(selected_view_name) + view_result = await selected_view.ask( + query=question, + llm=self._llm, + event_tracker=event_tracker, + n_retries=self.n_retries, + dry_run=dry_run, + llm_options=llm_options, + ) + return view_result + + async def _generate_textual_response( + self, + view_result: ViewExecutionResult, + question: str, + event_tracker: EventTracker, + llm_options: Optional[LLMOptions], + ) -> str: + """ + Generate a textual response from the view result. + + Args: + view_result: The result from the view. + question: The question to be answered. + event_tracker: The event tracker for logging and tracking events. + llm_options: Options for the LLM client. 
+ + Returns: + The generated textual response. + """ + textual_response = await self._nl_responder.generate_response( + result=view_result, + question=question, + event_tracker=event_tracker, + llm_options=llm_options, + ) + return textual_response + + def get_all_event_handlers(self) -> List[EventHandler]: + """ + Retrieves all event handlers, including those from a fallback collection if available. + + This method returns a list of event handlers. If there is no fallback collection, + it simply returns the event handlers stored in the current object. If a fallback + collection is available, it combines the event handlers from both the current object + and the fallback collection, ensuring no duplicates. + + Returns: + A list of event handlers. + """ + if not self._fallback_collection: + return self._event_handlers + return list(set(self._event_handlers).union(self._fallback_collection.get_all_event_handlers())) + + async def _handle_fallback( + self, + question: str, + dry_run: bool, + return_natural_response: bool, + llm_options: Optional[LLMOptions], + selected_view_name: str, + event_tracker: EventTracker, + caught_exception: Exception, + ) -> ExecutionResult: + """ + Handle fallback if the main query fails. + + Args: + question: The question to be answered. + dry_run: If True, only generate the query without executing it. + return_natural_response: If True, return the natural language response. + llm_options: Options for the LLM client. + selected_view_name: The name of the selected view. + event_tracker: The event tracker for logging and tracking events. + caught_exception: The exception that was caught. + + Returns: + The result from the fallback collection. 
+ + """ + if not self._fallback_collection: + raise caught_exception + + fallback_event = FallbackEvent( + triggering_collection_name=self.name, + triggering_view_name=selected_view_name, + fallback_collection_name=self._fallback_collection.name, + error_description=repr(caught_exception), + ) + + async with event_tracker.track_event(fallback_event) as span: + result = await self._fallback_collection.ask( + question=question, + dry_run=dry_run, + return_natural_response=return_natural_response, + llm_options=llm_options, + event_tracker=event_tracker, + ) + span(fallback_event) + return result + async def ask( self, question: str, dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. @@ -168,6 +368,7 @@ async def ask( the natural response will be included in the answer llm_options: options to use for the LLM client. If provided, these options will be merged with the default options provided to the LLM client, prioritizing option values other than NOT_GIVEN + event_tracker: Event tracker object for given ask. Returns: ExecutionResult object representing the result of the query execution. @@ -176,60 +377,68 @@ async def ask( ValueError: if collection is empty IQLError: if incorrect IQL was generated `n_retries` amount of times. ValueError: if incorrect IQL was generated `n_retries` amount of times. 
+ NoViewFoundError: if question does not match to any registered view, + UnsupportedQueryError: if the question could not be answered + IndexUpdateError: if index update failed """ - start_time = time.monotonic() - - event_tracker = EventTracker.initialize_with_handlers(self._event_handlers) + if not event_tracker: + is_fallback_call = False + event_handlers = self.get_all_event_handlers() + event_tracker = EventTracker.initialize_with_handlers(event_handlers) + await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) + else: + is_fallback_call = True - await event_tracker.request_start(RequestStart(question=question, collection_name=self.name)) + selected_view_name = "" - # select view - views = self.list() + try: + start_time = time.monotonic() + selected_view_name = await self._select_view( + question=question, event_tracker=event_tracker, llm_options=llm_options + ) - if len(views) == 0: - raise ValueError("Empty collection") - if len(views) == 1: - selected_view = next(iter(views)) - else: - selected_view = await self._view_selector.select_view( + start_time_view = time.monotonic() + view_result = await self._ask_view( + selected_view_name=selected_view_name, question=question, - views=views, event_tracker=event_tracker, llm_options=llm_options, + dry_run=dry_run, ) + end_time_view = time.monotonic() - view = self.get(selected_view) - - start_time_view = time.monotonic() - view_result = await view.ask( - query=question, - llm=self._llm, - event_tracker=event_tracker, - n_retries=self.n_retries, - dry_run=dry_run, - llm_options=llm_options, - ) - end_time_view = time.monotonic() + natural_response = ( + await self._generate_textual_response(view_result, question, event_tracker, llm_options) + if not dry_run and return_natural_response + else "" + ) - textual_response = None - if not dry_run and return_natural_response: - textual_response = await self._nl_responder.generate_response( - result=view_result, - question=question, 
- event_tracker=event_tracker, - llm_options=llm_options, + result = ExecutionResult( + results=view_result.results, + context=view_result.context, + execution_time=time.monotonic() - start_time, + execution_time_view=end_time_view - start_time_view, + view_name=selected_view_name, + textual_response=natural_response, ) - result = ExecutionResult( - results=view_result.results, - context=view_result.context, - execution_time=time.monotonic() - start_time, - execution_time_view=end_time_view - start_time_view, - view_name=selected_view, - textual_response=textual_response, - ) + except HANDLED_EXCEPTION_TYPES as caught_exception: + if self._fallback_collection: + result = await self._handle_fallback( + question=question, + dry_run=dry_run, + return_natural_response=return_natural_response, + llm_options=llm_options, + selected_view_name=selected_view_name, + event_tracker=event_tracker, + caught_exception=caught_exception, + ) + else: + raise caught_exception + + if not is_fallback_call: + await event_tracker.request_end(RequestEnd(result=result)) - await event_tracker.request_end(RequestEnd(result=result)) return result def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index b0e9586e..4a8de2b4 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -152,6 +152,7 @@ async def _ui_ask_query( generated_query = str(execution_result.context) data = self._load_results_into_dataframe(execution_result.results) textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response + except UnsupportedQueryError: generated_query = {"Query": "unsupported"} data = pd.DataFrame() diff --git a/tests/unit/test_fallback_collection.py b/tests/unit/test_fallback_collection.py new file mode 100644 index 00000000..137581b6 --- /dev/null +++ b/tests/unit/test_fallback_collection.py @@ 
-0,0 +1,172 @@ +from typing import List, Optional +from unittest.mock import AsyncMock, Mock + +import pytest +from sqlalchemy import create_engine + +import dbally +from dbally.audit import CLIEventHandler, EventTracker, OtelEventHandler +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler +from dbally.collection import Collection, ViewExecutionResult +from dbally.iql_generator.prompt import UnsupportedQueryError +from dbally.llms import LLM +from dbally.llms.clients import LLMOptions +from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockViewBase, MockViewSelector + +engine = create_engine("sqlite://", echo=True) + + +class MyText2SqlView(BaseText2SQLView): + """ + A Text2SQL view for the example. + """ + + def get_tables(self) -> List[TableConfig]: + return [ + TableConfig( + name="mock_table", + columns=[ + ColumnConfig("mock_field1", "SERIAL PRIMARY KEY"), + ColumnConfig("mock_field2", "VARCHAR(255)"), + ], + ), + ] + + async def ask( + self, + query: str, + llm: LLM, + event_tracker: EventTracker, + n_retries: int = 3, + dry_run: bool = False, + llm_options: Optional[LLMOptions] = None, + ) -> ViewExecutionResult: + return ViewExecutionResult( + results=[{"mock_result": "fallback_result"}], context={"mock_context": "fallback_context"} + ) + + +class MockView1(MockViewBase): + """ + Mock view 1 + """ + + def execute(self, dry_run=False) -> ViewExecutionResult: + return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + + def get_iql_generator(self, *_, **__) -> MockIQLGenerator: + raise UnsupportedQueryError + + +class MockView2(MockViewBase): + """ + Mock view 2 + """ + + +@pytest.fixture(name="base_collection") +def mock_base_collection() -> Collection: + """ + Returns a collection with two mock views + """ + collection = dbally.create_collection( + "foo", + llm=MockLLM(), + 
view_selector=MockViewSelector("MockView1"), + nl_responder=AsyncMock(), + ) + collection.add(MockView1) + collection.add(MockView2) + return collection + + +@pytest.fixture(name="fallback_collection") +def mock_fallback_collection() -> Collection: + """ + Returns a collection with two mock views + """ + collection = dbally.create_collection( + "fallback_foo", + llm=MockLLM(), + view_selector=MockViewSelector("MyText2SqlView"), + nl_responder=AsyncMock(), + ) + collection.add(MyText2SqlView, lambda: MyText2SqlView(engine)) + return collection + + +async def test_no_fallback_collection(base_collection: Collection, fallback_collection: Collection): + with pytest.raises(UnsupportedQueryError) as exc_info: + result = await base_collection.ask("Mock fallback question") + print(result) + print(exc_info) + + +async def test_fallback_collection(base_collection: Collection, fallback_collection: Collection): + base_collection.set_fallback(fallback_collection) + result = await base_collection.ask("Mock fallback question") + assert result.results == [{"mock_result": "fallback_result"}] + assert result.context == {"mock_context": "fallback_context"} + + +def test_get_all_event_handlers_no_fallback(): + handler1 = CLIEventHandler() + handler2 = BufferEventHandler() + + collection = Collection( + name="test_collection", + llm=MockLLM(), + nl_responder=AsyncMock(), + view_selector=Mock(), + event_handlers=[handler1, handler2], + ) + + result = collection.get_all_event_handlers() + + assert result == [handler1, handler2] + + +def test_get_all_event_handlers_with_fallback(): + handler1 = CLIEventHandler() + handler2 = BufferEventHandler() + handler3 = OtelEventHandler() + + fallback_collection = Collection( + name="fallback_collection", view_selector=Mock(), llm=Mock(), nl_responder=Mock(), event_handlers=[handler3] + ) + + collection = Collection( + name="test_collection", + view_selector=Mock(), + llm=MockLLM(), + nl_responder=AsyncMock(), + event_handlers=[handler1, handler2], + 
fallback_collection=fallback_collection, + ) + + result = collection.get_all_event_handlers() + + assert set(result) == {handler1, handler2, handler3} + + +def test_get_all_event_handlers_with_duplicates(): + handler1 = CLIEventHandler() + handler2 = BufferEventHandler() + + fallback_collection = Collection( + name="fallback_collection", view_selector=Mock(), llm=Mock(), nl_responder=Mock(), event_handlers=[handler2] + ) + + collection = Collection( + name="test_collection", + view_selector=Mock(), + llm=Mock(), + nl_responder=Mock(), + event_handlers=[handler1, handler2], + fallback_collection=fallback_collection, + ) + + result = collection.get_all_event_handlers() + + assert set(result) == {handler1, handler2}