From be2ea270805f646e9e8495f09cd3fecd79eb0f6c Mon Sep 17 00:00:00 2001 From: Zac Deziel <zachary.deziel@gmail.com> Date: Wed, 11 Dec 2024 13:05:44 -0800 Subject: [PATCH] Bug/case sensitivity (#101) * Add test to reproduce error * Use lower case for column names Column names are now lowercased during ingestion and validation, including those read from STAC metadata. Updated the test to query 'hex_id' and not 'Hex_Id'. --- .../src/space2stats_ingest/main.py | 15 +++--- space2stats_api/src/tests/test_ingest.py | 49 +++++++++++++++++++ 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/space2stats_api/src/space2stats_ingest/main.py b/space2stats_api/src/space2stats_ingest/main.py index fbf843c..c25eea5 100644 --- a/space2stats_api/src/space2stats_ingest/main.py +++ b/space2stats_api/src/space2stats_ingest/main.py @@ -21,12 +21,12 @@ def read_parquet_file(file_path: str) -> pa.Table: table = pq.read_table(tmp_file.name) else: table = pq.read_table(file_path) - return table + return table.rename_columns([col.lower() for col in table.column_names]) def get_stac_fields_from_item(stac_item_path: str) -> Set[str]: item = Item.from_file(stac_item_path) - columns = [c["name"] for c in item.properties.get("table:columns")] + columns = [c["name"].lower() for c in item.properties.get("table:columns")] return set(columns) @@ -55,8 +55,6 @@ def verify_columns( raise ValueError("The 'hex_id' column is missing from the Parquet file.") # Verify Parquet columns match the STAC fields - # TODO: Standarize the hex_id in the parquet files/STAC items. 
- # if parquet_columns - {"hex_id"} != stac_fields: if parquet_columns != stac_fields: extra_in_parquet = parquet_columns - stac_fields extra_in_stac = stac_fields - parquet_columns @@ -72,7 +70,7 @@ def verify_columns( FROM information_schema.columns WHERE table_name = '{TABLE_NAME}' """) - existing_columns = set(row[0] for row in cur.fetchall()) + existing_columns = set(row[0].lower() for row in cur.fetchall()) # Check for overlap in columns (excluding 'hex_id') overlapping_columns = parquet_columns.intersection(existing_columns) - {"hex_id"} @@ -154,7 +152,7 @@ def load_parquet_to_db( FROM information_schema.columns WHERE table_name = '{temp_table}' AND column_name NOT IN ( - SELECT column_name FROM information_schema.columns WHERE table_name = '{TABLE_NAME}' + SELECT LOWER(column_name) FROM information_schema.columns WHERE table_name = '{TABLE_NAME}' ) """) new_columns = cur.fetchall() @@ -166,14 +164,15 @@ def load_parquet_to_db( # Add new columns to the main table for column, column_type in new_columns: cur.execute( - f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}" + f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column.lower()} {column_type}" ) print(f"Adding new columns: {[c[0] for c in new_columns]}...") # Construct the SET clause for the update query update_columns = [ - f"{column} = temp.{column}" for column, _ in new_columns + f"{column.lower()} = temp.{column.lower()}" + for column, _ in new_columns ] set_clause = ", ".join(update_columns) diff --git a/space2stats_api/src/tests/test_ingest.py b/space2stats_api/src/tests/test_ingest.py index 4b99f2d..3474084 100644 --- a/space2stats_api/src/tests/test_ingest.py +++ b/space2stats_api/src/tests/test_ingest.py @@ -294,3 +294,52 @@ def test_hex_id_column_mandatory(clean_database, tmpdir): load_parquet_to_db(str(parquet_file), connection_string, str(item_file)) except ValueError as e: assert "The 'hex_id' column is missing from the Parquet file." 
in str(e) + + +def test_case_sensitivity_in_columns(clean_database, tmpdir): + connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}" + + # Create Parquet file with a column name that includes capitalization + parquet_file = tmpdir.join("case_sensitivity.parquet") + data = { + "Hex_ID": ["hex_1", "hex_2"], # Capitalized column name + "Sum_Pop": [100, 200], + } + table = pa.table(data) + pq.write_table(table, parquet_file) + + # Create corresponding STAC item with matching capitalization + stac_item = { + "type": "Feature", + "stac_version": "1.0.0", + "id": "space2stats_case_sensitivity", + "properties": { + "table:columns": [ + {"name": "Hex_ID", "type": "string"}, + {"name": "Sum_Pop", "type": "int64"}, + ], + "datetime": "2024-10-07T11:21:25.944150Z", + }, + "geometry": None, + "bbox": [-180, -90, 180, 90], + "links": [], + "assets": {}, + } + + item_file = tmpdir.join("case_sensitivity.json") + with open(item_file, "w") as f: + json.dump(stac_item, f) + + # Attempt to load into the database + load_parquet_to_db(str(parquet_file), connection_string, str(item_file)) + + # Validate the data was inserted correctly + with psycopg.connect(connection_string) as conn: + with conn.cursor() as cur: + cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_1'") + result = cur.fetchone() + assert result == ("hex_1", 100) + + cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_2'") + result = cur.fetchone() + assert result == ("hex_2", 200)