Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jhilgart/ensure sample values wrapped in str #35

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion semantic_model_generator/data_processing/proto_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import io
import json
from typing import Any, Dict, Type, TypeVar
from typing import Any, Dict, List, Type, TypeVar, Union

import ruamel.yaml
import yaml
Expand All @@ -13,6 +13,57 @@
ProtoMsg = TypeVar("ProtoMsg", bound=Message)


def ensure_sample_values_wrapped_in_quotes(yaml_str: str) -> str:
    """Return ``yaml_str`` with every column's ``sample_values`` single-quoted.

    The document is loaded and re-dumped with ``ruamel.yaml`` so that the
    order of keys in the message is preserved. For each table, the
    ``dimensions``, ``measures`` and ``time_dimensions`` columns that carry a
    ``sample_values`` list get that list rebuilt with each value formatted as
    ``'<value>'``. The dumped text is then post-processed with plain string
    replacements to normalise the quoting the dumper produced.

    Args:
        yaml_str: YAML text with a top-level ``tables`` sequence.

    Returns:
        The re-serialized YAML text.

    Raises:
        KeyError: If the loaded document has no ``tables`` key.
    """
    # Named ``yaml_rt`` (not ``yaml``) so the module-level ``yaml`` import is
    # not shadowed inside this function.
    yaml_rt = ruamel.yaml.YAML()
    yaml_rt.indent(mapping=2, sequence=4, offset=2)
    yaml_rt.preserve_quotes = True

    data = yaml_rt.load(yaml_str)

    for table in data["tables"]:
        for column_type in ("dimensions", "measures", "time_dimensions"):
            if column_type not in table:
                continue
            for column in table[column_type]:
                if "sample_values" not in column:
                    continue
                # NOTE(review): values whose string form already contains a
                # single quote are dropped from the list, not kept as-is.
                # Confirm this data loss is intended.
                column["sample_values"] = [
                    f"'{value}'"
                    for value in column["sample_values"]
                    if "'" not in str(value)
                ]

    # Dump data back to a YAML string.
    with io.StringIO() as stream:
        yaml_rt.dump(data, stream)
        output = stream.getvalue()

    # The order of the replacements below matters; each one operates on the
    # result of the previous one.
    # Replace triple single quotes at end of line with the block style
    # indicator, so multiline strings are preserved in block style.
    output = output.replace("'''\n", "|-\n")
    # Replace any remaining triple single quotes with single quotes.
    output = output.replace("'''", "'")
    # Collapse a double quote adjacent to a single quote into the single quote
    # (presumably for values the dumper emitted as "'value'" -- TODO confirm).
    output = output.replace("\"'", "'").replace("'\"", "'")
    return output


def proto_to_yaml(message: ProtoMsg) -> str:
"""Serializes the input proto into a yaml message.

Expand Down
4 changes: 3 additions & 1 deletion semantic_model_generator/generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,13 @@ def generate_base_semantic_model_from_snowflake(
)

yaml_str = proto_utils.proto_to_yaml(context)
# Update sample_values to surround them with single quotes
yaml_str = proto_utils.ensure_sample_values_wrapped_in_quotes(yaml_str)

# Once we have the yaml, update it to include the # <FILL-OUT> tokens.
yaml_str = append_comment_to_placeholders(yaml_str)

# Validate the generated yaml is within context limits

validate_context_length(yaml_str)

with open(write_path, "w") as f:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,6 @@ def _get_column_representation(
)
assert cursor_execute is not None, "cursor_execute should not be none "
res = cursor_execute.fetchall()
# Cast all values to string to ensure the list is json serializable.
# A better solution would be to identify the possible types that are not
# json serializable (e.g. datetime objects) and apply the appropriate casting
# in just those cases.
if len(res) > 0:
if isinstance(res[0], dict):
col_key = [k for k in res[0].keys()][0]
Expand Down
88 changes: 69 additions & 19 deletions semantic_model_generator/tests/validate_model_test.py

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions semantic_model_generator/validate/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,30 @@ def validate_contains_datatype_for_each_col(table: Table) -> None:
raise ValueError(
f"Your Semantic Model contains a col {time_dim_col.name} that does not have the `data_type` field. Please add."
)


def _is_quoted_scalar(item: str) -> bool:
    """Return True if ``item`` is one complete quoted YAML scalar.

    ``item`` is the text of a list entry with the leading ``-`` removed. It
    must start and end with the same quote character and contain no
    unescaped occurrence of that character in between.
    """
    if len(item) < 2 or item[0] != item[-1] or item[0] not in ("'", '"'):
        return False
    inner = item[1:-1]
    if item[0] == "'":
        # Inside a single-quoted scalar a literal quote is written as ''.
        return "'" not in inner.replace("''", "")
    # Inside a double-quoted scalar a literal quote is written as \".
    # Escaped backslashes are removed first so that \\" is not mistaken for
    # an escaped quote.
    unescaped = inner.replace("\\\\", "").replace('\\"', "")
    return '"' not in unescaped


def validate_sample_values_are_quoted(yaml_str: str) -> None:
    """Validate that all ``sample_values`` entries in the YAML text are quoted.

    The text is scanned line by line rather than parsed. Any line containing
    ``sample_values`` opens a block; the block ends at the first non-empty
    line that does not start with ``-`` or that contains ``- name:``. Every
    list entry inside a block must be wrapped in quotes.

    An entry is accepted when the line contains exactly two single quotes or
    exactly two double quotes, or when the entry is a complete quoted scalar
    that contains escaped quotes (for example ``'it''s'``).

    Args:
        yaml_str: The YAML document as text.

    Raises:
        ValueError: If an entry of a ``sample_values`` list is not quoted.
    """
    inside_sample_values = False
    for raw_line in yaml_str.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        if "sample_values" in line:
            inside_sample_values = True
            continue
        if not inside_sample_values:
            continue
        # Check if we are still in the list of sample values, or if we moved
        # on to another block element or a new table.
        if line[0] != "-" or "- name:" in line:
            inside_sample_values = False
            continue

        # Simple count rule, kept so every previously accepted line (e.g. a
        # quoted value followed by a comment) is still accepted.
        has_one_quote_pair = line.count("'") == 2 or line.count('"') == 2
        if not has_one_quote_pair and not _is_quoted_scalar(line[1:].strip()):
            raise ValueError(
                f"You need to have all sample_values: surrounded by quotes. Please fix the value {line}."
            )
3 changes: 3 additions & 0 deletions semantic_model_generator/validate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from semantic_model_generator.sqlgen.generate_sql import generate_select_with_all_cols
from semantic_model_generator.validate.context_length import validate_context_length
from semantic_model_generator.validate.fields import validate_sample_values_are_quoted


def validate(yaml_path: str, snowflake_account: str) -> None:
Expand All @@ -22,6 +23,8 @@ def validate(yaml_path: str, snowflake_account: str) -> None:
yaml_str = f.read()
# Validate the context length doesn't exceed max we can support.
validate_context_length(yaml_str)
validate_sample_values_are_quoted(yaml_str)

model = yaml_to_semantic_model(yaml_str)

connector = SnowflakeConnector(
Expand Down
Loading