Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(iql): add iql gen exception #77

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.llms.clients.exceptions import LLMError
from dbally.prompt.elements import FewShotExample
from dbally.prompt.template import PromptTemplate
from dbally.views.exposed_functions import ExposedFunction
Expand Down Expand Up @@ -52,13 +53,15 @@ async def generate_iql(
event_tracker: Event store used to audit the generation process.
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.

Returns:
Generated IQL query.

Raises:
IQLError: If IQL generation fails after all retries.
LLMError: If LLM text generation fails after all retries.
IQLError: If IQL parsing fails after all retries.
UnsupportedQueryError: If the question is not supported by the view.
"""
prompt_format = IQLGenerationPromptFormat(
question=question,
Expand All @@ -82,6 +85,9 @@ async def generate_iql(
allowed_functions=filters,
event_tracker=event_tracker,
)
except LLMError as exc:
if retry == n_retries:
raise exc
except IQLError as exc:
if retry == n_retries:
raise exc
Expand Down
3 changes: 3 additions & 0 deletions src/dbally/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ async def generate_text(

Returns:
Text response from LLM.

Raises:
LLMError: If LLM text generation fails.
"""
options = (self.default_options | options) if options else self.default_options
event = LLMEvent(prompt=prompt.chat, type=type(prompt).__name__)
Expand Down
3 changes: 3 additions & 0 deletions src/dbally/nl_responder/nl_responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ async def generate_response(

Returns:
Natural language response to the user question.

Raises:
LLMError: If LLM text generation fails.
"""
prompt_format = NLResponsePromptFormat(
question=question,
Expand Down
3 changes: 3 additions & 0 deletions src/dbally/view_selection/llm_view_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ async def select_view(

Returns:
The most relevant view name.

Raises:
LLMError: If LLM text generation fails.
"""
prompt_format = ViewSelectionPromptFormat(question=question, views=views)
formatted_prompt = self._prompt_template.format_prompt(prompt_format)
Expand Down
26 changes: 26 additions & 0 deletions src/dbally/views/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from dbally.exceptions import DbAllyError


class IQLGenerationError(DbAllyError):
"""
Exception for when an error occurs while generating IQL for a view.
"""

def __init__(
self,
view_name: str,
filters: Optional[str] = None,
aggregation: Optional[str] = None,
) -> None:
"""
Args:
view_name: Name of the view that caused the error.
filters: Filters generated by the view.
aggregation: Aggregation generated by the view.
"""
super().__init__(f"Error while generating IQL for view {view_name}")
self.view_name = view_name
self.filters = filters
self.aggregation = aggregation
36 changes: 27 additions & 9 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.iql import IQLQuery
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exposed_functions import ExposedFunction

from ..similarity import AbstractSimilarityIndex
Expand Down Expand Up @@ -57,21 +60,36 @@ async def ask(
The result of the query.

Raises:
IQLError: If the generated IQL query is not valid.
LLMError: If LLM text generation API fails.
IQLGenerationError: If the IQL generation fails.
"""
iql_generator = self.get_iql_generator(llm)

filters = self.list_filters()
examples = self.list_few_shots()

iql = await iql_generator.generate_iql(
question=query,
filters=filters,
examples=examples,
event_tracker=event_tracker,
llm_options=llm_options,
n_retries=n_retries,
)
try:
iql = await iql_generator.generate_iql(
question=query,
filters=filters,
examples=examples,
event_tracker=event_tracker,
llm_options=llm_options,
n_retries=n_retries,
)
except UnsupportedQueryError as exc:
raise IQLGenerationError(
view_name=self.__class__.__name__,
filters=None,
aggregation=None,
) from exc
except IQLError as exc:
raise IQLGenerationError(
view_name=self.__class__.__name__,
filters=exc.source,
aggregation=None,
) from exc

await self.apply_filters(iql)

result = self.execute(dry_run=dry_run)
Expand Down
Loading