Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug/case sensitivity #101

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions space2stats_api/src/space2stats_ingest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def read_parquet_file(file_path: str) -> pa.Table:
table = pq.read_table(tmp_file.name)
else:
table = pq.read_table(file_path)
return table
return table.rename_columns([col.lower() for col in table.column_names])


def get_stac_fields_from_item(stac_item_path: str) -> Set[str]:
item = Item.from_file(stac_item_path)
columns = [c["name"] for c in item.properties.get("table:columns")]
columns = [c["name"].lower() for c in item.properties.get("table:columns")]
zacdezgeo marked this conversation as resolved.
Show resolved Hide resolved
return set(columns)


Expand Down Expand Up @@ -55,8 +55,6 @@ def verify_columns(
raise ValueError("The 'hex_id' column is missing from the Parquet file.")

# Verify Parquet columns match the STAC fields
# TODO: Standarize the hex_id in the parquet files/STAC items.
# if parquet_columns - {"hex_id"} != stac_fields:
if parquet_columns != stac_fields:
extra_in_parquet = parquet_columns - stac_fields
extra_in_stac = stac_fields - parquet_columns
Expand All @@ -72,7 +70,7 @@ def verify_columns(
FROM information_schema.columns
WHERE table_name = '{TABLE_NAME}'
""")
existing_columns = set(row[0] for row in cur.fetchall())
existing_columns = set(row[0].lower() for row in cur.fetchall())

# Check for overlap in columns (excluding 'hex_id')
overlapping_columns = parquet_columns.intersection(existing_columns) - {"hex_id"}
Expand Down Expand Up @@ -154,7 +152,7 @@ def load_parquet_to_db(
FROM information_schema.columns
WHERE table_name = '{temp_table}'
AND column_name NOT IN (
SELECT column_name FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
SELECT LOWER(column_name) FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
)
""")
new_columns = cur.fetchall()
Expand All @@ -166,14 +164,15 @@ def load_parquet_to_db(
# Add new columns to the main table
for column, column_type in new_columns:
cur.execute(
f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}"
f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column.lower()} {column_type}"
)

print(f"Adding new columns: {[c[0] for c in new_columns]}...")

# Construct the SET clause for the update query
update_columns = [
f"{column} = temp.{column}" for column, _ in new_columns
f"{column.lower()} = temp.{column.lower()}"
for column, _ in new_columns
]
set_clause = ", ".join(update_columns)

Expand Down
49 changes: 49 additions & 0 deletions space2stats_api/src/tests/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,52 @@ def test_hex_id_column_mandatory(clean_database, tmpdir):
load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
except ValueError as e:
assert "The 'hex_id' column is missing from the Parquet file." in str(e)


def test_case_sensitivity_in_columns(clean_database, tmpdir):
    """Ingest a Parquet file and STAC item whose column names are mixed-case.

    Both inputs name their columns "Hex_ID" and "Sum_Pop". After calling
    ``load_parquet_to_db`` the test queries ``space2stats`` through the
    lower-case ``hex_id`` column and expects both rows to be present.
    """
    connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"

    # Parquet input: column names deliberately contain capital letters.
    parquet_file = tmpdir.join("case_sensitivity.parquet")
    mixed_case_table = pa.table(
        {
            "Hex_ID": ["hex_1", "hex_2"],
            "Sum_Pop": [100, 200],
        }
    )
    pq.write_table(mixed_case_table, parquet_file)

    # STAC item declaring the same columns with the same capitalization.
    column_specs = [
        {"name": "Hex_ID", "type": "string"},
        {"name": "Sum_Pop", "type": "int64"},
    ]
    stac_item = {
        "type": "Feature",
        "stac_version": "1.0.0",
        "id": "space2stats_case_sensitivity",
        "properties": {
            "table:columns": column_specs,
            "datetime": "2024-10-07T11:21:25.944150Z",
        },
        "geometry": None,
        "bbox": [-180, -90, 180, 90],
        "links": [],
        "assets": {},
    }

    item_file = tmpdir.join("case_sensitivity.json")
    with open(item_file, "w") as handle:
        json.dump(stac_item, handle)

    # Run the ingest under test.
    load_parquet_to_db(str(parquet_file), connection_string, str(item_file))

    # Each row must be retrievable via the lower-case column name.
    with psycopg.connect(connection_string) as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_1'")
        first_row = cur.fetchone()
        assert first_row == ("hex_1", 100)

        cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_2'")
        second_row = cur.fetchone()
        assert second_row == ("hex_2", 200)
Loading