Skip to content

Commit

Permalink
Merge pull request deepwel#20 from zqhZY/master
Browse files Browse the repository at this point in the history
  • Loading branch information
crownpku authored Nov 24, 2017
2 parents 047b28c + 507206c commit 5d20e5c
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 25 deletions.
49 changes: 27 additions & 22 deletions chi_annotator/task_center/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@
# -*- coding: utf8 -*-

import datetime
import io
import json
import logging
import os
import copy


from chi_annotator.task_center.components import ComponentBuilder
from chi_annotator.task_center.components import Component
from chi_annotator.task_center.common import Metadata
from chi_annotator.taks_center.common import Message
from chi_annotator.task_center.utils import *
from chi_annotator.algo_factory import components
from chi_annotator.algo_factory.common import Metadata
from chi_annotator.algo_factory.common import Message
from chi_annotator.algo_factory import utils
from chi_annotator.task_center.config import AnnotatorConfig

logger = logging.getLogger(__name__)


class Trainer(object):
"""Trainer will load the data and train all components.
Expand All @@ -27,7 +25,7 @@ class Trainer(object):
SUPPORTED_LANGUAGES = ["zh"]

def __init__(self, config, component_builder=None, skip_validation=False):
# type: (RasaNLUConfig, Optional[ComponentBuilder], bool) -> None
# type: (AnnotatorConfig, Optional[ComponentBuilder], bool) -> None

self.config = config
self.skip_validation = skip_validation
Expand All @@ -41,13 +39,13 @@ def __init__(self, config, component_builder=None, skip_validation=False):
# Before instantiating the component classes, lets check if all
# required packages are available
# TODO
#if not self.skip_validation:
# if not self.skip_validation:
# components.validate_requirements(config.pipeline)

# Transform the passed names of the pipeline components into classes
for component_name in config.pipeline:
component = component_builder.create_component(
component_name, config)
component_name, config)
self.pipeline.append(component)

def train(self, data):
Expand All @@ -72,7 +70,8 @@ def train(self, data):

for i, component in enumerate(self.pipeline):
logger.info("Starting to train component {}".format(component.name))
component.prepare_partial_processing(self.pipeline[:i], context)
# TODO , should we need component.prepare_partial_processing now?
# component.prepare_partial_processing(self.pipeline[:i], context)
updates = component.train(working_data, self.config, **context)
logger.info("Finished training component.")
if updates:
Expand All @@ -94,6 +93,8 @@ def persist(self, path, persistor=None, project_name=None,
for component in self.pipeline],
}

logger.debug("metadata is :" + metadata)

if project_name is None:
project_name = "default"

Expand All @@ -103,16 +104,20 @@ def persist(self, path, persistor=None, project_name=None,
model_name = "model_" + timestamp
dir_name = os.path.join(path, project_name, model_name)

create_dir(dir_name)
# create model dir
utils.create_dir(dir_name)

# copy and save train data to dir_name
if self.training_data:
# self.training_data.persist return nothing here
metadata.update(self.training_data.persist(dir_name))

for component in self.pipeline:
update = component.persist(dir_name)
if update:
metadata.update(update)

# save metadata to dir_name
Metadata(metadata, dir_name).persist(dir_name)

if persistor is not None:
Expand All @@ -131,26 +136,26 @@ def default_output_attributes():
return {"intent": {"name": "", "confidence": 0.0}, "entities": []}

@staticmethod
def load(model_dir, config=RasaNLUConfig(), component_builder=None,
def load(model_dir, config=AnnotatorConfig(), component_builder=None,
skip_valdation=False):
"""Creates an interpreter based on a persisted model."""

if isinstance(model_dir, Metadata):
# this is for backwards compatibilities (metadata passed as a dict)
model_metadata = model_dir
logger.warn("Deprecated use of `Interpreter.load` with a metadata "
"object. If you want to directly pass the metadata, "
"use `Interpreter.create(metadata, ...)`. If you want "
"to load the metadata from file, use "
"`Interpreter.load(model_dir, ...)")
logger.warning("Deprecated use of `Interpreter.load` with a metadata "
"object. If you want to directly pass the metadata, "
"use `Interpreter.create(metadata, ...)`. If you want "
"to load the metadata from file, use "
"`Interpreter.load(model_dir, ...)")
else:
model_metadata = Metadata.load(model_dir)
return Interpreter.create(model_metadata, config, component_builder,
skip_valdation)

@staticmethod
def create(model_metadata, # type: Metadata
config, # type: RasaNLUConfig
config, # type: AnnotatorConfig
component_builder=None, # type: Optional[ComponentBuilder]
skip_valdation=False # type: bool
):
Expand All @@ -173,8 +178,8 @@ def create(model_metadata, # type: Metadata

for component_name in model_metadata.pipeline:
component = component_builder.load_component(
component_name, model_metadata.model_dir,
model_metadata, config=config, **context)
component_name, model_metadata.model_dir,
model_metadata, config=config, **context)
try:
updates = component.provide_context()
if updates:
Expand Down
7 changes: 7 additions & 0 deletions tests/data/test_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"name": "email_spam_classification",
"model_type": "classification",
"pipeline": ["char_tokenizer"],
"language": "zh",
"path": "./tests/models"
}
62 changes: 60 additions & 2 deletions tests/taskcenter/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

from chi_annotator.task_center.config import AnnotatorConfig
from chi_annotator.task_center.data_loader import load_local_data
from chi_annotator.task_center.model import Trainer
from tests.utils.txt_to_json import create_tmp_test_file, rm_tmp_file


Expand Down Expand Up @@ -38,3 +37,62 @@ def test_load_default_config(self):
"""
config = AnnotatorConfig()
assert config["config"] == "config.json"

