-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathrun_experiment.py
60 lines (47 loc) · 2.01 KB
/
run_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging
import os

# Environment variables consumed by native libraries must be in place BEFORE
# those libraries are imported. ``finer`` (and possibly ``configurations``)
# presumably import TensorFlow / HuggingFace tokenizers — confirm — so these
# assignments are kept above the third-party and project imports below.
# '2' is intended to suppress TensorFlow's C++ INFO/WARNING log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import click  # noqa: E402
from configurations.configuration import Configuration  # noqa: E402
from finer import FINER  # noqa: E402

# Only let ERROR (and above) through from the Python loggers of the ML libraries.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
logging.getLogger('transformers').setLevel(logging.ERROR)
LOGGER = logging.getLogger(__name__)

# Click command group that commands in this script are registered on.
cli = click.Group()
def _log_parameters(section):
    """Log every entry of ``Configuration[section]`` under a section header.

    Values that are dicts are expanded one key per line, tab-indented.

    :param section: Name of a configuration section, e.g. "train_parameters".
    """
    # Header title is the first underscore-separated word, e.g. "hyper_parameters" -> "Hyper".
    LOGGER.info(f'\n---------------- {section.split("_")[0].capitalize()} Parameters ----------------')
    for param_name, value in Configuration[section].items():
        if isinstance(value, dict):
            LOGGER.info(f'{param_name}:')
            for p_name, p_value in value.items():
                LOGGER.info(f'\t{p_name}: {p_value}')
        else:
            LOGGER.info(f'{param_name}: {value}')


@cli.command()
@click.option('--method', default='transformer',
              help='Method to run ("bilstm", "transformer", "transformer_bilstm").')
# Restrict --mode to the values handled below; previously an unknown mode
# configured and instantiated the experiment, then silently did nothing.
@click.option('--mode', default='train', type=click.Choice(['train', 'evaluate']),
              help='Mode to run.')
def run_experiment(method, mode):
    """
    Main function that instantiates and runs a new experiment.

    :param method: Method to run ("bilstm", "transformer", "transformer_bilstm")
    :param mode: Mode to run ("train", "evaluate")
    """
    # Load the configuration for the chosen method/mode before building the experiment.
    Configuration.configure(method=method, mode=mode)
    experiment = FINER()

    if mode == 'train':
        LOGGER.info('\n---------------- Train ----------------')
        LOGGER.info(f"Log Name: {Configuration['task']['log_name']}")
        for params in ['train_parameters', 'general_parameters', 'hyper_parameters', 'evaluation']:
            _log_parameters(section=params)
        LOGGER.info('\n')
        experiment.train()
    elif mode == 'evaluate':
        LOGGER.info('\n---------------- Evaluate Pretrained Model ----------------')
        # Hyper-parameters are not logged in evaluation mode.
        for params in ['train_parameters', 'general_parameters', 'evaluation']:
            _log_parameters(section=params)
        LOGGER.info('\n')
        experiment.evaluate_pretrained_model()
if __name__ == '__main__':
    # Run the click command standalone so its options are given directly on the
    # command line (no sub-command name needed). Calling the command object
    # delegates to ``Command.main``, so invoke that entry point explicitly.
    run_experiment.main()