diff --git a/assets/ar/sequence_tagging_and_information_extraction/dialect_identification/ADI_BLOOMZ_ZeroShot.py b/assets/ar/sequence_tagging_and_information_extraction/dialect_identification/ADI_BLOOMZ_ZeroShot.py index 88542dad..a56dbc8d 100644 --- a/assets/ar/sequence_tagging_and_information_extraction/dialect_identification/ADI_BLOOMZ_ZeroShot.py +++ b/assets/ar/sequence_tagging_and_information_extraction/dialect_identification/ADI_BLOOMZ_ZeroShot.py @@ -17,37 +17,12 @@ def config(): "dataset": ADIDataset, "task": DialectIDTask, "model": PetalsModel, - "model_args": { - "class_labels": [ - "EGY", - "IRA", - "JOR", - "KSA", - "KUW", - "LEB", - "LIB", - "MOR", - "MSA", - "PAL", - "QAT", - "SUD", - "SYR", - "UAE", - "YEM", - ], - "max_tries": 3, - }, } def prompt(input_sample): - arr = input_sample.split() - if len(arr) > 500: - input_sample = arr[:500] - prompt_string = ( - f'Classify the following "text" into one of the following categories: "EGY", "IRA", "JOR", "KSA", "KUW", "LEB", "LIB", "MOR", "MSA", "PAL", "QAT", "SUD", "SYR", "UAE", "YEM"\n' - f"Please provide only the label.\n\n" + f'Identify the dialect of the following Arabic "text" given the following possible dialects: "Egyptian", "Iraqi", "Jordanian", "Saudi", "Kuwaiti", "Lebanese", "Libyan", "Moroccan", "modern standard Arabic", "Palestinian", "Qatari", "Sudanese", "Syrian", "Emirati", "Yemeni"\n' f"text: {input_sample}\n" f"label: \n" ) @@ -58,45 +33,36 @@ def prompt(input_sample): def post_process(response): + count_label_map = { + "Egyptian": "EGY", + "Iraqi": "IRA", + "Jordanian": "JOR", + "Saudi": "KSA", + "Kuwaiti": "KUW", + "Lebanese": "LEB", + "Libyan": "LIB", + "Moroccan": "MOR", + "modern standard Arabic": "MSA", + "Modern standard Arabic": "MSA", + "Modern Standard Arabic": "MSA", + "Palestinian": "PAL", + "Qatari": "QAT", + "Sudanese": "SUD", + "Syrian": "SYR", + "Emirati": "UAE", + "Yemeni": "YEM", + "Yemen": "YEM", + } + label = response["outputs"].strip() label = label.replace("", "") label = label.replace("", "") - label = label.lower() - - # label_list = config()["model_args"]["class_labels"] - # label_list = [lab.lower() for lab in label_list] - # - # if "label: " in label: - # label_fixed = label.replace("label: ", "").lower() - # elif label.lower() in label_list: - # label_fixed = label.lower() - # else: - # label_fixed = None - label_list = config()["model_args"]["class_labels"] - label_list = [dialect.lower() for dialect in label_list] - - label = label.replace("label:", "").strip() + label = label.replace("Dialect: ", "").replace("dialect: ", "") + label = label.replace("label: ", "") + label = label.strip() - if label in label_list: - label_fixed = label - elif "\n msa" in label: - label_fixed = "msa" - elif "\n ksa" in label: - label_fixed = "ksa" - elif "\n pal" in label: - label_fixed = "pal" - elif "\n egy" in label: - label_fixed = "egy" - elif "\n yem" in label: - label_fixed = "yem" - elif "\n syr" in label: - label_fixed = "syr" - elif "\n jor" in label: - label_fixed = "jor" - elif "\n ira" in label: - label_fixed = "ira" - elif "\n kuw" in label: - label_fixed = "kuw" + if label in count_label_map: + label_fixed = count_label_map[label].lower() else: label_fixed = None