diff --git a/chi_annotator/task_center/model.py b/chi_annotator/task_center/model.py
index 9c62b50..c2e73f6 100644
--- a/chi_annotator/task_center/model.py
+++ b/chi_annotator/task_center/model.py
@@ -126,7 +126,7 @@ class Interpreter(object):
     # Defines all attributes (& default values) that will be returned by `parse`
     @staticmethod
     def default_output_attributes():
-        return {"label": {"name": "", "confidence": 0.0}}
+        return {'classifylabel': {'name': '', 'confidence': 0.0}}
 
     @staticmethod
     def load(model_dir, config=AnnotatorConfig(), component_builder=None,
diff --git a/tests/data/test_config.json b/tests/data/test_config.json
index 33d8706..4bb8fab 100644
--- a/tests/data/test_config.json
+++ b/tests/data/test_config.json
@@ -1,5 +1,8 @@
 {
   "name": "email_spam_classification",
+  "path": "./tests/models",
+  "project": "chi_annotator",
+  "fixed_model_name": "chi_annotator_models",
   "model_type": "classification",
   "pipeline": [
     "char_tokenizer",
@@ -7,16 +10,9 @@
     "classifier_sklearn"
   ],
   "language": "zh",
-  "path": "./tests/models",
   "embedding_path": "./tests/data/vec.txt",
   "embedding_type": "text",
   "org_data" : "./tests/data/test_data.json",
-  "database_name" : "spam_emails_chi",
-  "labels": ["spam","notspam"],
-  "batch_num" : "10",
-  "inference_num" : "20",
-  "low_conf_num" : "10",
-  "confidence_threshold" : "0.95",
   "log_level": "INFO",
   "log_file": null
 }
diff --git a/tests/taskcenter/test_trainer.py b/tests/taskcenter/test_trainer.py
index 1944662..6defcf0 100644
--- a/tests/taskcenter/test_trainer.py
+++ b/tests/taskcenter/test_trainer.py
@@ -4,13 +4,32 @@
 import io
 import shutil
 
+import pytest
+
+from chi_annotator.algo_factory.common import TrainingData
 from chi_annotator.task_center.config import AnnotatorConfig
 from chi_annotator.task_center.data_loader import load_local_data
+from chi_annotator.task_center.model import Interpreter
 from chi_annotator.task_center.model import Trainer
 from tests.utils.txt_to_json import create_tmp_test_jsonfile, rm_tmp_file
 
 
 class TestTrainer(object):
+
+    @classmethod
+    def setup_class(cls):
+        """ setup any state specific to the execution of the given class (which
+        usually contains tests).
+        """
+        pass
+
+    @classmethod
+    def teardown_class(cls):
+        """ teardown any state that was previously setup with a call to
+        setup_class.
+        """
+        pass
+
     """
     test Trainer and Interpreter
     """
@@ -73,14 +92,25 @@ def test_pipeline_flow(self):
 
         interpreter = trainer.train(train_data)
         assert interpreter is not None
-        # TODO because only char_tokenizer now.
+        out1 = interpreter.parse(("点连接拿红包啦"))
+
+        # test persist and load
+        persisted_path = trainer.persist(config['path'],
+                                         config['project'],
+                                         config['fixed_model_name'])
+
+        interpreter_loaded = Interpreter.load(persisted_path, config)
+        out2 = interpreter_loaded.parse("点连接拿红包啦")
+        assert out1.get("classifylabel").get("name") == out2.get("classifylabel").get("name")
+
+        # remove tmp models
+        shutil.rmtree(config['path'], ignore_errors=True)
 
     def test_trainer_persist(self):
         """
         test pipeline persist, metadata will be saved
         :return:
         """
-        # TODO because only char_tokenizer now. nothing to be persist
         test_config = "tests/data/test_config.json"
         config = AnnotatorConfig(test_config)
 
@@ -106,64 +136,96 @@ def test_trainer_persist(self):
 
         # rm tmp files and dirs
         shutil.rmtree(config['path'], ignore_errors=True)
 
-    def test_predict_flow(self):
+    def test_train_model_empty_pipeline(self):
         """
-        test Interpreter flow, only predict now
+        train model with no component
         :return:
         """
         test_config = "tests/data/test_config.json"
         config = AnnotatorConfig(test_config)
+        config['pipeline'] = []
 
-        trainer = Trainer(config)
-        assert len(trainer.pipeline) > 0
-        # create tmp train set
         tmp_path = create_tmp_test_jsonfile("tmp.json")
         train_data = load_local_data(tmp_path)
-        # rm tmp train set
         rm_tmp_file("tmp.json")
 
-        interpreter = trainer.train(train_data)
-        text = "我是测试"
-        output = interpreter.parse(text)
-        assert "label" in output
-
-    def test_train_model_empty_pipeline(self):
-        """
-        train model with no component
-        :return:
-        """
-        # TODO
-        pass
-
-    def test_train_named_model(self):
-        """
-        test train model with certain name
-        :return:
-        """
-        # TODO
-        pass
+        with pytest.raises(ValueError):
+            trainer = Trainer(config)
+            trainer.train(train_data)
 
     def test_handles_pipeline_with_non_existing_component(self):
         """
         handle no exist component in pipeline
         :return:
         """
-        # TODO
-        pass
+        test_config = "tests/data/test_config.json"
+        config = AnnotatorConfig(test_config)
+        config['pipeline'].append("unknown_component")
+
+        tmp_path = create_tmp_test_jsonfile("tmp.json")
+        train_data = load_local_data(tmp_path)
+        rm_tmp_file("tmp.json")
+
+        with pytest.raises(Exception) as execinfo:
+            trainer = Trainer(config)
+            trainer.train(train_data)
+        assert "Failed to find component" in str(execinfo.value)
 
     def test_load_and_persist_without_train(self):
         """
         test save and load model without train
         :return:
         """
-        # TODO
-        pass
+        test_config = "tests/data/test_config.json"
+        config = AnnotatorConfig(test_config)
+
+        trainer = Trainer(config)
+        assert len(trainer.pipeline) > 0
+        # create tmp train set
+        tmp_path = create_tmp_test_jsonfile("tmp.json")
+        train_data = load_local_data(tmp_path)
+        # rm tmp train set
+        rm_tmp_file("tmp.json")
+
+        # interpreter = trainer.train(train_data)
+        # test persist and load
+        persisted_path = trainer.persist(config['path'],
+                                         config['project'],
+                                         config['fixed_model_name'])
+
+        interpreter_loaded = Interpreter.load(persisted_path, config)
+        assert interpreter_loaded.pipeline
+        assert interpreter_loaded.parse("hello") is not None
+        assert interpreter_loaded.parse("Hello today is Monday, again!") is not None
+        # remove tmp models
+        shutil.rmtree(config['path'], ignore_errors=True)
 
     def test_train_with_empty_data(self):
         """
         test train with empty train data
         :return:
         """
-        # TODO
-        pass
+        test_config = "tests/data/test_config.json"
+        config = AnnotatorConfig(test_config)
+
+        trainer = Trainer(config)
+        assert len(trainer.pipeline) > 0
+        # create tmp train set
+
+        train_data = TrainingData([])
+        # rm tmp train set
+
+        trainer.train(train_data)
+        # test persist and load
+        persisted_path = trainer.persist(config['path'],
+                                         config['project'],
+                                         config['fixed_model_name'])
+
+        interpreter_loaded = Interpreter.load(persisted_path, config)
+        # remove tmp models
+        assert interpreter_loaded.pipeline
+        assert interpreter_loaded.parse("hello") is not None
+        assert interpreter_loaded.parse("Hello today is Monday, again!") is not None
+
+