diff --git a/semantic_model_generator/data_processing/proto_utils.py b/semantic_model_generator/data_processing/proto_utils.py
index 6b066d68..62f664f7 100644
--- a/semantic_model_generator/data_processing/proto_utils.py
+++ b/semantic_model_generator/data_processing/proto_utils.py
@@ -7,7 +7,9 @@ from google.protobuf.message import Message
 from strictyaml import dirty_load
 
+from semantic_model_generator.data_processing.sql_parsing import extract_table_columns
 from semantic_model_generator.protos import semantic_model_pb2
+from semantic_model_generator.protos.semantic_model_pb2 import SemanticModel
 from semantic_model_generator.validate.schema import SCHEMA
 
 ProtoMsg = TypeVar("ProtoMsg", bound=Message)
@@ -63,7 +65,140 @@ def proto_to_dict(message: ProtoMsg) -> dict[str, Any]:
         raise ValueError(f"Failed to convert protobuf message to dictionary: {e}")
 
 
-def yaml_to_semantic_model(yaml_str: str) -> semantic_model_pb2.SemanticModel:
+def context_to_column_format(ctx: SemanticModel) -> SemanticModel:
+    """
+    Converts a SemanticModel from dimension/measure format to column format.
+    Returns a new SemanticModel object that is in column format.
+    """
+    ret = SemanticModel()
+    ret.CopyFrom(ctx)
+    for table in ret.tables:
+        column_format = len(table.columns) > 0
+        dimension_measure_format = (
+            len(table.dimensions) > 0
+            or len(table.time_dimensions) > 0
+            or len(table.measures) > 0
+        )
+        if column_format and dimension_measure_format:
+            raise ValueError(
+                f"table {table.name} defines both columns and dimensions/time_dimensions/measures."
+            )
+        if column_format:
+            continue
+        for d in table.dimensions:
+            col = semantic_model_pb2.Column()
+            col.kind = semantic_model_pb2.ColumnKind.dimension
+            col.name = d.name
+            col.synonyms.extend(d.synonyms)
+            col.description = d.description
+            col.expr = d.expr
+            col.data_type = d.data_type
+            col.unique = d.unique
+            col.sample_values.extend(d.sample_values)
+            # Do in-memory indexing & retrieval of sample values
+            # for dimensions that don't have a search service defined on them.
+            # The number of sample values passed to the model may be capped
+            # to the first N, but retrieving the sample values
+            # based on the question means that many more values can be added
+            # to the semantic model, and only passed to the model when relevant.
+            col.index_and_retrieve_values = not d.cortex_search_service_name
+            col.cortex_search_service_name = d.cortex_search_service_name
+            table.columns.append(col)
+        del table.dimensions[:]
+
+        for td in table.time_dimensions:
+            col = semantic_model_pb2.Column()
+            col.kind = semantic_model_pb2.ColumnKind.time_dimension
+            col.name = td.name
+            col.synonyms.extend(td.synonyms)
+            col.description = td.description
+            col.expr = td.expr
+            col.data_type = td.data_type
+            col.unique = td.unique
+            col.sample_values.extend(td.sample_values)
+            table.columns.append(col)
+        del table.time_dimensions[:]
+
+        for m in table.measures:
+            col = semantic_model_pb2.Column()
+            col.kind = semantic_model_pb2.ColumnKind.measure
+            col.name = m.name
+            col.synonyms.extend(m.synonyms)
+            col.description = m.description
+            col.expr = m.expr
+            col.data_type = m.data_type
+            col.default_aggregation = m.default_aggregation
+            col.sample_values.extend(m.sample_values)
+            table.columns.append(col)
+        del table.measures[:]
+    return ret
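A minimal usage sketch of context_to_column_format (not part of the patch); the table and field values below are hypothetical, and it assumes the generated semantic_model_pb2 bindings are importable:

from semantic_model_generator.protos import semantic_model_pb2

ctx = semantic_model_pb2.SemanticModel()
tbl = ctx.tables.add()
tbl.name = "orders"
dim = tbl.dimensions.add()
dim.name = "region"
dim.expr = "region"

col_ctx = context_to_column_format(ctx)
# The dimension is carried over as a Column of kind `dimension`, and the
# returned copy has an empty `dimensions` list; the input ctx is untouched.
assert col_ctx.tables[0].columns[0].kind == semantic_model_pb2.ColumnKind.dimension
assert not col_ctx.tables[0].dimensions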
+ """ + + def _find_table_by_name( + ctx: SemanticModel, table_name: str + ) -> semantic_model_pb2.Table | None: + for table in ctx.tables: + if table.name.lower() == table_name.lower(): + return table + return None + + if not ctx.metrics: + # No metric exsiting in the definition, exit validation. + return + if not ctx.relationships: + raise ValueError("Semantic model has metrics but no join paths defined.") + + join_pairs = [ + {join.left_table.lower(), join.right_table.lower()} + for join in ctx.relationships + ] + for metric in ctx.metrics: + # First find all tables referred in the metrics. All columns is supposed to be fully qualified with logical table names. + # Raises error if: + # 1. Found any columns not fully qualified with logical table name. + # 2. Only one logical table referred in a metric, indicating it should be defined as a measure, not a metric. + # 3. For now only supports metric defined across two tables. Raise error if more than two tables referred. + # 4. The join path between the two tables must be defined in the semantic model. + tbl_col_mapping = extract_table_columns(metric.expr) + non_qualified_cols = tbl_col_mapping.get("") + if non_qualified_cols and len(non_qualified_cols) > 0: + raise ValueError( + f"Error in {metric.name}; Columns within metric definition needs to be qualified with corresponding logical table name." + ) + tbls_referred = set(key.lower() for key in tbl_col_mapping.keys()) + if len(tbls_referred) == 1: + raise ValueError( + f"Error in {metric.name}; Metric calculation only referred to one logical table, please define as a measure, instead of metric" + ) + if len(tbls_referred) > 2: + raise ValueError( + f"Error in {metric.name}; Currently only accept metric defined across two tables" + ) + if tbls_referred not in join_pairs: + raise ValueError( + f"Error in {metric.name}; No direct join relationship defined between {','.join(sorted(tbls_referred))}" + ) + + for k, v in tbl_col_mapping.items(): + tbl = _find_table_by_name(ctx, k) + if tbl is None: + raise ValueError( + f"Error in {metric.name}; Metric calculation referred to undefined logical table name {k}" + ) + + for col in v: + if col.lower() not in [c.name.lower() for c in tbl.columns]: # type: ignore + raise ValueError( + f"Error in {metric.name}; Metric calculation referred to undefined logical column name {col} in table {k}" + ) + + +def yaml_to_semantic_model(yaml_str: str) -> SemanticModel: """ Deserializes the input yaml into a SemanticModel Protobuf message. 
+
+
+def yaml_to_semantic_model(yaml_str: str) -> SemanticModel:
     """
     Deserializes the input yaml into a SemanticModel Protobuf message.
 
     The input yaml must be fully representable as json, so yaml features like
@@ -84,4 +219,7 @@ def yaml_to_semantic_model(yaml_str: str) -> semantic_model_pb2.SemanticModel:
         yaml_str, SCHEMA, label="semantic model", allow_flow_style=True
     )
     msg = semantic_model_pb2.SemanticModel()
-    return json_format.ParseDict(parsed_yaml.data, msg)
+    ctx: SemanticModel = json_format.ParseDict(parsed_yaml.data, msg)
+    col_format_ctx = context_to_column_format(ctx)
+    _validate_metric(col_format_ctx)
+    return col_format_ctx
diff --git a/semantic_model_generator/data_processing/sql_parsing.py b/semantic_model_generator/data_processing/sql_parsing.py
new file mode 100644
index 00000000..11c674db
--- /dev/null
+++ b/semantic_model_generator/data_processing/sql_parsing.py
@@ -0,0 +1,449 @@
+import re
+from typing import Dict, List, Optional, Set, Union
+
+import sqlglot
+from sqlglot.dialects.dialect import NormalizationStrategy
+from sqlglot.dialects.snowflake import Snowflake
+from sqlglot.errors import OptimizeError, ParseError, TokenError
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+from sqlglot.optimizer.qualify_columns import Resolver, _qualify_columns
+from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.schema import Schema, ensure_schema
+
+from semantic_model_generator.validate.keywords import SF_RESERVED_WORDS
+
+DOUBLE_QUOTE = '"'
+_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER = r"[A-Za-z_][A-Za-z0-9_]*(?:\$[A-Za-z0-9_]*)?"
+_SF_UNQUOTED_CASE_SENSITIVE_IDENTIFIER = r"[A-Z_][A-Z0-9_]*(?:\$[A-Z0-9_]*)?"
+UNQUOTED_CASE_INSENSITIVE_RE = re.compile(
+    f"^({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER})$"
+)
+UNQUOTED_CASE_SENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_SENSITIVE_IDENTIFIER})$")
+
+
+def _get_escaped_name(id: str) -> str:
+    """Wraps an identifier in double quotes, escaping any double quotes
+    already present in it by doubling them.
+
+    NOTE: See note in :meth:`_is_quoted`.
+
+    Args:
+        id: The string to be checked & treated.
+
+    Returns:
+        The identifier wrapped in double quotes, with any embedded double
+        quotes doubled.
+    """
+    escape_quotes = id.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE)
+    return DOUBLE_QUOTE + escape_quotes + DOUBLE_QUOTE
+
+
+def get_escaped_names(
+    ids: Optional[Union[str, List[str]]]
+) -> Optional[Union[str, List[str]]]:
+    """Given user-provided identifier(s), computes the equivalent column name
+    identifier(s), double-quoting names so that special characters and case
+    sensitivity are preserved.
+    https://docs.snowflake.com/en/sql-reference/identifiers-syntax
+
+    Args:
+        ids: User-provided column name identifier(s).
+
+    Returns:
+        Double-quoted identifiers for column names, to make sure that column
+        names are treated as case sensitive.
+
+    Raises:
+        ValueError: if the input type is unsupported or column name
+            identifiers are invalid.
+    """
+
+    if ids is None:
+        return None
+    elif type(ids) is list:
+        return [_get_escaped_name(id) for id in ids]
+    elif type(ids) is str:
+        return _get_escaped_name(ids)
+    else:
+        raise ValueError(
+            "Unsupported type. Only a string or a list of strings is supported for "
+            "selecting columns."
+        )
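A quick sketch (not part of the patch) of the expected outputs of these quoting helpers, following the doubling rule described above:

print(_get_escaped_name('My "Col"'))    # '"My ""Col"""'
print(get_escaped_names("plain"))       # '"plain"' (still wrapped in double quotes)
print(get_escaped_names(["a", 'b"c']))  # ['"a"', '"b""c"']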
+
+
+def should_be_quoted(identifier: str) -> bool:
+    """Checks whether a given identifier should be quoted.
+
+    NOTE: Assumes the identifier is given as it is stored in DB metadata
+    and as shown in INFORMATION_SCHEMA or the output of a DESCRIBE command.
+    (Upper case for unquoted identifiers.)
+
+    Args:
+        identifier: The identifier to be checked.
+
+    Returns:
+        Whether the identifier should be quoted.
+    """
+    if UNQUOTED_CASE_SENSITIVE_RE.match(identifier):
+        if identifier in SF_RESERVED_WORDS:
+            return True
+        return False
+
+    return True
+
+
+def get_llm_friendly_name(identifier: str) -> str:
+    """Returns the simplest form of an identifier for an LLM (lower case
+    preferred), putting the identifier in double quotes if needed.
+
+    NOTE: Assumes the identifier is given as it is stored in DB metadata
+    and as shown in INFORMATION_SCHEMA or the output of a DESCRIBE command.
+    (Upper case for unquoted identifiers.)
+
+    Args:
+        identifier: The identifier to be checked.
+
+    Returns:
+        Transformed identifier.
+    """
+
+    if should_be_quoted(identifier):
+        return get_escaped_names(identifier)  # type: ignore
+
+    return identifier.lower()
+
+
+def get_all_table_names(sql_str: str) -> List[str]:
+    """
+    Given a string of SQL, returns all the tables present in the query.
+    """
+    return [
+        table.name
+        for table in sqlglot.parse_one(sql_str).find_all(sqlglot.exp.Table)
+        if table and table.name
+    ]
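A sketch of the expected behavior of the identifier helpers on metadata-style (upper case) inputs; the identifier values are hypothetical, and the second case assumes SELECT is in SF_RESERVED_WORDS:

print(should_be_quoted("ORDERS"))         # False: plain upper-case identifier
print(should_be_quoted("SELECT"))         # True, assuming SELECT is a reserved word
print(should_be_quoted("My Table"))       # True: mixed case and a space
print(get_llm_friendly_name("ORDERS"))    # orders
print(get_llm_friendly_name("My Table"))  # "My Table"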
+ """ + parsed = sqlglot.parse_one(sql_str) + main_query_tables = [] + + def _is_in_main_query(expr: sqlglot.Expression) -> bool: + """Checks if the expression is present in a subquery by ascending the parents""" + while expr.parent: + if isinstance(expr.parent, sqlglot.expressions.Subquery) or isinstance( + expr.parent, sqlglot.expressions.Where + ): + return False + expr = expr.parent + return True + + def _find_tables(expression: sqlglot.Expression) -> None: + """Recurses through the parsed tree to find tables not in subqueries""" + if ( + isinstance(expression, sqlglot.expressions.Table) + and _is_in_main_query(expression) + and expression.name + ): + main_query_tables.append(expression.name) + for arg in expression.args.values(): + if isinstance(arg, sqlglot.expressions.Expression): + _find_tables(arg) + elif isinstance(arg, list): + for sub in arg: + _find_tables(sub) + + # Start with the FROM clause + from_clause = parsed.args.get("from") + if from_clause: + for from_expr in from_clause.find_all(sqlglot.expressions.Table): + _find_tables(from_expr) + + # Look at the With clauses as well + for with_clause in parsed.find_all(sqlglot.expressions.With): + _find_tables(with_clause) + + # Then look at the JOIN clauses + for join in parsed.find_all(sqlglot.expressions.Join): + _find_tables(join) + + return list(dict.fromkeys(main_query_tables)) # Remove duplicates, preserves order + + +def get_all_column_names(sql_str: str) -> List[str]: + """ + Given a string of SQL, returns all the columns selected from in the query + """ + return [ + column.name + for column in sqlglot.parse_one(sql_str).find_all(sqlglot.exp.Column) + ] + + +def get_all_column_names_from_select(sql_str: str) -> List[str]: + """ + Gets all the columns present specifically in the select clause + Ignores columns used in a subquery or in a where clause + """ + parsed = sqlglot.parse_one(sql_str) + select_columns = [] + + def _is_in_main_query(expr: sqlglot.Expression) -> bool: + """ + Checks if the expression is present in a subquery or + where clause by ascending the parents + """ + while expr.parent: + if isinstance(expr.parent, sqlglot.expressions.Subquery) or isinstance( + expr.parent, sqlglot.expressions.Where + ): + return False + expr = expr.parent + return True + + # Traverse SELECT expressions + for select in parsed.find_all(sqlglot.expressions.Select): + for expression in select.args.get("expressions", []): + # don't look in where clauses + if not _is_in_main_query(expression): + continue + # If it's a column or an alias containing a column, add it to the list + if isinstance(expression, sqlglot.expressions.Column): + select_columns.append(expression.name) + elif hasattr(expression, "this") and isinstance( + expression.this, sqlglot.expressions.Column + ): + select_columns.append(expression.this.sql()) + + return select_columns + + +def get_tables_with_distinct_columns_from_used_in_query( + sql_str: str, + all_table_names_in_schemas: List[str], + columns_per_table: Dict[str, List[str]], +) -> List[str]: + """ + Finds unused tables in the schema that do not contain at least one of the + currently used columns from the currently in-use table names. + + Note that to simplify the implementation, a subquery is not considered in-use. 
+
+
+def get_tables_with_distinct_columns_from_used_in_query(
+    sql_str: str,
+    all_table_names_in_schemas: List[str],
+    columns_per_table: Dict[str, List[str]],
+) -> List[str]:
+    """
+    Finds tables in the schema that are not used in the query and that are
+    missing at least one of the columns the query currently uses.
+
+    Note that to simplify the implementation, a subquery is not considered in-use.
+
+    If there are no columns in use, returns an empty list.
+    """
+    tables_without_present_columns = []
+    columns_used_in_query = get_all_column_names(sql_str)
+    table_names_in_use = get_table_names_excluding_subqueries(sql_str)
+
+    if len(columns_used_in_query) == 0:
+        return []
+
+    for table_name in all_table_names_in_schemas:
+        if table_name in table_names_in_use:
+            continue
+        current_column_is_not_present_in_table = False
+        for used_column in columns_used_in_query:
+            if used_column not in columns_per_table[table_name]:
+                current_column_is_not_present_in_table = True
+        if current_column_is_not_present_in_table:
+            tables_without_present_columns.append(table_name)
+    return tables_without_present_columns
+
+
+def get_columns_not_present_in_any_in_use_tables(
+    table_names_in_use: List[str],
+    all_table_names_in_schemas: List[str],
+    columns_per_table: Dict[str, List[str]],
+) -> List[str]:
+    """
+    Gets all columns from other tables that are not present in any of the
+    currently used tables.
+
+    Assumes that the tables will have column information in the dictionary.
+    """
+    all_columns_in_schema = [
+        column
+        for table in all_table_names_in_schemas
+        for column in columns_per_table.get(table, [])
+    ]
+    all_columns_in_current_tables = [
+        column
+        for table in table_names_in_use
+        for column in columns_per_table.get(table, [])
+    ]
+    return [
+        column
+        for column in all_columns_in_schema
+        if column not in all_columns_in_current_tables
+    ]
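A toy-schema sketch (not part of the patch) of the two helpers above; the table and column names are hypothetical:

columns_per_table = {
    "orders": ["id", "amount", "customer_id"],
    "customers": ["id", "name"],
    "products": ["sku", "price"],
}
all_tables = list(columns_per_table)
sql = "SELECT orders.amount FROM orders"
print(get_tables_with_distinct_columns_from_used_in_query(sql, all_tables, columns_per_table))
# ['customers', 'products']: neither table contains the used column 'amount'
print(get_columns_not_present_in_any_in_use_tables(["orders"], all_tables, columns_per_table))
# ['name', 'sku', 'price']: 'id' also appears in orders, so it is excluded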
+
+
+class FQNNormalizationError(Exception):
+    pass
+
+
+class FQNNormalizationQualifyColumnError(FQNNormalizationError):
+    pass
+
+
+class FQNNormalizationParsingError(FQNNormalizationError):
+    pass
+
+
+class LowercaseSnowflake(Snowflake):  # type: ignore
+    """The 'snowflake' dialect enforces uppercase identifier normalization by
+    default, so we introduce a custom dialect that applies lowercase instead"""
+
+    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
+
+
+def transform_sql_to_fqn_form(
+    sql: str, schema_simple: dict[str, dict[str, str]], pretty_output: bool = False
+) -> str:
+    """Transforms a query into "fully qualified names" form, where all column
+    identifiers are expanded to table_name.column_name form. Additionally, all
+    non-mixed-case identifiers are lowercased and the whole query is formatted
+    with sqlglot.
+
+    Args:
+        sql (str): input SQL query text
+        schema_simple (dict[str, dict[str, str]]): schema dict, the same as in
+            SQLGlot's optimize method: https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer
+        pretty_output (bool, optional): apply tabs and newlines in final formatting. Defaults to False.
+
+    Raises:
+        FQNNormalizationParsingError: if the query cannot be parsed
+        FQNNormalizationQualifyColumnError: if a column present in the query cannot be
+            identified because it is missing from the provided schema
+
+    Returns:
+        str: the normalized query
+    """
+    # We have to normalize the schema here in order to handle mixed-case columns
+    normalized_schema = {
+        get_llm_friendly_name(table_name): {
+            get_llm_friendly_name(c_name): c_type for c_name, c_type in columns.items()
+        }
+        for table_name, columns in schema_simple.items()
+    }
+    schema: Schema = ensure_schema(normalized_schema, dialect=LowercaseSnowflake)
+
+    try:
+        parsed = sqlglot.parse_one(sql, dialect=LowercaseSnowflake)
+    except (ParseError, TokenError) as e:
+        raise FQNNormalizationParsingError(str(e))
+
+    # Normalize all unquoted identifiers to lowercase
+    parsed = normalize_identifiers(parsed, LowercaseSnowflake)
+    # This handles normalization of some mixed-case table names
+    parsed = qualify_tables(parsed, schema=schema, dialect=LowercaseSnowflake)
+
+    # Traverse every possible scope (one scope corresponds to one particular select expression in the query)
+    for scope in traverse_scope(parsed):
+        # The Resolver object helps identify a column's parent table in a given scope
+        resolver = Resolver(scope, schema, infer_schema=schema.empty)
+        try:
+            # sqlglot's optimization, which should qualify all columns (but in some cases it won't)
+            _qualify_columns(scope, resolver)
+        except OptimizeError as e:
+            raise FQNNormalizationQualifyColumnError(str(e))
+
+        # Gather the alias-to-table mapping for all tables in the current scope
+        aliases = {
+            table.alias: table.this
+            for table in scope.expression.find_all(sqlglot.exp.Table)
+            if table.alias
+        }
+        aliases.update(
+            {
+                a.name: a.parent.this
+                for a in scope.expression.find_all(sqlglot.exp.TableAlias)
+                if isinstance(a.parent, sqlglot.exp.Table)
+            }
+        )
+
+        # Remove all table aliases
+        for table in scope.tables:
+            if table.alias:
+                table.set("alias", None)
+
+        # Fix all cases unhandled by the _qualify_columns(...) optimization (GROUP BY and QUALIFY OVER)
+        # and expand all remaining table aliases to table identifiers
+        for column in scope._raw_columns:  # type: ignore
+            if column.table == "":
+                if (table_name_or_alias := resolver.get_table(column.name)) is not None:
+                    table_name = aliases.get(
+                        table_name_or_alias.name, table_name_or_alias.name
+                    )
+                    column.set("table", table_name)
+            elif column.table in aliases:
+                column.set("table", aliases[column.table])
+
+        # Fix the case where we star-expand one particular table via an alias:
+        # SELECT f.* FROM foo as f --> SELECT foo.* FROM foo
+        for star in scope.find_all(sqlglot.exp.Star):  # type: ignore
+            if (
+                isinstance(star.parent, sqlglot.exp.Column)
+                or isinstance(star.parent, sqlglot.exp.Table)
+            ) and star.parent.table in aliases:  # type: ignore
+                star.parent.set("table", aliases[star.parent.table])  # type: ignore
+
+    return str(parsed.sql(LowercaseSnowflake, pretty=pretty_output))
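A rough usage sketch of transform_sql_to_fqn_form with a hypothetical one-table schema; the exact output formatting may differ slightly:

schema = {"orders": {"id": "int", "amount": "int"}}
print(transform_sql_to_fqn_form("SELECT o.AMOUNT FROM ORDERS AS o", schema))
# SELECT orders.amount FROM orders  (identifiers lowercased, alias expanded)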
+
+
+def extract_table_columns(sql: str) -> Dict[str, Set[str]]:
+    """
+    Given an arbitrary SQL string, returns a map from referenced tables to their referenced columns.
+    """
+    # First, qualify all column names with their table names.
+    qualified_sql = transform_sql_to_fqn_form(sql, {}, pretty_output=True)
+    table_columns: Dict[str, Set[str]] = {}
+    parse = sqlglot.parse_one(qualified_sql, read=Snowflake)
+
+    # Find all tables that are referenced.
+    # This covers the case where a column is never specifically referenced in the
+    # table, such as `select count(*) from table`.
+    for t in parse.find_all(sqlglot.expressions.Table):
+        if t.name not in table_columns:
+            table_columns[t.name] = set()
+
+    # Now map all the referenced columns to their tables.
+    for e in parse.find_all(sqlglot.expressions.Column):
+        table_columns.setdefault(e.table, set()).add(e.name)
+
+    # Finally, find any tables that we `select * from` and add * to the columns.
+    # Note that this only finds `*` when it is not qualified by a table name.
+    # If it is qualified by a table name, it will be parsed as a column and
+    # handled above.
+    for ss in parse.find_all(sqlglot.expressions.Select):
+        if sqlglot.expressions.Star() not in ss.expressions:
+            continue
+        from_ = ss.args.get("from", None)
+        joins = ss.args.get("joins", [])
+        if from_ is None:
+            print(f"No from clause found in `select *` statement in query {sql}")
+            continue
+
+        for table in [from_] + joins:
+            table_columns.setdefault(table.this.name, set()).add("*")
+
+    return table_columns
diff --git a/semantic_model_generator/tests/validate_model_test.py b/semantic_model_generator/tests/validate_model_test.py
index 4191a3be..8f143638 100644
--- a/semantic_model_generator/tests/validate_model_test.py
+++ b/semantic_model_generator/tests/validate_model_test.py
@@ -245,8 +245,8 @@ def test_invalid_yaml_too_long_context(
 
     expected_error = (
         "Your semantic model is too large. "
-        "Passed size is 164952 characters. "
-        "We need you to remove 41032 characters in your semantic model. Please check: \n"
+        "Passed size is 165064 characters. "
+        "We need you to remove 41144 characters in your semantic model. Please check: \n"
         " (1) If you have long descriptions that can be truncated. \n"
         " (2) If you can remove some columns that are not used within your tables. \n"
         " (3) If you have extra tables you do not need."
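Returning to extract_table_columns above, a closing sketch (not part of the patch) of expected outputs on hypothetical tables; set ordering is arbitrary:

print(extract_table_columns(
    "SELECT o.amount, c.name FROM orders o JOIN customers c ON o.customer_id = c.id"
))
# {'orders': {'amount', 'customer_id'}, 'customers': {'name', 'id'}}
print(extract_table_columns("SELECT COUNT(*) FROM orders"))
# {'orders': set()}: the table is recorded even though no column is named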