-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprepare.py
executable file
·86 lines (74 loc) · 2.78 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# coding=utf-8
import argparse
import json
import shutil
import pickle
import os
import logging
import multiprocessing as mp
from os.path import dirname, exists, join
import torch
import tqdm
from inputters import inputters
from utils.building_utils import build_model
def _build_arg_parser():
    """Create the command-line parser for this data preparation script.

    Options are registered in a fixed order so the generated ``--help``
    output lists them the same way every time.
    """
    cli = argparse.ArgumentParser()

    # Mandatory string options: model config, inputter key, and input file.
    for flag in ('--config_name', '--inputter_name', '--train_input_file'):
        cli.add_argument(flag, type=str, required=True)

    # Length limits. Only the encoder-side limit has a numeric default; the
    # other two default to None and are forwarded to the inputter as-is.
    length_help = 'discard data longer than this'
    cli.add_argument('--max_input_length', type=int, default=150, help=length_help)
    for flag in ('--max_decoder_input_length', '--max_knowledge_length'):
        cli.add_argument(flag, type=int, default=None, help=length_help)

    cli.add_argument('--label_num', type=int, default=None)

    # Boolean switches (False unless given on the command line).
    switches = (
        ('--only_encode', 'only do encoding'),
        ('--single_processing', 'do not use multiprocessing'),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)

    return cli


parser = _build_arg_parser()
args = parser.parse_args()
# Identifiers forwarded to build_model() so it can locate the right
# inputter / model configuration.
names = {
    'inputter_name': args.inputter_name,
    'config_name': args.config_name,
}

# Look up the inputter factory by name and instantiate it.
inputter = inputters[args.inputter_name]()
# NOTE(review): only_toker=True presumably makes build_model return just the
# tokenizer -- confirm against utils.building_utils.
toker = build_model(only_toker=True, **names)

# Each non-empty line is later parsed with json.loads() in process_data().
# Read explicitly as UTF-8 so decoding does not depend on the platform
# locale, and drop blank lines (e.g. a trailing newline at end of file),
# which would otherwise make json.loads() raise.
with open(args.train_input_file, encoding='utf-8') as f:
    reader = [line for line in f if line.strip()]

# Output directory: ./DATA/<inputter_name>.<config_name>
# makedirs(exist_ok=True) creates ./DATA as well and avoids the
# check-then-create race of exists() followed by mkdir().
save_dir = f'./DATA/{args.inputter_name}.{args.config_name}'
os.makedirs(save_dir, exist_ok=True)

# Options passed unchanged to both inputter conversion steps and later
# written to meta.json.
kwargs = {
    'max_input_length': args.max_input_length,
    'max_decoder_input_length': args.max_decoder_input_length,
    'max_knowledge_length': args.max_knowledge_length,
    'label_num': args.label_num,
    'only_encode': args.only_encode,
}
def process_data(line):
    """Turn one line of the input file into features.

    The line is parsed as JSON, passed through the inputter's
    ``convert_data_to_inputs`` step, and the result is passed through
    ``convert_inputs_to_features``. Both steps receive the module-level
    tokenizer and the shared ``kwargs`` options.

    Args:
        line: A string containing a single JSON document.

    Returns:
        Whatever ``inputter.convert_inputs_to_features`` returns; the caller
        extends a list with it, so it is expected to be an iterable.

    Raises:
        json.JSONDecodeError: If ``line`` is not valid JSON.
    """
    example = json.loads(line)
    # Both conversion steps take the same tokenizer and option set.
    shared = dict(toker=toker, **kwargs)
    model_inputs = inputter.convert_data_to_inputs(data=example, **shared)
    return inputter.convert_inputs_to_features(inputs=model_inputs, **shared)
# Guard the entry point: under the "spawn" start method worker processes
# re-import this module, and without the guard each worker would try to
# create its own pool and redo the whole run.
if __name__ == '__main__':
    processed_data = []
    if args.single_processing:
        # map() is lazy, so the bar advances as each line is converted.
        for features in tqdm.tqdm(map(process_data, reader), total=len(reader)):
            processed_data.extend(features)
    else:
        with mp.Pool(processes=mp.cpu_count()) as pool:
            # Wrap the *results* (not the input lines) in tqdm so the bar
            # tracks finished work rather than tasks handed to the pool.
            # imap() yields results in input order.
            results = pool.imap(process_data, reader)
            for features in tqdm.tqdm(results, total=len(reader)):
                processed_data.extend(features)

    # save data
    data_path = f'{save_dir}/data2.pkl'
    with open(data_path, 'wb') as file:
        pickle.dump(processed_data, file)

    # save relevant information to reproduce
    kwargs.update({'n_examples': len(processed_data)})
    with open(f'{save_dir}/meta.json', 'w') as writer:
        json.dump(kwargs, writer, indent=4)
    torch.save(toker, f'{save_dir}/tokenizer.pt')