Skip to content

Commit

Permalink
Merge pull request #106 from FFTYYY/master
Browse files Browse the repository at this point in the history
update loss & a small change in requirements
  • Loading branch information
xuyige authored Nov 12, 2018
2 parents 5ec58e3 + 3cadd5a commit abf840c
Show file tree
Hide file tree
Showing 4 changed files with 509 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
## Requirements

- numpy>=1.14.2
- torch==0.4.0
- torch>=0.4.0
- torchvision>=0.1.8
- tensorboardX

Expand Down
246 changes: 192 additions & 54 deletions fastNLP/core/loss.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,196 @@
import torch

def squash(predict, truth, **kwargs):
    '''Flatten model output and target so they fit standard torch loss functions.

    :param predict: Tensor, model output, [... , tag_size]
    :param truth: Tensor, truth from dataset, any shape
    :param kwargs: ignored extra arguments (kept for a uniform pre_pro signature)
    :return: (predict, truth) with predict reshaped to 2-D and truth to 1-D
    '''
    tag_size = predict.size()[-1]
    flat_predict = predict.view(-1, tag_size)
    flat_truth = truth.view(-1)
    return flat_predict, flat_truth

def unpad(predict, truth, **kwargs):
    '''Strip padding from sequence output before computing the loss.

    Uses torch.nn.utils.rnn.pack_padded_sequence(), which also flattens the
    batch/time dimensions (so this method contains squash()).

    :param predict: Tensor, [batch_size , max_len , tag_size]
    :param truth: Tensor, [batch_size , max_len]
    :param kwargs: extra arguments; kwargs["lens"] is expected to exist
            kwargs["lens"]: list or LongTensor, [batch_size], where the i-th
            element is the true length of the i-th sequence
    :return: (predict, truth) containing only the non-padded positions
    '''
    lens = kwargs.get("lens")
    if lens is None:
        # nothing to unpad without the true lengths
        return predict, truth
    lens = torch.LongTensor(lens)
    # pack_padded_sequence requires sequences sorted by decreasing length
    sorted_lens, order = lens.sort(0, descending=True)
    packed_predict = torch.nn.utils.rnn.pack_padded_sequence(
        predict[order], sorted_lens, batch_first=True)
    packed_truth = torch.nn.utils.rnn.pack_padded_sequence(
        truth[order], sorted_lens, batch_first=True)
    return packed_predict.data, packed_truth.data

def unpad_mask(predict, truth, **kwargs):
    '''Strip padding from sequence output using a selection mask.

    Builds a mask from the true lengths with make_mask() and delegates to
    mask() (which itself contains squash()).

    :param predict: Tensor, [batch_size , max_len , tag_size]
    :param truth: Tensor, [batch_size , max_len]
    :param kwargs: extra arguments; kwargs["lens"] is expected to exist
            kwargs["lens"]: list or LongTensor, [batch_size], where the i-th
            element is the true length of the i-th sequence
    :return: (predict, truth) containing only the non-padded positions
    '''
    lens = kwargs.get("lens")
    if lens is None:
        # without lengths there is nothing to mask out
        return predict, truth
    max_len = truth.size()[1]
    return mask(predict, truth, mask=make_mask(lens, max_len))

def mask(predict, truth, **kwargs):
    '''Select specific positions from predict/truth using a mask.

    Performs the same flattening as squash() (inlined below), then keeps only
    the positions where the mask is set.

    :param predict: Tensor, [batch_size , max_len , tag_size]
    :param truth: Tensor, [batch_size , max_len]
    :param kwargs: extra arguments; kwargs["mask"] is expected to exist
            kwargs["mask"]: ByteTensor/BoolTensor, [batch_size , max_len];
            positions that are 1/True will be selected
    :return: (predict, truth) restricted to the selected positions
    '''
    # local name `sel` instead of `mask` so we do not shadow this function's
    # own name (the original code rebound `mask` inside its own body)
    sel = kwargs.get("mask")
    if sel is None:
        return predict, truth

    # flatten exactly like squash(): [B, T, C] -> [B*T, C] and [B, T] -> [B*T]
    predict = predict.view(-1, predict.size()[-1])
    truth = truth.view(-1)
    # `!= 0` normalizes Byte masks to bool, which modern torch indexing wants;
    # row indexing is value-identical to the original
    # permute/masked_select/view/permute sequence, just direct
    keep = sel.view(-1) != 0

    predict = predict[keep]
    truth = truth[keep]

    return predict, truth

