
Commit

update logic
sfc-gh-cnivera committed Aug 6, 2024
1 parent 5f8c674 commit 2e3c74e
Showing 3 changed files with 141 additions and 8 deletions.
79 changes: 73 additions & 6 deletions semantic_model_generator/tests/samples/validate_yamls.py
@@ -49,9 +49,10 @@
sample_values:
- '631'
"""
_LONG_VQR_CONTEXT = """

_LONG_VQR_CONTEXT_TEMPLATE = """
- name: "Max spend"
question: "Over the past week what was spend?"
question: "Over the past week what was spend from {index}?"
verified_at: 1714497970
verified_by: jonathan
sql: "
@@ -75,8 +76,11 @@
Minute DESC;
"
"""
_VALID_YAML_LONG_VQR_CONTEXT = (
"""name: my test semantic model

# Generate 100 unique variations to avoid duplicate verified query error.
long_vqr_contexts = [_LONG_VQR_CONTEXT_TEMPLATE.format(index=i) for i in range(100)]

_VALID_YAML_LONG_VQR_CONTEXT = """name: my test semantic model
tables:
- name: ALIAS
base_table:
@@ -119,8 +123,8 @@
sample_values:
- '631'
verified_queries:
"""
+ _LONG_VQR_CONTEXT * 100
""" + "\n".join(
long_vqr_contexts
)


@@ -375,3 +379,66 @@
)
],
)

_VALID_YAML_WITH_SINGLE_VERIFIED_QUERY = """
name: jaffle_shop
tables:
- name: orders
  description: Order overview data mart, offering key details for each order including
    if it's a customer's first order and a food vs. drink item breakdown. One row
    per order.
  base_table:
    database: autosql_dataset_dbt_jaffle_shop
    schema: data
    table: orders
  filters:
  - name: large_order
    expr: cogs > 100
  - name: custom_filter
    expr: my_udf(col1, col2)
  - name: window_func
    expr: COUNT(i) OVER (PARTITION BY p ORDER BY o) count_i_Range_Pre
  verified_queries:
  - name: daily cumulative expenses in 2023 dec
    question: daily cumulative expenses in 2023 dec
    sql: " SELECT date, SUM(daily_cogs) OVER ( ORDER BY date ROWS BETWEEN UNBOUNDED
      PRECEDING AND CURRENT ROW ) AS cumulative_cogs FROM __daily_revenue WHERE date
      BETWEEN '2023-12-01' AND '2023-12-31' ORDER BY date DESC; "
    verified_at: '1714752498'
    verified_by: renee
"""

_INVALID_YAML_DUPLICATE_VERIFIED_QUERIES = """
name: jaffle_shop
tables:
- name: orders
  description: Order overview data mart, offering key details for each order including
    if it's a customer's first order and a food vs. drink item breakdown. One row
    per order.
  base_table:
    database: autosql_dataset_dbt_jaffle_shop
    schema: data
    table: orders
  filters:
  - name: large_order
    expr: cogs > 100
  - name: custom_filter
    expr: my_udf(col1, col2)
  - name: window_func
    expr: COUNT(i) OVER (PARTITION BY p ORDER BY o) count_i_Range_Pre
  verified_queries:
  - name: daily cumulative expenses in 2023 dec
    question: daily cumulative expenses in 2023 dec
    sql: " SELECT date, SUM(daily_cogs) OVER ( ORDER BY date ROWS BETWEEN UNBOUNDED
      PRECEDING AND CURRENT ROW ) AS cumulative_cogs FROM __daily_revenue WHERE date
      BETWEEN '2023-12-01' AND '2023-12-31' ORDER BY date DESC; "
    verified_at: '1714752498'
    verified_by: renee
  - name: daily cumulative expenses in 2023 dec
    question: daily cumulative expenses in 2023 dec
    sql: " SELECT date, SUM(daily_cogs) OVER ( ORDER BY date ROWS BETWEEN UNBOUNDED
      PRECEDING AND CURRENT ROW ) AS cumulative_cogs FROM __daily_revenue WHERE date
      BETWEEN '2023-12-01' AND '2023-12-31' ORDER BY date DESC; "
    verified_at: '1714752498'
    verified_by: renee
"""
32 changes: 32 additions & 0 deletions semantic_model_generator/tests/validate_model_test.py
@@ -90,6 +90,24 @@ def temp_invalid_yaml_too_long_context():
yield tmp.name


@pytest.fixture
def temp_valid_yaml_with_verified_query():
    """Create a temporary YAML file with the test data."""
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp:
        tmp.write(validate_yamls._VALID_YAML_WITH_SINGLE_VERIFIED_QUERY)
        tmp.flush()
        yield tmp.name


@pytest.fixture
def temp_invalid_yaml_duplicate_verified_queries():
    """Create a temporary YAML file with the test data."""
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp:
        tmp.write(validate_yamls._INVALID_YAML_DUPLICATE_VERIFIED_QUERIES)
        tmp.flush()
        yield tmp.name


@mock.patch("semantic_model_generator.validate_model.logger")
def test_valid_yaml_flow_style(
mock_logger, temp_valid_yaml_file_flow_style, mock_snowflake_connection
@@ -247,3 +265,17 @@ def test_valid_yaml_many_sample_values(mock_logger, mock_snowflake_connection):
tmp.write(yaml)
tmp.flush()
assert validate_from_local_path(tmp.name, account_name) is None


@mock.patch("semantic_model_generator.validate_model.logger")
def test_invalid_yaml_duplicate_verified_queries(
    mock_logger, temp_invalid_yaml_duplicate_verified_queries, mock_snowflake_connection
):
    account_name = "snowflake test"
    with pytest.raises(
        YAMLValidationError,
        match=r"Duplicate verified query found\.\n in \"semantic model\", line \d+, column \d+:\n verified_queries:\n \^ \(line: \d+\)\ndaily cumulative expenses in 2023 dec\n in \"semantic model\", line \d+, column \d+:\n verified_by: renee\n \^ \(line: \d+\)",
    ):
        validate_from_local_path(
            temp_invalid_yaml_duplicate_verified_queries, account_name
        )
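
The new negative test can be run on its own to see the duplicate-query error surfaced end to end. A sketch of the invocation, assuming pytest is installed and the command is run from the repository root:

pytest semantic_model_generator/tests/validate_model_test.py::test_invalid_yaml_duplicate_verified_queries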
38 changes: 36 additions & 2 deletions semantic_model_generator/validate/schema.py
@@ -9,7 +9,18 @@

import sqlglot
from google.protobuf.descriptor import Descriptor, EnumDescriptor, FieldDescriptor
from strictyaml import Bool, Decimal, Enum, Int, Map, Optional, Seq, Str, Validator
from strictyaml import (
    Bool,
    Decimal,
    Enum,
    Int,
    Map,
    Optional,
    Seq,
    Str,
    Validator,
    YAMLValidationError,
)

from semantic_model_generator.protos import semantic_model_pb2

Expand All @@ -32,6 +43,26 @@ def validate_scalar(self, chunk): # type: ignore
return chunk.contents


class VerifiedQueries(Seq):  # type: ignore
    """
    Validator for the verified_queries field.
    We ensure that there are no duplicate verified queries, by checking for duplicate (question, sql) pairs.
    """

    def validate(self, chunk):  # type: ignore
        super().validate(chunk)
        seen_queries = set()
        for query in chunk.contents:
            qa_pair = (query["question"], query["sql"])
            if qa_pair in seen_queries:
                raise YAMLValidationError(
                    context="Duplicate verified query found.",
                    problem=query["name"],
                    chunk=chunk,
                )
            seen_queries.add(qa_pair)


def create_schema_for_message(
message: Descriptor, precomputed_types: Dict[str, Validator]
) -> Validator:
@@ -69,7 +100,10 @@ def create_schema_for_field(
raise Exception(f"unsupported type: {field_descriptor.type}")

if field_descriptor.label == FieldDescriptor.LABEL_REPEATED:
field_type = Seq(field_type)
if field_descriptor.name == "verified_queries":
field_type = VerifiedQueries(field_type)
else:
field_type = Seq(field_type)

return field_type

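The duplicate check added in VerifiedQueries can also be exercised directly against a small strictyaml document, without going through validate_from_local_path. The snippet below is a minimal sketch rather than part of this commit: the three-field query_schema is a hand-written stand-in for the protobuf-derived schema that create_schema_for_message builds, and the YAML content is illustrative.

from strictyaml import Map, Str, YAMLValidationError, load

from semantic_model_generator.validate.schema import VerifiedQueries

# Hypothetical minimal schema for a single verified query; the real schema is
# generated from the semantic_model_pb2 descriptors.
query_schema = Map({"name": Str(), "question": Str(), "sql": Str()})

duplicate_queries = """
- name: max spend
  question: what was spend last week?
  sql: SELECT 1
- name: max spend copy
  question: what was spend last week?
  sql: SELECT 1
"""

try:
    load(duplicate_queries, VerifiedQueries(query_schema))
except YAMLValidationError as exc:
    # Both entries share the same (question, sql) pair, so the validator
    # raises "Duplicate verified query found." pointing at the second one.
    print(exc)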

