diff --git a/pandasai/connectors/pandas.py b/pandasai/connectors/pandas.py
index d7063746a..d59138195 100644
--- a/pandasai/connectors/pandas.py
+++ b/pandasai/connectors/pandas.py
@@ -16,6 +16,7 @@
 from ..helpers.file_importer import FileImporter
 from ..helpers.logger import Logger
 from .base import BaseConnector
+import sqlglot


 class PandasConnectorConfig(BaseModel):
@@ -165,6 +166,7 @@ def enable_sql_query(self, table_name=None):
             raise PandasConnectorTableNotFound("Table name not found!")

         table = table_name or self.name
+
         duckdb_relation = duckdb.from_df(self.pandas_df)
         duckdb_relation.create(table)
         self.sql_enabled = True
@@ -173,7 +175,8 @@ def execute_direct_sql_query(self, sql_query):
         if not self.sql_enabled:
             self.enable_sql_query()

-        sql_query = sql_query.replace("`", '"')
+
+        sql_query = sqlglot.transpile(sql_query, read="mysql", write="duckdb")[0]
         return duckdb.query(sql_query).df()

     @property
diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py
index 663e60225..5c414095a 100644
--- a/pandasai/connectors/sql.py
+++ b/pandasai/connectors/sql.py
@@ -11,6 +11,7 @@
 from sqlalchemy import asc, create_engine, select, text
 from sqlalchemy.engine import Connection

+import sqlglot
 import pandasai.pandas as pd
 from pandasai.exceptions import MaliciousQueryError
@@ -651,7 +652,7 @@ def cs_table_name(self):
         return f'"{self.config.table}"'

     def execute_direct_sql_query(self, sql_query):
-        sql_query = sql_query.replace("`", '"')
+        sql_query = sqlglot.transpile(sql_query, read="mysql", write="postgres")[0]
         return super().execute_direct_sql_query(sql_query)
diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py
index f3a2f0ad6..5c9f56bcf 100644
--- a/pandasai/ee/agents/semantic_agent/__init__.py
+++ b/pandasai/ee/agents/semantic_agent/__init__.py
@@ -14,7 +14,7 @@ from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import (
     GenerateDFSchemaPrompt,
 )
-from pandasai.exceptions import InvalidConfigError, InvalidTrainJson
+from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson
 from pandasai.helpers.cache import Cache
 from pandasai.helpers.memory import Memory
 from pandasai.llm.bamboo_llm import BambooLLM
@@ -51,6 +51,8 @@ def __init__(

         self._create_schema()

+        self._sort_dfs_according_to_schema()
+
         self.init_duckdb_instance()

         # semantic agent works only with direct sql true
@@ -125,8 +127,38 @@ def query(self, query):
     def init_duckdb_instance(self):
         for index, tables in enumerate(self._schema):
             if isinstance(self.dfs[index], PandasConnector):
+                self._sync_pandas_dataframe_schema(self.dfs[index], tables)
                 self.dfs[index].enable_sql_query(tables["table"])

+    def _sync_pandas_dataframe_schema(self, df: PandasConnector, schema: dict):
+        for dimension in schema["dimensions"]:
+            if dimension["type"] == "date":
+                column = dimension["sql"]
+                df.pandas_df[column] = pd.to_datetime(df.pandas_df[column])
+
+    def _sort_dfs_according_to_schema(self):
+        schema_dict = {
+            table["table"]: [dim["sql"] for dim in table["dimensions"]]
+            for table in self._schema
+        }
+        sorted_dfs = []
+
+        for table in self._schema:
+            matched = False
+            for df in self.dfs:
+                df_columns = df.get_head().columns
+                if all(column in df_columns for column in schema_dict[table["table"]]):
+                    sorted_dfs.append(df)
+                    matched = True
+                    break
+
+            if not matched:
+                raise InvalidSchemaJson(
+                    f"The SQL columns of table {table['table']} don't match any dataframe"
+                )
+
+        self.dfs = sorted_dfs
+
     def _create_schema(self):
         """
         Generate schema on the initialization of Agent class
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py
index 6c22bfc82..00e75764e 100644
--- a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py
+++ b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py
@@ -1,3 +1,4 @@
+import traceback
 from typing import Any, Callable

 from pandasai.ee.helpers.query_builder import QueryBuilder
@@ -13,17 +14,21 @@ class CodeGenerator(BaseLogicUnit):
     """

     def __init__(
-        self, on_code_generation: Callable[[str, Exception], None] = None, **kwargs
+        self,
+        on_code_generation: Callable[[str, Exception], None] = None,
+        on_failure=None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.on_code_generation = on_code_generation
+        self.on_failure = on_failure

-    def execute(self, input: Any, **kwargs) -> Any:
+    def execute(self, input_data: Any, **kwargs) -> Any:
         """
         This method will return output according to Implementation.

-        :param input: Your input data.
+        :param input_data: Your input data.
         :param kwargs: A dictionary of keyword arguments.
             - 'logger' (any): The logger for logging.
             - 'config' (Config): Global configurations for the test
@@ -36,15 +41,16 @@ def execute(self, input: Any, **kwargs) -> Any:
         schema = pipeline_context.get("df_schema")
         query_builder = QueryBuilder(schema)

-        sql_query = query_builder.generate_sql(input)
+        retry_count = 0
+        while retry_count <= pipeline_context.config.max_retries:
+            try:
+                sql_query = query_builder.generate_sql(input_data)

-        print(sql_query)
+                response_type = self._get_type(input_data)

-        response_type = self._get_type(input)
+                gen_code = self._generate_code(response_type, input_data)

-        gen_code = self._generate_code(response_type, input)
-
-        code = f"""
+                code = f"""
 {"import matplotlib.pyplot as plt" if response_type == "plot" else ""}
 import pandas as pd

@@ -54,27 +60,46 @@ def execute(self, input: Any, **kwargs) -> Any:
 {gen_code}
 """

-        logger.log(f"""Code Generated: {code}""")
+                logger.log(f"""Code Generated: {code}""")

-        # Implement error handling pipeline here...
+                # Implement error handling pipeline here...
-        return LogicUnitOutput(
-            code,
-            True,
-            "Code Generated Successfully",
-            {"content_type": "string", "value": code},
-        )
+                return LogicUnitOutput(
+                    code,
+                    True,
+                    "Code Generated Successfully",
+                    {"content_type": "string", "value": code},
+                )
+            except Exception:
+                if (
+                    retry_count == pipeline_context.config.max_retries
+                    or not self.on_failure
+                ):
+                    raise
+
+                traceback_errors = traceback.format_exc()
+
+                input_data = self.on_failure(input_data, traceback_errors)
+
+                retry_count += 1

-    def _get_type(self, input: dict) -> bool:
-        return "number" if input["type"] == "number" else "plot"
+    def _get_type(self, input: dict) -> str:
+        return (
+            "plot"
+            if input["type"] in ["bar", "line", "histogram", "pie", "scatter", "boxplot"]
+            else input["type"]
+        )

     def _generate_code(self, type, query):
         if type == "number":
             code = self._generate_code_for_number(query)
-
-            # Format code final output
             return f"""
-result = {{"type": "number","value": {code}}}
+{code}
+result = {{"type": "number","value": total_value}}
+"""
+        elif type == "dataframe":
+            return """
+result = {"type": "dataframe","value": data}
 """
         else:
             code = self.generate_matplotlib_code(query)
@@ -90,7 +115,7 @@ def _generate_code_for_number(self, query: dict) -> str:
         else:
             value = query["dimensions"][0].split(".")[1]

-        return f'data["{value}"].iloc[0]'
+        return f'total_value = data["{value}"].sum()\n'

     def generate_matplotlib_code(self, query: dict) -> str:
         chart_type = query["type"]
@@ -137,11 +162,11 @@ def generate_matplotlib_code(self, query: dict) -> str:
         code += code_generator(query)

         if x_label:
-            code += f"plt.xlabel('{x_label}')\n"
+            code += f"plt.xlabel('''{x_label}''')\n"
         if y_label:
-            code += f"plt.ylabel('{y_label}')\n"
+            code += f"plt.ylabel('''{y_label}''')\n"
         if title:
-            code += f"plt.title('{title}')\n"
+            code += f"plt.title('''{title}''')\n"
         if legend_display:
             code += f"plt.legend(loc='{legend_position}')\n"
@@ -153,7 +178,7 @@ def generate_matplotlib_code(self, query: dict) -> str:
         return code

     def _generate_bar_code(self, query):
-        x_key = query["dimensions"][0].split(".")[1]
+        x_key = self._get_dimensions_key(query)
         plots = ""
         for measure in query["measures"]:
             if isinstance(measure, str):
@@ -175,12 +200,7 @@ def _generate_pie_code(self, query):
         return f"""plt.pie(data["{measure}"], labels=data["{dimension}"], autopct='%1.1f%%')\n"""

     def _generate_line_code(self, query):
-        if "dimensions" in query and len(query["dimensions"]) > 0:
-            x_key = query["dimensions"][0].split(".")[1]
-        else:
-            dimension = query["timeDimensions"][0]["dimension"]
-            x_key = dimension.split(".")[1]
-
+        x_key = self._get_dimensions_key(query)
         plots = ""
         for measure in query["measures"]:
             field_name = measure.split(".")[1]
@@ -200,3 +220,11 @@ def _generate_hist_code(self, query):
     def _generate_box_code(self, query):
         y_key = query["measures"][0].split(".")[1]
         return f"plt.boxplot(data['{y_key}'])\n"
+
+    def _get_dimensions_key(self, query):
+        if "dimensions" in query and len(query["dimensions"]) > 0:
+            return query["dimensions"][0].split(".")[1]
+
+        time_dimension = query["timeDimensions"][0]
+        dimension = time_dimension["dimension"].split(".")[1]
+        return f"{dimension}_by_{time_dimension['granularity']}"
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py
index 364dd1ab8..9d9c60189 100644
--- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py
+++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py
@@ -1,6 +1,9 @@
 from typing import Optional

 from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator
+from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import (
+    FixSemanticJsonPipeline,
+)
 from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall
 from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import (
     SemanticPromptGeneration,
@@ -40,14 +43,42 @@ def __init__(
                     on_execution=on_prompt_generation,
                 ),
                 LLMCall(),
-                CodeGenerator(on_execution=on_code_generation),
+                CodeGenerator(
+                    on_execution=on_code_generation,
+                    on_failure=self.on_wrong_semantic_json,
+                ),
                 CodeCleaning(),
             ],
         )

+        self.fix_semantic_json_pipeline = FixSemanticJsonPipeline(
+            context=context,
+            logger=logger,
+            query_exec_tracker=query_exec_tracker,
+            on_code_generation=on_code_generation,
+            on_prompt_generation=on_prompt_generation,
+        )
+
         self._context = context
         self._logger = logger

     def run(self, input: ErrorCorrectionPipelineInput):
         self._logger.log(f"Executing Pipeline: {self.__class__.__name__}")
         return self.pipeline.run(input)
+
+    def on_wrong_semantic_json(self, code, errors):
+        self.query_exec_tracker.add_step(
+            {
+                "type": "CodeGenerator",
+                "success": False,
+                "message": "Failed to validate JSON",
+                "execution_time": None,
+                "data": {
+                    "content_type": "code",
+                    "value": code,
+                    "exception": errors,
+                },
+            }
+        )
+        correction_input = ErrorCorrectionPipelineInput(code, errors)
+        return self.fix_semantic_json_pipeline.run(correction_input)
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_prompt_generation.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_prompt_generation.py
deleted file mode 100644
index 4826cf78c..000000000
--- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_prompt_generation.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import traceback
-from typing import Any, Callable
-
-from pandasai.exceptions import ExecuteSQLQueryNotUsed, InvalidLLMOutputType
-from pandasai.helpers.logger import Logger
-from pandasai.pipelines.base_logic_unit import BaseLogicUnit
-from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import (
-    ErrorCorrectionPipelineInput,
-)
-from pandasai.pipelines.logic_unit_output import LogicUnitOutput
-from pandasai.pipelines.pipeline_context import PipelineContext
-from pandasai.prompts.base import BasePrompt
-from pandasai.prompts.correct_error_prompt import CorrectErrorPrompt
-from pandasai.prompts.correct_execute_sql_query_usage_error_prompt import (
-    CorrectExecuteSQLQueryUsageErrorPrompt,
-)
-from pandasai.prompts.correct_output_type_error_prompt import (
-    CorrectOutputTypeErrorPrompt,
-)
-
-
-class ErrorPromptGeneration(BaseLogicUnit):
-    on_prompt_generation: Callable[[str], None]
-
-    def __init__(
-        self,
-        on_prompt_generation=None,
-        skip_if=None,
-        on_execution=None,
-        before_execution=None,
-    ):
-        self.on_prompt_generation = on_prompt_generation
-        super().__init__(skip_if, on_execution, before_execution)
-
-    def execute(self, input: ErrorCorrectionPipelineInput, **kwargs) -> Any:
-        """
-        A method to retry the code execution with error correction framework.
-
-        Args:
-            code (str): A python code
-            context (PipelineContext) : Pipeline Context
-            logger (Logger) : Logger
-            e (Exception): An exception
-            dataframes
-
-        Returns (str): A python code
-        """
-        self.context: PipelineContext = kwargs.get("context")
-        self.logger: Logger = kwargs.get("logger")
-        e = input.exception
-
-        prompt = self.get_prompt(e, input.code)
-        if self.on_prompt_generation:
-            self.on_prompt_generation(prompt)
-
-        self.logger.log(f"Using prompt: {prompt}")
-
-        return LogicUnitOutput(
-            prompt,
-            True,
-            "Prompt Generated Successfully",
-            {
-                "content_type": "prompt",
-                "value": prompt.to_string(),
-            },
-        )
-
-    def get_prompt(self, e: Exception, code: str) -> BasePrompt:
-        """
-        Return a prompt by key.
-
-        Args:
-            values (dict): The values to use for the prompt
-
-        Returns:
-            BasePrompt: The prompt
-        """
-        traceback_errors = traceback.format_exc()
-        return (
-            CorrectOutputTypeErrorPrompt(
-                context=self.context,
-                code=code,
-                error=traceback_errors,
-                output_type=self.context.get("output_type"),
-            )
-            if isinstance(e, InvalidLLMOutputType)
-            else (
-                CorrectExecuteSQLQueryUsageErrorPrompt(
-                    context=self.context, code=code, error=traceback_errors
-                )
-                if isinstance(e, ExecuteSQLQueryNotUsed)
-                else CorrectErrorPrompt(
-                    context=self.context,
-                    code=code,
-                    error=traceback_errors,
-                )
-            )
-        )
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py
new file mode 100644
index 000000000..3ec39ea40
--- /dev/null
+++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py
@@ -0,0 +1,44 @@
+from typing import Optional
+
+from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_schema_prompt import (
+    FixSemanticSchemaPrompt,
+)
+from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall
+from pandasai.helpers.logger import Logger
+from pandasai.helpers.query_exec_tracker import QueryExecTracker
+from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import (
+    ErrorCorrectionPipelineInput,
+)
+from pandasai.pipelines.pipeline import Pipeline
+from pandasai.pipelines.pipeline_context import PipelineContext
+
+
+class FixSemanticJsonPipeline:
+    """
+    Pipeline to regenerate the semantic JSON when its validation fails
+    """
+
+    _context: PipelineContext
+    _logger: Logger
+
+    def __init__(
+        self,
+        context: Optional[PipelineContext] = None,
+        logger: Optional[Logger] = None,
+        query_exec_tracker: QueryExecTracker = None,
+        on_prompt_generation=None,
+        on_code_generation=None,
+    ):
+        self.pipeline = Pipeline(
+            context=context,
+            logger=logger,
+            query_exec_tracker=query_exec_tracker,
+            steps=[FixSemanticSchemaPrompt(), LLMCall()],
+        )
+
+        self._context = context
+        self._logger = logger
+
+    def run(self, input: ErrorCorrectionPipelineInput):
+        self._logger.log(f"Executing Pipeline: {self.__class__.__name__}")
+        return self.pipeline.run(input)
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py
new file mode 100644
index 000000000..e7e5425d9
--- /dev/null
+++ b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py
@@ -0,0 +1,61 @@
+import json
+from typing import Any, Callable
+
+from pandasai.ee.agents.semantic_agent.prompts.fix_semantic_json import (
+    FixSemanticJsonPrompt,
+)
+from pandasai.helpers.logger import Logger
+from pandasai.pipelines.base_logic_unit import BaseLogicUnit
+from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import (
+    ErrorCorrectionPipelineInput,
+)
+from pandasai.pipelines.logic_unit_output import LogicUnitOutput
+from pandasai.pipelines.pipeline_context import PipelineContext
+
+
+class FixSemanticSchemaPrompt(BaseLogicUnit):
+    on_prompt_generation: Callable[[str], None]
+
+    def __init__(
+        self,
+        on_prompt_generation=None,
+        skip_if=None,
+        on_execution=None,
+        before_execution=None,
+    ):
+        self.on_prompt_generation = on_prompt_generation
+        super().__init__(skip_if, on_execution, before_execution)
+
+    def execute(self, input: ErrorCorrectionPipelineInput, **kwargs) -> Any:
+        """
+        Generate a prompt to fix the malformed semantic JSON.
+
+        Args:
+            input (ErrorCorrectionPipelineInput): The generated JSON and the
+                error raised while validating it
+            kwargs: A dictionary of keyword arguments
+                - 'context' (PipelineContext): Pipeline context
+                - 'logger' (Logger): Logger
+
+        Returns (LogicUnitOutput): The prompt to fix the semantic JSON
+        """
+        self.context: PipelineContext = kwargs.get("context")
+        self.logger: Logger = kwargs.get("logger")
+
+        prompt = FixSemanticJsonPrompt(
+            context=self.context,
+            generated_json=input.code,
+            error=input.exception,
+            schema=json.dumps(self.context.get("df_schema")),
+        )
+        self.logger.log(f"Using prompt: {prompt}")
+
+        return LogicUnitOutput(
+            prompt,
+            True,
+            "Prompt Generated Successfully",
+            {
+                "content_type": "prompt",
+                "value": prompt.to_string(),
+            },
+        )
diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
index b511640b2..14e9ea870 100644
--- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
+++ b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py
@@ -4,6 +4,9 @@
 from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.error_correction_pipeline import (
     ErrorCorrectionPipeline,
 )
+from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import (
+    FixSemanticJsonPipeline,
+)
 from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall
 from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import (
     SemanticPromptGeneration,
@@ -18,6 +21,9 @@
 from pandasai.pipelines.chat.cache_lookup import CacheLookup
 from pandasai.pipelines.chat.code_cleaning import CodeCleaning
 from pandasai.pipelines.chat.code_execution import CodeExecution
+from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import (
+    ErrorCorrectionPipelineInput,
+)
 from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline
 from pandasai.pipelines.chat.result_validation import ResultValidation
 from pandasai.pipelines.pipeline import Pipeline
@@ -61,7 +67,10 @@ def __init__(
                     on_execution=on_prompt_generation,
                 ),
                 LLMCall(),
-                CodeGenerator(on_execution=on_code_generation),
+                CodeGenerator(
+                    on_execution=on_code_generation,
+                    on_failure=self.on_wrong_semantic_json,
+                ),
                 CodeCleaning(
                     skip_if=self.no_code,
                     on_failure=self.on_code_cleaning_failure,
@@ -95,6 +104,31 @@ def __init__(
             on_prompt_generation=on_prompt_generation,
         )

+        self.fix_semantic_json_pipeline = FixSemanticJsonPipeline(
+            context=context,
+            logger=logger,
+            query_exec_tracker=self.query_exec_tracker,
+            on_code_generation=on_code_generation,
+            on_prompt_generation=on_prompt_generation,
+        )
+
         self.context = context
         self._logger = logger
         self.last_error = None
+
+    def on_wrong_semantic_json(self, code, errors):
+        self.query_exec_tracker.add_step(
+            {
+                "type": "CodeGenerator",
+                "success": False,
+                "message": "Failed to validate JSON",
+                "execution_time": None,
+                "data": {
+                    "content_type": "code",
+                    "value": code,
+                    "exception": errors,
+                },
+            }
+        )
+        correction_input = ErrorCorrectionPipelineInput(code, errors)
+        return self.fix_semantic_json_pipeline.run(correction_input)
diff --git a/pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py b/pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py
new file mode 100644
index 000000000..b027eb7f1
--- /dev/null
+++ b/pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+
+from jinja2 import Environment, FileSystemLoader
+
+from pandasai.prompts.base import BasePrompt
+
+
+class FixSemanticJsonPrompt(BasePrompt):
+    """Prompt to fix the generated semantic JSON."""
+
+    template_path = "fix_semantic_json_prompt.tmpl"
+
+    def __init__(self, **kwargs):
+        """Initialize the prompt."""
+        self.props = kwargs
+
+        if self.template:
+            env = Environment()
+            self.prompt = env.from_string(self.template)
+        elif self.template_path:
+            # find path to template file
+            current_dir_path = Path(__file__).parent
+
+            path_to_template = current_dir_path / "templates"
+            env = Environment(loader=FileSystemLoader(path_to_template))
+            self.prompt = env.get_template(self.template_path)
+
+        self._resolved_prompt = None
+
+    def to_json(self):
+        context = self.props["context"]
+        memory = context.memory
+        conversations = memory.to_json()
+        system_prompt = memory.get_system_prompt()
+        return {
+            "conversation": conversations,
+            "system_prompt": system_prompt,
+            "prompt": self.to_string(),
+        }
diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl
new file mode 100644
index 000000000..b973c53df
--- /dev/null
+++ b/pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl
@@ -0,0 +1,13 @@
+=== SemanticAgent ===
+The user asked the following question:
+{{context.memory.get_conversation()}}
+# SCHEMA
+{{schema}}
+
+You generated this JSON:
+{{generated_json}}
+
+It fails with the following error:
+{{error}}
+
+Understand the error in the JSON and return the fixed JSON
\ No newline at end of file
diff --git a/pandasai/ee/helpers/query_builder.py b/pandasai/ee/helpers/query_builder.py
index 48a0d761f..a3c0ba85c 100644
--- a/pandasai/ee/helpers/query_builder.py
+++ b/pandasai/ee/helpers/query_builder.py
@@ -1,5 +1,11 @@
 import re

+from pandasai.exceptions import InvalidSchemaJson
+
+
+MISSING_TABLE_NAME_MESSAGE = "All measures, dimensions, timeDimensions, order and filters must have the format Table_Name.Dimension or Table_Name.Measure"
+TABLE_NOT_FOUND_MESSAGE = "Table {0} doesn't exist"
+

 class QueryBuilder:
     """
@@ -28,6 +34,7 @@ def __init__(self, schema):
         }

     def generate_sql(self, query):
+        self._validate_query(query)
         measures = query.get("measures", [])
         dimensions = query.get("dimensions", [])
         time_dimensions = query.get("timeDimensions", [])
@@ -54,10 +61,59 @@ def generate_sql(self, query):

         return sql

+    def _validate_table(self, value: str):
+        value_splitted = value.split(".")
+        if len(value_splitted) == 1:
+            raise InvalidSchemaJson(MISSING_TABLE_NAME_MESSAGE)
+
+        table = self.find_table(value_splitted[0])
+        if not table:
+            raise InvalidSchemaJson(TABLE_NOT_FOUND_MESSAGE.format(value_splitted[0]))
+
+    def _validate_query(self, query: dict):
+        for measure in query.get("measures", []):
+            self._validate_table(measure)
+
+        for dimension in query.get("dimensions", []):
+            self._validate_table(dimension)
+
+        for dimension in query.get("timeDimensions", []):
+            self._validate_table(dimension["dimension"])
+
+        for order in query.get("order", []):
+            self._validate_table(order["id"])
+
+        for filter in query.get("filters", []):
+            self._validate_table(filter["member"])
+
+    def _validate_fix_query(self, query):
+        for index, measure in enumerate(query.get("measures", [])):
+            query["measures"][index] = self._validate_and_fix_mapped_measure(measure)
+
+        for index, dimension in enumerate(query.get("dimensions", [])):
+            query["dimensions"][index] = self._validate_and_fix_mapped_dimension(
+                dimension
+            )
+
+        for index, dimension in enumerate(query.get("timeDimensions", [])):
+            query["timeDimensions"][index][
+                "dimension"
+            ] = self._validate_and_fix_mapped_dimension(dimension["dimension"])
+
+        for index, order in enumerate(query.get("order", [])):
+            query["order"][index]["id"] = self._validate_and_fix_mapped_order(
+                order["id"]
+            )
+
+        for index, filter in enumerate(query.get("filters", [])):
+            query["filters"][index]["member"] = self._validate_and_fix_mapped_filter(
+                filter["member"]
+            )
+
+        return query
+
     def _generate_columns(self, dimensions, time_dimensions, measures):
-        all_dimensions = list(
-            dict.fromkeys(dimensions + [td["dimension"] for td in time_dimensions])
-        )
+        all_dimensions = list(dict.fromkeys(dimensions))

         columns = []
         for dim in all_dimensions:
@@ -87,9 +143,111 @@

         return list(dict.fromkeys(columns))  # preserve order and return unique columns

+    def _validate_and_fix_mapped_measure(self, value):
+        value_splitted = value.split(".")
+        if len(value_splitted) == 1:
+            table_name = self._find_table_name_in_measure_if_not_exists(
+                value_splitted[0]
+            )
+            if table_name is None:
+                raise ValueError(
+                    "Measure must have a table; expected format is TableName.measure"
+                )
+            return f"{table_name}.{value_splitted[0]}"
+        return value
+
+    def _validate_and_fix_mapped_dimension(self, value):
+        value_splitted = value.split(".")
+        if len(value_splitted) == 1:
+            table_name = self._find_table_name_in_dimension_if_not_exists(
+                value_splitted[0]
+            )
+            if table_name is None:
+                raise ValueError(
+                    "Dimension must have a table; expected format is TableName.dimension"
+                )
+            return f"{table_name}.{value_splitted[0]}"
+        return value
+
+    def _validate_and_fix_mapped_order(self, value):
+        value_splitted = value.split(".")
+        if len(value_splitted) == 1:
+            table_name = self._find_table_name_in_orders_if_not_exists(
+                value_splitted[0]
+            )
+            if table_name is None:
+                raise ValueError(
+                    "Order member must have a table; expected format is TableName.member"
+                )
+            return f"{table_name}.{value_splitted[0]}"
+        return value
+
+    def _validate_and_fix_mapped_filter(self, value):
+        value_splitted = value.split(".")
+        if len(value_splitted) == 1:
+            table_name = self._find_table_name_in_filter_if_not_exists(
+                value_splitted[0]
+            )
+            if table_name is None:
+                raise ValueError(
+                    "Filter member must have a table; expected format is TableName.member"
+                )
+            return f"{table_name}.{value_splitted[0]}"
+        return value
+
+    def _find_table_name_in_filter_if_not_exists(self, filter_name: str):
+        """
+        Find the table name for a filter member when it is missing
+        """
+        for table in self.schema:
+            for dimension in table["dimensions"]:
+                if dimension["name"] == filter_name:
+                    return table["name"]
+
+        return None
+
+    def _find_table_name_in_measure_if_not_exists(self, measure_name: str):
+        """
+        Find the table name for a measure when it is missing
+        """
+        for table in self.schema:
+            for measure in table["measures"]:
+                if measure["name"] == measure_name:
+                    return table["name"]
+
+        return None
+
+    def _find_table_name_in_dimension_if_not_exists(self, dimension_name: str):
+        """
+        Find the table name for a dimension when it is missing
+        """
+        for table in self.schema:
+            for dimension in table["dimensions"]:
+                if dimension["name"] == dimension_name:
+                    return table["name"]
+
+        return None
+
+    def _find_table_name_in_orders_if_not_exists(self, dimension_name: str):
+        """
+        Find the table name for an order member (dimension or measure) when it is missing
+        """
+        for table in self.schema:
+            for dimension in table["dimensions"]:
+                if dimension["name"] == dimension_name:
+                    return table["name"]
+
+            for measure in table["measures"]:
+                if measure["name"] == dimension_name:
+                    return table["name"]
+
+        return None
+
     def _generate_time_dimension_column(self, time_dimension):
         dimension = time_dimension["dimension"]
-        granularity = time_dimension["granularity"]
+        granularity = (
+            time_dimension["granularity"] if "granularity" in time_dimension else "day"
+        )

         if granularity not in self.supported_granularities:
             raise ValueError(
@@ -183,7 +341,7 @@ def _build_group_by_clause(self, dimensions, time_dimensions):
         group_by_dimensions = [
             self.find_dimension(dim)["name"] for dim in dimensions
         ] + [
-            f"{self.find_dimension(td['dimension'])['name']}_by_{td['granularity']}"
+            f"{self.find_dimension(td['dimension'])['name']}_by_{td.get('granularity', 'day')}"
             for td in time_dimensions
         ]

@@ -202,11 +360,34 @@ def _build_order_clause(self, query):
         if "order" not in query or len(query["order"]) == 0:
             return ""

-        order = query["order"][0]
-        name = (self.find_measure(order["id"]) or self.find_dimension(order["id"]))[
-            "name"
-        ]
-        return f" ORDER BY {name} {order['direction']}"
+        order_clauses = []
+        for order in query["order"]:
+            name = None
+            if measure := self.find_measure(order["id"]):
+                name = measure["name"]
+
+            if (
+                name is None
+                and "timeDimensions" in query
+                and len(query["timeDimensions"]) > 0
+            ):
+                for time_dimension in query["timeDimensions"]:
+                    if time_dimension["dimension"] == order["id"]:
+                        granularity = time_dimension.get("granularity", "day")
+                        name = f"{self.find_dimension(order['id'])['name']}_by_{granularity}"
+
+            if name is None and "dimensions" in query and len(query["dimensions"]) > 0:
+                if dimension := self.find_dimension(order["id"]):
+                    name = dimension["name"]
+
+            if name is None:
+                name = (
+                    self.find_measure(order["id"]) or self.find_dimension(order["id"])
+                )["name"]
+
+            order_clauses.append(f"{name} {order['direction']}")
+
+        return f" ORDER BY {', '.join(order_clauses)}"

     def _build_limit_clause(self, query):
         return f" LIMIT {query['limit']}" if "limit" in query else ""
@@ -223,15 +404,14 @@ def resolve_date_range(self, time_dimension):

         table_column = f"`{table['table']}`.`{dimension_info['sql']}`"

-        if isinstance(date_range, list):
-            if len(date_range) != 2:
-                raise ValueError(
-                    "Invalid date range. It should contain exactly two dates."
-                )
+        if isinstance(date_range, list) and len(date_range) == 2:
             start_date, end_date = date_range
             return f"{table_column} BETWEEN '{start_date}' AND '{end_date}'"
         else:
-            if date_range not in self.supportedDateRanges:
+            if isinstance(date_range, list) and len(date_range) == 1:
+                date_range = date_range[0]
+
+            if date_range not in self.supported_date_ranges:
                 raise ValueError(f"Unsupported date range: {date_range}")

             if date_range == "last week":
@@ -286,6 +466,7 @@ def process_filter(self, filter):
             "lte": "<=",
             "beforeDate": "<",
             "afterDate": ">",
+            "in": "IN",
         }

         multi_value_operators = {"equals": "IN", "notEquals": "NOT IN"}
@@ -307,12 +488,12 @@ def _build_query_condition(
         self,
         table_column,
         operator,
         values,
         single_value_operators,
         multi_value_operators,
     ):
         if operator in single_value_operators:
-            if operator in ["equals", "notEquals"]:
+            if operator in ["equals", "notEquals", "in"]:
                 if len(values) == 1:
-                    operator_str = "=" if operator == "equals" else "!="
+                    operator_str = "=" if operator in ["equals", "in"] else "!="
                     return f"{table_column} {operator_str} '{values[0]}'"
                 else:
-                    operator_str = "IN" if operator == "equals" else "NOT IN"
+                    operator_str = "IN" if operator in ["equals", "in"] else "NOT IN"
                     formatted_values = "', '".join(values)
                     return f"{table_column} {operator_str} ('{formatted_values}')"
diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py
index ef8be33a7..b0e6ebfa6 100644
--- a/pandasai/exceptions.py
+++ b/pandasai/exceptions.py
@@ -254,3 +254,11 @@ class InvalidTrainJson(Exception):
     Args:
         Exception (Exception): Invalid train json
     """
+
+
+class InvalidSchemaJson(Exception):
+    """
+    Raise error if schema json is not correct
+    Args:
+        Exception (Exception): Invalid json schema
+    """
diff --git a/poetry.lock b/poetry.lock
index 0d65e8280..b879dc8d1 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] name = "aiohttp" @@ -5314,6 +5314,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -6048,30 +6049,51 @@ description = "Database Abstraction Library" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ + {file = "SQLAlchemy-1.4.50-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:54138aa80d2dedd364f4e8220eef284c364d3270aaef621570aa2bd99902e2e8"}, {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d00665725063692c42badfd521d0c4392e83c6c826795d38eb88fb108e5660e5"}, {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85292ff52ddf85a39367057c3d7968a12ee1fb84565331a36a8fead346f08796"}, {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d0fed0f791d78e7767c2db28d34068649dfeea027b83ed18c45a423f741425cb"}, {file = "SQLAlchemy-1.4.50-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db4db3c08ffbb18582f856545f058a7a5e4ab6f17f75795ca90b3c38ee0a8ba4"}, + {file = "SQLAlchemy-1.4.50-cp310-cp310-win32.whl", hash = "sha256:6c78e3fb4a58e900ec433b6b5f4efe1a0bf81bbb366ae7761c6e0051dd310ee3"}, + {file = "SQLAlchemy-1.4.50-cp310-cp310-win_amd64.whl", hash = "sha256:d55f7a33e8631e15af1b9e67c9387c894fedf6deb1a19f94be8731263c51d515"}, + {file = "SQLAlchemy-1.4.50-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:324b1fdd50e960a93a231abb11d7e0f227989a371e3b9bd4f1259920f15d0304"}, {file = "SQLAlchemy-1.4.50-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14b0cacdc8a4759a1e1bd47dc3ee3f5db997129eb091330beda1da5a0e9e5bd7"}, {file = "SQLAlchemy-1.4.50-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1fb9cb60e0f33040e4f4681e6658a7eb03b5cb4643284172f91410d8c493dace"}, + {file = "SQLAlchemy-1.4.50-cp311-cp311-win32.whl", hash = "sha256:8bdab03ff34fc91bfab005e96f672ae207d87e0ac7ee716d74e87e7046079d8b"}, + {file = "SQLAlchemy-1.4.50-cp311-cp311-win_amd64.whl", hash = "sha256:52e01d60b06f03b0a5fc303c8aada405729cbc91a56a64cead8cb7c0b9b13c1a"}, + {file = "SQLAlchemy-1.4.50-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:77fde9bf74f4659864c8e26ac08add8b084e479b9a18388e7db377afc391f926"}, {file = 
"SQLAlchemy-1.4.50-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cb501d585aa74a0f86d0ea6263b9c5e1d1463f8f9071392477fd401bd3c7cc"}, {file = "SQLAlchemy-1.4.50-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a7a66297e46f85a04d68981917c75723e377d2e0599d15fbe7a56abed5e2d75"}, + {file = "SQLAlchemy-1.4.50-cp312-cp312-win32.whl", hash = "sha256:e86c920b7d362cfa078c8b40e7765cbc34efb44c1007d7557920be9ddf138ec7"}, + {file = "SQLAlchemy-1.4.50-cp312-cp312-win_amd64.whl", hash = "sha256:6b3df20fbbcbcd1c1d43f49ccf3eefb370499088ca251ded632b8cbaee1d497d"}, + {file = "SQLAlchemy-1.4.50-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:fb9adc4c6752d62c6078c107d23327aa3023ef737938d0135ece8ffb67d07030"}, {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1db0221cb26d66294f4ca18c533e427211673ab86c1fbaca8d6d9ff78654293"}, {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7dbe6369677a2bea68fe9812c6e4bbca06ebfa4b5cde257b2b0bf208709131"}, {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a9bddb60566dc45c57fd0a5e14dd2d9e5f106d2241e0a2dc0c1da144f9444516"}, {file = "SQLAlchemy-1.4.50-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82dd4131d88395df7c318eeeef367ec768c2a6fe5bd69423f7720c4edb79473c"}, + {file = "SQLAlchemy-1.4.50-cp36-cp36m-win32.whl", hash = "sha256:1b9c4359d3198f341480e57494471201e736de459452caaacf6faa1aca852bd8"}, + {file = "SQLAlchemy-1.4.50-cp36-cp36m-win_amd64.whl", hash = "sha256:35e4520f7c33c77f2636a1e860e4f8cafaac84b0b44abe5de4c6c8890b6aaa6d"}, + {file = "SQLAlchemy-1.4.50-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:f5b1fb2943d13aba17795a770d22a2ec2214fc65cff46c487790192dda3a3ee7"}, {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:273505fcad22e58cc67329cefab2e436006fc68e3c5423056ee0513e6523268a"}, {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3257a6e09626d32b28a0c5b4f1a97bced585e319cfa90b417f9ab0f6145c33c"}, {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d69738d582e3a24125f0c246ed8d712b03bd21e148268421e4a4d09c34f521a5"}, {file = "SQLAlchemy-1.4.50-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34e1c5d9cd3e6bf3d1ce56971c62a40c06bfc02861728f368dcfec8aeedb2814"}, + {file = "SQLAlchemy-1.4.50-cp37-cp37m-win32.whl", hash = "sha256:7b4396452273aedda447e5aebe68077aa7516abf3b3f48408793e771d696f397"}, + {file = "SQLAlchemy-1.4.50-cp37-cp37m-win_amd64.whl", hash = "sha256:752f9df3dddbacb5f42d8405b2d5885675a93501eb5f86b88f2e47a839cf6337"}, + {file = "SQLAlchemy-1.4.50-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:35c7ed095a4b17dbc8813a2bfb38b5998318439da8e6db10a804df855e3a9e3a"}, {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1fcee5a2c859eecb4ed179edac5ffbc7c84ab09a5420219078ccc6edda45436"}, {file = 
"SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbaf6643a604aa17e7a7afd74f665f9db882df5c297bdd86c38368f2c471f37d"}, {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2e70e0673d7d12fa6cd363453a0d22dac0d9978500aa6b46aa96e22690a55eab"}, {file = "SQLAlchemy-1.4.50-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b881ac07d15fb3e4f68c5a67aa5cdaf9eb8f09eb5545aaf4b0a5f5f4659be18"}, + {file = "SQLAlchemy-1.4.50-cp38-cp38-win32.whl", hash = "sha256:8a219688297ee5e887a93ce4679c87a60da4a5ce62b7cb4ee03d47e9e767f558"}, + {file = "SQLAlchemy-1.4.50-cp38-cp38-win_amd64.whl", hash = "sha256:a648770db002452703b729bdcf7d194e904aa4092b9a4d6ab185b48d13252f63"}, + {file = "SQLAlchemy-1.4.50-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:4be4da121d297ce81e1ba745a0a0521c6cf8704634d7b520e350dce5964c71ac"}, {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f6997da81114daef9203d30aabfa6b218a577fc2bd797c795c9c88c9eb78d49"}, {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdb77e1789e7596b77fd48d99ec1d2108c3349abd20227eea0d48d3f8cf398d9"}, {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:128a948bd40780667114b0297e2cc6d657b71effa942e0a368d8cc24293febb3"}, {file = "SQLAlchemy-1.4.50-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2d526aeea1bd6a442abc7c9b4b00386fd70253b80d54a0930c0a216230a35be"}, + {file = "SQLAlchemy-1.4.50-cp39-cp39-win32.whl", hash = "sha256:a7c9b9dca64036008962dd6b0d9fdab2dfdbf96c82f74dbd5d86006d8d24a30f"}, + {file = "SQLAlchemy-1.4.50-cp39-cp39-win_amd64.whl", hash = "sha256:df200762efbd672f7621b253721644642ff04a6ff957236e0e2fe56d9ca34d2c"}, {file = "SQLAlchemy-1.4.50.tar.gz", hash = "sha256:3b97ddf509fc21e10b09403b5219b06c5b558b27fc2453150274fa4e70707dbf"}, ] @@ -6154,6 +6176,94 @@ databricks-sql-connector = ">=2,<3" PyHive = ">=0,<1" SQLAlchemy = ">=1,<2" +[[package]] +name = "sqlglot" +version = "25.0.3" +description = "An easily customizable SQL parser and transpiler" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sqlglot-25.0.3-py3-none-any.whl", hash = "sha256:810dedc451e2d4e947effe50eeb6a5e85a6f624053086bb77d0228a45b78e0b7"}, + {file = "sqlglot-25.0.3.tar.gz", hash = "sha256:84bddaf24e28d761ea5d3857de2bbb9515f1482799189de54f1c1e8b1bc6147c"}, +] + +[package.dependencies] +sqlglotrs = {version = "0.2.5", optional = true, markers = "extra == \"rs\""} + +[package.extras] +dev = ["duckdb (>=0.6)", "maturin (>=1.4,<2.0)", "mypy", "pandas", "pandas-stubs", "pdoc", "pre-commit", "python-dateutil", "ruff (==0.4.3)", "types-python-dateutil", "typing-extensions"] +rs = ["sqlglotrs (==0.2.5)"] + +[[package]] +name = "sqlglotrs" +version = "0.2.5" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sqlglotrs-0.2.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2acbae527bde996379a2d686b8ba59fe1020763aa9a8edf729ce75fe323cdc4d"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:6aa56bcaf6e2a5938406364ab8b99871d919adf1f0e2ec2e7f4649c9d3f4d7a7"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8ff9eee2925aad22b177236e9c4ca2602edce77406688fdc00aeb51184800d8"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1d7a467638bc85c603fae00f63c3f7df36db4091ab6127217499c8fbaef950bf"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38887a78fc4c5884512f543eea390df0e701ff76e4bf81cc5a0cd11fe05415be"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf15597e1ee5e0a7b2af4cedbca605a4a74f16efd1b6efbcec01aa3aa0dfb952"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc89bf7cd6bf0273d63feea21fd1f232049100d9826c4102d78c391b3749d068"}, + {file = "sqlglotrs-0.2.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1cf01b88c1592e71d40d5804035003198bdf0d86f420e351deec67b4c8e9d252"}, + {file = "sqlglotrs-0.2.5-cp310-none-win32.whl", hash = "sha256:4bd1179d7b2ecccd7758d94a958462c8a96e7c2353743ef540f838ee267728e0"}, + {file = "sqlglotrs-0.2.5-cp310-none-win_amd64.whl", hash = "sha256:638de506ff1aec4bb60663a8ae14a19eeef9213ca87fdd3fd04d442ba213d6b2"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:9a9b1733d9f4a8200150adcb20bc1439ae7c605fbd84441255e1bcd1758179ed"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:15772fa7b1f785390526e4e553bdaadd9a7e9b6d16460f51a44f0d0ae23ac058"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae429d9c634523fbb2942c6a39ec17b4de66de6385123c7c6e6c2bb5028b64c6"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:15410c109d0df94e6629302585388e9e155b84887b831df8560ac748eb0bf3cf"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dfd4b8ec6a8e74072638e0d3f2940db6d46ff8b5394d028c3d9f7883e78dca54"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:485066352ab78dbfc124c2e9a1b64e9afbdcce94f2da8fc792e0fc1f85b48558"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c7ec211fb0a71ee31cb06ec692cbcd091518ca6d8c747deaa3c41f635f3785e"}, + {file = "sqlglotrs-0.2.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:142bdda9331ddf5b0f816e8f866f77309b7e8a9ab3a176e762f223b1c7be4400"}, + {file = "sqlglotrs-0.2.5-cp311-none-win32.whl", hash = "sha256:8a08f7f4f996a7eb7e22a26762edc6ea1aac52b0ffe38b656ec82b5db79c1a8b"}, + {file = "sqlglotrs-0.2.5-cp311-none-win_amd64.whl", hash = "sha256:f8ca6ebcee9e5a9f22ec0a70127fe886c72d17c56c77aa991333d19dc04c95c2"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3ef969600fab89c7a779c77fc9bb2ac3e167e9080b558b4594c0bd496f0789b3"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1883817065fbbbd2a4aec83e636513164f149b5908016706285536d0579892ad"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eca6c7646f70b55cef024ae6c22f600d53af0c313c1a424f5dbd79357e80cedb"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:8499c90cf665a794715e6aedd8ff7c54da6dfaa28c0c6b25b2f7204e116c9743"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81ed3f0f09f05b221c023b6e8b78edfaf92d6fdbe5d126bda0b36365e72e2ba1"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:03810fb4d6bb2be12366d99cff2065debd9d55c633a3c8217395460be1f96ae6"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ed01ecb2cf026b01b7c2a4ee34884733768e54f50d538d326aaa4d4430f8b32"}, + {file = "sqlglotrs-0.2.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c741aa802440749db9ce8dda7109d407daea018dec671ee7329d54134f040d7"}, + {file = "sqlglotrs-0.2.5-cp312-none-win32.whl", hash = "sha256:87389ca6b0cccaa0e284ebe6c0da7436a797b2f11bc2e24e7c90a11108a1c2b3"}, + {file = "sqlglotrs-0.2.5-cp312-none-win_amd64.whl", hash = "sha256:7e16ae96a7ec89d8159d1a1127ec31c7786f4801acb5480a885cdd669f0165bb"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:1c3b2b33b30ec54ef8ed412cb880b33267f5b4dd0368d3cb515fbbe26c130b3e"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:fcbbd976f61edc8ca788b041befdfb902f9b7fd202423e5b3f743ff0aab133b9"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6dd3fdadffcffef79a7b044bb2124b73d29ea734c0a82c72a230dcaa7f15c224"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac6ab16fb8d453f8dd2ec9452f1c0c6f0d2350812170a34a5c3343283374156b"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de112dd19ec7ef478a2ffb9a0fa48b9506ea8544254d8a999cf820efb432f3c8"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5af93bb4429c8bee14e31e482fe26a313cf3b18fe033bbdbfc064faaf4dc9eb2"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c65c43ebd5c5c381edc27be1c08f690069141e12e637ad188896e5a036587c7"}, + {file = "sqlglotrs-0.2.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:135fc0d10bdb9324d5ba5766dc97f0b3f033cee9aea5c68d5c3cb1c1a1d1ff22"}, + {file = "sqlglotrs-0.2.5-cp37-none-win32.whl", hash = "sha256:b27242fead5aeefa5e57ab8afa1ebbd1ea9b765440b795c426b2dabac88ad732"}, + {file = "sqlglotrs-0.2.5-cp37-none-win_amd64.whl", hash = "sha256:72be938208de4b0cb8830316e1cae432b3aac6a6d8b7610b306c7cdb1a228a74"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:ee6e0a5dc335bbec38bf7a051f3ad8fd34994d314198449e853926af35d4f03c"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d2a0f28dbc60df24c6cd8ce882d6e8c16601d5e002e83fa2243ad32a89ca222"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e36caa661ad7e6e83910450b6be7ae7b5e54be4fcdd575f7f1833d50df78dda9"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a849ae0a4ef6b65ecb8db651aa720d46577c7f4ed619df279abdf074346fe9b"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:564875e5c1f503d298a214f89cd9dbaae314c307ce830893566e30c02c7ee14d"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:35359fea9e08c37c696ff6004468982750a3ab5d59491ab8253ecc13cc93f180"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93e4dcea915fddb803004a411370966e974b25cb728cb3523205d9d2d92301e6"}, + {file = "sqlglotrs-0.2.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f6328501aa9924ccaf9cfb335b510b1796bee11dd74dd8349b30446585e298d7"}, + {file = "sqlglotrs-0.2.5-cp38-none-win32.whl", hash = "sha256:4421ab8097c5fa8999222f1d92fbbf8bdef0f841ef5f04b315050e4de378d19b"}, + {file = "sqlglotrs-0.2.5-cp38-none-win_amd64.whl", hash = "sha256:dfa80d1bc147817a02105ee7f6bd12e7f698af91e76788d02f6fcd209f6c8f33"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b1fdcad5c279efa65ffd8c0db55ddca8d1fe8a196d225292fa545d0bb7a6ccd4"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bdf11a8fa5198c2acf42024d3fa6da26e75a997f5b044e0c5ec10e4fa5d431c7"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2d7ef75407e7096104a9f611c3512d06d313b596988522c46d89570a58da265"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d5304581f4c4da0e6c03500bd9f1e2fe8bce81f71117cdc0bdc32cd2b072c562"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ff7398ce89398578b2cb8fa55b2c036510a02e7050bbd541fb16a3a1eb575fb1"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f359d3b2c74d48bfa2e8ee7ab86865462aa526307241d0ede230143c7a8c566"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cca1586871b0b2056136150fab81d7d7b3f028d2e3de5c4a615ad80d43b102eb"}, + {file = "sqlglotrs-0.2.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e12663dae6bd90841c7492a351422901ac91c766b784e232d4423aabf6602223"}, + {file = "sqlglotrs-0.2.5-cp39-none-win32.whl", hash = "sha256:785d81d4712fbe949f60059e878048ff5561580e3b4aa26aecc864bd33ed8c0a"}, + {file = "sqlglotrs-0.2.5-cp39-none-win_amd64.whl", hash = "sha256:ab5ef66d6bce6f6a6394a6e303f9eb792a6283558dcbf9550a4f7701c35d4809"}, + {file = "sqlglotrs-0.2.5.tar.gz", hash = "sha256:e0d1f4b1672ba2574600c5bfa02ab9a21512b4f4e56f73ff52e4eff41f0f6898"}, +] + [[package]] name = "starlette" version = "0.35.1" @@ -6183,12 +6293,20 @@ files = [ {file = "statsmodels-0.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5a6a0a1a06ff79be8aa89c8494b33903442859add133f0dda1daf37c3c71682e"}, {file = "statsmodels-0.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77b3cd3a5268ef966a0a08582c591bd29c09c88b4566c892a7c087935234f285"}, {file = "statsmodels-0.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c64ebe9cf376cba0c31aed138e15ed179a1d128612dd241cdf299d159e5e882"}, + {file = "statsmodels-0.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:229b2f676b4a45cb62d132a105c9c06ca8a09ffba060abe34935391eb5d9ba87"}, {file = "statsmodels-0.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb471f757fc45102a87e5d86e87dc2c8c78b34ad4f203679a46520f1d863b9da"}, {file = "statsmodels-0.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:582f9e41092e342aaa04920d17cc3f97240e3ee198672f194719b5a3d08657d6"}, {file = "statsmodels-0.14.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7ebe885ccaa64b4bc5ad49ac781c246e7a594b491f08ab4cfd5aa456c363a6f6"}, {file = 
"statsmodels-0.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b587ee5d23369a0e881da6e37f78371dce4238cf7638a455db4b633a1a1c62d6"}, {file = "statsmodels-0.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef7fa4813c7a73b0d8a0c830250f021c102c71c95e9fe0d6877bcfb56d38b8c"}, + {file = "statsmodels-0.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afe80544ef46730ea1b11cc655da27038bbaa7159dc5af4bc35bbc32982262f2"}, {file = "statsmodels-0.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:a6ad7b8aadccd4e4dd7f315a07bef1bca41d194eeaf4ec600d20dea02d242fce"}, + {file = "statsmodels-0.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0eea4a0b761aebf0c355b726ac5616b9a8b618bd6e81a96b9f998a61f4fd7484"}, + {file = "statsmodels-0.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4c815ce7a699047727c65a7c179bff4031cff9ae90c78ca730cfd5200eb025dd"}, + {file = "statsmodels-0.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575f61337c8e406ae5fa074d34bc6eb77b5a57c544b2d4ee9bc3da6a0a084cf1"}, + {file = "statsmodels-0.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8be53cdeb82f49c4cb0fda6d7eeeb2d67dbd50179b3e1033510e061863720d93"}, + {file = "statsmodels-0.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:6f7d762df4e04d1dde8127d07e91aff230eae643aa7078543e60e83e7d5b40db"}, + {file = "statsmodels-0.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:fc2c7931008a911e3060c77ea8933f63f7367c0f3af04f82db3a04808ad2cd2c"}, {file = "statsmodels-0.14.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3757542c95247e4ab025291a740efa5da91dc11a05990c033d40fce31c450dc9"}, {file = "statsmodels-0.14.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:de489e3ed315bdba55c9d1554a2e89faa65d212e365ab81bc323fa52681fc60e"}, {file = "statsmodels-0.14.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e290f4718177bffa8823a780f3b882d56dd64ad1c18cfb4bc8b5558f3f5757"}, @@ -7265,4 +7383,4 @@ yfinance = ["yfinance"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.9.7 || >3.9.7,<4.0" -content-hash = "16c678bbdeb8bda34dad4558b16b0b8f71006e6c224bd6f3b162714114eebec0" +content-hash = "9117af78fd823a070e13aba67984d556d1056fbfecbd04bdbe74c3e863a6857c" diff --git a/pyproject.toml b/pyproject.toml index 1326981d2..0bd6db9cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ qdrant-client = {version = "^1.8.0", extras = ["fastembed"], optional = true } ibm-watsonx-ai = { version = "^0.2.3", optional = true, markers = "python_version >= '3.10'"} cx-Oracle = { version = "^8.3.0", optional = true } pinecone-client = { version = "^4.1.0", optional = true, markers = "python_version >= '3.10'"} +sqlglot = {extras = ["rs"], version = "^25.0.3"} [tool.poetry.group.dev] optional = true diff --git a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py index 54a830400..c50d26752 100644 --- a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py +++ b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py @@ -111,9 +111,9 @@ def test_generate_matplolib_par_code( data = execute_sql_query(sql_query) plt.bar(data["ship_country"], data["order_count"], label="order_count") -plt.xlabel('Country') -plt.ylabel('Number of Orders') -plt.title('Orders Count by Country') +plt.xlabel('''Country''') +plt.ylabel('''Number of Orders''') +plt.title('''Orders 
Count by Country''') plt.legend(loc='best') @@ -151,7 +151,7 @@ def test_generate_matplolib_pie_chart_code( data = execute_sql_query(sql_query) plt.pie(data["order_count"], labels=data["ship_country"], autopct='%1.1f%%') -plt.title('Orders Count by Country') +plt.title('''Orders Count by Country''') plt.legend(loc='best') @@ -193,9 +193,9 @@ def test_generate_matplolib_line_chart_code( data = execute_sql_query(sql_query) plt.plot(data["order_date"], data["order_count"]) -plt.xlabel('Order Date') -plt.ylabel('Number of Orders') -plt.title('Orders Over Time') +plt.xlabel('''Order Date''') +plt.ylabel('''Number of Orders''') +plt.title('''Orders Over Time''') plt.legend(loc='best') @@ -232,7 +232,7 @@ def test_generate_matplolib_scatter_chart_code( data = execute_sql_query(sql_query) plt.scatter(data['order_date'], data['ship_via']) -plt.title('Total Freight by Order Date') +plt.title('''Total Freight by Order Date''') plt.legend(loc='best') @@ -274,9 +274,9 @@ def test_generate_matplolib_histogram_chart_code( data = execute_sql_query(sql_query) plt.hist(data['total_freight']) -plt.xlabel('Total Freight') -plt.ylabel('Frequency') -plt.title('Distribution of Total Freight') +plt.xlabel('''Total Freight''') +plt.ylabel('''Frequency''') +plt.title('''Distribution of Total Freight''') plt.savefig("charts.png") @@ -307,19 +307,20 @@ def test_generate_matplolib_boxplot_chart_code( logic_unit = code_gen.execute(json_str, context=context, logger=logger) assert isinstance(logic_unit, LogicUnitOutput) + print(logic_unit.output) assert ( logic_unit.output == """ -import matplotlib.pyplot as plt + import pandas as pd sql_query="SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country" data = execute_sql_query(sql_query) plt.boxplot(data['total_freight']) -plt.xlabel('Shipping Country') -plt.ylabel('Total Freight') -plt.title('Distribution of Total Freight by Shipping Country') +plt.xlabel('''Shipping Country''') +plt.ylabel('''Total Freight''') +plt.title('''Distribution of Total Freight by Shipping Country''') plt.savefig("charts.png") @@ -343,6 +344,7 @@ def test_generate_matplolib_number_type( logic_unit = code_gen.execute(json_str, context=context, logger=logger) assert isinstance(logic_unit, LogicUnitOutput) + print(logic_unit.output) assert ( logic_unit.output == """ @@ -353,7 +355,9 @@ def test_generate_matplolib_number_type( data = execute_sql_query(sql_query) -result = {"type": "number","value": data["order_count"].iloc[0]} +total_value = data["order_count"].sum() + +result = {"type": "number","value": total_value} """ ) @@ -391,16 +395,115 @@ def test_generate_timedimension_query( import matplotlib.pyplot as plt import pandas as pd -sql_query="SELECT `users`.`starredAt` AS starred_at, COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2022-01-01' AND '2023-03-31' GROUP BY starred_at_by_month" +sql_query="SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2022-01-01' AND '2023-03-31' GROUP BY starred_at_by_month" data = execute_sql_query(sql_query) -plt.plot(data["starred_at"], data["user_count"]) -plt.xlabel('Month') -plt.ylabel('Number of Stars') -plt.title('Stars Count per Month') +plt.plot(data["starred_at_by_month"], data["user_count"]) +plt.xlabel('''Month''') +plt.ylabel('''Number of Stars''') +plt.title('''Stars Count 
 plt.legend(loc='best')
 
+plt.savefig("charts.png")
+
+result = {"type": "plot","value": "charts.png"}
+"""
+    )
+
+    def test_generate_timedimension_for_year(
+        self, context: PipelineContext, logger: Logger
+    ):
+        code_gen = CodeGenerator()
+        context.add("df_schema", STARS_SCHEMA)
+        json_str = {
+            "type": "line",
+            "measures": ["Users.user_count"],
+            "timeDimensions": [
+                {
+                    "dimension": "Users.starred_at",
+                    "dateRange": ["this year"],
+                    "granularity": "month",
+                }
+            ],
+            "options": {
+                "xLabel": "Time Period",
+                "yLabel": "Stars Count",
+                "title": "Stars Count Per Month This Year",
+                "legend": {"display": True, "position": "bottom"},
+            },
+            "filters": [],
+            "order": [{"id": "Users.starred_at", "direction": "asc"}],
+        }
+
+        logic_unit = code_gen.execute(json_str, context=context, logger=logger)
+        print(logic_unit.output)
+        assert isinstance(logic_unit, LogicUnitOutput)
+        assert (
+            logic_unit.output
+            == """
+import matplotlib.pyplot as plt
+import pandas as pd
+
+sql_query="SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` >= DATE_TRUNC('year', CURRENT_DATE) AND `users`.`starredAt` < DATE_TRUNC('year', CURRENT_DATE) + INTERVAL '1 year' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc"
+data = execute_sql_query(sql_query)
+
+plt.plot(data["starred_at_by_month"], data["user_count"])
+plt.xlabel('''Time Period''')
+plt.ylabel('''Stars Count''')
+plt.title('''Stars Count Per Month This Year''')
+plt.legend(loc='best')
+
+
+plt.savefig("charts.png")
+
+result = {"type": "plot","value": "charts.png"}
+"""
+        )
+
+    def test_generate_timedimension_histogram_for_year(
+        self, context: PipelineContext, logger: Logger
+    ):
+        code_gen = CodeGenerator()
+        context.add("df_schema", STARS_SCHEMA)
+        json_str = {
+            "type": "histogram",
+            "dimensions": ["Users.starred_at"],
+            "measures": ["Users.user_count"],
+            "timeDimensions": [
+                {
+                    "dimension": "Users.starred_at",
+                    "dateRange": ["2023-01-01", "2023-12-31"],
+                    "granularity": "month",
+                }
+            ],
+            "options": {
+                "xLabel": "Starred Month",
+                "yLabel": "Number of Users",
+                "title": "Distribution of Stars per Month in 2023",
+                "legend": {"display": False},
+            },
+            "filters": [],
+            "order": [{"id": "Users.starred_at", "direction": "asc"}],
+        }
+
+        logic_unit = code_gen.execute(json_str, context=context, logger=logger)
+        assert isinstance(logic_unit, LogicUnitOutput)
+        assert (
+            logic_unit.output
+            == """
+import matplotlib.pyplot as plt
+import pandas as pd
+
+sql_query="SELECT `users`.`starredAt` AS starred_at, COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at, starred_at_by_month ORDER BY starred_at_by_month asc"
+data = execute_sql_query(sql_query)
+
+plt.hist(data['user_count'])
+plt.xlabel('''Starred Month''')
+plt.ylabel('''Number of Users''')
+plt.title('''Distribution of Stars per Month in 2023''')
+
+
 plt.savefig("charts.png")
 
 result = {"type": "plot","value": "charts.png"}
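Editor's note on the two timeDimension tests above: the expected queries assert the two ways a `dateRange` is lowered to SQL — an explicit `["start", "end"]` pair becomes an inclusive `BETWEEN`, while the relative `["this year"]` becomes a half-open `DATE_TRUNC` window, with monthly buckets produced by `DATE_FORMAT(..., '%Y-%m')`. A minimal sketch of that mapping (the helper name is hypothetical, for illustration only — this is not QueryBuilder's actual code):

```python
# Hypothetical helper mirroring the predicates the expected sql_query
# strings assert above; illustration only, not QueryBuilder's real code.
def time_dimension_filter(column: str, date_range: list) -> str:
    if date_range == ["this year"]:
        # Relative range: half-open window over the current calendar year.
        return (
            f"{column} >= DATE_TRUNC('year', CURRENT_DATE) "
            f"AND {column} < DATE_TRUNC('year', CURRENT_DATE) + INTERVAL '1 year'"
        )
    # Explicit range: inclusive bounds, as in the 2023 histogram test.
    start, end = date_range
    return f"{column} BETWEEN '{start}' AND '{end}'"


print(time_dimension_filter("`users`.`starredAt`", ["this year"]))
print(time_dimension_filter("`users`.`starredAt`", ["2023-01-01", "2023-12-31"]))
```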
States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", + "order_id": [ + 10248, + 10249, + 10250, + 10251, + 10252, + 10253, + 10254, + 10255, + 10256, + 10257, + ], + "customer_id": [ + "VINET", + "TOMSP", + "HANAR", + "VICTE", + "SUPRD", + "HANAR", + "CHOPS", + "RICSU", + "WELLI", + "HILAA", + ], + "employee_id": [5, 6, 4, 3, 4, 3, 4, 7, 3, 4], + "order_date": pd.to_datetime( + [ + "1996-07-04", + "1996-07-05", + "1996-07-08", + "1996-07-08", + "1996-07-09", + "1996-07-10", + "1996-07-11", + "1996-07-12", + "1996-07-15", + "1996-07-16", + ] + ), + "required_date": pd.to_datetime( + [ + "1996-08-01", + "1996-08-16", + "1996-08-05", + "1996-08-05", + "1996-08-06", + "1996-08-07", + "1996-08-08", + "1996-08-09", + "1996-08-12", + "1996-08-13", + ] + ), + "shipped_date": pd.to_datetime( + [ + "1996-07-16", + "1996-07-10", + "1996-07-12", + "1996-07-15", + "1996-07-11", + "1996-07-16", + "1996-07-23", + "1996-07-26", + "1996-07-17", + "1996-07-22", + ] + ), + "ship_via": [3, 1, 2, 1, 2, 2, 2, 3, 2, 1], + "ship_name": [ + "Vins et alcools Chevalier", + "Toms Spezialitäten", + "Hanari Carnes", + "Victuailles en stock", + "Suprêmes délices", + "Hanari Carnes", + "Chop-suey Chinese", + "Richter Supermarkt", + "Wellington Importadora", + "HILARION-Abastos", + ], + "ship_address": [ + "59 rue de l'Abbaye", + "Luisenstr. 48", + "Rua do Paço, 67", + "2, rue du Commerce", + "Boulevard Tirou, 255", + "Rua do Paço, 67", + "Hauptstr. 31", + "Starenweg 5", + "Rua do Mercado, 12", + "Carrera 22 con Ave. Carlos Soublette #8-35", ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, + "ship_city": [ + "Reims", + "Münster", + "Rio de Janeiro", + "Lyon", + "Charleroi", + "Rio de Janeiro", + "Bern", + "Genève", + "Resende", + "San Cristóbal", ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, + "ship_region": [ + "CJ", + None, + "RJ", + "RH", + None, + "RJ", + None, + None, + "SP", + "Táchira", + ], + "ship_postal_code": [ + "51100", + "44087", + "05454-876", + "69004", + "B-6000", + "05454-876", + "3012", + "1204", + "08737-363", + "5022", + ], + "ship_country": [ + "France", + "Germany", + "Brazil", + "France", + "Belgium", + "Brazil", + "Switzerland", + "Switzerland", + "Brazil", + "Venezuela", ], } )