Skip to content

Commit

Permalink
Merge branch 'dev' into 1057-implement-ml-tags-backend-in-cosmos
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirandawadi authored Nov 21, 2024
2 parents e8b0cf0 + 0f0d407 commit 911010b
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 54 deletions.
12 changes: 10 additions & 2 deletions .envs/.local/.django
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ SINEQUA_CONFIGS_REPO_WEBAPP_PR_BRANCH='dummy_branch'
# Slack Webhook
# ------------------------------------------------------------------------------
SLACK_WEBHOOK_URL=''
#Server Credentials
#--------------------------------------------------------------------------------
LRM_DEV_USER=''
LRM_DEV_PASSWORD=''
XLI_USER=''
XLI_PASSWORD=''
LRM_QA_USER=''
LRM_QA_PASSWORD=''

#Server Tokens
#--------------------------------------------------------------------------------
LRM_DEV_TOKEN=''
XLI_TOKEN=''
6 changes: 4 additions & 2 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@
SLACK_WEBHOOK_URL = env("SLACK_WEBHOOK_URL")
# Per-server Sinequa credentials and tokens; values come from .envs/.local/.django.
# The generic LRM_USER/LRM_PASSWORD pair was replaced by per-environment
# (dev/qa) credentials, so only the per-server names are read here — env()
# raises for keys missing from the env file.
XLI_USER = env("XLI_USER")
XLI_PASSWORD = env("XLI_PASSWORD")
LRM_DEV_USER = env("LRM_DEV_USER")
LRM_DEV_PASSWORD = env("LRM_DEV_PASSWORD")
LRM_QA_USER = env("LRM_QA_USER")
LRM_QA_PASSWORD = env("LRM_QA_PASSWORD")
LRM_DEV_TOKEN = env("LRM_DEV_TOKEN")
XLI_TOKEN = env("XLI_TOKEN")
2 changes: 1 addition & 1 deletion scripts/ej/cmr_to_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def categorize_processing_level(level):
# remove existing data
EnvironmentalJusticeRow.objects.filter(destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV).delete()

