Validate context length for generation and validation #33

Merged
5 changes: 5 additions & 0 deletions semantic_model_generator/generate_model.py
@@ -17,6 +17,7 @@
    get_valid_schemas_tables_columns_df,
)
from semantic_model_generator.snowflake_utils.utils import create_fqn_table
from semantic_model_generator.validate.context_length import validate_context_length

_PLACEHOLDER_COMMENT = " "
_FILL_OUT_TOKEN = " # <FILL-OUT>"
@@ -299,6 +300,10 @@ def generate_base_semantic_model_from_snowflake(
    # Once we have the yaml, update to include the # <FILL-OUT> tokens.
    yaml_str = append_comment_to_placeholders(yaml_str)

    # Validate the generated yaml is within context limits.
    validate_context_length(yaml_str)

    with open(write_path, "w") as f:
        f.write(yaml_str)
    logger.info(f"Semantic model saved to {write_path}")
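For reference, a hypothetical end-to-end call that would exercise this new check during generation (the table, account, file, and model names below are placeholders, not values from this PR):

from semantic_model_generator.generate_model import generate_base_semantic_model_from_snowflake

# Raises ValueError before writing the YAML file if the generated model
# exceeds the context budget enforced by validate_context_length.
generate_base_semantic_model_from_snowflake(
    base_tables=["my_db.my_schema.MY_TABLE"],
    snowflake_account="my_account",
    output_yaml_path="my_model.yaml",
    semantic_model_name="My Semantic Model",
)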
82 changes: 81 additions & 1 deletion semantic_model_generator/tests/generate_model_test.py
@@ -138,6 +138,21 @@ def mock_snowflake_connection():
        comment=None,
    )

_TABLE_WITH_THAT_EXCEEDS_CONTEXT = Table(
    id_=0,
    name="PRODUCTS",
    columns=[
        Column(
            id_=0,
            column_name="SKU",
            column_type="NUMBER",
            values=["1", "2", "3"] * 3000,
            comment=None,
        ),
    ],
    comment=None,
)


@pytest.fixture
def mock_snowflake_connection_env(monkeypatch):
@@ -262,6 +277,40 @@ def mock_dependencies_object_dtype(mock_snowflake_connection):
        yield


@pytest.fixture
def mock_dependencies_exceed_context(mock_snowflake_connection):
    valid_schemas_tables_columns_df_alias = pd.DataFrame(
        {
            "TABLE_NAME": ["PRODUCTS"],
            "COLUMN_NAME": ["SKU"],
            "DATA_TYPE": ["OBJECT"],
        }
    )
    valid_schemas_tables_columns_df_zip_code = pd.DataFrame(
        {
            "TABLE_NAME": ["PRODUCTS"],
            "COLUMN_NAME": ["SKU"],
            "DATA_TYPE": ["NUMBER"],
        }
    )
    valid_schemas_tables_representations = [
        valid_schemas_tables_columns_df_alias,
        valid_schemas_tables_columns_df_zip_code,
    ]
    table_representations = [
        _TABLE_WITH_THAT_EXCEEDS_CONTEXT,  # Value returned on the first call.
    ]

    with patch(
        "semantic_model_generator.generate_model.get_valid_schemas_tables_columns_df",
        side_effect=valid_schemas_tables_representations,
    ), patch(
        "semantic_model_generator.generate_model.get_table_representation",
        side_effect=table_representations,
    ):
        yield


def test_raw_schema_to_semantic_context(
    mock_dependencies, mock_snowflake_connection, mock_snowflake_connection_env
):
Expand Down Expand Up @@ -406,7 +455,7 @@ def test_generate_base_context_from_table_that_has_not_supported_dtype(
base_tables = ["test_db.schema_test.ALIAS"]
snowflake_account = "test_account"
output_path = "output_model_path.yaml"
semantic_model_name = "Another Incredible Semantic Model with new dtypes"
semantic_model_name = "Another Incredible Semantic Model with unsupported dtypes"

with pytest.raises(ValueError) as excinfo:
generate_base_semantic_model_from_snowflake(
Expand All @@ -432,6 +481,37 @@ def test_generate_base_context_from_table_that_has_not_supported_dtype(
mock_file().write.assert_not_called()


@patch("semantic_model_generator.generate_model.logger")
@patch("builtins.open", new_callable=mock_open)
def test_generate_base_context_from_table_that_has_too_long_context(
    mock_file,
    mock_logger,
    mock_dependencies_exceed_context,
    mock_snowflake_connection,
    mock_snowflake_connection_env,
):
    base_tables = ["test_db.schema_test.ALIAS"]
    snowflake_account = "test_account"
    output_path = "output_model_path.yaml"
    semantic_model_name = "Another Incredible Semantic Model with long context"

    with pytest.raises(ValueError) as excinfo:
        generate_base_semantic_model_from_snowflake(
            base_tables=base_tables,
            snowflake_account=snowflake_account,
            output_yaml_path=output_path,
            semantic_model_name=semantic_model_name,
        )

    assert (
        str(excinfo.value)
        == "Your semantic model is too large. Passed size is 144558 characters. We need you to remove 116556 characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values."
    )

    mock_file().write.assert_not_called()
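
As a sanity check, the numbers in the expected message follow directly from the heuristic in validate/context_length.py. A minimal sketch of that arithmetic, using the values asserted above:

CHARS_PER_TOKEN = 4
_MODEL_CONTEXT_LENGTH = 7000

passed_chars = 144558                                       # len(yaml_str) in this test
estimated_tokens = passed_chars // CHARS_PER_TOKEN          # 36139
excess_tokens = estimated_tokens - _MODEL_CONTEXT_LENGTH    # 29139
chars_to_remove = excess_tokens * CHARS_PER_TOKEN           # 116556, as asserted above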


def test_semantic_model_to_yaml() -> None:
    want_yaml = "name: transaction_ctx\ntables:\n  - name: transactions\n    description: A table containing data about financial transactions. Each row contains\n      details of a financial transaction.\n    base_table:\n      database: my_database\n      schema: my_schema\n      table: transactions\n    dimensions:\n      - name: transaction_id\n        description: A unique id for this transaction.\n        expr: transaction_id\n        data_type: BIGINT\n        unique: true\n    time_dimensions:\n      - name: initiation_date\n        description: Timestamp when the transaction was initiated. In UTC.\n        expr: initiation_date\n        data_type: DATETIME\n    measures:\n      - name: amount\n        description: The amount of this transaction.\n        expr: amount\n        data_type: DECIMAL\n        default_aggregation: sum\n"
    got = semantic_model_pb2.SemanticModel(
73 changes: 71 additions & 2 deletions semantic_model_generator/tests/validate_model_test.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions semantic_model_generator/validate/context_length.py
@@ -0,0 +1,11 @@
_MODEL_CONTEXT_LENGTH = 7000 # We use 7k so that we can reserve 1k for response tokens.
Review comment (Collaborator):

    Shouldn't we account for instruction tokens in the prompt?

Reply from @sfc-gh-jhilgart (Collaborator, Author), May 8, 2024:

    This is pretty loose as is. Instruction tokens, even for llama3, are only ~20 tokens.

    I can add an additional buffer here though!

sfc-gh-jhilgart marked this conversation as resolved.


def validate_context_length(yaml_str: str) -> None:
    # Pass in the str version of the semantic context yaml.
    # This isn't exactly how many tokens the model will be, but should roughly be correct.
    CHARS_PER_TOKEN = 4  # As per https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
    if len(yaml_str) // CHARS_PER_TOKEN > _MODEL_CONTEXT_LENGTH:
        raise ValueError(
            f"Your semantic model is too large. Passed size is {len(yaml_str)} characters. We need you to remove {((len(yaml_str) // CHARS_PER_TOKEN) - _MODEL_CONTEXT_LENGTH) * CHARS_PER_TOKEN} characters in your semantic model. Please check: \n (1) If you have long descriptions that can be truncated. \n (2) If you can remove some columns that are not used within your tables. \n (3) If you have extra tables you do not need. \n (4) If you can remove sample values."
        )
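Following up on the review thread above, a minimal sketch of how an extra instruction-token buffer could be folded into this check. The _INSTRUCTION_TOKEN_BUFFER name and its value are assumptions for illustration, not part of this PR:

_MODEL_CONTEXT_LENGTH = 7000  # Reserve 1k of the window for response tokens.
_INSTRUCTION_TOKEN_BUFFER = 100  # Hypothetical headroom for prompt instruction tokens (~20 for llama3).


def validate_context_length(yaml_str: str) -> None:
    CHARS_PER_TOKEN = 4
    token_budget = _MODEL_CONTEXT_LENGTH - _INSTRUCTION_TOKEN_BUFFER
    if len(yaml_str) // CHARS_PER_TOKEN > token_budget:
        raise ValueError(
            "Your semantic model is too large. Please trim descriptions, columns, tables, or sample values."
        )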
7 changes: 4 additions & 3 deletions semantic_model_generator/validate_model.py
@@ -6,6 +6,7 @@
    SnowflakeConnector,
)
from semantic_model_generator.sqlgen.generate_sql import generate_select_with_all_cols
from semantic_model_generator.validate.context_length import validate_context_length


def validate(yaml_path: str, snowflake_account: str) -> None:
@@ -19,6 +20,8 @@ def validate(yaml_path: str, snowflake_account: str) -> None:
    """
    with open(yaml_path) as f:
        yaml_str = f.read()
    # Validate that the context length doesn't exceed the max we can support.
    validate_context_length(yaml_str)
    model = yaml_to_semantic_model(yaml_str)

    connector = SnowflakeConnector(
@@ -38,9 +41,7 @@ def validate(yaml_path: str, snowflake_account: str) -> None:
            # Run the query
            _ = conn.cursor().execute(select)
        except Exception as e:
-            raise ValueError(
-                f"Unable to execute query with your logical table against base tables on Snowflake. Error = {e}"
-            )
+            raise ValueError(f"Unable to validate your semantic model. Error = {e}")
        logger.info(f"Validated logical table: {table.name}")

    logger.info(f"Successfully validated {yaml_path}")