#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
This script trains gensim LDA topic models as specified by a model config file.

A dictionary and a bag-of-words corpus are first built from the corpus files
(or loaded from disk), and one LdaMulticore model is then trained for every
combination of the list-valued LDA parameters in the config.
'''
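# A minimal sketch of the expected config layout, inferred from how this script
# reads it (the actual configs live under configs/models/lda/; the values below
# are illustrative placeholders, not authoritative):
#
#   meta:
#     library_version: <installed gensim version>
#     model_id: ''
#   paths:
#     model_dir: <output directory for trained models>
#     corpus_path: <path of the serialized MmCorpus>
#     dictionary_path: <path of the saved gensim Dictionary>
#     base_dir: <root directory of the corpus files>
#     source_dir_name: <subdirectory holding the cleaned text>
#   params:
#     min_tokens: <minimum tokens per document>
#     dictionary: {no_below: ..., no_above: ..., keep_n: ..., keep_tokens: []}
#     lda: {num_topics: [...], workers: -1, ...}  # list values are grid-expanded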
import copy
import itertools
import json
import logging
import os
from pathlib import Path

import click
import contexttimer
import gensim
from gensim.corpora import Dictionary, MmCorpus
from gensim.models.ldamulticore import LdaMulticore

import wb_nlp
from wb_nlp import dir_manager
from wb_nlp.processing.corpus import MultiDirGenerator
from wb_nlp.utils.scripts import (
    configure_logger,
    create_get_directory,
    generate_model_hash,
    load_config,
)

_logger = logging.getLogger(__file__)


def checkpoint_log(timer, logger, message=''):
    '''Log the elapsed time, in minutes, together with an optional message.'''
    logger.info('Time elapsed now in minutes: %s %s',
                timer.elapsed / 60, message)


@click.command()
@click.option('-c', '--config', 'cfg_path', required=True,
              type=click.Path(exists=True), help='path to config file')
@click.option('--quiet', 'log_level', flag_value=logging.WARNING, default=True)
@click.option('-v', '--verbose', 'log_level', flag_value=logging.INFO)
@click.option('-vv', '--very-verbose', 'log_level', flag_value=logging.DEBUG)
@click.option('--train-dictionary', 'load_dictionary', flag_value=False, default=True)
@click.option('--load-dictionary', 'load_dictionary', flag_value=True)
@click.option('--from-files', 'load_dump', flag_value=False, default=True)
@click.option('--from-dump', 'load_dump', flag_value=True)
@click.version_option(wb_nlp.__version__)
def main(cfg_path: Path, log_level: int, load_dictionary: bool, load_dump: bool):
    '''
    Entry point for the LDA model training script.
    '''
    with contexttimer.Timer() as timer:
        configure_logger(log_level)

        if load_dump:
            assert load_dictionary, "Can't load a corpus dump without using the --load-dictionary flag."
        config = load_config(cfg_path, 'model_config', _logger)
        assert gensim.__version__ == config['meta']['library_version'], (
            'Installed gensim version must match the library_version pinned in the config.')

        paths_conf = config['paths']

        model_dir = Path(dir_manager.get_path_from_root(paths_conf['model_dir']))
        model_dir.mkdir(parents=True, exist_ok=True)

        corpus_path = paths_conf['corpus_path']

        file_generator = MultiDirGenerator(
            base_dir=paths_conf['base_dir'],
            source_dir_name=paths_conf['source_dir_name'],
            split=True,
            min_tokens=config['params']['min_tokens'],
            logger=_logger,
        )
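        # MultiDirGenerator is a wb_nlp helper; it presumably streams tokenized
        # documents from the source directories, with `min_tokens` filtering out
        # documents too short to be informative.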

        _logger.info('Training dictionary...')
        dictionary_params = config['params']['dictionary']
        dictionary_file = Path(paths_conf['dictionary_path'])

        checkpoint_log(
            timer, _logger, message='Loading or generating dictionary...')
        if load_dictionary and dictionary_file.exists():
            g_dict = Dictionary.load(str(dictionary_file))
        else:
            assert not load_dump, "Can't generate a dictionary from a corpus dump. Use the --from-files flag instead."
            g_dict = Dictionary(file_generator)
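            # For reference: gensim's filter_extremes drops tokens that appear in
            # fewer than `no_below` documents or in more than `no_above` (a
            # fraction) of documents, then keeps only the `keep_n` most frequent
            # tokens; tokens listed in `keep_tokens` are never removed.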
            g_dict.filter_extremes(
                no_below=dictionary_params['no_below'],
                no_above=dictionary_params['no_above'],
                keep_n=dictionary_params['keep_n'],
                keep_tokens=dictionary_params['keep_tokens'])

            # Build the id -> token mapping once so it can be reused as id2word.
            g_dict.id2token = {token_id: token for token, token_id in g_dict.token2id.items()}
            g_dict.save(str(dictionary_file))

        checkpoint_log(
            timer, _logger, message='Loading or generating corpus...')

        if load_dump:
            _logger.info('Loading saved corpus...')
            corpus = MmCorpus(corpus_path)
        else:
            _logger.info('Generating corpus...')
            corpus = [g_dict.doc2bow(d) for d in file_generator]
            _logger.info('Saving corpus to %s...', corpus_path)
            MmCorpus.serialize(corpus_path, corpus)
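        # doc2bow converts a token list into sparse (token_id, count) pairs;
        # e.g., if 'economy' has id 0 and 'growth' has id 1:
        #   g_dict.doc2bow(['economy', 'growth', 'economy']) -> [(0, 2), (1, 1)]
        # MmCorpus then streams the serialized corpus from disk, so --from-dump
        # runs avoid rebuilding (and holding) the full corpus in memory.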

        _logger.info('Generating model configurations...')

        # Find the LDA parameters whose config values are lists.
        lda_params = config['params']['lda']
        list_params = sorted(
            filter(lambda x: isinstance(lda_params[x], list), lda_params))
        _logger.info(list_params)
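
        # The config appears to store `workers` as a non-positive offset
        # (e.g. -1), so the effective worker count is the machine's CPU count
        # plus that offset, floored at one worker.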
        lda_params['workers'] = max(1, (os.cpu_count() or 1) + lda_params['workers'])

        lda_params_set = []
        for vals in itertools.product(*[lda_params[lp] for lp in list_params]):
            _lda_params = dict(lda_params)
            for k, val in zip(list_params, vals):
                _lda_params[k] = val
            lda_params_set.append(_lda_params)
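        # Each combination of the list-valued parameters defines one model.
        # Worked example (parameter names illustrative): with num_topics: [25, 50]
        # and passes: [2, 5] in the config, the product above yields four
        # parameter sets: (25, 2), (25, 5), (50, 2), and (50, 5).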

        _logger.info('Training models...')
        checkpoint_log(timer, _logger, message='Starting now...')

        models_count = len(lda_params_set)
        for model_num, model_params in enumerate(lda_params_set, 1):
            # Deep-copy the config so that editing nested dicts for this
            # model's record can't mutate the shared `config` object.
            record_config = copy.deepcopy(config)
            record_config['params']['lda'] = dict(model_params)
            record_config['meta']['model_id'] = ''

            # Models are identified by a hash of their full configuration,
            # which makes runs easy to track and reproduce.
            model_hash = generate_model_hash(record_config)

            sub_model_dir = create_get_directory(model_dir, model_hash)
            with open(sub_model_dir / f'model_config_{model_hash}.json', 'w') as open_file:
                json.dump(record_config, open_file)

            _logger.info('Training model_id: %s', model_hash)
            _logger.info(model_params)

            model_params['id2word'] = dict(g_dict.id2token)
            lda = LdaMulticore(corpus, **model_params)

            lda.save(str(sub_model_dir / f'model_{model_hash}.lda.bz2'))
            _logger.info(lda.print_topics())

            checkpoint_log(
                timer, _logger,
                message=f'Finished running model {model_num}/{models_count}...')


if __name__ == '__main__':
    # Local machine:
    # python -u scripts/models/train_lda_model.py -c configs/models/lda/test.yml -vv |& tee ./logs/train_lda_model.py.log
    # python -u scripts/models/train_lda_model.py -c configs/models/lda/test.yml -vv --load-dictionary --from-dump |& tee ./logs/train_lda_model.py.log
    # On w1lxbdatad07:
    # python -u scripts/models/train_lda_model.py -c configs/models/lda/default.yml -vv |& tee ./logs/train_lda_model.py.log
    # python -u scripts/models/train_lda_model.py -c configs/models/lda/default.yml -vv --load-dictionary --from-dump |& tee ./logs/train_lda_model.py.log
    main()