-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add demographic attributes assets for JAIS
- Loading branch information
Showing
4 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
48 changes: 48 additions & 0 deletions
48
assets/ar/demographic_attributes/gender/ArabGend_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from llmebench.datasets import ArabGendDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import ClassificationTask | ||
|
||
|
||
def metadata():
    """Return descriptive information about this benchmark asset.

    Returns:
        A dict with the author, the model name, a free-text description of
        the deployment, and a ``scores`` mapping whose Macro-F1 entry is
        left as an empty string.
    """
    info = dict(
        author="Arabic Language Technologies, QCRI, HBKU",
        model="JAIS-13b",
        description="Locally hosted JAIS-13b-chat model using FastChat.",
    )
    info["scores"] = {"Macro-F1": ""}
    return info
|
||
|
||
def config():
    """Return the dataset/task/model wiring for the ArabGend gender asset.

    Returns:
        A dict naming the dataset, task and model classes to use, plus the
        ``model_args`` (class labels ``"m"``/``"f"`` and ``max_tries``)
        handed to the model.
    """
    model_args = {"class_labels": ["m", "f"], "max_tries": 3}
    return {
        "dataset": ArabGendDataset,
        "task": ClassificationTask,
        "model": FastChatModel,
        "model_args": model_args,
    }
|
||
|
||
def prompt(input_sample):
    """Build the zero-shot chat prompt asking for the gender of a name.

    Args:
        input_sample: The name to classify; it is interpolated into the
            prompt text as-is.

    Returns:
        A one-element list containing a single ``user`` chat message.
    """
    # The name line must end with a newline; without it the "gender:" cue is
    # glued directly onto the name (e.g. "name: Saragender:").
    base_prompt = (
        "Identify the gender from the following name as 'female' or 'male'.\n\n"
        f"name: {input_sample}\n"
        "gender: \n"
    )
    return [
        {
            "role": "user",
            "content": base_prompt,
        },
    ]
|
||
|
||
def post_process(response):
    """Map the model's free-text answer onto the ``"m"`` / ``"f"`` labels.

    Args:
        response: Chat-completion style dict; the answer text is read from
            ``response["choices"][0]["message"]["content"]``.

    Returns:
        ``"f"`` if the answer mentions "female", ``"m"`` if it contains the
        standalone word "male", otherwise ``None``.
    """
    import re

    text = response["choices"][0]["message"]["content"].strip().lower()

    # "female" is tested first so an answer mentioning both resolves the same
    # way it always has ("f").
    if "female" in text:
        return "f"
    # Word boundaries keep "male" from matching inside "female" or inside an
    # echoed name, while still accepting "Male.", "gender: male", etc.
    if re.search(r"\bmale\b", text):
        return "m"
    return None
58 changes: 58 additions & 0 deletions
58
assets/ar/demographic_attributes/gender/ArapTweet_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from llmebench.datasets import ArapTweetDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import ClassificationTask | ||
|
||
|
||
def metadata():
    """Return descriptive information about this benchmark asset.

    Returns:
        A dict with the author, model name, deployment description and a
        ``scores`` mapping whose Macro-F1 value is an empty string.
    """
    author = "Arabic Language Technologies, QCRI, HBKU"
    description = "Locally hosted JAIS-13b-chat model using FastChat."
    scores = {"Macro-F1": ""}
    return {
        "author": author,
        "model": "JAIS-13b",
        "description": description,
        "scores": scores,
    }
|
||
|
||
def config():
    """Return the dataset/task/model wiring for the ArapTweet gender asset.

    Returns:
        A dict naming the dataset, task and model classes to use, plus the
        ``model_args`` (class labels ``"Female"``/``"Male"`` and
        ``max_tries``) handed to the model.
    """
    settings = {
        "dataset": ArapTweetDataset,
        "task": ClassificationTask,
        "model": FastChatModel,
    }
    settings["model_args"] = {
        "class_labels": ["Female", "Male"],
        "max_tries": 30,
    }
    return settings
|
||
|
||
def prompt(input_sample):
    """Build the zero-shot chat prompt asking for the gender of a name.

    Args:
        input_sample: The name to classify; it is interpolated into the
            prompt text as-is.

    Returns:
        A one-element list containing a single ``user`` chat message.
    """
    # The name line must end with a newline; without it the "gender:" cue is
    # glued directly onto the name (e.g. "name: Omargender:").
    base_prompt = (
        "Identify the gender from the following name as 'Female' or 'Male'.\n\n"
        f"name: {input_sample}\n"
        "gender: \n"
    )

    return [
        {
            "role": "user",
            "content": base_prompt,
        },
    ]
