diff --git a/chi_annotator/algo_factory/common.py b/chi_annotator/algo_factory/common.py index bf19ca7..a01be51 100644 --- a/chi_annotator/algo_factory/common.py +++ b/chi_annotator/algo_factory/common.py @@ -92,7 +92,7 @@ def persist(self, model_dir): }) with io.open(os.path.join(model_dir, 'metadata.json'), 'w') as f: - f.write(str(json.dumps(metadata, indent=4))) + f.write(json.dumps(metadata, ensure_ascii=False, indent=4)) class Message(object): diff --git a/chi_annotator/task_center/config.py b/chi_annotator/task_center/config.py index 35cd529..1ccbdfd 100644 --- a/chi_annotator/task_center/config.py +++ b/chi_annotator/task_center/config.py @@ -162,7 +162,7 @@ def make_unicode(self, config): # to unify that and ease further usage of the config, we convert everything to unicode for k, v in config.items(): if type(v) is bytes: - config[k] = str(v, "utf-8") + config[k] = v.decode("utf-8") return config def override(self, config): diff --git a/chi_annotator/task_center/local_offline_train.py b/chi_annotator/task_center/local_offline_train.py new file mode 100644 index 0000000..dcf74af --- /dev/null +++ b/chi_annotator/task_center/local_offline_train.py @@ -0,0 +1,107 @@ +""" +This is a script to run the pipeline from the command line, just to test the data flow; it will be + modified in the future. + +The features of this script are as follows: +1、load config file from args of command line. +2、train pipeline and save model according to config. +3、predict is to be added. + +You can run the shell command below at the root dir of the project: + python -m chi_annotator.task_center.local_offline_train -c ./tests/data/test_config.json + +You can change the config file path given after the -c argument. You can refer to test_config.json for your +own target. + +Note: + It only supports loading local JSON-format training data for now. You can generate tmp training data with the + following command in the root dir of the project: + python -m tests.utils.txt_to_json + + More ways to load data will be supported in the future. 
+""" +import argparse +import logging +import os + +from chi_annotator.algo_factory.components import ComponentBuilder +from chi_annotator.task_center.data_loader import load_local_data +from chi_annotator.task_center.model import Interpreter +from chi_annotator.task_center.model import Trainer + +from chi_annotator.task_center.config import AnnotatorConfig + +logger = logging.getLogger(__name__) + + +def create_argparser(): + parser = argparse.ArgumentParser( + description='train a custom language parser') + parser.add_argument('-c', '--config', required=True, + help="configuration file") + return parser + + +class TrainingException(Exception): + """Exception wrapping lower level exceptions that may happen while training + + Attributes: + failed_target_project -- name of the failed project + message -- first argument of the wrapped exception, "" when there is none + """ + + def __init__(self, failed_target_project=None, exception=None): + self.failed_target_project = failed_target_project + args = getattr(exception, "args", None) + self.message = args[0] if args else "" + + def __str__(self): + return self.message + + +def init(): # pragma: no cover + # type: () -> AnnotatorConfig + """Combines passed arguments to create Annotator config.""" + + parser = create_argparser() + args = parser.parse_args() + config = AnnotatorConfig(args.config, os.environ, vars(args)) + return config + + +def do_train_in_worker(config): + # type: (AnnotatorConfig) -> Text + """Loads the trainer and the data and runs the training in a worker.""" + + try: + _, _, persisted_path = do_train(config) + return persisted_path + except Exception as e: + raise TrainingException(config.get("project"), e) + + +def do_train(config, # type: AnnotatorConfig + component_builder=None # type: Optional[ComponentBuilder] + ): + # type: (...) 
-> Tuple[Trainer, Interpreter, Text] + """Loads the trainer and the data and runs the training of the model.""" + + # Ensure we are training a model that we can save in the end + # WARN: there is still a race condition if a model with the same name is + # trained in another subprocess + trainer = Trainer(config, component_builder) + training_data = load_local_data(config['org_data']) + interpreter = trainer.train(training_data) + persisted_path = trainer.persist(config['path'], + config['project'], + config['fixed_model_name']) + return trainer, interpreter, persisted_path + + +if __name__ == '__main__': + config = init() + log_filename = config["log_file"] if config["log_file"] is not None else "task_center.log" + logging.basicConfig(level=config['log_level'], filename=log_filename) + + do_train(config) + logger.info("Finished training") diff --git a/chi_annotator/task_center/model.py b/chi_annotator/task_center/model.py index 2b85ba7..d219c09 100644 --- a/chi_annotator/task_center/model.py +++ b/chi_annotator/task_center/model.py @@ -79,7 +79,7 @@ def train(self, data): return Interpreter(self.pipeline, context) - def persist(self, path, persistor=None, project_name=None, + def persist(self, path, project_name=None, fixed_model_name=None): # type: (Text, Optional[Persistor], Text) -> Text """Persist all components of the pipeline to the passed path. @@ -93,8 +93,6 @@ def persist(self, path, persistor=None, project_name=None, for component in self.pipeline], } - logger.debug("metadata is :" + metadata) - if project_name is None: project_name = "default" @@ -106,22 +104,17 @@ def persist(self, path, persistor=None, project_name=None, # create model dir utils.create_dir(dir_name) - - # copy and save train data to dir_name - if self.training_data: - # self.training_data.persist return nothing here - metadata.update(self.training_data.persist(dir_name)) + # TODO we have no need to copy and save train data to model. 
+ # if self.training_data: + # # self.training_data.persist return nothing here + # metadata.update(self.training_data.persist(dir_name)) for component in self.pipeline: update = component.persist(dir_name) if update: metadata.update(update) - # save metadata to dir_name Metadata(metadata, dir_name).persist(dir_name) - - if persistor is not None: - persistor.persist(dir_name, model_name, project_name) logger.info("Successfully saved model into " "'{}'".format(os.path.abspath(dir_name))) return dir_name @@ -133,7 +126,7 @@ class Interpreter(object): # Defines all attributes (& default values) that will be returned by `parse` @staticmethod def default_output_attributes(): - return {"intent": {"name": "", "confidence": 0.0}, "entities": []} + return {"label": {"name": "", "confidence": 0.0}} @staticmethod def load(model_dir, config=AnnotatorConfig(), component_builder=None, diff --git a/chi_annotator/user_instance/examples/classify/spam_email_classify_config.json b/chi_annotator/user_instance/examples/classify/spam_email_classify_config.json index b69a251..1adff48 100644 --- a/chi_annotator/user_instance/examples/classify/spam_email_classify_config.json +++ b/chi_annotator/user_instance/examples/classify/spam_email_classify_config.json @@ -15,5 +15,7 @@ "batch_num" : "10", "inference_num" : "20", "low_conf_num" : "10", - "confidence_threshold" : "0.95" + "confidence_threshold" : "0.95", + "log_level": "INFO", + "log_file": null } diff --git a/tests/data/test_config.json b/tests/data/test_config.json index 7ab2ccf..9754a39 100644 --- a/tests/data/test_config.json +++ b/tests/data/test_config.json @@ -3,5 +3,15 @@ "model_type": "classification", "pipeline": ["char_tokenizer"], "language": "zh", - "path": "./tests/models" + "path": "./tests/models", + "wordvec_file": "./tests/data/vec.txt", + "org_data" : "./tests/data/test_data.json", + "database_name" : "spam_emails_chi", + "labels": ["spam","notspam"], + "batch_num" : "10", + "inference_num" : "20", + "low_conf_num" : 
"10", + "confidence_threshold" : "0.95", + "log_level": "INFO", + "log_file": null } diff --git a/tests/taskcenter/test_trainer.py b/tests/taskcenter/test_trainer.py index 7d43ae5..8990010 100644 --- a/tests/taskcenter/test_trainer.py +++ b/tests/taskcenter/test_trainer.py @@ -1,3 +1,9 @@ +# -*- coding: utf-8 -*- +import json +import os +import io +import shutil + from chi_annotator.task_center.config import AnnotatorConfig from chi_annotator.task_center.data_loader import load_local_data from chi_annotator.task_center.model import Trainer @@ -75,24 +81,95 @@ def test_pipeline_flow(self): def test_trainer_persist(self): """ - test pipeline persist + test pipeline persist, metadata will be saved :return: """ # TODO because only char_tokenizer now. nothing to be persist - pass + test_config = "tests/data/test_config.json" + config = AnnotatorConfig(test_config) + + trainer = Trainer(config) + assert len(trainer.pipeline) == 1 + # char_tokenizer component should been created + assert trainer.pipeline[0] is not None + # create tmp train set + tmp_path = create_tmp_test_file("tmp.json") + train_data = load_local_data(tmp_path) + # rm tmp train set + rm_tmp_file("tmp.json") + + trainer.train(train_data) + persisted_path = trainer.persist(config['path'], + config['project'], + config['fixed_model_name']) + # load persisted metadata + metadata_path = os.path.join(persisted_path, 'metadata.json') + with io.open(metadata_path) as f: + metadata = json.load(f) + assert 'trained_at' in metadata + # rm tmp files and dirs + shutil.rmtree(config['path'], ignore_errors=True) + + def test_predict_flow(self): + """ + test Interpreter flow, only predict now + :return: + """ + test_config = "tests/data/test_config.json" + config = AnnotatorConfig(test_config) + + trainer = Trainer(config) + assert len(trainer.pipeline) == 1 + # char_tokenizer component should been created + assert trainer.pipeline[0] is not None + # create tmp train set + tmp_path = create_tmp_test_file("tmp.json") + 
train_data = load_local_data(tmp_path) + # rm tmp train set + rm_tmp_file("tmp.json") + + interpreter = trainer.train(train_data) + text = "我是测试" + output = interpreter.parse(text) + assert "label" in output def test_train_model_empty_pipeline(self): + """ + train model with no component + :return: + """ + # TODO pass def test_train_named_model(self): + """ + test train model with certain name + :return: + """ + # TODO pass def test_handles_pipeline_with_non_existing_component(self): + """ + handle no exist component in pipeline + :return: + """ + # TODO pass def test_load_and_persist_without_train(self): + """ + test save and load model without train + :return: + """ + # TODO pass def test_train_with_empty_data(self): + """ + test train with empty train data + :return: + """ + # TODO pass diff --git a/tests/utils/txt_to_json.py b/tests/utils/txt_to_json.py index e2fe0dd..7b68a77 100644 --- a/tests/utils/txt_to_json.py +++ b/tests/utils/txt_to_json.py @@ -28,3 +28,6 @@ def rm_tmp_file(filename): """ os.remove('tests/data/'+filename) +if __name__ == '__main__': + dirname = create_tmp_test_file("test_data.json") + print("test_data.json has been created in " + dirname)