From 131bda77d7d624fd0d2ad7b96182efc7cdceb6ff Mon Sep 17 00:00:00 2001 From: Arid Hasan Date: Tue, 2 Jan 2024 00:43:11 -0400 Subject: [PATCH] Add demographic attributes assets for JAIS --- .../gender/ArabGend_JAIS13b_ZeroShot.py | 48 ++++++ .../gender/ArapTweet_JAIS13b_ZeroShot.py | 58 +++++++ .../location/Location_JAIS13b_ZeroShot.py | 81 +++++++++ .../name_info/NameInfo_JAIS13b_ZeroShot.py | 161 ++++++++++++++++++ 4 files changed, 348 insertions(+) create mode 100644 assets/ar/demographic_attributes/gender/ArabGend_JAIS13b_ZeroShot.py create mode 100644 assets/ar/demographic_attributes/gender/ArapTweet_JAIS13b_ZeroShot.py create mode 100644 assets/ar/demographic_attributes/location/Location_JAIS13b_ZeroShot.py create mode 100644 assets/ar/demographic_attributes/name_info/NameInfo_JAIS13b_ZeroShot.py diff --git a/assets/ar/demographic_attributes/gender/ArabGend_JAIS13b_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArabGend_JAIS13b_ZeroShot.py new file mode 100644 index 00000000..b6a902f7 --- /dev/null +++ b/assets/ar/demographic_attributes/gender/ArabGend_JAIS13b_ZeroShot.py @@ -0,0 +1,48 @@ +from llmebench.datasets import ArabGendDataset +from llmebench.models import FastChatModel +from llmebench.tasks import ClassificationTask + + +def metadata(): + return { + "author": "Arabic Language Technologies, QCRI, HBKU", + "model": "JAIS-13b", + "description": "Locally hosted JAIS-13b-chat model using FastChat.", + "scores": {"Macro-F1": ""}, + } + + +def config(): + return { + "dataset": ArabGendDataset, + "task": ClassificationTask, + "model": FastChatModel, + "model_args": { + "class_labels": ["m", "f"], + "max_tries": 3, + }, + } + + +def prompt(input_sample): + base_prompt = ( + f"Identify the gender from the following name as 'female' or 'male'.\n\n" + f"name: {input_sample}" + f"gender: \n" + ) + return [ + { + "role": "user", + "content": base_prompt, + }, + ] + + +def post_process(response): + label = response["choices"][0]["message"]["content"] + if label.lower() == 'male': + return 'm' + elif "female" in label.lower(): + return "f" + else: + return None diff --git a/assets/ar/demographic_attributes/gender/ArapTweet_JAIS13b_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArapTweet_JAIS13b_ZeroShot.py new file mode 100644 index 00000000..3c56fc81 --- /dev/null +++ b/assets/ar/demographic_attributes/gender/ArapTweet_JAIS13b_ZeroShot.py @@ -0,0 +1,58 @@ +from llmebench.datasets import ArapTweetDataset +from llmebench.models import FastChatModel +from llmebench.tasks import ClassificationTask + + +def metadata(): + return { + "author": "Arabic Language Technologies, QCRI, HBKU", + "model": "JAIS-13b", + "description": "Locally hosted JAIS-13b-chat model using FastChat.", + "scores": {"Macro-F1": ""}, + } + + +def config(): + return { + "dataset": ArapTweetDataset, + "task": ClassificationTask, + "model": FastChatModel, + "model_args": { + "class_labels": ["Female", "Male"], + "max_tries": 30, + }, + } + + +def prompt(input_sample): + base_prompt = ( + f"Identify the gender from the following name as 'Female' or 'Male'.\n\n" + f"name: {input_sample}" + f"gender: \n" + ) + + return [ + { + "role": "user", + "content": base_prompt, + }, + ] + + +def post_process(response): + label = response["choices"][0]["message"]["content"] + # label = label.replace("gender:", "").strip() + if "gender: Female" in label or "\nFemale" in label or label == "Female": + label = "Female" + elif ( + "gender: Male" in label + or "\nMale" in label + or "likely to be 'Male'" in label + or label == "Male" + or "typically a 'Male' name" in label + ): + label = "Male" + else: + label = None + + return label diff --git a/assets/ar/demographic_attributes/location/Location_JAIS13b_ZeroShot.py b/assets/ar/demographic_attributes/location/Location_JAIS13b_ZeroShot.py new file mode 100644 index 00000000..e558f5aa --- /dev/null +++ b/assets/ar/demographic_attributes/location/Location_JAIS13b_ZeroShot.py @@ -0,0 +1,81 @@ +from llmebench.datasets import LocationDataset +from llmebench.models import FastChatModel +from llmebench.tasks import DemographyLocationTask + + +def metadata(): + return { + "author": "Arabic Language Technologies, QCRI, HBKU", + "model": "JAIS-13b", + "description": "Locally hosted JAIS-13b-chat model using FastChat.", + "scores": {"Macro-F1": ""}, + } + + +def config(): + return { + "dataset": LocationDataset, + "task": DemographyLocationTask, + "model": FastChatModel, + "model_args": { + "class_labels": [ + "ae", + "OTHERS", + "bh", + "dz", + "eg", + "iq", + "jo", + "kw", + "lb", + "ly", + "ma", + "om", + "ps", + "qa", + "sa", + "sd", + "so", + "sy", + "tn", + "UNK", + "ye", + "mr", + ], + "max_tries": 30, + }, + } + + +def prompt(input_sample): + base_prompt = ( + f"Given the following 'user location', identify and map it to its corresponding country code in accordance with ISO 3166-1 alpha-2. " + f"Please write the country code only, with no additional explanations. " + f"If the country is not an Arab country, please write 'OTHERS'. If the location doesn't map to a recognized country, write 'UNK'.\n\n" + f"user location: {input_sample}\n" + f"country code: \n" + ) + + return [ + { + "role": "user", + "content": base_prompt, + }, + ] + + +def post_process(response): + label = response["choices"][0]["message"]["content"].lower() + + label_list = config()["model_args"]["class_labels"] + + if "country code: " in label: + label_fixed = label.replace("country code: ", "") + elif label.lower() == 'uae': + label_fixed = 'ae' + elif label in label_list: + label_fixed = label + else: + label_fixed = None + + return label_fixed diff --git a/assets/ar/demographic_attributes/name_info/NameInfo_JAIS13b_ZeroShot.py b/assets/ar/demographic_attributes/name_info/NameInfo_JAIS13b_ZeroShot.py new file mode 100644 index 00000000..1640829e --- /dev/null +++ b/assets/ar/demographic_attributes/name_info/NameInfo_JAIS13b_ZeroShot.py @@ -0,0 +1,161 @@ +from llmebench.datasets import NameInfoDataset +from llmebench.models import FastChatModel +from llmebench.tasks import DemographyNameInfoTask + + +def metadata(): + return { + "author": "Arabic Language Technologies, QCRI, HBKU", + "model": "JAIS-13b", + "description": "Locally hosted JAIS-13b-chat model using FastChat.", + "scores": {"Macro-F1": ""}, + } + + +def config(): + return { + "dataset": NameInfoDataset, + "task": DemographyNameInfoTask, + "model": FastChatModel, + "model_args": { + "class_labels": [ + "gb", + "us", + "cl", + "fr", + "ru", + "pl", + "in", + "it", + "kr", + "gh", + "ca", + "sa", + "at", + "de", + "cn", + "br", + "dk", + "se", + "bd", + "cu", + "jp", + "be", + "es", + "co", + "id", + "iq", + "pk", + "tr", + "il", + "ch", + "ar", + "ro", + "nl", + "ps", + "ug", + "ir", + "cg", + "do", + "ee", + "tn", + "gr", + "np", + "ie", + "sy", + "hu", + "eg", + "ma", + "ve", + "ph", + "no", + "bg", + "si", + "ke", + "au", + "et", + "py", + "af", + "pt", + "th", + "bo", + "mx", + "lb", + "za", + "fi", + "hr", + "vn", + "ly", + "nz", + "qa", + "kh", + "ci", + "ng", + "sg", + "cm", + "dz", + "tz", + "ae", + "pe", + "az", + "lu", + "ec", + "cz", + "ua", + "uy", + "sd", + "ao", + "my", + "lv", + "kw", + "tw", + "bh", + "lk", + "ye", + "cr", + "jo", + "pa", + "om", + "uz", + "by", + "kz", + ], + "max_tries": 30, + }, + } + + +def prompt(input_sample): + base_prompt = ( + f"Label the country of the following person 'name'. Write ONLY the country code in ISO 3166-1 alpha-2 format.\n\n" + f"name: {input_sample}\n" + f"country: \n" + ) + return [ + { + "role": "user", + "content": base_prompt, + }, + ] + + +def post_process(response): + label = response["choices"][0]["message"]["content"] + + label_list = config()["model_args"]["class_labels"] + + if "name: " in label: + label_fixed = label.replace("name: ", "").lower() + elif label.lower() == 'uae': + label_fixed = 'ae' + elif label.lower() in label_list: + label_fixed = label.lower() + elif ( + "I'm sorry, but I cannot predict the country" in label + or "I cannot predict the country" in label + ): + label_fixed = None + else: + label_fixed = None + + return label_fixed