|
||
|
||
def post_process(response):
    """Map the model's free-text answer onto the ``"Female"`` / ``"Male"`` labels.

    Args:
        response: Chat-completion style dict; the answer text is read from
            ``response["choices"][0]["message"]["content"]``.

    Returns:
        ``"Female"`` or ``"Male"`` when the answer is exactly that word
        (ignoring surrounding whitespace) or contains one of the known
        answer phrasings; otherwise ``None``. Matching is case-sensitive.
    """
    raw = response["choices"][0]["message"]["content"]
    # Exact comparisons run on the stripped text so that "Male\n" or " Female"
    # are not rejected just because of padding. The substring checks run on
    # the raw text because two of the markers start with a newline.
    bare = raw.strip()

    female_markers = ("gender: Female", "\nFemale")
    male_markers = (
        "gender: Male",
        "\nMale",
        "likely to be 'Male'",
        "typically a 'Male' name",
    )

    # Female is tested first, as before, so mixed answers resolve identically.
    if bare == "Female" or any(marker in raw for marker in female_markers):
        return "Female"
    if bare == "Male" or any(marker in raw for marker in male_markers):
        return "Male"
    return None
81 changes: 81 additions & 0 deletions
81
assets/ar/demographic_attributes/location/Location_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from llmebench.datasets import LocationDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import DemographyLocationTask | ||
|
||
|
||
def metadata():
    """Return descriptive information about this benchmark asset.

    Returns:
        A dict with ``author``, ``model``, ``description`` and ``scores``
        entries; the Macro-F1 score is an empty string.
    """
    fields = (
        ("author", "Arabic Language Technologies, QCRI, HBKU"),
        ("model", "JAIS-13b"),
        ("description", "Locally hosted JAIS-13b-chat model using FastChat."),
        ("scores", {"Macro-F1": ""}),
    )
    return dict(fields)
|
||
|
||
def config():
    """Return the dataset/task/model wiring for the user-location asset.

    Returns:
        A dict naming the dataset, task and model classes to use, plus the
        ``model_args`` handed to the model: the list of accepted class
        labels (lower-case country codes together with the upper-case
        ``"OTHERS"`` and ``"UNK"``) and ``max_tries``.
    """
    # Order is kept exactly as originally declared.
    class_labels = [
        "ae", "OTHERS", "bh", "dz", "eg", "iq", "jo", "kw",
        "lb", "ly", "ma", "om", "ps", "qa", "sa", "sd",
        "so", "sy", "tn", "UNK", "ye", "mr",
    ]
    model_args = {"class_labels": class_labels, "max_tries": 30}
    return {
        "dataset": LocationDataset,
        "task": DemographyLocationTask,
        "model": FastChatModel,
        "model_args": model_args,
    }
|
||
|
||
def prompt(input_sample):
    """Build the zero-shot chat prompt mapping a user location to a country code.

    Args:
        input_sample: The free-text user location; interpolated into the
            prompt as-is.

    Returns:
        A one-element list containing a single ``user`` chat message.
    """
    instructions = " ".join(
        [
            "Given the following 'user location', identify and map it to its corresponding country code in accordance with ISO 3166-1 alpha-2.",
            "Please write the country code only, with no additional explanations.",
            "If the country is not an Arab country, please write 'OTHERS'. If the location doesn't map to a recognized country, write 'UNK'.",
        ]
    )
    content = f"{instructions}\n\nuser location: {input_sample}\ncountry code: \n"
    message = {"role": "user", "content": content}
    return [message]
