From be2ea270805f646e9e8495f09cd3fecd79eb0f6c Mon Sep 17 00:00:00 2001
From: Zac Deziel <zachary.deziel@gmail.com>
Date: Wed, 11 Dec 2024 13:05:44 -0800
Subject: [PATCH] Bug/case sensitivity (#101)

* Add test to reproduce error

* Use lowercase for column names

Ingesting and validating column names is now done entirely in lowercase, including the column names read from STAC metadata. Updated the test to query 'hex_id' rather than 'Hex_ID'.
---
 .../src/space2stats_ingest/main.py            | 15 +++---
 space2stats_api/src/tests/test_ingest.py      | 49 +++++++++++++++++++
 2 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/space2stats_api/src/space2stats_ingest/main.py b/space2stats_api/src/space2stats_ingest/main.py
index fbf843c..c25eea5 100644
--- a/space2stats_api/src/space2stats_ingest/main.py
+++ b/space2stats_api/src/space2stats_ingest/main.py
@@ -21,12 +21,12 @@ def read_parquet_file(file_path: str) -> pa.Table:
             table = pq.read_table(tmp_file.name)
     else:
         table = pq.read_table(file_path)
-    return table
+    return table.rename_columns([col.lower() for col in table.column_names])
 
 
 def get_stac_fields_from_item(stac_item_path: str) -> Set[str]:
     item = Item.from_file(stac_item_path)
-    columns = [c["name"] for c in item.properties.get("table:columns")]
+    columns = [c["name"].lower() for c in item.properties.get("table:columns")]
     return set(columns)
 
 
@@ -55,8 +55,6 @@ def verify_columns(
         raise ValueError("The 'hex_id' column is missing from the Parquet file.")
 
     # Verify Parquet columns match the STAC fields
-    # TODO: Standarize the hex_id in the parquet files/STAC items.
-    # if parquet_columns - {"hex_id"} != stac_fields:
     if parquet_columns != stac_fields:
         extra_in_parquet = parquet_columns - stac_fields
         extra_in_stac = stac_fields - parquet_columns
@@ -72,7 +70,7 @@ def verify_columns(
                 FROM information_schema.columns
                 WHERE table_name = '{TABLE_NAME}'
             """)
-            existing_columns = set(row[0] for row in cur.fetchall())
+            existing_columns = set(row[0].lower() for row in cur.fetchall())
 
     # Check for overlap in columns (excluding 'hex_id')
     overlapping_columns = parquet_columns.intersection(existing_columns) - {"hex_id"}
@@ -154,7 +152,7 @@ def load_parquet_to_db(
                 FROM information_schema.columns
                 WHERE table_name = '{temp_table}'
                 AND column_name NOT IN (
-                    SELECT column_name FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
+                    SELECT LOWER(column_name) FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
                 )
             """)
             new_columns = cur.fetchall()
@@ -166,14 +164,15 @@ def load_parquet_to_db(
                 # Add new columns to the main table
                 for column, column_type in new_columns:
                     cur.execute(
-                        f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}"
+                        f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column.lower()} {column_type}"
                     )
 
                 print(f"Adding new columns: {[c[0] for c in new_columns]}...")
 
                 # Construct the SET clause for the update query
                 update_columns = [
-                    f"{column} = temp.{column}" for column, _ in new_columns
+                    f"{column.lower()} = temp.{column.lower()}"
+                    for column, _ in new_columns
                 ]
                 set_clause = ", ".join(update_columns)
 
diff --git a/space2stats_api/src/tests/test_ingest.py b/space2stats_api/src/tests/test_ingest.py
index 4b99f2d..3474084 100644
--- a/space2stats_api/src/tests/test_ingest.py
+++ b/space2stats_api/src/tests/test_ingest.py
@@ -294,3 +294,52 @@ def test_hex_id_column_mandatory(clean_database, tmpdir):
         load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
     except ValueError as e:
         assert "The 'hex_id' column is missing from the Parquet file." in str(e)
+
+
+def test_case_sensitivity_in_columns(clean_database, tmpdir):
+    connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"
+
+    # Create Parquet file with a column name that includes capitalization
+    parquet_file = tmpdir.join("case_sensitivity.parquet")
+    data = {
+        "Hex_ID": ["hex_1", "hex_2"],  # Capitalized column name
+        "Sum_Pop": [100, 200],
+    }
+    table = pa.table(data)
+    pq.write_table(table, parquet_file)
+
+    # Create corresponding STAC item with matching capitalization
+    stac_item = {
+        "type": "Feature",
+        "stac_version": "1.0.0",
+        "id": "space2stats_case_sensitivity",
+        "properties": {
+            "table:columns": [
+                {"name": "Hex_ID", "type": "string"},
+                {"name": "Sum_Pop", "type": "int64"},
+            ],
+            "datetime": "2024-10-07T11:21:25.944150Z",
+        },
+        "geometry": None,
+        "bbox": [-180, -90, 180, 90],
+        "links": [],
+        "assets": {},
+    }
+
+    item_file = tmpdir.join("case_sensitivity.json")
+    with open(item_file, "w") as f:
+        json.dump(stac_item, f)
+
+    # Attempt to load into the database
+    load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
+
+    # Validate the data was inserted correctly
+    with psycopg.connect(connection_string) as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_1'")
+            result = cur.fetchone()
+            assert result == ("hex_1", 100)
+
+            cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_2'")
+            result = cur.fetchone()
+            assert result == ("hex_2", 200)