From 5995d4b53ca51bb073701f921554ebf126e2238f Mon Sep 17 00:00:00 2001 From: Fahim Imaduddin Dalvi Date: Sat, 28 Oct 2023 22:41:54 +0300 Subject: [PATCH] Fix GPT4 asset --- .../sentiment/SST2_GPT4_ZeroShot.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py b/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py index 2be46c17..544cb7ff 100644 --- a/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py +++ b/assets/en/sentiment_emotion_others/sentiment/SST2_GPT4_ZeroShot.py @@ -1,4 +1,4 @@ -from llmebench.datasets import SST2 +from llmebench.datasets import HuggingFaceDataset from llmebench.models import OpenAIModel from llmebench.tasks import SentimentTask @@ -13,13 +13,22 @@ def metadata(): def config(): return { - "dataset": SST2, + "dataset": HuggingFaceDataset, + "dataset_args": { + "huggingface_dataset_name": "sst2", + "column_mapping": { + "input": "sentence", + "label": "label", + "input_id": "idx", + }, + }, "task": SentimentTask, "model": OpenAIModel, "model_args": { "class_labels": ["positive", "negative"], "max_tries": 3, }, + "general_args": {"custom_test_split": "validation"}, } @@ -45,7 +54,13 @@ def post_process(response): label = response["choices"][0]["message"]["content"].lower() label_fixed = label.replace("label:", "").replace("sentiment: ", "").strip() + if label_fixed.startswith("Please provide the text"): label_fixed = None - return label_fixed + if label_fixed == "positive": + return 1 + elif label_fixed == "negative": + return 0 + + return None