Add partition test
Sreesh Maheshwar committed Dec 23, 2024
1 parent 8523b3a commit b581214
Showing 2 changed files with 57 additions and 16 deletions.
37 changes: 37 additions & 0 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -280,6 +280,43 @@ def test_query_filter_v1_v2_append_null(
assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}"


@pytest.mark.integration
@pytest.mark.parametrize(
"part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"]
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_object_storage_excludes_partition(
session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int
) -> None:
nested_field = TABLE_SCHEMA.find_field(part_col)
partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col)
)

tbl = _create_table(
session_catalog=session_catalog,
identifier=f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}",
properties={"format-version": str(format_version), "write.object-storage.enabled": True},
data=[arrow_table_with_null],
partition_spec=partition_spec,
)

original_paths = tbl.inspect.data_files().to_pydict()["file_path"]
assert len(original_paths) == 3

# Update props to exclude partitioned paths and append data
with tbl.transaction() as tx:
tx.set_properties({"write.object-storage.partitioned-paths": False})
tbl.append(arrow_table_with_null)

added_paths = set(tbl.inspect.data_files().to_pydict()["file_path"]) - set(original_paths)
assert len(added_paths) == 3

# All paths before the props update should contain the partition, while all paths after should not
assert all(f"{part_col}=" in path for path in original_paths)
assert all(f"{part_col}=" not in path for path in added_paths)


@pytest.mark.integration
@pytest.mark.parametrize(
"spec",
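Note (not part of the commit): with write.object-storage.enabled, PyIceberg's object-store location provider injects hash-derived entropy directories under the table's data/ prefix, and write.object-storage.partitioned-paths (default true) controls whether a Hive-style {col}={value} segment is still appended after them. The test above exercises that toggle; the following is a minimal standalone sketch of the same sequence, in which the catalog name, table identifier, column name, and data are illustrative assumptions rather than values from the commit:

    import pyarrow as pa
    from pyiceberg.catalog import load_catalog
    from pyiceberg.partitioning import PartitionField, PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.transforms import IdentityTransform
    from pyiceberg.types import LongType, NestedField

    catalog = load_catalog("default")  # assumes a configured catalog named "default"

    schema = Schema(NestedField(field_id=1, name="n", field_type=LongType(), required=False))
    spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="n")
    )
    tbl = catalog.create_table(
        "default.object_storage_example",  # illustrative identifier
        schema=schema,
        partition_spec=spec,
        properties={"write.object-storage.enabled": "true"},
    )

    df = pa.table({"n": [1, 2, 3]})
    tbl.append(df)  # data files land under .../data/<entropy dirs>/n=<value>/...

    with tbl.transaction() as tx:
        tx.set_properties({"write.object-storage.partitioned-paths": "false"})

    tbl.append(df)  # files written from here on omit the n=<value> segment
    paths = tbl.inspect.data_files().to_pydict()["file_path"]
    # The first three paths contain "n=", the three added afterwards do not.
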
36 changes: 20 additions & 16 deletions tests/integration/test_writes/test_writes.py
@@ -273,7 +273,6 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
     # Since we don't rewrite, this should produce a new manifest with an ADDED entry
     tbl.append(arrow_table_with_null)

-
     rows = spark.sql(
         f"""
         SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
@@ -285,27 +284,32 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
     assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0]
     assert [row.deleted_data_files_count for row in rows] == [0, 1, 0, 0, 0]


 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def test_object_storage_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
-    # TODO: What to do about "tbl.add_files()"?
-    identifier = "default.object_stored_table"
-
-    tbl = _create_table(session_catalog, identifier, {"format-version": format_version, "write.object-storage.enabled": True}, [])
+def test_object_storage_data_files(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier="default.object_stored",
+        properties={"format-version": format_version, "write.object-storage.enabled": True},
+        data=[arrow_table_with_null],
+    )
     tbl.append(arrow_table_with_null)

-    paths = tbl.inspect.entries().to_pydict()["data_file"]
-    assert len(paths) == 1
-    location = paths[0]["file_path"]
+    paths = tbl.inspect.data_files().to_pydict()["file_path"]
+    assert len(paths) == 2

-    parts = location.split("/")
-    assert len(parts) == 11
-
-    assert location.startswith("s3://warehouse/default/object_stored_table/data/")
-    for i in range(6, 10):
-        assert len(parts[i]) == (8 if i == 9 else 4)
-        assert all(c in "01" for c in parts[i])
+    for location in paths:
+        assert location.startswith("s3://warehouse/default/object_stored/data/")
+        parts = location.split("/")
+        assert len(parts) == 11
+
+        # Entropy binary directories should have been injected
+        for i in range(6, 10):
+            assert parts[i]
+            assert all(c in "01" for c in parts[i])


 @pytest.mark.integration
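Note (not part of the commit): the rewritten assertions check the object-storage path layout, in which everything between data/ and the file name consists of entropy directories made up of 0 and 1 characters. The previous revision pinned four directories of fixed widths (4, 4, 4, then 8); the updated test only requires each segment to be a non-empty binary string. Below is a hedged standalone version of that check, with an illustrative path rather than one captured from a real run:

    def looks_like_object_storage_path(path: str, prefix: str) -> bool:
        # Mirrors the test: <prefix>/data/<b>/<b>/<b>/<b>/<file>, where each <b>
        # is a non-empty string of binary digits (11 "/"-separated parts in total
        # for an s3://bucket/db/table prefix).
        parts = path.split("/")
        return (
            path.startswith(f"{prefix}/data/")
            and len(parts) == 11
            and all(part and all(c in "01" for c in part) for part in parts[6:10])
        )

    # Illustrative path, not from a real run:
    example = "s3://warehouse/default/object_stored/data/0110/1010/0011/10100101/00000-0-example.parquet"
    assert looks_like_object_storage_path(example, "s3://warehouse/default/object_stored")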
