Skip to content

Commit

Permalink
fix save mode bugs and add predict part test (deepwel#22)
Browse files Browse the repository at this point in the history
* fix save mode bugs and add predict part test

* 1. add log file config. 2. add local_offline_train script and fix config encode bugs in py2 env

* add comments
  • Loading branch information
zqhZY authored and crownpku committed Nov 25, 2017
1 parent 5d20e5c commit 40c0470
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 19 deletions.
2 changes: 1 addition & 1 deletion chi_annotator/algo_factory/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def persist(self, model_dir):
})

with io.open(os.path.join(model_dir, 'metadata.json'), 'w') as f:
f.write(str(json.dumps(metadata, indent=4)))
f.write(json.dumps(metadata, ensure_ascii=False, indent=4))


class Message(object):
Expand Down
2 changes: 1 addition & 1 deletion chi_annotator/task_center/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_unicode(self, config):
# to unify that and ease further usage of the config, we convert everything to unicode
for k, v in config.items():
if type(v) is bytes:
config[k] = str(v, "utf-8")
config[k] = str(v).encode("utf-8")
return config

def override(self, config):
Expand Down
107 changes: 107 additions & 0 deletions chi_annotator/task_center/local_offline_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
This is a script for running the pipeline from the command line. It is only meant for testing
the data flow and will be modified in the future.
The features of this script are as follows:
1. Load the config file given in the command-line arguments.
2. Train the pipeline and save the model according to the config.
3. Prediction is still to be added.
You can run the shell command below from the root dir of the project:
python -m chi_annotator.task_center.local_offline_train -c ./tests/data/test_config.json
You can change the config file path that follows the -c argument; refer to test_config.json
when writing a config for your own target.
Note:
It only supports loading local JSON-format train data for now. You can generate a tmp train set
with the following command from the root dir of the project:
python -m tests.utils.txt_to_json
More ways of loading data will be supported in the future.
"""
import argparse
import logging
import os

from chi_annotator.algo_factory.components import ComponentBuilder
from chi_annotator.task_center.data_loader import load_local_data
from chi_annotator.task_center.model import Interpreter
from chi_annotator.task_center.model import Trainer

from chi_annotator.task_center.config import AnnotatorConfig

logger = logging.getLogger(__name__)


def create_argparser():
    """Build the command-line parser used by this script.

    The parser accepts a single, mandatory option ``-c`` / ``--config``
    whose value is exposed as ``config`` on the parsed namespace.
    """
    arg_parser = argparse.ArgumentParser(description='train a custom language parser')
    # The config file is the only input of the script, hence it is required.
    arg_parser.add_argument('-c', '--config', help="configuration file", required=True)
    return arg_parser


class TrainingException(Exception):
    """Exception wrapping lower level exceptions that may happen while training.

    Attributes:
        failed_target_project -- name of the failed project
        message -- explanation of the failure: the first argument of the
            wrapped exception, its string form when it carries no arguments,
            or an empty string when no exception was wrapped
    """

    def __init__(self, failed_target_project=None, exception=None):
        self.failed_target_project = failed_target_project
        # ``message`` is always defined so that ``__str__`` never fails with
        # an AttributeError when the exception is raised without a cause.
        if exception is None:
            self.message = ""
        else:
            # Not every exception carries arguments (e.g. ``ValueError()``),
            # so fall back to its string form instead of indexing blindly.
            wrapped_args = getattr(exception, "args", None)
            self.message = wrapped_args[0] if wrapped_args else str(exception)

    def __str__(self):
        # ``__str__`` has to return a string, but the first argument of the
        # wrapped exception may be any object (e.g. ``KeyError(42)``).
        return str(self.message)


def init():  # pragma: no cover
    # type: () -> AnnotatorConfig
    """Create the Annotator config from the command-line arguments.

    The config is built from the file passed with ``-c`` / ``--config``,
    the process environment and the parsed arguments as a dict.
    """
    cli_args = create_argparser().parse_args()
    return AnnotatorConfig(cli_args.config, os.environ, vars(cli_args))


def do_train_in_worker(config):
    # type: (AnnotatorConfig) -> Text
    """Run the training for ``config`` and return the persisted model path.

    Any exception raised by ``do_train`` is wrapped into a
    ``TrainingException`` carrying the value of ``config.get("project")``.
    """
    try:
        # Only the persisted path is of interest to the caller.
        _, _, model_path = do_train(config)
    except Exception as err:
        raise TrainingException(config.get("project"), err)
    else:
        return model_path


def do_train(config,  # type: AnnotatorConfig
             component_builder=None  # type: Optional[ComponentBuilder]
             ):
    # type: (...) -> Tuple[Trainer, Interpreter, Text]
    """Train a model as described by ``config`` and persist it.

    A ``Trainer`` is built from ``config`` and ``component_builder``, the
    data found at ``config['org_data']`` is loaded and used for training,
    and the result is persisted with ``config['path']``,
    ``config['project']`` and ``config['fixed_model_name']``.

    Returns the trainer, the interpreter produced by the training and the
    path returned by ``Trainer.persist``.
    """
    # WARN: there is still a race condition if a model with the same name is
    # trained in another subprocess
    model_trainer = Trainer(config, component_builder)
    train_set = load_local_data(config['org_data'])
    trained_interpreter = model_trainer.train(train_set)
    model_path = model_trainer.persist(
        config['path'],
        config['project'],
        config['fixed_model_name'])
    return model_trainer, trained_interpreter, model_path


if __name__ == '__main__':
    # Script entry point: build the config, set up file logging, then train.
    config = init()

    # Fall back to a default log file when the config does not name one.
    log_filename = config["log_file"]
    if log_filename is None:
        log_filename = "task_center.log"
    logging.basicConfig(filename=log_filename, level=config['log_level'])

    do_train(config)
    logger.info("Finished training")
19 changes: 6 additions & 13 deletions chi_annotator/task_center/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def train(self, data):

return Interpreter(self.pipeline, context)

def persist(self, path, persistor=None, project_name=None,
def persist(self, path, project_name=None,
fixed_model_name=None):
# type: (Text, Optional[Persistor], Text) -> Text
"""Persist all components of the pipeline to the passed path.
Expand All @@ -93,8 +93,6 @@ def persist(self, path, persistor=None, project_name=None,
for component in self.pipeline],
}

