Skip to content

Commit

Permalink
add pipeline unit test and simplify test config (deepwel#26)
Browse files Browse the repository at this point in the history
* add pipeline unit test
* simplify test config
* fix related bugs
  • Loading branch information
zqhZY authored and crownpku committed Nov 28, 2017
1 parent 14cb4f5 commit 651903a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 42 deletions.
2 changes: 1 addition & 1 deletion chi_annotator/task_center/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class Interpreter(object):
# Defines all attributes (& default values) that will be returned by `parse`
@staticmethod
def default_output_attributes():
return {"label": {"name": "", "confidence": 0.0}}
return {'classifylabel': {'name': '', 'confidence': 0.0}}

@staticmethod
def load(model_dir, config=AnnotatorConfig(), component_builder=None,
Expand Down
10 changes: 3 additions & 7 deletions tests/data/test_config.json
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
{
"name": "email_spam_classification",
"path": "./tests/models",
"project": "chi_annotator",
"fixed_model_name": "chi_annotator_models",
"model_type": "classification",
"pipeline": [
"char_tokenizer",
"sentence_embedding_extractor",
"classifier_sklearn"
],
"language": "zh",
"path": "./tests/models",
"embedding_path": "./tests/data/vec.txt",
"embedding_type": "text",
"org_data" : "./tests/data/test_data.json",
"database_name" : "spam_emails_chi",
"labels": ["spam","notspam"],
"batch_num" : "10",
"inference_num" : "20",
"low_conf_num" : "10",
"confidence_threshold" : "0.95",
"log_level": "INFO",
"log_file": null
}
130 changes: 96 additions & 34 deletions tests/taskcenter/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,32 @@
import io
import shutil

import pytest

from chi_annotator.algo_factory.common import TrainingData
from chi_annotator.task_center.config import AnnotatorConfig
from chi_annotator.task_center.data_loader import load_local_data
from chi_annotator.task_center.model import Interpreter
from chi_annotator.task_center.model import Trainer
from tests.utils.txt_to_json import create_tmp_test_jsonfile, rm_tmp_file


class TestTrainer(object):

@classmethod
def setup_class(cls):
""" setup any state specific to the execution of the given class (which
usually contains tests).
"""
pass

@classmethod
def teardown_class(cls):
""" teardown any state that was previously setup with a call to
setup_class.
"""
pass

"""
test Trainer and Interpreter
"""
Expand Down Expand Up @@ -73,14 +92,25 @@ def test_pipeline_flow(self):

interpreter = trainer.train(train_data)
assert interpreter is not None
# TODO because only char_tokenizer now.
out1 = interpreter.parse(("点连接拿红包啦"))

# test persist and load
persisted_path = trainer.persist(config['path'],
config['project'],
config['fixed_model_name'])

interpreter_loaded = Interpreter.load(persisted_path, config)
out2 = interpreter_loaded.parse("点连接拿红包啦")
assert out1.get("classifylabel").get("name") == out2.get("classifylabel").get("name")

# remove tmp models
shutil.rmtree(config['path'], ignore_errors=True)

def test_trainer_persist(self):
"""
test pipeline persist, metadata will be saved
:return:
"""
# TODO because only char_tokenizer now. nothing to be persist
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)

Expand All @@ -106,64 +136,96 @@ def test_trainer_persist(self):
# rm tmp files and dirs
shutil.rmtree(config['path'], ignore_errors=True)

def test_predict_flow(self):
def test_train_model_empty_pipeline(self):
"""
test Interpreter flow, only predict now
train model with no component
:return:
"""
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)
config['pipeline'] = []

trainer = Trainer(config)
assert len(trainer.pipeline) > 0
# create tmp train set
tmp_path = create_tmp_test_jsonfile("tmp.json")
train_data = load_local_data(tmp_path)
# rm tmp train set
rm_tmp_file("tmp.json")

interpreter = trainer.train(train_data)
text = "我是测试"
output = interpreter.parse(text)
assert "label" in output

def test_train_model_empty_pipeline(self):
"""
train model with no component
:return:
"""
# TODO
pass

def test_train_named_model(self):
"""
test train model with certain name
:return:
"""
# TODO
pass
with pytest.raises(ValueError):
trainer = Trainer(config)
trainer.train(train_data)

def test_handles_pipeline_with_non_existing_component(self):
"""
handle no exist component in pipeline
:return:
"""
# TODO
pass
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)
config['pipeline'].append("unknown_component")

tmp_path = create_tmp_test_jsonfile("tmp.json")
train_data = load_local_data(tmp_path)
rm_tmp_file("tmp.json")

with pytest.raises(Exception) as execinfo:
trainer = Trainer(config)
trainer.train(train_data)
assert "Failed to find component" in str(execinfo.value)

def test_load_and_persist_without_train(self):
"""
test save and load model without train
:return:
"""
# TODO
pass
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)

trainer = Trainer(config)
assert len(trainer.pipeline) > 0
# create tmp train set
tmp_path = create_tmp_test_jsonfile("tmp.json")
train_data = load_local_data(tmp_path)
# rm tmp train set
rm_tmp_file("tmp.json")

# interpreter = trainer.train(train_data)
# test persist and load
persisted_path = trainer.persist(config['path'],
config['project'],
config['fixed_model_name'])

interpreter_loaded = Interpreter.load(persisted_path, config)
assert interpreter_loaded.pipeline
assert interpreter_loaded.parse("hello") is not None
assert interpreter_loaded.parse("Hello today is Monday, again!") is not None
# remove tmp models
shutil.rmtree(config['path'], ignore_errors=True)

def test_train_with_empty_data(self):
"""
test train with empty train data
:return:
"""
# TODO
pass
test_config = "tests/data/test_config.json"
config = AnnotatorConfig(test_config)

trainer = Trainer(config)
assert len(trainer.pipeline) > 0
# create tmp train set

train_data = TrainingData([])
# rm tmp train set

trainer.train(train_data)
# test persist and load
persisted_path = trainer.persist(config['path'],
config['project'],
config['fixed_model_name'])

interpreter_loaded = Interpreter.load(persisted_path, config)
# remove tmp models
assert interpreter_loaded.pipeline
assert interpreter_loaded.parse("hello") is not None
assert interpreter_loaded.parse("Hello today is Monday, again!") is not None



0 comments on commit 651903a

Please sign in to comment.