Skip to content

Commit

Permalink
Implement generic ClassificationTask (#242)
Browse files Browse the repository at this point in the history
This commit adds support for a generic `ClassificationTask` with many popular metrics.

* Added SST2 dataset, GPT4-ZeroShot and Llama-2-13b-chat-hf assets.

* updated the classification task with gender datasets

* Fix errorneous imports

---------

Co-authored-by: Fahim Imaduddin Dalvi <[email protected]>
  • Loading branch information
firojalam and fdalvi authored Oct 28, 2023
1 parent 9b9de5a commit 5325b63
Show file tree
Hide file tree
Showing 14 changed files with 67 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArabGendDataset
from llmebench.models import PetalsModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -14,7 +14,7 @@ def metadata():
def config():
return {
"dataset": ArabGendDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": PetalsModel,
"model_args": {
"class_labels": ["m", "f"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArabGendDataset
from llmebench.models import LegacyOpenAIModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -14,7 +14,7 @@ def metadata():
def config():
return {
"dataset": ArabGendDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": LegacyOpenAIModel,
"model_args": {
"class_labels": ["m", "f"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArabGendDataset
from llmebench.models import OpenAIModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -14,7 +14,7 @@ def metadata():
def config():
return {
"dataset": ArabGendDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": OpenAIModel,
"model_args": {
"class_labels": ["m", "f"],
Expand Down
4 changes: 2 additions & 2 deletions assets/ar/demographic_attributes/gender/ArabGend_Random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArabGendDataset
from llmebench.models import RandomModel
from llmebench.tasks import DemographyGenderTask, TaskType
from llmebench.tasks import ClassificationTask, TaskType


def metadata():
Expand All @@ -15,7 +15,7 @@ def metadata():
def config():
return {
"dataset": ArabGendDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": RandomModel,
"model_args": {
"task_type": TaskType.Classification,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArapTweetDataset
from llmebench.models import PetalsModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -15,7 +15,7 @@ def metadata():
def config():
return {
"dataset": ArapTweetDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": PetalsModel,
"model_args": {
"class_labels": ["Female", "Male"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArapTweetDataset
from llmebench.models import LegacyOpenAIModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -15,7 +15,7 @@ def metadata():
def config():
return {
"dataset": ArapTweetDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": LegacyOpenAIModel,
"model_args": {
"class_labels": ["Female", "Male"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArapTweetDataset
from llmebench.models import OpenAIModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -15,7 +15,7 @@ def metadata():
def config():
return {
"dataset": ArapTweetDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": OpenAIModel,
"model_args": {
"class_labels": ["Female", "Male"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArapTweetDataset
from llmebench.models import OpenAIModel
from llmebench.tasks import DemographyGenderTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -15,7 +15,7 @@ def metadata():
def config():
return {
"dataset": ArapTweetDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": OpenAIModel,
"model_args": {
"class_labels": ["Female", "Male"],
Expand Down
4 changes: 2 additions & 2 deletions assets/ar/demographic_attributes/gender/ArapTweet_Random.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import ArapTweetDataset
from llmebench.models import RandomModel
from llmebench.tasks import DemographyGenderTask, TaskType
from llmebench.tasks import ClassificationTask, TaskType


def metadata():
Expand All @@ -15,7 +15,7 @@ def metadata():
def config():
return {
"dataset": ArapTweetDataset,
"task": DemographyGenderTask,
"task": ClassificationTask,
"model": RandomModel,
"model_args": {
"task_type": TaskType.Classification,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import HuggingFaceDataset
from llmebench.models import OpenAIModel
from llmebench.tasks import SentimentTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -22,7 +22,7 @@ def config():
"input_id": "idx",
},
},
"task": SentimentTask,
"task": ClassificationTask,
"model": OpenAIModel,
"model_args": {
"class_labels": ["positive", "negative"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llmebench.datasets import HuggingFaceDataset
from llmebench.models import FastChatModel
from llmebench.tasks import SentimentTask
from llmebench.tasks import ClassificationTask


def metadata():
Expand All @@ -23,7 +23,7 @@ def config():
"input_id": "idx",
},
},
"task": SentimentTask,
"task": ClassificationTask,
"model": FastChatModel,
"general_args": {"custom_test_split": "validation"},
}
Expand Down
44 changes: 44 additions & 0 deletions llmebench/tasks/Classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from llmebench.tasks.task_base import TaskBase


class ClassificationTask(TaskBase):
def __init__(self, **kwargs):
super(ClassificationTask, self).__init__(**kwargs)

def evaluate(self, true_labels, predicted_labels):
predicted_labels = [
p if p is not None else self.get_random_prediction(set(true_labels))
for p in predicted_labels
]

acc_score = accuracy_score(true_labels, predicted_labels)
macro_precision = precision_score(
true_labels, predicted_labels, average="macro"
)
macro_recall = recall_score(true_labels, predicted_labels, average="macro")
macro_f1 = f1_score(true_labels, predicted_labels, average="macro")

micro_precision = precision_score(
true_labels, predicted_labels, average="micro"
)
micro_recall = recall_score(true_labels, predicted_labels, average="micro")
micro_f1 = f1_score(true_labels, predicted_labels, average="micro")

w_precision = precision_score(true_labels, predicted_labels, average="weighted")
w_recall = recall_score(true_labels, predicted_labels, average="weighted")
w_f1 = f1_score(true_labels, predicted_labels, average="weighted")

return {
"Accuracy": acc_score,
"Macro precision": macro_precision,
"Macro recall": macro_recall,
"Macro F1": macro_f1,
"Micro precision": micro_precision,
"Micro recall": micro_recall,
"Micro F1": micro_f1,
"Weighted Precision": w_precision,
"Weighted Recall": w_recall,
"Weighted F1": w_f1,
}
1 change: 1 addition & 0 deletions llmebench/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .Attentionworthy import AttentionworthyTask
from .Checkworthiness import CheckworthinessTask
from .ClaimDetection import ClaimDetectionTask
from .Classification import ClassificationTask
from .DemographyGender import DemographyGenderTask
from .DemographyLocation import DemographyLocationTask
from .DemographyNameInfo import DemographyNameInfoTask
Expand Down

0 comments on commit 5325b63

Please sign in to comment.