def test_trainer_init(self):
    """
    Build a Trainer from the test config and check its pipeline.
    :return:
    """
    trainer = Trainer(AnnotatorConfig("tests/data/test_config.json"))
    pipeline = trainer.pipeline

    # exactly one component is expected in the pipeline
    assert len(pipeline) == 1
    # the char_tokenizer component should have been created
    assert pipeline[0] is not None

def test_pipeline_flow(self):
    """
    Run the trainer's ``train`` step over the test pipeline.

    Builds a Trainer from ``tests/data/test_config.json``, loads a temporary
    training set and checks that ``train`` returns something other than None.
    :return:
    """
    test_config = "tests/data/test_config.json"
    config = AnnotatorConfig(test_config)

    trainer = Trainer(config)
    assert len(trainer.pipeline) == 1
    # the char_tokenizer component should have been created
    assert trainer.pipeline[0] is not None
    # create the temporary training set
    tmp_path = create_tmp_test_file("tmp.json")
    try:
        train_data = load_local_data(tmp_path)
    finally:
        # always remove the temporary training set, even when loading fails,
        # so a failed run does not leave tmp.json behind
        rm_tmp_file("tmp.json")

    interpreter = trainer.train(train_data)
    assert interpreter is not None
    # TODO: extend the assertions; the pipeline only has char_tokenizer for now.

def test_trainer_persist(self):
    """
    Placeholder for the pipeline persistence test; it asserts nothing yet.
    :return:
    """
    # TODO: the pipeline only has char_tokenizer for now, so there is nothing to persist.
    pass

def test_train_model_empty_pipeline(self):
    # Placeholder: the body is only `pass`, so nothing is asserted yet.
    # NOTE(review): presumably meant to cover training with an empty pipeline - TODO implement.
    pass

def test_train_named_model(self):
    # Placeholder: the body is only `pass`, so nothing is asserted yet.
    # NOTE(review): presumably meant to cover training a model with an explicit name - TODO implement.
    pass

def test_handles_pipeline_with_non_existing_component(self):
    # Placeholder: the body is only `pass`, so nothing is asserted yet.
    # NOTE(review): presumably meant to cover a pipeline naming an unknown component - TODO implement.
    pass

def test_load_and_persist_without_train(self):
    # Placeholder: the body is only `pass`, so nothing is asserted yet.
    # NOTE(review): presumably meant to cover load/persist when train was never called - TODO implement.
    pass

def test_train_with_empty_data(self):
    # Placeholder: the body is only `pass`, so nothing is asserted yet.
    # NOTE(review): presumably meant to cover training on an empty data set - TODO implement.
    pass

3 changes: 2 additions & 1 deletion tests/utils/txt_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ def rm_tmp_file(filename):
rm tmp files
:return:
"""
os.remove('tests/data/'+filename)
os.remove('tests/data/'+filename)

0 comments on commit 5d20e5c

Please sign in to comment.