Skip to content

Commit

Permalink
* final clean up
Browse files Browse the repository at this point in the history
* remove conflicts
* all tests passed
  • Loading branch information
FengZiYjun committed Dec 7, 2018
1 parent 267baec commit db0a789
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 26 deletions.
4 changes: 2 additions & 2 deletions fastNLP/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def __getitem__(self, idx):
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __getattr__(self, item):
# Not tested. Don't use !!
if item == "field_arrays":
raise AttributeError
# TODO dataset.x
if item in self.field_arrays:
if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item]
try:
reader = DataLoaderRegister.get_reader(item)
Expand Down
2 changes: 1 addition & 1 deletion fastNLP/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch
if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
elif metrics is not None:
elif len(metrics) > 0:
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

# prepare loss
Expand Down
16 changes: 0 additions & 16 deletions fastNLP/io/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,6 @@ def load_with_cache(cls, data_path, cache_path):
return obj


class ToyLoader0(BaseLoader):
"""
For CharLM
"""

def __init__(self, data_path):
super(ToyLoader0, self).__init__(data_path)

def load(self):
with open(self.data_path, 'r') as f:
corpus = f.read().lower()
import re
corpus = re.sub(r"<unk>", "unk", corpus)
return corpus.split()


class DataLoaderRegister:
""""register for data sets"""
_readers = {}
Expand Down
10 changes: 7 additions & 3 deletions fastNLP/io/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def convert(self, data):
raise NotImplementedError


@DataSet.set_reader("read_naive")
class NativeDataSetLoader(DataSetLoader):
def __init__(self):
super(NativeDataSetLoader, self).__init__()
Expand All @@ -87,7 +86,9 @@ def load(self, path):
return ds


@DataSet.set_reader('read_raw')
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')


class RawDataSetLoader(DataSetLoader):
def __init__(self):
super(RawDataSetLoader, self).__init__()
Expand All @@ -101,6 +102,8 @@ def load(self, data_path, split=None):

def convert(self, data):
return convert_seq_dataset(data)


DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


Expand Down Expand Up @@ -171,6 +174,8 @@ def convert(self, data):
"""Convert lists of strings into Instances with Fields.
"""
return convert_seq2seq_dataset(data)


DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')


Expand Down Expand Up @@ -348,7 +353,6 @@ def convert(self, data):
pass


@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
"""
People Daily Corpus: Chinese word segmentation, POS tag, NER
Expand Down
14 changes: 14 additions & 0 deletions test/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,20 @@ def test_get_field(self):
self.assertTrue(isinstance(ans, FieldArray))
self.assertEqual(ans.content, [[5, 6]] * 10)

def test_reader(self):
# 跑通即可
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

ds = DataSet().read_pos("test/data_for_tests/people.txt")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):
Expand Down
18 changes: 14 additions & 4 deletions test/core/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class TestOptim(unittest.TestCase):
def test_SGD(self):
optim = SGD(torch.nn.Linear(10, 3).parameters())
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("momentum" in optim.__dict__["settings"])
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
Expand All @@ -22,13 +22,18 @@ def test_SGD(self):
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)

with self.assertRaises(RuntimeError):
optim = SGD(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.SGD))

with self.assertRaises(TypeError):
_ = SGD("???")
with self.assertRaises(RuntimeError):
with self.assertRaises(TypeError):
_ = SGD(0.001, lr=0.002)

def test_Adam(self):
optim = Adam(torch.nn.Linear(10, 3).parameters())
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("weight_decay" in optim.__dict__["settings"])
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
Expand All @@ -42,3 +47,8 @@ def test_Adam(self):
optim = Adam(lr=0.002, weight_decay=0.989)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)

optim = Adam(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.Adam))

0 comments on commit db0a789

Please sign in to comment.