Skip to content

Commit

Permalink
Merge pull request #1073 from NASA-IMPACT/add_per_indicator_thresholding
Browse files Browse the repository at this point in the history
add per indicator thrsholding and new dump
  • Loading branch information
CarsonDavis authored Nov 20, 2024
2 parents 52ab3a5 + d537302 commit 0f0d407
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion scripts/ej/cmr_to_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def categorize_processing_level(level):
# remove existing data
EnvironmentalJusticeRow.objects.filter(destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV).delete()

ej_dump = json.load(open("backups/ej_dump_20240815_112916.json"))
ej_dump = json.load(open("backups/ej_dump_20241017_133151.json.json"))
for dataset in ej_dump:
ej_row = EnvironmentalJusticeRow(
destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV,
Expand Down
37 changes: 26 additions & 11 deletions scripts/ej/create_ej_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
inferences are supplied by the classification model. the contact point is Bishwas
cmr is supplied by running
github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
move to the serve like this: scp ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/
move to the server like this: scp ej_dump_20241017_133151.json sde:/home/ec2-user/sde_indexing_helper/backups/
"""

import json
Expand All @@ -19,20 +19,22 @@ def save_to_json(data: dict | list, file_path: str) -> None:
json.dump(data, file, indent=2)


def process_classifications(predictions: list[dict[str, float]], threshold: float = 0.5) -> list[str]:
def process_classifications(predictions: list[dict[str, float]], thresholds: dict[str, float]) -> list[str]:
"""
Process the predictions and classify as follows:
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification
2. Filter classifications based on the threshold, excluding 'Not EJ'
3. Default to 'Not EJ' if no classifications meet the threshold
Process the predictions and classify based on the individual thresholds per indicator:
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification.
2. Filter classifications based on their individual thresholds, excluding 'Not EJ'.
3. Default to 'Not EJ' if no classifications meet the threshold.
"""
highest_prediction = max(predictions, key=lambda x: x["score"])

if highest_prediction["label"] == "Not EJ":
return ["Not EJ"]

classifications = [
pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ"
pred["label"]
for pred in predictions
if pred["score"] >= thresholds[pred["label"]] and pred["label"] != "Not EJ"
]

return classifications if classifications else ["Not EJ"]
Expand Down Expand Up @@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str]
def update_cmr_with_classifications(
inferences: list[dict[str, dict]],
cmr_dict: dict[str, dict[str, dict]],
threshold: float = 0.5,
thresholds: dict[str, float],
) -> list[dict[str, dict]]:
"""Update CMR data with valid classifications based on inferences."""

predicted_cmr = []

for inference in inferences:
classifications = process_classifications(predictions=inference["predictions"], threshold=threshold)
classifications = process_classifications(predictions=inference["predictions"], thresholds=thresholds)
classifications = remove_unauthorized_classifications(classifications)

if classifications:
Expand All @@ -84,17 +86,30 @@ def update_cmr_with_classifications(


def main():
inferences = load_json_file("cmr-inference.json")
thresholds = {
"Not EJ": 0.80,
"Climate Change": 0.95,
"Disasters": 0.80,
"Extreme Heat": 0.50,
"Food Availability": 0.80,
"Health & Air Quality": 0.90,
"Human Dimensions": 0.80,
"Urban Flooding": 0.50,
"Water Availability": 0.80,
}

inferences = load_json_file("alpha-1.3-wise-vortex-42-predictions.json")
cmr = load_json_file("cmr_collections_umm_20240807_142146.json")

cmr_dict = create_cmr_dict(cmr)

predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=0.8)
predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, thresholds=thresholds)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_name = f"ej_dump_{timestamp}.json"

save_to_json(predicted_cmr, file_name)
print(f"Saved to {file_name}")


if __name__ == "__main__":
Expand Down

0 comments on commit 0f0d407

Please sign in to comment.