# Load the classified EJ dump produced by create_ej_dump.py and copied to the server.
# NOTE(review): the committed filename ends in ".json.json" — the doubled extension
# looks accidental (create_ej_dump.py writes "ej_dump_<timestamp>.json"); confirm the
# actual file name in backups/ before changing it.
ej_dump = json.load(open("backups/ej_dump_20241017_133151.json.json"))
for dataset in ej_dump:
ej_row = EnvironmentalJusticeRow(
destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV,
Expand Down
37 changes: 26 additions & 11 deletions scripts/ej/create_ej_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
inferences are supplied by the classification model. the contact point is Bishwas
cmr is supplied by running
github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
move to the server like this: scp ej_dump_20241017_133151.json sde:/home/ec2-user/sde_indexing_helper/backups/
"""

import json
Expand All @@ -19,20 +19,22 @@ def save_to_json(data: dict | list, file_path: str) -> None:
json.dump(data, file, indent=2)


def process_classifications(predictions: list[dict[str, float]], thresholds: dict[str, float]) -> list[str]:
    """
    Process the predictions and classify based on the individual thresholds per indicator:
    1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification.
    2. Filter classifications based on their individual thresholds, excluding 'Not EJ'.
    3. Default to 'Not EJ' if no classifications meet the threshold.

    Each prediction is a dict with "label" and "score" keys; ``thresholds`` maps
    each label to its minimum accepted score.
    """
    highest_prediction = max(predictions, key=lambda x: x["score"])

    if highest_prediction["label"] == "Not EJ":
        return ["Not EJ"]

    classifications = [
        pred["label"]
        for pred in predictions
        # check the label first so the threshold lookup is never done for "Not EJ"
        if pred["label"] != "Not EJ" and pred["score"] >= thresholds[pred["label"]]
    ]

    return classifications if classifications else ["Not EJ"]
Expand Down Expand Up @@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str]
def update_cmr_with_classifications(
inferences: list[dict[str, dict]],
cmr_dict: dict[str, dict[str, dict]],
threshold: float = 0.5,
thresholds: dict[str, float],
) -> list[dict[str, dict]]:
"""Update CMR data with valid classifications based on inferences."""

predicted_cmr = []

for inference in inferences:
classifications = process_classifications(predictions=inference["predictions"], threshold=threshold)
classifications = process_classifications(predictions=inference["predictions"], thresholds=thresholds)
classifications = remove_unauthorized_classifications(classifications)

if classifications:
Expand All @@ -84,17 +86,30 @@ def update_cmr_with_classifications(


def main():
    """Load inferences and CMR data, apply per-label thresholds, and save the EJ dump."""
    # Per-indicator score thresholds; a label is kept only when its score
    # meets its own threshold (see process_classifications).
    thresholds = {
        "Not EJ": 0.80,
        "Climate Change": 0.95,
        "Disasters": 0.80,
        "Extreme Heat": 0.50,
        "Food Availability": 0.80,
        "Health & Air Quality": 0.90,
        "Human Dimensions": 0.80,
        "Urban Flooding": 0.50,
        "Water Availability": 0.80,
    }

    inferences = load_json_file("alpha-1.3-wise-vortex-42-predictions.json")
    cmr = load_json_file("cmr_collections_umm_20240807_142146.json")

    cmr_dict = create_cmr_dict(cmr)

    predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, thresholds=thresholds)

    # Timestamped output name, e.g. ej_dump_20241017_133151.json
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"ej_dump_{timestamp}.json"

    save_to_json(predicted_cmr, file_name)
    print(f"Saved to {file_name}")


if __name__ == "__main__":
Expand Down
34 changes: 26 additions & 8 deletions sde_collections/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,24 @@
from .models.candidate_url import CandidateURL, ResolvedTitle
from .models.collection import Collection, WorkflowHistory
from .models.pattern import DivisionPattern, IncludePattern, TitlePattern
from .tasks import import_candidate_urls_from_api
from .models.collection_choice_fields import TDAMMTags
from .tasks import fetch_and_update_full_text, import_candidate_urls_from_api


def fetch_and_update_text_for_server(modeladmin, request, queryset, server_name):
    """Queue an async full-text fetch task for each selected collection on *server_name*."""
    for selected_collection in queryset:
        fetch_and_update_full_text.delay(selected_collection.id, server_name)
    notice = f"Started importing URLs from {server_name.upper()} Server"
    modeladmin.message_user(request, notice)


@admin.action(description="Import candidate URLs from LRM Dev Server with Full Text")
def fetch_full_text_lrm_dev_action(modeladmin, request, queryset):
    """Admin action: queue full-text fetch for the selected collections on LRM Dev."""
    fetch_and_update_text_for_server(modeladmin, request, queryset, server_name="lrm_dev")


@admin.action(description="Import candidate URLs from XLI Server with Full Text")
def fetch_full_text_lis_action(modeladmin, request, queryset):
    # Admin action: queue full-text fetch for the selected collections on XLI.
    # NOTE(review): the function name keeps the legacy "lis" spelling although the
    # server was renamed to "xli"; renaming requires updating the admin `actions`
    # registration list as well, so it is left as-is here.
    fetch_and_update_text_for_server(modeladmin, request, queryset, "xli")


@admin.action(description="Generate deployment message")
Expand Down Expand Up @@ -111,7 +127,7 @@ def import_candidate_urls_from_api_caller(modeladmin, request, queryset, server_
messages.add_message(
request,
messages.INFO,
f"Started importing URLs from the API for: {collection_names} from {server_name.title()}",
f"Started importing URLs from the API for: {collection_names} from {server_name.upper()} Server",
)


Expand All @@ -135,19 +151,19 @@ def import_candidate_urls_secret_production(modeladmin, request, queryset):
import_candidate_urls_from_api_caller(modeladmin, request, queryset, "secret_production")


@admin.action(description="Import candidate URLs from XLI Server")
def import_candidate_urls_xli_server(modeladmin, request, queryset):
    """Admin action: import candidate URLs from the XLI server (renamed from "Li's Server")."""
    import_candidate_urls_from_api_caller(modeladmin, request, queryset, "xli")


@admin.action(description="Import candidate URLs from LRM Dev Server")
def import_candidate_urls_lrm_dev_server(modeladmin, request, queryset):
    """Admin action: import candidate URLs from the LRM Dev server."""
    import_candidate_urls_from_api_caller(modeladmin, request, queryset, "lrm_dev")


@admin.action(description="Import candidate URLs from LRM QA Server")
def import_candidate_urls_lrm_qa_server(modeladmin, request, queryset):
    """Admin action: import candidate URLs from the LRM QA server."""
    import_candidate_urls_from_api_caller(modeladmin, request, queryset, "lrm_qa")


class ExportCsvMixin:
Expand Down Expand Up @@ -239,9 +255,11 @@ class CollectionAdmin(admin.ModelAdmin, ExportCsvMixin, UpdateConfigMixin):
import_candidate_urls_production,
import_candidate_urls_secret_test,
import_candidate_urls_secret_production,
import_candidate_urls_lis_server,
import_candidate_urls_xli_server,
import_candidate_urls_lrm_dev_server,
import_candidate_urls_lrm_qa_server,
fetch_full_text_lrm_dev_action,
fetch_full_text_lis_action,
]
ordering = ("cleaning_order",)

Expand Down
18 changes: 18 additions & 0 deletions sde_collections/migrations/0059_candidateurl_scraped_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.9 on 2024-10-21 23:10

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the nullable `scraped_text` text column to CandidateURL."""

    dependencies = [
        ("sde_collections", "0058_candidateurl_division_collection_is_multi_division_and_more"),
    ]

    operations = [
        migrations.AddField(
            model_name="candidateurl",
            name="scraped_text",
            # Bare TextField; verbose_name/help_text/default are added in 0060.
            field=models.TextField(blank=True, null=True),
        ),
    ]
24 changes: 24 additions & 0 deletions sde_collections/migrations/0060_alter_candidateurl_scraped_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 4.2.9 on 2024-11-07 17:34

from django.db import migrations, models


class Migration(migrations.Migration):
    """Align `CandidateURL.scraped_text` with the model definition.

    Adds the verbose name, help text, and empty-string default to the field
    introduced in 0059 (no schema change beyond field metadata/default).
    """

    dependencies = [
        ("sde_collections", "0059_candidateurl_scraped_text"),
    ]

    operations = [
        migrations.AlterField(
            model_name="candidateurl",
            name="scraped_text",
            field=models.TextField(
                blank=True,
                default="",
                help_text="This is the text scraped by Sinequa",
                null=True,
                verbose_name="Scraped Text",
            ),
        ),
    ]
7 changes: 7 additions & 0 deletions sde_collections/models/candidate_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ class CandidateURL(models.Model):
blank=True,
help_text="This is the original title scraped by Sinequa",
)
# Full text scraped by Sinequa for this URL (see migrations 0059/0060).
# NOTE(review): the field allows both "" (default) and NULL — Django convention
# is to avoid null=True on text fields; confirm NULLs are intentional before
# changing, since removing null=True needs a data migration.
scraped_text = models.TextField(
    "Scraped Text",
    default="",
    null=True,
    blank=True,
    help_text="This is the text scraped by Sinequa",
)
generated_title = models.CharField(
"Generated Title",
default="",
Expand Down
Loading

0 comments on commit 911010b

Please sign in to comment.