We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
您好, 很抱歉这个issue可能会打扰到项目组成员,但对于此项目上的复现我一直不得要点,得不到与文章相同的结果,还望前辈拨冗解惑。 对于贵组放出的 chatgpt-detector-roberta-chinese 模型的描述,此模型是由mix-filter训练得到的。 我采取的测试方式如下所示
最后对raw-full进行测试的结果: 2024-03-05 19:44:46,902 - testing - INFO - test_doc: {'f1': 0.9976726144297905}
与原论文的表中数据显著不同,所以我想请教一下,是我的测试方式有误吗,如果有误,正确的测试方式应该是什么?
最后,无论如何都感谢贵组的工作。
import argparse
import os
import numpy as np
import sys
import evaluate
import pandas as pd
import torch
import logging
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import Dataset, concatenate_datasets
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
    BertForSequenceClassification,
)

# Logging: DEBUG to the console via basicConfig, INFO and above persisted to test.log.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('testing')
file_handler = logging.FileHandler('test.log')
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(
    logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

sys.path.append('./')


def _str2bool(value):
    """Parse a command-line boolean flag value.

    Fixes the classic argparse ``type=bool`` bug: ``bool('False')`` is True,
    so previously any non-empty string passed to ``--stacking`` was truthy.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


_PARSER = argparse.ArgumentParser('ptm detector')
_PARSER.add_argument('--model_name', type=str,
                     default='/data1/xxxxxx/DeepfakeText-chinese/model/chinese-roberta-wwm-ext',
                     help='ptm model name')
_PARSER.add_argument('--roberta_model', type=str,
                     default='/data1/xxxxxx/DeepfakeText-chinese/model/chatgpt-detector-roberta-chinese',
                     help='roberta_model')
_PARSER.add_argument('--test_doc', type=str, default='../../data/zh_doc_test.csv',
                     help='input doc test file path')
_PARSER.add_argument('--test_sent', type=str, default='../../data/shuffled_zh_sent_test.csv',
                     help='input test sent file path')
_PARSER.add_argument('--batch_size', type=int, default=16, help='batch size')
_PARSER.add_argument('--epochs', type=int, default=2, help='epochs')
_PARSER.add_argument('--num_labels', type=int, default=2, help='num_labels')
_PARSER.add_argument('--cuda', type=str, default='0', help='gpu ids, like: 1,2,3')
_PARSER.add_argument('--seed', type=int, default=42, help='random seed.')
_PARSER.add_argument('--max_length', type=int, default=365, help='max_length')
# Was ``type=bool``, which silently parsed '--stacking False' as True.
_PARSER.add_argument('--stacking', type=_str2bool, default=True, help='stacking')

_ARGS = _PARSER.parse_args()

if len(_ARGS.cuda) > 1:
    # Multi-GPU run: silence tokenizer fork warnings, enable distributed
    # debugging, and pin cuBLAS workspace for deterministic kernels.
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
    os.environ["OMP_NUM_THREADS"] = '8'
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # if cuda >= 10.2

os.environ['CUDA_VISIBLE_DEVICES'] = _ARGS.cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def create_dataloader(args: argparse.Namespace):
    """Build the evaluation dataloaders: [test_doc, test_sent].

    Each CSV is expected to contain the columns 'id', 'question', 'answer',
    'source' plus a 'label' column; 'answer' is tokenized to fixed
    ``args.max_length`` and the other named columns are dropped, leaving
    token tensors and 'label' for the collator.
    """
    raw_datasets = []
    for path in (args.test_doc, args.test_sent):
        raw_datasets.append(Dataset.from_pandas(pd.read_csv(path)))

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)

    def tokenize_fn(example):
        return tokenizer(example['answer'], max_length=args.max_length,
                         padding='max_length', truncation=True)

    drop_columns = ['id', 'question', 'answer', 'source']
    tokenized_datasets = [
        ds.map(tokenize_fn, batched=True, remove_columns=drop_columns)
        for ds in raw_datasets
    ]

    def collate_fn(examples):
        return tokenizer.pad(examples, return_tensors='pt')

    # shuffle=False: evaluation order must be deterministic.
    return [
        DataLoader(ds, shuffle=False, collate_fn=collate_fn,
                   batch_size=args.batch_size)
        for ds in tokenized_datasets
    ]


def run_eval(args, dataloaders):
    """Evaluate the fine-tuned detector on each dataloader and log F1.

    Loads the sequence classifier from ``args.roberta_model`` and reports
    the metric once per dataloader under the names 'test_doc' / 'test_sent'.
    Renamed from ``eval`` to avoid shadowing the builtin.
    """
    if not args.stacking:
        return

    config = AutoConfig.from_pretrained(args.roberta_model, num_labels=2)
    model = BertForSequenceClassification.from_pretrained(
        args.roberta_model, config=config).to(device)
    model.eval()  # make eval mode explicit: dropout off, deterministic outputs
    print("roberta_rnn_model loaded")

    eval_name_list = ['test_doc', 'test_sent']
    for item, eval_name in enumerate(eval_name_list):
        metric = evaluate.load("/data1/xxxxxx/DeepfakeText-chinese/dataset/metrics/f1")
        for batch in tqdm(dataloaders[item], desc='Evaling', colour="green"):
            batch = batch.to(device)
            with torch.no_grad():
                labels = batch.pop('label')
                logits = model(**batch)['logits']
            metric.add_batch(predictions=logits.argmax(dim=-1), references=labels)
        logger.info(f"{eval_name}: {metric.compute()}")


if __name__ == '__main__':
    # Guarded entry point (was bare top-level calls with the 'daataLoader' typo).
    dataloaders = create_dataloader(_ARGS)
    run_eval(_ARGS, dataloaders)
The text was updated successfully, but these errors were encountered:
No branches or pull requests
您好,
很抱歉这个issue可能会打扰到项目组成员,但对于此项目上的复现我一直不得要点,得不到与文章相同的结果,还望前辈拨冗解惑。
对于贵组放出的 chatgpt-detector-roberta-chinese 模型的描述,此模型是由mix-filter训练得到的。
我采取的测试方式如下所示
最后对raw-full进行测试的结果:
2024-03-05 19:44:46,902 - testing - INFO - test_doc: {'f1': 0.9976726144297905}
与原论文的表中数据显著不同,所以我想请教一下,是我的测试方式有误吗,如果有误,正确的测试方式应该是什么?
最后,无论如何都感谢贵组的工作。
The text was updated successfully, but these errors were encountered: