-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
42 lines (34 loc) · 1.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import json
import logging
import os
from pprint import pformat
from importlib import import_module
from vocab import Vocab, UnkVocab
from dataset import Dataset, Ontology
from preprocess_data import dann
def _load_json(filename):
    """Open ``filename`` inside the ``dann`` directory and return its parsed JSON."""
    # JSON text is UTF-8 per RFC 8259; don't depend on the platform locale.
    with open(os.path.join(dann, filename), encoding='utf-8') as f:
        return json.load(f)


def load_dataset(splits=('train', 'dev', 'test')):
    """Load the preprocessed dataset artefacts from the ``dann`` directory.

    Reads ``ontology.json``, ``vocab.json`` and ``emb.json``, then one
    ``<split>.json`` file for every name in ``splits``.

    Args:
        splits: Names of the splits to load; split ``s`` is read from
            ``<dann>/<s>.json``.

    Returns:
        A 5-tuple ``(dataset, ontology, vocab, E, init_state)``:
        ``dataset`` maps each split name to a ``Dataset``; ``ontology`` is an
        ``Ontology``; ``vocab`` is a ``Vocab``; ``E`` is the raw parsed
        content of ``emb.json``; ``init_state`` is always ``None``.

    Raises:
        OSError: if one of the files cannot be opened.
        json.JSONDecodeError: if one of the files is not valid JSON.
    """
    ontology = Ontology.from_dict(_load_json('ontology.json'))
    vocab = Vocab.from_dict(_load_json('vocab.json'))
    E = _load_json('emb.json')
    # Loading init_state.json is currently disabled; callers always get None.
    init_state = None
    dataset = {}
    for split in splits:
        # NOTE(review): kept at WARNING level as in the original, presumably so
        # progress is visible without logging configuration; consider INFO.
        logging.warning('loading split {}'.format(split))
        dataset[split] = Dataset.from_dict(_load_json('{}.json'.format(split)))
    logging.info('dataset sizes: {}'.format(pformat({k: len(v) for k, v in dataset.items()})))
    return dataset, ontology, vocab, E, init_state
def get_models(models_dir='models'):
    """Return the sorted names of the model modules found in ``models_dir``.

    An entry is reported when it is a ``.py`` file (reported with only the
    trailing ``.py`` suffix removed) or a directory (a package). Entries are
    skipped when their name starts with ``_`` (e.g. ``__init__.py``,
    ``__pycache__``) or ``.`` (hidden files, which can never be imported as
    ``models.<name>``), or when the name is exactly ``model``.

    Args:
        models_dir: Directory to scan. Defaults to ``'models'`` relative to the
            current working directory.

    Returns:
        A sorted list of module names, so the result is deterministic
        regardless of ``os.listdir`` ordering.
    """
    names = []
    for entry in os.listdir(models_dir):
        if entry.startswith(('_', '.')) or entry == 'model':
            continue
        if entry.endswith('.py'):
            # Strip only the suffix; str.replace would also remove inner '.py'.
            names.append(entry[:-len('.py')])
        elif os.path.isdir(os.path.join(models_dir, entry)):
            names.append(entry)
    return sorted(names)
def load_model(model, *args, **kwargs):
    """Import ``models.<model>`` and build an instance of its ``Model`` class.

    Args:
        model: Name of a module inside the ``models`` package.
        *args: Positional arguments forwarded to the ``Model`` constructor.
        **kwargs: Keyword arguments forwarded to the ``Model`` constructor.

    Returns:
        The newly constructed ``Model`` instance.

    Raises:
        ModuleNotFoundError: if ``models.<model>`` cannot be imported.
        AttributeError: if that module defines no ``Model`` attribute.
    """
    module = import_module('models.{}'.format(model))
    model_cls = module.Model
    instance = model_cls(*args, **kwargs)
    logging.info('loaded model {}'.format(model_cls))
    return instance
if __name__ == '__main__':
    # Smoke test: load every split. Configure logging first; otherwise the
    # INFO-level dataset-size summary emitted by load_dataset is discarded.
    logging.basicConfig(level=logging.INFO)
    load_dataset()