|
||
|
||
def post_process(response):
    """Map the model's free-text answer onto one of the configured class labels.

    The answer is matched case-insensitively and the label is returned in
    the exact spelling used in ``config()``. This matters because the label
    list mixes lower-case country codes with the upper-case ``"OTHERS"`` and
    ``"UNK"``; comparing a lower-cased answer directly against that list can
    never match those two classes.

    Args:
        response: Chat-completion style dict; the answer text is read from
            ``response["choices"][0]["message"]["content"]``.

    Returns:
        The matching class label, ``"ae"`` for the common non-ISO answer
        "uae", or ``None`` when the answer is not a known label. An echoed
        "country code: " prefix and surrounding whitespace are ignored.
    """
    label_list = config()["model_args"]["class_labels"]

    # Lower-cased answer -> label exactly as declared in the config.
    canonical = {known.lower(): known for known in label_list}
    # "UAE" is not an ISO 3166-1 alpha-2 code but is a frequent model answer.
    canonical["uae"] = "ae"

    answer = response["choices"][0]["message"]["content"].lower()
    # The model sometimes repeats the cue from the prompt before answering.
    candidate = answer.replace("country code: ", "").strip()

    return canonical.get(candidate)
161 changes: 161 additions & 0 deletions
161
assets/ar/demographic_attributes/name_info/NameInfo_JAIS13b_ZeroShot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from llmebench.datasets import NameInfoDataset | ||
from llmebench.models import FastChatModel | ||
from llmebench.tasks import DemographyNameInfoTask | ||
|
||
|
||
def metadata():
    """Return descriptive information about this benchmark asset.

    Returns:
        A dict with the author, model name, deployment description and a
        ``scores`` mapping whose Macro-F1 value is an empty string.
    """
    details = {}
    details["author"] = "Arabic Language Technologies, QCRI, HBKU"
    details["model"] = "JAIS-13b"
    details["description"] = "Locally hosted JAIS-13b-chat model using FastChat."
    details["scores"] = {"Macro-F1": ""}
    return details
|
||
|
||
def config():
    """Return the dataset/task/model wiring for the name-to-country asset.

    Returns:
        A dict naming the dataset, task and model classes to use, plus the
        ``model_args`` handed to the model: the list of accepted lower-case
        two-letter country codes and ``max_tries``.
    """
    # Whitespace-separated for compactness; order matches the original list.
    class_labels = """
        gb us cl fr ru pl in it kr gh
        ca sa at de cn br dk se bd cu
        jp be es co id iq pk tr il ch
        ar ro nl ps ug ir cg do ee tn
        gr np ie sy hu eg ma ve ph no
        bg si ke au et py af pt th bo
        mx lb za fi hr vn ly nz qa kh
        ci ng sg cm dz tz ae pe az lu
        ec cz ua uy sd ao my lv kw tw
        bh lk ye cr jo pa om uz by kz
    """.split()
    model_args = {"class_labels": class_labels, "max_tries": 30}
    return {
        "dataset": NameInfoDataset,
        "task": DemographyNameInfoTask,
        "model": FastChatModel,
        "model_args": model_args,
    }
|
||
|
||
def prompt(input_sample):
    """Build the zero-shot chat prompt asking for the country of a person's name.

    Args:
        input_sample: The person name; interpolated into the prompt as-is.

    Returns:
        A one-element list containing a single ``user`` chat message.
    """
    question = (
        "Label the country of the following person 'name'. "
        "Write ONLY the country code in ISO 3166-1 alpha-2 format."
    )
    # Joined with newlines: blank line after the question, trailing newline
    # after the "country: " cue.
    lines = [question, "", f"name: {input_sample}", "country: ", ""]
    return [dict(role="user", content="\n".join(lines))]
|
||
|
||
def post_process(response):
    """Map the model's free-text answer onto one of the configured country codes.

    Args:
        response: Chat-completion style dict; the answer text is read from
            ``response["choices"][0]["message"]["content"]``.

    Returns:
        The lower-case country code when the answer (after removing an
        echoed "name: " / "country: " cue and surrounding whitespace) is in
        the label list from ``config()``; ``"ae"`` for the non-ISO answer
        "uae"; otherwise ``None``. Refusals such as "I cannot predict the
        country" are not valid labels and therefore also yield ``None``.
    """
    label_list = config()["model_args"]["class_labels"]

    text = response["choices"][0]["message"]["content"]
    # The prompt cues the answer with "country: ", so that is the prefix the
    # model is most likely to echo; "name: " is still stripped as before.
    for echoed_cue in ("name: ", "country: "):
        text = text.replace(echoed_cue, "")
    candidate = text.strip().lower()

    # "UAE" is not an ISO 3166-1 alpha-2 code but is a frequent model answer.
    if candidate == "uae":
        return "ae"
    # Validate every path so free text is never returned as a label.
    if candidate in label_list:
        return candidate
    return None