From d537302dcbbc288175ec81f62994b5fec84fbcbc Mon Sep 17 00:00:00 2001 From: Carson Davis Date: Thu, 17 Oct 2024 13:52:37 -0500 Subject: [PATCH] add per indicator thrsholding and new dump --- scripts/ej/cmr_to_models.py | 2 +- scripts/ej/create_ej_dump.py | 37 +++++++++++++++++++++++++----------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/scripts/ej/cmr_to_models.py b/scripts/ej/cmr_to_models.py index 130de722..f7ba46db 100644 --- a/scripts/ej/cmr_to_models.py +++ b/scripts/ej/cmr_to_models.py @@ -69,7 +69,7 @@ def categorize_processing_level(level): # remove existing data EnvironmentalJusticeRow.objects.filter(destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV).delete() -ej_dump = json.load(open("backups/ej_dump_20240815_112916.json")) +ej_dump = json.load(open("backups/ej_dump_20241017_133151.json.json")) for dataset in ej_dump: ej_row = EnvironmentalJusticeRow( destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV, diff --git a/scripts/ej/create_ej_dump.py b/scripts/ej/create_ej_dump.py index 36d7f722..c44aebc5 100644 --- a/scripts/ej/create_ej_dump.py +++ b/scripts/ej/create_ej_dump.py @@ -2,7 +2,7 @@ inferences are supplied by the classification model. the contact point is Bishwas cmr is supplied by running github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py -move to the serve like this: scp ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/ +move to the server like this: scp ej_dump_20241017_133151.json sde:/home/ec2-user/sde_indexing_helper/backups/ """ import json @@ -19,12 +19,12 @@ def save_to_json(data: dict | list, file_path: str) -> None: json.dump(data, file, indent=2) -def process_classifications(predictions: list[dict[str, float]], threshold: float = 0.5) -> list[str]: +def process_classifications(predictions: list[dict[str, float]], thresholds: dict[str, float]) -> list[str]: """ - Process the predictions and classify as follows: - 1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification - 2. Filter classifications based on the threshold, excluding 'Not EJ' - 3. Default to 'Not EJ' if no classifications meet the threshold + Process the predictions and classify based on the individual thresholds per indicator: + 1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification. + 2. Filter classifications based on their individual thresholds, excluding 'Not EJ'. + 3. Default to 'Not EJ' if no classifications meet the threshold. """ highest_prediction = max(predictions, key=lambda x: x["score"]) @@ -32,7 +32,9 @@ def process_classifications(predictions: list[dict[str, float]], threshold: floa return ["Not EJ"] classifications = [ - pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ" + pred["label"] + for pred in predictions + if pred["score"] >= thresholds[pred["label"]] and pred["label"] != "Not EJ" ] return classifications if classifications else ["Not EJ"] @@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str] def update_cmr_with_classifications( inferences: list[dict[str, dict]], cmr_dict: dict[str, dict[str, dict]], - threshold: float = 0.5, + thresholds: dict[str, float], ) -> list[dict[str, dict]]: """Update CMR data with valid classifications based on inferences.""" predicted_cmr = [] for inference in inferences: - classifications = process_classifications(predictions=inference["predictions"], threshold=threshold) + classifications = process_classifications(predictions=inference["predictions"], thresholds=thresholds) classifications = remove_unauthorized_classifications(classifications) if classifications: @@ -84,17 +86,30 @@ def update_cmr_with_classifications( def main(): - inferences = load_json_file("cmr-inference.json") + thresholds = { + "Not EJ": 0.80, + "Climate Change": 0.95, + "Disasters": 0.80, + "Extreme Heat": 0.50, + "Food Availability": 0.80, + "Health & Air Quality": 0.90, + "Human Dimensions": 0.80, + "Urban Flooding": 0.50, + "Water Availability": 0.80, + } + + inferences = load_json_file("alpha-1.3-wise-vortex-42-predictions.json") cmr = load_json_file("cmr_collections_umm_20240807_142146.json") cmr_dict = create_cmr_dict(cmr) - predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=0.8) + predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, thresholds=thresholds) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") file_name = f"ej_dump_{timestamp}.json" save_to_json(predicted_cmr, file_name) + print(f"Saved to {file_name}") if __name__ == "__main__":