logger.debug("metadata is :" + metadata)

if project_name is None:
project_name = "default"

Expand All @@ -106,22 +104,17 @@ def persist(self, path, persistor=None, project_name=None,

# create model dir
utils.create_dir(dir_name)

# copy and save train data to dir_name
if self.training_data:
# self.training_data.persist return nothing here
metadata.update(self.training_data.persist(dir_name))
# TODO we have no need to copy and save train data to model.
# if self.training_data:
# # self.training_data.persist return nothing here
# metadata.update(self.training_data.persist(dir_name))

for component in self.pipeline:
update = component.persist(dir_name)
if update:
metadata.update(update)

# save metadata to dir_name
Metadata(metadata, dir_name).persist(dir_name)

if persistor is not None:
persistor.persist(dir_name, model_name, project_name)
logger.info("Successfully saved model into "
"'{}'".format(os.path.abspath(dir_name)))
return dir_name
Expand All @@ -133,7 +126,7 @@ class Interpreter(object):
# Defines all attributes (& default values) that will be returned by `parse`
@staticmethod
def default_output_attributes():
return {"intent": {"name": "", "confidence": 0.0}, "entities": []}
return {"label": {"name": "", "confidence": 0.0}}

@staticmethod
def load(model_dir, config=AnnotatorConfig(), component_builder=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@
"batch_num" : "10",
"inference_num" : "20",
"low_conf_num" : "10",
"confidence_threshold" : "0.95"
"confidence_threshold" : "0.95",
"log_level": "INFO",
"log_file": null
}
12 changes: 11 additions & 1 deletion tests/data/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,15 @@
"model_type": "classification",
"pipeline": ["char_tokenizer"],
"language": "zh",
"path": "./tests/models"
"path": "./tests/models",
"wordvec_file": "./tests/data/vec.txt",
"org_data" : "./tests/data/test_data.json",
"database_name" : "spam_emails_chi",
"labels": ["spam","notspam"],
"batch_num" : "10",
"inference_num" : "20",
"low_conf_num" : "10",
"confidence_threshold" : "0.95",
"log_level": "INFO",
"log_file": null
}
81 changes: 79 additions & 2 deletions tests/taskcenter/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# -*- coding: utf-8 -*-
import json
import os
import io
import shutil

from chi_annotator.task_center.config import AnnotatorConfig
from chi_annotator.task_center.data_loader import load_local_data
from chi_annotator.task_center.model import Trainer
Expand Down Expand Up @@ -75,24 +81,95 @@ def test_pipeline_flow(self):

def test_trainer_persist(self):
"""
test pipeline persist
test pipeline persist, metadata will be saved
:return:
"""
# TODO because only char_tokenizer now. nothing to be persist
pass
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)

trainer = Trainer(config)
assert len(trainer.pipeline) == 1
# char_tokenizer component should been created
assert trainer.pipeline[0] is not None
# create tmp train set
tmp_path = create_tmp_test_file("tmp.json")
train_data = load_local_data(tmp_path)
# rm tmp train set
rm_tmp_file("tmp.json")

trainer.train(train_data)
persisted_path = trainer.persist(config['path'],
config['project'],
config['fixed_model_name'])
# load persisted metadata
metadata_path = os.path.join(persisted_path, 'metadata.json')
with io.open(metadata_path) as f:
metadata = json.load(f)
assert 'trained_at' in metadata
# rm tmp files and dirs
shutil.rmtree(config['path'], ignore_errors=True)

def test_predict_flow(self):
"""
test Interpreter flow, only predict now
:return:
"""
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)

trainer = Trainer(config)
assert len(trainer.pipeline) == 1
# char_tokenizer component should been created
assert trainer.pipeline[0] is not None
# create tmp train set
tmp_path = create_tmp_test_file("tmp.json")
train_data = load_local_data(tmp_path)
# rm tmp train set
rm_tmp_file("tmp.json")

interpreter = trainer.train(train_data)
text = "我是测试"
output = interpreter.parse(text)
assert "label" in output

def test_train_model_empty_pipeline(self):
"""
train model with no component
:return:
"""
# TODO
pass

def test_train_named_model(self):
"""
test train model with certain name
:return:
"""
# TODO
pass

def test_handles_pipeline_with_non_existing_component(self):
"""
handle no exist component in pipeline
:return:
"""
# TODO
pass

def test_load_and_persist_without_train(self):
"""
test save and load model without train
:return:
"""
# TODO
pass

def test_train_with_empty_data(self):
"""
test train with empty train data
:return:
"""
# TODO
pass

3 changes: 3 additions & 0 deletions tests/utils/txt_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ def rm_tmp_file(filename):
"""
os.remove('tests/data/'+filename)

if __name__ == '__main__':
dirname = create_tmp_test_file("test_data.json")
print("test_data.json has been created in " + dirname)

0 comments on commit 40c0470

Please sign in to comment.