Skip to content

Commit

Permalink
Update ADI BLOOMZ asset (#181)
Browse files: browse the repository at this point in the history
Updated prompt and post-processing to improve performance.
  • Loading branch information
MaramHasanain authored Oct 2, 2023
1 parent b54033c commit c4d6f6a
Showing 1 changed file with 27 additions and 61 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,37 +17,12 @@ def config():
"dataset": ADIDataset,
"task": DialectIDTask,
"model": PetalsModel,
"model_args": {
"class_labels": [
"EGY",
"IRA",
"JOR",
"KSA",
"KUW",
"LEB",
"LIB",
"MOR",
"MSA",
"PAL",
"QAT",
"SUD",
"SYR",
"UAE",
"YEM",
],
"max_tries": 3,
},
}


def prompt(input_sample):
arr = input_sample.split()
if len(arr) > 500:
input_sample = arr[:500]

prompt_string = (
f'Classify the following "text" into one of the following categories: "EGY", "IRA", "JOR", "KSA", "KUW", "LEB", "LIB", "MOR", "MSA", "PAL", "QAT", "SUD", "SYR", "UAE", "YEM"\n'
f"Please provide only the label.\n\n"
f'Identify the dialect of the following Arabic "text" given the following possible dialects: "Egyptian", "Iraqi", "Jordanian", "Saudi", "Kuwaiti", "Lebanese", "Libyan", "Moroccan", "modern standard Arabic", "Palestinian", "Qatari", "Sudanese", "Syrian", "Emirati", "Yemeni"\n'
f"text: {input_sample}\n"
f"label: \n"
)
Expand All @@ -58,45 +33,36 @@ def prompt(input_sample):


def post_process(response):
count_label_map = {
"Egyptian": "EGY",
"Iraqi": "IRA",
"Jordanian": "JOR",
"Saudi": "KSA",
"Kuwaiti": "KUW",
"Lebanese": "LEB",
"Libyan": "LIB",
"Moroccan": "MOR",
"modern standard Arabic": "MSA",
"Modern standard Arabic": "MSA",
"Modern Standard Arabic": "MSA",
"Palestinian": "PAL",
"Qatari": "QAT",
"Sudanese": "SUD",
"Syrian": "SYR",
"Emirati": "UAE",
"Yemeni": "YEM",
"Yemen": "YEM",
}

label = response["outputs"].strip()
label = label.replace("<s>", "")
label = label.replace("</s>", "")
label = label.lower()

# label_list = config()["model_args"]["class_labels"]
# label_list = [lab.lower() for lab in label_list]
#
# if "label: " in label:
# label_fixed = label.replace("label: ", "").lower()
# elif label.lower() in label_list:
# label_fixed = label.lower()
# else:
# label_fixed = None
label_list = config()["model_args"]["class_labels"]
label_list = [dialect.lower() for dialect in label_list]

label = label.replace("label:", "").strip()
label = label.replace("Dialect: ", "").replace("dialect: ", "")
label = label.replace("label: ", "")
label = label.strip()

if label in label_list:
label_fixed = label
elif "\n msa" in label:
label_fixed = "msa"
elif "\n ksa" in label:
label_fixed = "ksa"
elif "\n pal" in label:
label_fixed = "pal"
elif "\n egy" in label:
label_fixed = "egy"
elif "\n yem" in label:
label_fixed = "yem"
elif "\n syr" in label:
label_fixed = "syr"
elif "\n jor" in label:
label_fixed = "jor"
elif "\n ira" in label:
label_fixed = "ira"
elif "\n kuw" in label:
label_fixed = "kuw"
if label in count_label_map:
label_fixed = count_label_map[label].lower()
else:
label_fixed = None

Expand Down

0 comments on commit c4d6f6a

Please sign in to comment.