-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add JAIS assests for arabic news categorization
- Loading branch information
Showing
4 changed files
with
295 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from llmebench.datasets import ASNDDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import NewsCategorizationTask | ||
|
||
|
||
def metadata(): | ||
return { | ||
"author": "Arabic Language Technologies, QCRI, HBKU", | ||
"model": "JAIS-13b", | ||
"description": "Locally hosted JAIS-13b-chat model using FastChat.", | ||
"scores": {"Macro-F1": ""}, | ||
} | ||
|
||
|
||
def config(): | ||
return { | ||
"dataset": ASNDDataset, | ||
"task": NewsCategorizationTask, | ||
"model": FastChatModel, | ||
} | ||
|
||
|
||
def prompt(input_sample): | ||
base_prompt = ( | ||
f"صنف التغريدة التالية إلى واحدة من الفئات التالية: " | ||
f"جريمة-حرب-صراع ، روحي-ديني ، صحة ، سياسة ، حقوق-الإنسان-حرية-الصحافة ، " | ||
f"تعليم ، أعمال-اقتصاد ، فن-ترفيه ، أخرى ، " | ||
f"علوم-تكنولوجيا ، رياضة ، بيئة\n" | ||
f"\nالتغريدة: {input_sample}" | ||
f"\nالفئة: \n" | ||
) | ||
|
||
return [ | ||
{ | ||
"role": "user", | ||
"content": base_prompt, | ||
}, | ||
] | ||
|
||
|
||
def post_process(response): | ||
label = response["choices"][0]["message"]["content"] | ||
|
||
if "جريمة-حرب-صراع" in label or "صراع-حرب" in label: | ||
label_fixed = "crime-war-conflict" | ||
elif "روحي" in label or "ديني" in label: | ||
label_fixed = "spiritual" | ||
elif "صحة" in label: | ||
label_fixed = "health" | ||
elif "سياسة" in label: | ||
label_fixed = "politics" | ||
elif "حقوق-الإنسان-حرية-الصحافة" in label: | ||
label_fixed = "human-rights-press-freedom" | ||
elif "تعليم" in label: | ||
label_fixed = "education" | ||
elif "أعمال-و-اقتصاد" in label or "أعمال" in label or "اقتصاد" in label: | ||
label_fixed = "business-and-economy" | ||
elif "فن-و-ترفيه" in label or "ترفيه" in label: | ||
label_fixed = "art-and-entertainment" | ||
elif "أخرى" in label: | ||
label_fixed = "others" | ||
elif "علم-و-تكنولوجيا" in label or "علوم" in label or "تكنولوجيا" in label: | ||
label_fixed = "science-and-technology" | ||
elif "رياضة" in label: | ||
label_fixed = "sports" | ||
elif "بيئة" in label: | ||
label_fixed = "environment" | ||
else: | ||
label_fixed = "others" | ||
|
||
return label_fixed |
73 changes: 73 additions & 0 deletions
73
assets/ar/news_categorization/SANADAkhbarona_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import random | ||
|
||
from llmebench.datasets import SANADAkhbaronaDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import NewsCategorizationTask | ||
|
||
|
||
random.seed(1333) | ||
|
||
|
||
def metadata(): | ||
return { | ||
"author": "Arabic Language Technologies, QCRI, HBKU", | ||
"model": "JAIS-13b", | ||
"description": "Locally hosted JAIS-13b-chat model using FastChat.", | ||
"scores": {"Macro-F1": ""}, | ||
} | ||
|
||
|
||
def config(): | ||
return { | ||
"dataset": SANADAkhbaronaDataset, | ||
"task": NewsCategorizationTask, | ||
"model": FastChatModel, | ||
"model_args": { | ||
"class_labels": [ | ||
"politics", | ||
"religion", | ||
"medical", | ||
"sports", | ||
"tech", | ||
"finance", | ||
"culture", | ||
], | ||
"max_tries": 30, | ||
}, | ||
} | ||
|
||
|
||
def prompt(input_sample): | ||
base_prompt = ( | ||
f'Categorize the news "article" into one of the following categories: politics, religion, medical, sports, tech, finance, culture\n\n' | ||
f"article: {input_sample}\n" | ||
f"category: \n" | ||
) | ||
return [ | ||
{ | ||
"role": "user", | ||
"content": base_prompt, | ||
}, | ||
] | ||
|
||
|
||
def post_process(response): | ||
label = response["choices"][0]["message"]["content"] | ||
|
||
label_fixed = label.lower() | ||
label_fixed = label_fixed.replace("category: ", "") | ||
label_fixed = label_fixed.replace("science/physics", "tech") | ||
label_fixed = label_fixed.replace("health/nutrition", "medical") | ||
if "سياسة" in label or "السياسة" in label: | ||
label_fixed = "politics" | ||
if len(label_fixed.split("\s+")) > 1: | ||
label_fixed = label_fixed.split("\s+")[0] | ||
label_fixed = random.choice(label_fixed.split("/")).strip() | ||
if "science/physics" in label_fixed: | ||
label_fixed = label_fixed.replace("science/physics", "tech") | ||
if label_fixed.startswith("culture"): | ||
label_fixed = label_fixed.split("(")[0] | ||
|
||
label_fixed = label_fixed.replace("culture.", "culture") | ||
|
||
return label_fixed |
73 changes: 73 additions & 0 deletions
73
assets/ar/news_categorization/SANADAlArabiya_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import random | ||
|
||
from llmebench.datasets import SANADAlArabiyaDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import NewsCategorizationTask | ||
|
||
|
||
random.seed(1333) | ||
|
||
|
||
def metadata(): | ||
return { | ||
"author": "Arabic Language Technologies, QCRI, HBKU", | ||
"model": "JAIS-13b", | ||
"description": "Locally hosted JAIS-13b-chat model using FastChat.", | ||
"scores": {"Macro-F1": ""}, | ||
} | ||
|
||
|
||
def config(): | ||
return { | ||
"dataset": SANADAlArabiyaDataset, | ||
"task": NewsCategorizationTask, | ||
"model": FastChatModel, | ||
"model_args": { | ||
"class_labels": [ | ||
"politics", | ||
"religion", | ||
"medical", | ||
"sports", | ||
"tech", | ||
"finance", | ||
"culture", | ||
], | ||
"max_tries": 30, | ||
}, | ||
} | ||
|
||
|
||
def prompt(input_sample): | ||
base_prompt = ( | ||
f'Categorize the news "article" into one of the following categories: politics, religion, medical, sports, tech, finance, culture\n\n' | ||
f"article: {input_sample}\n" | ||
f"category: \n" | ||
) | ||
return [ | ||
{ | ||
"role": "user", | ||
"content": base_prompt, | ||
}, | ||
] | ||
|
||
|
||
def post_process(response): | ||
label = response["choices"][0]["message"]["content"] | ||
|
||
label_fixed = label.lower() | ||
label_fixed = label_fixed.replace("category: ", "") | ||
label_fixed = label_fixed.replace("science/physics", "tech") | ||
label_fixed = label_fixed.replace("health/nutrition", "medical") | ||
if "سياسة" in label or "السياسة" in label: | ||
label_fixed = "politics" | ||
if len(label_fixed.split("\s+")) > 1: | ||
label_fixed = label_fixed.split("\s+")[0] | ||
label_fixed = random.choice(label_fixed.split("/")).strip() | ||
if "science/physics" in label_fixed: | ||
label_fixed = label_fixed.replace("science/physics", "tech") | ||
if label_fixed.startswith("culture"): | ||
label_fixed = label_fixed.split("(")[0] | ||
|
||
label_fixed = label_fixed.replace("culture.", "culture") | ||
|
||
return label_fixed |
78 changes: 78 additions & 0 deletions
78
assets/ar/news_categorization/SANADAlKhaleej_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import random | ||
|
||
from llmebench.datasets import SANADAlKhaleejDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import NewsCategorizationTask | ||
|
||
|
||
random.seed(1333) | ||
|
||
|
||
def metadata(): | ||
return { | ||
"author": "Arabic Language Technologies, QCRI, HBKU", | ||
"model": "JAIS-13b", | ||
"description": "Locally hosted JAIS-13b-chat model using FastChat.", | ||
"scores": {"Macro-F1": ""}, | ||
} | ||
|
||
|
||
def config(): | ||
return { | ||
"dataset": SANADAlKhaleejDataset, | ||
"task": NewsCategorizationTask, | ||
"model": FastChatModel, | ||
"model_args": { | ||
"class_labels": [ | ||
"culture", | ||
"finance", | ||
"medical", | ||
"politics", | ||
"religion", | ||
"sports", | ||
"tech", | ||
], | ||
"max_tries": 30, | ||
}, | ||
} | ||
|
||
|
||
def prompt(input_sample): | ||
base_prompt = ( | ||
f'Categorize the news "article" into one of the following categories: culture, finance, medical, politics, religion, sports, tech\n\n' | ||
f"article: {input_sample}\n" | ||
f"category: \n" | ||
) | ||
return [ | ||
{ | ||
"role": "user", | ||
"content": base_prompt, | ||
}, | ||
] | ||
|
||
|
||
def post_process(response): | ||
label = response["choices"][0]["message"]["content"] | ||
label_list = config()["model_args"]["class_labels"] | ||
label_fixed = label.lower() | ||
label_fixed = label_fixed.replace("category: ", "") | ||
label_fixed = label_fixed.replace("science/physics", "tech") | ||
label_fixed = label_fixed.replace("health/nutrition", "medical") | ||
|
||
if "سياسة" in label or "السياسة" in label: | ||
label_fixed = "politics" | ||
|
||
if label_fixed.strip() in label_list: | ||
label_fixed = label_fixed.strip() | ||
|
||
elif "science/physics" in label_fixed: | ||
label_fixed = label_fixed.replace("science/physics", "tech") | ||
elif label_fixed.startswith("culture"): | ||
label_fixed = label_fixed.split("(")[0] | ||
label_fixed = label_fixed.replace("culture.", "culture") | ||
elif "/" in label: | ||
label_fixed = random.choice(label_fixed.split("/")).strip() | ||
else: | ||
label_fixed = None | ||
|
||
return label_fixed |