Skip to content

Commit

Permalink
Bug/case sensitivity (#101)
Browse files Browse the repository at this point in the history
* Add test to reproduce error

* Use lower case for column names

Ingesting and validating column names is all done in lowercase which includes STAC metadata. Updated test to query 'hex_id' and not 'Hex_Id'.
  • Loading branch information
zacdezgeo authored Dec 11, 2024
1 parent 104ad65 commit be2ea27
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 8 deletions.
15 changes: 7 additions & 8 deletions space2stats_api/src/space2stats_ingest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def read_parquet_file(file_path: str) -> pa.Table:
table = pq.read_table(tmp_file.name)
else:
table = pq.read_table(file_path)
return table
return table.rename_columns([col.lower() for col in table.column_names])


def get_stac_fields_from_item(stac_item_path: str) -> Set[str]:
item = Item.from_file(stac_item_path)
columns = [c["name"] for c in item.properties.get("table:columns")]
columns = [c["name"].lower() for c in item.properties.get("table:columns")]
return set(columns)


Expand Down Expand Up @@ -55,8 +55,6 @@ def verify_columns(
raise ValueError("The 'hex_id' column is missing from the Parquet file.")

# Verify Parquet columns match the STAC fields
# TODO: Standarize the hex_id in the parquet files/STAC items.
# if parquet_columns - {"hex_id"} != stac_fields:
if parquet_columns != stac_fields:
extra_in_parquet = parquet_columns - stac_fields
extra_in_stac = stac_fields - parquet_columns
Expand All @@ -72,7 +70,7 @@ def verify_columns(
FROM information_schema.columns
WHERE table_name = '{TABLE_NAME}'
""")
existing_columns = set(row[0] for row in cur.fetchall())
existing_columns = set(row[0].lower() for row in cur.fetchall())

# Check for overlap in columns (excluding 'hex_id')
overlapping_columns = parquet_columns.intersection(existing_columns) - {"hex_id"}
Expand Down Expand Up @@ -154,7 +152,7 @@ def load_parquet_to_db(
FROM information_schema.columns
WHERE table_name = '{temp_table}'
AND column_name NOT IN (
SELECT column_name FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
SELECT LOWER(column_name) FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
)
""")
new_columns = cur.fetchall()
Expand All @@ -166,14 +164,15 @@ def load_parquet_to_db(
# Add new columns to the main table
for column, column_type in new_columns:
cur.execute(
f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}"
f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column.lower()} {column_type}"
)

print(f"Adding new columns: {[c[0] for c in new_columns]}...")

# Construct the SET clause for the update query
update_columns = [
f"{column} = temp.{column}" for column, _ in new_columns
f"{column.lower()} = temp.{column.lower()}"
for column, _ in new_columns
]
set_clause = ", ".join(update_columns)

Expand Down
49 changes: 49 additions & 0 deletions space2stats_api/src/tests/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,52 @@ def test_hex_id_column_mandatory(clean_database, tmpdir):
load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
except ValueError as e:
assert "The 'hex_id' column is missing from the Parquet file." in str(e)


def test_case_sensitivity_in_columns(clean_database, tmpdir):
connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"

# Create Parquet file with a column name that includes capitalization
parquet_file = tmpdir.join("case_sensitivity.parquet")
data = {
"Hex_ID": ["hex_1", "hex_2"], # Capitalized column name
"Sum_Pop": [100, 200],
}
table = pa.table(data)
pq.write_table(table, parquet_file)

# Create corresponding STAC item with matching capitalization
stac_item = {
"type": "Feature",
"stac_version": "1.0.0",
"id": "space2stats_case_sensitivity",
"properties": {
"table:columns": [
{"name": "Hex_ID", "type": "string"},
{"name": "Sum_Pop", "type": "int64"},
],
"datetime": "2024-10-07T11:21:25.944150Z",
},
"geometry": None,
"bbox": [-180, -90, 180, 90],
"links": [],
"assets": {},
}

item_file = tmpdir.join("case_sensitivity.json")
with open(item_file, "w") as f:
json.dump(stac_item, f)

# Attempt to load into the database
load_parquet_to_db(str(parquet_file), connection_string, str(item_file))

# Validate the data was inserted correctly
with psycopg.connect(connection_string) as conn:
with conn.cursor() as cur:
cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_1'")
result = cur.fetchone()
assert result == ("hex_1", 100)

cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_2'")
result = cur.fetchone()
assert result == ("hex_2", 200)

0 comments on commit be2ea27

Please sign in to comment.