def make_mask(lens, tar_len):
    '''Build a mask whose i-th row selects positions [:lens[i]].

    (Borrowed from fastNLP.models.sequence_modeling.seq_mask.)

    :param lens: list or LongTensor, [batch_size]
    :param tar_len: int, the padded (maximum) sequence length
    :return mask: ByteTensor/BoolTensor, [batch_size , tar_len],
            where mask[b, i] is set iff i < lens[b]
    '''
    lens = torch.LongTensor(lens)
    positions = torch.arange(tar_len).long()
    # broadcast [1, tar_len] against [batch_size, 1]; equivalent to stacking
    # torch.ge(lens, i + 1) for each position i
    return positions.unsqueeze(0) < lens.unsqueeze(1)

# Map pre-processing names to the functions defined above, so that
# Loss(pre_pro=[...]) and Loss.add_pre_pro() can accept strings as well as
# callables.
method_dict = {
    "squash" : squash,
    "unpad" : unpad,
    "unpad_mask" : unpad_mask,
    "mask" : mask,
}

# Lookup table used by Loss._get_loss(): lower-cased torch loss-class names
# (every key ends with "loss") mapped to the torch.nn loss classes.
loss_function_name = {
    "L1Loss".lower() : torch.nn.L1Loss,
    "BCELoss".lower() : torch.nn.BCELoss,
    "MSELoss".lower() : torch.nn.MSELoss,
    "NLLLoss".lower() : torch.nn.NLLLoss,
    "KLDivLoss".lower() : torch.nn.KLDivLoss,
    # every name should end with "loss"
    # NOTE(review): torch.nn.NLLLoss2d was deprecated and later removed from
    # PyTorch - confirm the supported torch version before relying on this entry
    "NLLLoss2dLoss".lower() : torch.nn.NLLLoss2d,
    "SmoothL1Loss".lower() : torch.nn.SmoothL1Loss,
    "SoftMarginLoss".lower() : torch.nn.SoftMarginLoss,
    "PoissonNLLLoss".lower() : torch.nn.PoissonNLLLoss,
    "MultiMarginLoss".lower() : torch.nn.MultiMarginLoss,
    "CrossEntropyLoss".lower() : torch.nn.CrossEntropyLoss,
    "BCEWithLogitsLoss".lower() : torch.nn.BCEWithLogitsLoss,
    "MarginRankingLoss".lower() : torch.nn.MarginRankingLoss,
    "TripletMarginLoss".lower() : torch.nn.TripletMarginLoss,
    "HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss,
    "CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss,
    "MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss,
    "MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss,
}

