From 5325b63c1d1a5a27b01720983324b83b2dce8aec Mon Sep 17 00:00:00 2001 From: "Firoj Alam, Scientist, QCRI" Date: Sat, 28 Oct 2023 23:19:27 +0300 Subject: [PATCH] Implement generic `ClassificationTask` (#242) This commit adds support for a generic `ClassificationTask` with many popular metrics. * Added SST2 dataset, GPT4-ZeroShot and Llama-2-13b-chat-hf assets. * updated the classification task with gender datasets * Fix errorneous imports --------- Co-authored-by: Fahim Imaduddin Dalvi --- .../gender/ArabGend_BLOOMZ_ZeroShot.py | 4 +- .../gender/ArabGend_GPT35_ZeroShot.py | 4 +- .../gender/ArabGend_GPT4_ZeroShot.py | 4 +- .../gender/ArabGend_Random.py | 4 +- .../gender/ArapTweet_BLOOMZ_ZeroShot.py | 4 +- .../gender/ArapTweet_GPT35_ZeroShot.py | 4 +- .../gender/ArapTweet_GPT4_FewShot.py | 4 +- .../gender/ArapTweet_GPT4_ZeroShot.py | 4 +- .../gender/ArapTweet_Random.py | 4 +- ...ot.py => ArSAS_Llama2_7b_chat_ZeroShot.py} | 0 .../sentiment/SST2_GPT4_ZeroShot.py | 4 +- ...hot.py => SST2_Llama2_7b_chat_ZeroShot.py} | 4 +- llmebench/tasks/Classification.py | 44 +++++++++++++++++++ llmebench/tasks/__init__.py | 1 + 14 files changed, 67 insertions(+), 22 deletions(-) rename assets/ar/sentiment_emotion_others/sentiment/{ArSAS_Llama_7b_chat_ZeroShot.py => ArSAS_Llama2_7b_chat_ZeroShot.py} (100%) rename assets/en/sentiment_emotion_others/sentiment/{SST2_Llama_7b_chat_ZeroShot.py => SST2_Llama2_7b_chat_ZeroShot.py} (96%) create mode 100644 llmebench/tasks/Classification.py diff --git a/assets/ar/demographic_attributes/gender/ArabGend_BLOOMZ_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArabGend_BLOOMZ_ZeroShot.py index 6a2d7350..8403ec0c 100644 --- a/assets/ar/demographic_attributes/gender/ArabGend_BLOOMZ_ZeroShot.py +++ b/assets/ar/demographic_attributes/gender/ArabGend_BLOOMZ_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArabGendDataset from llmebench.models import PetalsModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -14,7 +14,7 @@ def metadata(): def config(): return { "dataset": ArabGendDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": PetalsModel, "model_args": { "class_labels": ["m", "f"], diff --git a/assets/ar/demographic_attributes/gender/ArabGend_GPT35_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArabGend_GPT35_ZeroShot.py index 154e3ce0..ee8e1f78 100644 --- a/assets/ar/demographic_attributes/gender/ArabGend_GPT35_ZeroShot.py +++ b/assets/ar/demographic_attributes/gender/ArabGend_GPT35_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArabGendDataset from llmebench.models import LegacyOpenAIModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -14,7 +14,7 @@ def metadata(): def config(): return { "dataset": ArabGendDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": LegacyOpenAIModel, "model_args": { "class_labels": ["m", "f"], diff --git a/assets/ar/demographic_attributes/gender/ArabGend_GPT4_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArabGend_GPT4_ZeroShot.py index 2eed8889..9f2daf89 100644 --- a/assets/ar/demographic_attributes/gender/ArabGend_GPT4_ZeroShot.py +++ b/assets/ar/demographic_attributes/gender/ArabGend_GPT4_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArabGendDataset from llmebench.models import OpenAIModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -14,7 +14,7 @@ def metadata(): def config(): return { "dataset": ArabGendDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": OpenAIModel, "model_args": { "class_labels": ["m", "f"], diff --git a/assets/ar/demographic_attributes/gender/ArabGend_Random.py b/assets/ar/demographic_attributes/gender/ArabGend_Random.py index df47515f..2059cee0 100644 --- a/assets/ar/demographic_attributes/gender/ArabGend_Random.py +++ b/assets/ar/demographic_attributes/gender/ArabGend_Random.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArabGendDataset from llmebench.models import RandomModel -from llmebench.tasks import DemographyGenderTask, TaskType +from llmebench.tasks import ClassificationTask, TaskType def metadata(): @@ -15,7 +15,7 @@ def metadata(): def config(): return { "dataset": ArabGendDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": RandomModel, "model_args": { "task_type": TaskType.Classification, diff --git a/assets/ar/demographic_attributes/gender/ArapTweet_BLOOMZ_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArapTweet_BLOOMZ_ZeroShot.py index 92668b05..a30141f3 100644 --- a/assets/ar/demographic_attributes/gender/ArapTweet_BLOOMZ_ZeroShot.py +++ b/assets/ar/demographic_attributes/gender/ArapTweet_BLOOMZ_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArapTweetDataset from llmebench.models import PetalsModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -15,7 +15,7 @@ def metadata(): def config(): return { "dataset": ArapTweetDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": PetalsModel, "model_args": { "class_labels": ["Female", "Male"], diff --git a/assets/ar/demographic_attributes/gender/ArapTweet_GPT35_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArapTweet_GPT35_ZeroShot.py index 656ea94c..5021366d 100644 --- a/assets/ar/demographic_attributes/gender/ArapTweet_GPT35_ZeroShot.py +++ b/assets/ar/demographic_attributes/gender/ArapTweet_GPT35_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArapTweetDataset from llmebench.models import LegacyOpenAIModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -15,7 +15,7 @@ def metadata(): def config(): return { "dataset": ArapTweetDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": LegacyOpenAIModel, "model_args": { "class_labels": ["Female", "Male"], diff --git a/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_FewShot.py b/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_FewShot.py index 107a651e..5c442209 100644 --- a/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_FewShot.py +++ b/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_FewShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArapTweetDataset from llmebench.models import OpenAIModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -15,7 +15,7 @@ def metadata(): def config(): return { "dataset": ArapTweetDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": OpenAIModel, "model_args": { "class_labels": ["Female", "Male"], diff --git a/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_ZeroShot.py b/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_ZeroShot.py index 4a9d662c..7ec58a82 100644 --- a/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_ZeroShot.py +++ b/assets/ar/demographic_attributes/gender/ArapTweet_GPT4_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArapTweetDataset from llmebench.models import OpenAIModel -from llmebench.tasks import DemographyGenderTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -15,7 +15,7 @@ def metadata(): def config(): return { "dataset": ArapTweetDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": OpenAIModel, "model_args": { "class_labels": ["Female", "Male"], diff --git a/assets/ar/demographic_attributes/gender/ArapTweet_Random.py b/assets/ar/demographic_attributes/gender/ArapTweet_Random.py index 7409b18b..1dfb3b1d 100644 --- a/assets/ar/demographic_attributes/gender/ArapTweet_Random.py +++ b/assets/ar/demographic_attributes/gender/ArapTweet_Random.py @@ -1,6 +1,6 @@ from llmebench.datasets import ArapTweetDataset from llmebench.models import RandomModel -from llmebench.tasks import DemographyGenderTask, TaskType +from llmebench.tasks import ClassificationTask, TaskType def metadata(): @@ -15,7 +15,7 @@ def metadata(): def config(): return { "dataset": ArapTweetDataset, - "task": DemographyGenderTask, + "task": ClassificationTask, "model": RandomModel, "model_args": { "task_type": TaskType.Classification, diff --git a/assets/ar/sentiment_emotion_others/sentiment/ArSAS_Llama_7b_chat_ZeroShot.py b/assets/ar/sentiment_emotion_others/sentiment/ArSAS_Llama2_7b_chat_ZeroShot.py similarity index 100% rename from assets/ar/sentiment_emotion_others/sentiment/ArSAS_Llama_7b_chat_ZeroShot.py rename to assets/ar/sentiment_emotion_others/sentiment/ArSAS_Llama2_7b_chat_ZeroShot.py diff --git a/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py b/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py index 544cb7ff..8c021689 100644 --- a/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py +++ b/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import HuggingFaceDataset from llmebench.models import OpenAIModel -from llmebench.tasks import SentimentTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -22,7 +22,7 @@ def config(): "input_id": "idx", }, }, - "task": SentimentTask, + "task": ClassificationTask, "model": OpenAIModel, "model_args": { "class_labels": ["positive", "negative"], diff --git a/assets/en/sentiment_emotion_others/sentiment/SST2_Llama_7b_chat_ZeroShot.py b/assets/en/sentiment_emotion_others/sentiment/SST2_Llama2_7b_chat_ZeroShot.py similarity index 96% rename from assets/en/sentiment_emotion_others/sentiment/SST2_Llama_7b_chat_ZeroShot.py rename to assets/en/sentiment_emotion_others/sentiment/SST2_Llama2_7b_chat_ZeroShot.py index ba5691a6..bfa72b2c 100644 --- a/assets/en/sentiment_emotion_others/sentiment/SST2_Llama_7b_chat_ZeroShot.py +++ b/assets/en/sentiment_emotion_others/sentiment/SST2_Llama2_7b_chat_ZeroShot.py @@ -1,6 +1,6 @@ from llmebench.datasets import HuggingFaceDataset from llmebench.models import FastChatModel -from llmebench.tasks import SentimentTask +from llmebench.tasks import ClassificationTask def metadata(): @@ -23,7 +23,7 @@ def config(): "input_id": "idx", }, }, - "task": SentimentTask, + "task": ClassificationTask, "model": FastChatModel, "general_args": {"custom_test_split": "validation"}, } diff --git a/llmebench/tasks/Classification.py b/llmebench/tasks/Classification.py new file mode 100644 index 00000000..8591659f --- /dev/null +++ b/llmebench/tasks/Classification.py @@ -0,0 +1,44 @@ +from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score + +from llmebench.tasks.task_base import TaskBase + + +class ClassificationTask(TaskBase): + def __init__(self, **kwargs): + super(ClassificationTask, self).__init__(**kwargs) + + def evaluate(self, true_labels, predicted_labels): + predicted_labels = [ + p if p is not None else self.get_random_prediction(set(true_labels)) + for p in predicted_labels + ] + + acc_score = accuracy_score(true_labels, predicted_labels) + macro_precision = precision_score( + true_labels, predicted_labels, average="macro" + ) + macro_recall = recall_score(true_labels, predicted_labels, average="macro") + macro_f1 = f1_score(true_labels, predicted_labels, average="macro") + + micro_precision = precision_score( + true_labels, predicted_labels, average="micro" + ) + micro_recall = recall_score(true_labels, predicted_labels, average="micro") + micro_f1 = f1_score(true_labels, predicted_labels, average="micro") + + w_precision = precision_score(true_labels, predicted_labels, average="weighted") + w_recall = recall_score(true_labels, predicted_labels, average="weighted") + w_f1 = f1_score(true_labels, predicted_labels, average="weighted") + + return { + "Accuracy": acc_score, + "Macro precision": macro_precision, + "Macro recall": macro_recall, + "Macro F1": macro_f1, + "Micro precision": micro_precision, + "Micro recall": micro_recall, + "Micro F1": micro_f1, + "Weighted Precision": w_precision, + "Weighted Recall": w_recall, + "Weighted F1": w_f1, + } diff --git a/llmebench/tasks/__init__.py b/llmebench/tasks/__init__.py index 0510eab6..41d0f748 100644 --- a/llmebench/tasks/__init__.py +++ b/llmebench/tasks/__init__.py @@ -8,6 +8,7 @@ from .Attentionworthy import AttentionworthyTask from .Checkworthiness import CheckworthinessTask from .ClaimDetection import ClaimDetectionTask +from .Classification import ClassificationTask from .DemographyGender import DemographyGenderTask from .DemographyLocation import DemographyLocationTask from .DemographyNameInfo import DemographyNameInfoTask