From db0a789d619c0e47564c89c910ba1db9e26a49c1 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Fri, 7 Dec 2018 19:09:50 +0800 Subject: [PATCH] * final clean up * remove conflicts * all tests passed --- fastNLP/core/dataset.py | 4 ++-- fastNLP/core/trainer.py | 2 +- fastNLP/io/base_loader.py | 16 ---------------- fastNLP/io/dataset_loader.py | 10 +++++++--- test/core/test_dataset.py | 14 ++++++++++++++ test/core/test_optimizer.py | 18 ++++++++++++++---- 6 files changed, 38 insertions(+), 26 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index a08961fc..52dac2fc 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -98,10 +98,10 @@ def __getitem__(self, idx): raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) def __getattr__(self, item): + # Not tested. Don't use !! if item == "field_arrays": raise AttributeError - # TODO dataset.x - if item in self.field_arrays: + if isinstance(item, str) and item in self.field_arrays: return self.field_arrays[item] try: reader = DataLoaderRegister.get_reader(item) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 6cb6b560..5997ebbc 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -85,7 +85,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch if metric_key is not None: self.increase_better = False if metric_key[0] == "-" else True self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key - elif metrics is not None: + elif len(metrics) > 0: self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') # prepare loss diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index a3ce410b..b01c233a 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -31,22 +31,6 @@ def load_with_cache(cls, data_path, cache_path): return obj -class ToyLoader0(BaseLoader): - """ - For CharLM - """ - - def __init__(self, data_path): - super(ToyLoader0, self).__init__(data_path) - - def load(self): - with open(self.data_path, 'r') as f: - corpus = f.read().lower() - import re - corpus = re.sub(r"", "unk", corpus) - return corpus.split() - - class DataLoaderRegister: """"register for data sets""" _readers = {} diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index a1cfe33f..641a631e 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -75,7 +75,6 @@ def convert(self, data): raise NotImplementedError -@DataSet.set_reader("read_naive") class NativeDataSetLoader(DataSetLoader): def __init__(self): super(NativeDataSetLoader, self).__init__() @@ -87,7 +86,9 @@ def load(self, path): return ds -@DataSet.set_reader('read_raw') +DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') + + class RawDataSetLoader(DataSetLoader): def __init__(self): super(RawDataSetLoader, self).__init__() @@ -101,6 +102,8 @@ def load(self, data_path, split=None): def convert(self, data): return convert_seq_dataset(data) + + DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') @@ -171,6 +174,8 @@ def convert(self, data): """Convert lists of strings into Instances with Fields. """ return convert_seq2seq_dataset(data) + + DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') @@ -348,7 +353,6 @@ def convert(self, data): pass -@DataSet.set_reader('read_people_daily') class PeopleDailyCorpusLoader(DataSetLoader): """ People Daily Corpus: Chinese word segmentation, POS tag, NER diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 74ad5958..01963af6 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -178,6 +178,20 @@ def test_get_field(self): self.assertTrue(isinstance(ans, FieldArray)) self.assertEqual(ans.content, [[5, 6]] * 10) + def test_reader(self): + # 跑通即可 + ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") + self.assertTrue(isinstance(ds, DataSet)) + self.assertTrue(len(ds) > 0) + + ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") + self.assertTrue(isinstance(ds, DataSet)) + self.assertTrue(len(ds) > 0) + + ds = DataSet().read_pos("test/data_for_tests/people.txt") + self.assertTrue(isinstance(ds, DataSet)) + self.assertTrue(len(ds) > 0) + class TestDataSetIter(unittest.TestCase): def test__repr__(self): diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py index 8ffa1a72..83ed6000 100644 --- a/test/core/test_optimizer.py +++ b/test/core/test_optimizer.py @@ -7,7 +7,7 @@ class TestOptim(unittest.TestCase): def test_SGD(self): - optim = SGD(torch.nn.Linear(10, 3).parameters()) + optim = SGD(model_params=torch.nn.Linear(10, 3).parameters()) self.assertTrue("lr" in optim.__dict__["settings"]) self.assertTrue("momentum" in optim.__dict__["settings"]) res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) @@ -22,13 +22,18 @@ def test_SGD(self): self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) - with self.assertRaises(RuntimeError): + optim = SGD(0.001) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) + res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) + self.assertTrue(isinstance(res, torch.optim.SGD)) + + with self.assertRaises(TypeError): _ = SGD("???") - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): _ = SGD(0.001, lr=0.002) def test_Adam(self): - optim = Adam(torch.nn.Linear(10, 3).parameters()) + optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) self.assertTrue("lr" in optim.__dict__["settings"]) self.assertTrue("weight_decay" in optim.__dict__["settings"]) res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) @@ -42,3 +47,8 @@ def test_Adam(self): optim = Adam(lr=0.002, weight_decay=0.989) self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) + + optim = Adam(0.001) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) + res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) + self.assertTrue(isinstance(res, torch.optim.Adam))