class Loss(object):
    """A callable object representing a loss function.

    Wraps either ``None`` (useful when Trainer.__init__ only performs a type
    check) or a loss borrowed from PyTorch by name via ``loss_function_name``.
    Before the wrapped loss is evaluated, ``predict`` and ``truth`` are passed
    through a pipeline of ``pre_pro`` functions (``squash``/``unpad``/``mask``
    etc.) that reshape padded sequence output into what the torch loss expects.

    NOTE(review): the previous revision defined ``__init__`` and ``get`` twice
    in this class body; only the later definitions were ever effective, so the
    dead earlier duplicates have been removed.  ``_borrow_from_pytorch`` is
    retained for backward compatibility even though ``__init__`` no longer
    calls it.
    """

    def __init__(self, loss_name, pre_pro=None, **kwargs):
        """
        :param loss_name: str or None, the name of a loss function.
                None is useful when Trainer.__init__ performs type checks only.
        :param pre_pro: list of callables and/or str; strings are translated to
                the pre-defined functions in ``method_dict``.  Defaults to
                ``[squash]``.  Each function must accept
                ``(predict, truth, **kwargs)`` and return the processed
                ``(predict, truth)`` pair.
        :param kwargs: extra keyword arguments forwarded to the torch loss
                constructor (see ``_get_loss``).
        :raises NotImplementedError: if loss_name is neither None nor a str.
        """
        if loss_name is None:
            # this is useful when Trainer.__init__ performs type check
            self._loss = None
        elif not isinstance(loss_name, str):
            raise NotImplementedError
        else:
            self._loss = self._get_loss(loss_name, **kwargs)

        if pre_pro is None:
            # default kept out of the signature to avoid a mutable default argument
            pre_pro = [squash]
        # translate string names via method_dict; unknown names map to None
        # and are skipped in __call__
        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]

    def add_pre_pro(self, func):
        """Append one pre-processing step to the pipeline.

        :param func: a callable or str, a method to reform parameters before
                calculating loss; strings are translated via ``method_dict``.
                Unknown string names are silently ignored.
        """
        if not callable(func):
            func = method_dict.get(func)
        if func is None:
            return
        self.pre_pro.append(func)

    @staticmethod
    def _get_loss(loss_name, **kwargs):
        """Resolve and instantiate a torch loss class by name.

        The name is lower-cased, underscores are removed, and a trailing
        "loss" is appended if missing, so "cross_entropy", "CrossEntropy"
        and "CrossEntropyLoss" all resolve to torch.nn.CrossEntropyLoss.

        :param loss_name: str, the name of a loss function
        :param kwargs: forwarded to the torch loss constructor
        :return: a callable torch loss object
        :raises KeyError: if the normalized name is not in loss_function_name
        """
        loss_name = loss_name.strip().lower()
        loss_name = "".join(loss_name.split("_"))

        if len(loss_name) < 4 or loss_name[-4:] != "loss":
            loss_name += "loss"
        return loss_function_name[loss_name](**kwargs)

    def get(self):
        """Return self.

        This method exists just to make some existing codes (which expect a
        ``.get()`` accessor) run error-freely; the Loss object itself is the
        callable loss.
        """
        return self

    @staticmethod
    def _borrow_from_pytorch(loss_name):
        """(legacy) Given the name of a loss function, return it from PyTorch.

        Retained for backward compatibility; ``__init__`` now resolves names
        through ``_get_loss`` instead.

        :param loss_name: str, "cross_entropy" (combines log softmax and nll
                loss in a single function) or "nll" (negative log likelihood)
        :return loss: a PyTorch loss (callable)
        :raises NotImplementedError: for any other name
        """

        class InnerCrossEntropy:
            """A simple wrapper to guarantee input shapes."""

            def __init__(self):
                self.f = torch.nn.CrossEntropyLoss()

            def __call__(self, predict, truth):
                # flatten the target the way CrossEntropyLoss expects
                truth = truth.view(-1, )
                return self.f(predict, truth)

        if loss_name == "cross_entropy":
            return InnerCrossEntropy()
        elif loss_name == 'nll':
            return torch.nn.NLLLoss()
        else:
            raise NotImplementedError

    def __call__(self, predict, truth, **kwargs):
        """Apply the pre_pro pipeline (in order of addition) and the loss.

        :param predict: Tensor, model output
        :param truth: Tensor, truth from dataset
        :param kwargs: extra arguments passed through to every pre_pro
                function; for example, if unpad_mask() is in pre_pro there
                should be a kwarg named ``lens``
        :return: the value produced by the wrapped loss function
        """
        for f in self.pre_pro:
            if f is None:
                # a string name that was not found in method_dict
                continue
            predict, truth = f(predict, truth, **kwargs)

        return self._loss(predict, truth)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy>=1.14.2
torch==0.4.0
torch>=0.4.0
torchvision>=0.1.8
tensorboardX
Loading

0 comments on commit abf840c

Please sign in to comment.