Skip to content

Commit

Permalink
Update test_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
FFTYYY committed Nov 10, 2018
1 parent 1f15b52 commit 07fb61e
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions test/core/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling
Expand Down Expand Up @@ -51,6 +50,8 @@ def test_case_1(self):
print ("loss = %f" % (los))
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self):
#验证squash()的正确性
print ("----------------------------------")
Expand Down Expand Up @@ -82,12 +83,14 @@ def test_case_2(self):

y = tc.log(y)
los = loss_func(y , gy)
print ("loss = %f" % (los))

r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
r /= 6
print ("loss = %f" % (los))
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_3(self):
#验证pack_padded_sequence()的正确性
print ("----------------------------------")
Expand Down Expand Up @@ -130,6 +133,8 @@ def test_case_3(self):
r /= 6
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_4(self):
#验证unpad()的正确性
print ("----------------------------------")
Expand Down Expand Up @@ -169,6 +174,9 @@ def test_case_4(self):
r /= 7
print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_5(self):
#验证mask()和make_mask()的正确性
print ("----------------------------------")
Expand Down Expand Up @@ -217,6 +225,10 @@ def test_case_5(self):
r /= 8
print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))
self.assertEqual(int(los2 * 1000), int(r * 1000))

def test_case_6(self):
#验证unpad_mask()的正确性
print ("----------------------------------")
Expand Down Expand Up @@ -256,6 +268,8 @@ def test_case_6(self):
r /= 7
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_7(self):
#验证一些其他东西
print ("----------------------------------")
Expand Down Expand Up @@ -295,6 +309,7 @@ def test_case_7(self):
r = - log(.3) - log(.5) - log(.3)
r /= 3
print ("r = %f" % (r))
self.assertEqual(int(los * 1000), int(r * 1000))

if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit 07fb61e

Please sign in to comment.