The MMLU dataset results of google_flan_t5_large are lower than your experimental results #91

Open
zl-comment opened this issue Dec 26, 2024 · 1 comment


@zl-comment

[image attachment]

```
# imports assumed for this snippet (module paths may differ across promptbench versions)
import logging

from tqdm import tqdm

import promptbench as pb
from promptbench import LLMModel
from promptbench.prompt_attack import Attack

import localpathconfig  # user-local config module holding dataset paths

# create the dataset (there is also an mrpc experiment)
dataset = pb.DatasetLoader.load_dataset("mmlu", local_path=localpathconfig.MMLU_PATH)
logging.info("dataset: mmlu")
# make sure the dataset contains enough data
if len(dataset) >= 1000:
    # take the first 1000 records
    validation_dataset = dataset[:1000]
else:
    validation_dataset = dataset

prompts = [
      "In relation to the multiple-choice question on {}, please provide the accurate answer by choosing 'A', 'B', 'C', or 'D'",
      "For each multiple-choice question about {}, identify the correct answer by selecting 'A', 'B', 'C', or 'D'",
      "Answer the subsequent multiple-choice question about {} by picking the right option among 'A', 'B', 'C', or 'D'",
      "As an expert in {}, respond to the following multiple-choice question by selecting 'A', 'B', 'C', or 'D'",
      "Considering your familiarity with {}, attend to the following multiple-choice question by picking 'A', 'B', 'C', or 'D'",
      "As someone well-versed in {}, please address the multiple-choice question below by selecting 'A', 'B', 'C', or 'D'"
        ]

try:
    model_t5 = LLMModel(model='google/flan-t5-large', temperature=0.5)
    print("Statement executed successfully; the model was loaded.")
except Exception as e:
    print("Statement failed; the error message follows:")
    print(str(e))

# define the projection function required by the output process
def proj_func(pred):
    mapping = {
        "a": 0,
        "b": 1,
        "c": 2,
        "d": 3
    }
    pred_lower = pred.lower()  # convert the prediction to lowercase
    if pred_lower in mapping:
        return mapping[pred_lower]
    else:
        logging.info(f"ERROR OUT: {pred}")  # record unmatched outputs in the log file
        return -1
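
# Added illustration (not in the original snippet): proj_func only matches a bare
# single letter, so any longer answer string falls through to -1.
assert proj_func("A") == 0
assert proj_func("d") == 3
assert proj_func("The answer is C") == -1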


# define the evaluation function required by the attack
def eval_func(prompt, validation_dataset, model):
    logging.info(f"Prompt: {prompt}")  # record the prompt in the log file
    preds = []
    labels = []
    for d in tqdm(validation_dataset, desc="process"):
        input_text = pb.InputProcess.basic_format(prompt.replace("{}", "{content}"), d)
        raw_output = model(input_text)  # the model does return an answer
        output = pb.OutputProcess.cls(raw_output, proj_func)  # map the raw output to a label index (or -1 if it cannot be mapped)
        preds.append(output)
        labels.append(d["label"])
    return pb.Eval.compute_cls_accuracy(preds, labels)


# define the unmodifiable words in the prompt
# for example, the labels "positive" and "negative" are unmodifiable, and "content" is modifiable because it is a placeholder
# if your labels are enclosed with '', you need to add \' to the unmodifiable words (due to one feature of textattack)
unmodifiable_words = ['A', 'B', 'C', 'D', 'A\'', 'B\'', 'C\'', 'D\'', 'a', 'b', 'c', 'd', 'a\'', 'b\'', 'c\'', 'd\'']

# print all supported attacks
print(Attack.attack_list())
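
# Added sketch (not part of the original snippet): score each prompt without any
# attack, to compare the clean MMLU accuracy of flan-t5-large against the
# reported results. Uses only the functions and objects defined above.
for prompt in prompts:
    acc = eval_func(prompt, validation_dataset, model_t5)
    print(f"clean accuracy {acc:.4f} for prompt: {prompt}")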
```

@zl-comment (Author) commented:

```
    # Modified __init__ of the MMLU dataset class in promptbench, adding a
    # local_path option that reads each task's validation split from local
    # parquet files. Module-level imports assumed by this snippet:
    # pandas as pd, datasets, and datasets.load_dataset.
    def __init__(self, local_path=None):
        print(local_path)
        self.data = []
        self.tasks = ['high_school_european_history', 'business_ethics', 'clinical_knowledge', 'medical_genetics',
                    'high_school_us_history', 'high_school_physics', 'high_school_world_history', 'virology',
                    'high_school_microeconomics', 'econometrics', 'college_computer_science', 'high_school_biology',
                    'abstract_algebra', 'professional_accounting', 'philosophy', 'professional_medicine', 'nutrition',
                    'global_facts', 'machine_learning', 'security_studies', 'public_relations', 'professional_psychology',
                    'prehistory', 'anatomy', 'human_sexuality', 'college_medicine', 'high_school_government_and_politics',
                    'college_chemistry', 'logical_fallacies', 'high_school_geography', 'elementary_mathematics', 'human_aging',
                    'college_mathematics', 'high_school_psychology', 'formal_logic', 'high_school_statistics', 'international_law',
                    'high_school_mathematics', 'high_school_computer_science', 'conceptual_physics', 'miscellaneous', 'high_school_chemistry',
                    'marketing', 'professional_law', 'management', 'college_physics', 'jurisprudence', 'world_religions', 'sociology',
                    'us_foreign_policy', 'high_school_macroeconomics', 'computer_security', 'moral_scenarios', 'moral_disputes',
                    'electrical_engineering', 'astronomy', 'college_biology']

        if local_path:
            for task in self.tasks:
                data = pd.read_parquet(f"{local_path}/{task}/validation-00000-of-00001.parquet")
                # convert the DataFrame to a datasets.Dataset
                data = datasets.Dataset.from_pandas(data)
                for d in data:
                    d["task"] = task
                    self.data.append({"content": d, "label": d["answer"]})
        else:
            # fallback to hub loading; note that local_path is empty in this branch,
            # so the first argument to load_dataset would need the actual dataset id
            for task in self.tasks:
                data = load_dataset(local_path, task)["test"]
                for d in data:
                    d["task"] = task
                    self.data.append({"content": d, "label": d["answer"]})
```

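For reference, a quick sanity check on the records produced by this loader might look like the sketch below. It assumes the local parquet files follow the standard Hugging Face MMLU column layout ("question", "choices", and an integer "answer" index); those column names and `localpathconfig.MMLU_PATH` are taken from the snippets above or assumed, not verified against this repository.

```
import promptbench as pb
import localpathconfig  # user-local config module from the first snippet

# Load the local copy through the modified loader shown above.
dataset = pb.DatasetLoader.load_dataset("mmlu", local_path=localpathconfig.MMLU_PATH)

# Each record is {"content": <row dict plus "task">, "label": <integer answer index>}.
sample = dataset[0]
print(sample["label"])                    # integer 0-3, on the same scale as proj_func's return values
print(sample["content"]["task"])          # subject name injected by the loader loop
print(sample["content"].get("question"))  # assumed column name from the MMLU parquet
```

Since proj_func above also returns integer indices (0-3, or -1 when the output is not a single letter), predictions and labels end up on the same scale before pb.Eval.compute_cls_accuracy is called.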