run.py
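"""
Entry script for the open intent discovery pipeline: it parses command-line
arguments, sets up file/console logging, prepares the data and backbone model,
runs the selected method's train/test routines, and optionally saves the
results. List-valued parameters can be grid-searched via the --tune flag.

Example invocation (illustrative; adjust the dataset, method, and paths to
your own setup):

    python run.py --dataset banking --method DeepAligned --backbone bert \
        --known_cls_ratio 0.75 --labeled_ratio 0.1 --train --save_results
"""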
import argparse
import datetime
import itertools
import logging
import os
import sys

from configs.base import ParamManager
from dataloaders.base import DataManager
from backbones.base import ModelManager
from methods import method_map
from utils.functions import save_results, set_seed
def parse_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument('--type', type=str, default='open_intent_discovery', help="Type for methods.")
    parser.add_argument('--logger_name', type=str, default='Discovery', help="Logger name for open intent discovery.")
    parser.add_argument('--log_dir', type=str, default='logs', help="Logger directory.")
    parser.add_argument("--dataset", default='banking', type=str, help="The name of the dataset to train on.")
    parser.add_argument("--known_cls_ratio", default=0.75, type=float, help="The ratio of known classes.")
    parser.add_argument("--num_workers", default=8, type=int, help="The number of worker processes for data loading.")
    parser.add_argument("--labeled_ratio", default=0.1, type=float, help="The ratio of labeled samples in the training set.")
    parser.add_argument("--cluster_num_factor", default=1.0, type=float, help="The factor (magnification) of the number of clusters K.")
    parser.add_argument("--method", type=str, default='DeepAligned', help="Which method to use.")
    parser.add_argument("--train", action="store_true", help="Whether to train the model.")
    parser.add_argument("--tune", action="store_true", help="Whether to tune the model hyperparameters.")
    parser.add_argument("--save_model", action="store_true", help="Save the trained model for open intent discovery.")
    parser.add_argument("--backbone", type=str, default='bert', help="Which backbone to use.")
    parser.add_argument('--setting', type=str, default='semi_supervised', help="Type for clustering methods.")
    parser.add_argument("--config_file_name", type=str, default='DeepAligned.py', help="The name of the config file.")
    parser.add_argument('--seed', type=int, default=0, help="Random seed for initialization.")
    parser.add_argument("--gpu_id", type=str, default='0', help="Select the GPU id.")
    parser.add_argument("--pipe_results_path", type=str, default='pipe_results', help="The path to save results of pipeline methods.")
    parser.add_argument("--data_dir", default=sys.path[0] + '/../data', type=str,
                        help="The input data dir. Should contain the .csv files (or other data files) for the task.")
    parser.add_argument("--output_dir", default='/home/sharing/disk1/zhl/TEXTOIR/outputs', type=str,
                        help="The output directory where all train data will be written.")
    parser.add_argument("--model_dir", default='models', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--result_dir", type=str, default='results', help="The path to save results.")
    parser.add_argument("--results_file_name", type=str, default='results.csv', help="The file name of all the results.")
    parser.add_argument("--save_results", action="store_true", help="Save the final results for open intent discovery.")

    args = parser.parse_args()
    return args
def set_logger(args):
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    file_name = f"{args.method}_{args.dataset}_{args.backbone}_{args.known_cls_ratio}_{args.labeled_ratio}_{time}.log"
    args.logger_file_name = f"{args.method}_{args.dataset}_{args.backbone}_{time}"
    print('logger_file_name', args.logger_file_name)

    logger = logging.getLogger(args.logger_name)
    logger.setLevel(logging.DEBUG)

    fh = logging.FileHandler(os.path.join(args.log_dir, file_name))
    fh_formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
    fh.setFormatter(fh_formatter)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch_formatter = logging.Formatter('%(name)s - %(message)s')
    ch.setFormatter(ch_formatter)
    logger.addHandler(ch)

    return logger
def run(args, logger, debug_args=None):
    set_seed(args.seed)

    logger.info('Data and Model Preparation...')
    data = DataManager(args)
    model = ModelManager(args, data)

    # Look up the selected method by name and instantiate its manager with the
    # prepared data and backbone model.
    method_manager = method_map[args.method]
    method = method_manager(args, data, model, logger_name=args.logger_name)

    if args.train:
        logger.info('Training Begin...')
        method.train(args, data)
        logger.info('Training Finished...')

    logger.info('Testing Begin...')
    outputs = method.test(args, data)
    logger.info('Testing Finished...')

    if args.save_results:
        logger.info('Results saved in %s', str(os.path.join(args.result_dir, args.results_file_name)))
        save_results(args, outputs, debug_args=debug_args)
if __name__ == '__main__':

    sys.path.append('.')
    args = parse_arguments()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    logger = set_logger(args)
    logger.info('Open Intent Discovery Begin...')
    logger.info('Parameters Initialization...')

    param = ParamManager(args)
    args = param.args

    logger.debug("=" * 30 + " Params " + "=" * 30)
    for k in args.keys():
        logger.debug(f"{k}:\t{args[k]}")
    logger.debug("=" * 30 + " End Params " + "=" * 30)

    if args.tune:
        logger.info('Tuning begins...')

        debug_args = {}
        for k, v in args.items():
            if isinstance(v, list):
                debug_args[k] = v

        logger.info("***** Tuning parameters: *****")
        for key in debug_args.keys():
            logger.info(" %s = %s", key, str(debug_args[key]))

        for result in itertools.product(*debug_args.values()):
            for i, key in enumerate(debug_args.keys()):
                args[key] = result[i]
            run(args, logger, debug_args=debug_args)
    else:
        run(args, logger)
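# Tuning note (illustrative): when --tune is set, every list-valued entry in the
# merged args becomes a search dimension, and run() is called once per
# combination produced by itertools.product. For example, with hypothetical
# config values args['lr'] = [1e-5, 5e-5] and args['train_batch_size'] = [64, 128],
# the loop above would execute four runs.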