Merge pull request #106 from FFTYYY/master
update loss & a small change in requirements
Showing 4 changed files with 509 additions and 56 deletions.
@@ -1,58 +1,196 @@
import torch

def squash(predict , truth , **kwargs):
    '''Reshape tensors so they fit the loss functions in PyTorch.
    :param predict : Tensor, model output
    :param truth : Tensor, truth from dataset
    :param **kwargs : extra arguments
    :return predict , truth: predict & truth after processing
    '''
    return predict.view(-1 , predict.size()[-1]) , truth.view(-1,)
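
# A minimal usage sketch (illustrative, not part of the original commit):
# squash() flattens sequence-shaped outputs so they fit losses such as
# torch.nn.CrossEntropyLoss, which expects [N, C] inputs and [N] targets.
#     predict = torch.randn(2, 5, 3)       # [batch_size, max_len, tag_size]
#     truth = torch.zeros(2, 5).long()     # [batch_size, max_len]
#     p, t = squash(predict, truth)        # p: [10, 3], t: [10]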

def unpad(predict , truth , **kwargs):
    '''Process padded sequence output to get the true loss.
    Uses the pack_padded_sequence() method.
    This method also performs the reshaping done by squash().
    :param predict : Tensor, [batch_size , max_len , tag_size]
    :param truth : Tensor, [batch_size , max_len]
    :param **kwargs : extra arguments, kwargs["lens"] is expected to exist
        kwargs["lens"] : list or LongTensor, [batch_size]
            the i-th element is the true length of the i-th sequence
    :return predict , truth: predict & truth after processing
    '''
    if kwargs.get("lens") is None:
        return predict , truth
    lens = torch.LongTensor(kwargs["lens"])
    lens , idx = torch.sort(lens , descending = True)
    predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx] , lens , batch_first = True).data
    truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx] , lens , batch_first = True).data
    return predict , truth
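
# Illustrative sketch with assumed toy shapes (not from the original commit):
# with lens = [3, 2], pack_padded_sequence keeps only the first lens[i]
# steps of each sequence, so both outputs cover 3 + 2 = 5 positions.
#     predict = torch.randn(2, 4, 3)               # [batch_size, max_len, tag_size]
#     truth = torch.zeros(2, 4).long()
#     p, t = unpad(predict, truth, lens=[3, 2])    # p: [5, 3], t: [5]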

def unpad_mask(predict , truth , **kwargs):
    '''Process padded sequence output to get the true loss.
    Uses the mask() method.
    This method also performs the reshaping done by squash().
    :param predict : Tensor, [batch_size , max_len , tag_size]
    :param truth : Tensor, [batch_size , max_len]
    :param **kwargs : extra arguments, kwargs["lens"] is expected to exist
        kwargs["lens"] : list or LongTensor, [batch_size]
            the i-th element is the true length of the i-th sequence
    :return predict , truth: predict & truth after processing
    '''
    if kwargs.get("lens") is None:
        return predict , truth
    mas = make_mask(kwargs["lens"] , truth.size()[1])
    return mask(predict , truth , mask = mas)

def mask(predict , truth , **kwargs):
    '''Select specific elements from a Tensor.
    This method also performs the reshaping done by squash().
    :param predict : Tensor, [batch_size , max_len , tag_size]
    :param truth : Tensor, [batch_size , max_len]
    :param **kwargs : extra arguments, kwargs["mask"] is expected to exist
        kwargs["mask"] : ByteTensor, [batch_size , max_len]
            the mask Tensor; the positions that are 1 will be selected
    :return predict , truth: predict & truth after processing
    '''
    if kwargs.get("mask") is None:
        return predict , truth
    mask = kwargs["mask"]

    predict , truth = squash(predict , truth)
    mask = mask.view(-1,)

    predict = torch.masked_select(predict.permute(1,0) , mask).view(predict.size()[-1] , -1).permute(1,0)
    truth = torch.masked_select(truth , mask)

    return predict , truth
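
# Hedged example with assumed shapes (not from the original commit):
# positions where the mask is 1 survive; the rest are dropped before
# the loss is computed.
#     m = torch.ByteTensor([[1, 1, 0], [1, 0, 0]])    # keep 3 of 6 positions
#     predict = torch.randn(2, 3, 4)
#     truth = torch.zeros(2, 3).long()
#     p, t = mask(predict, truth, mask=m)             # p: [3, 4], t: [3]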

def make_mask(lens , tar_len):
    '''Generate a mask that selects [:lens[i]] for the i-th element.
    Borrowed from fastNLP.models.sequence_modeling.seq_mask.
    :param lens : list or LongTensor, [batch_size]
    :param tar_len : int
    :return mask : ByteTensor
    '''
    lens = torch.LongTensor(lens)
    mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
    mask = torch.stack(mask, 1)
    return mask
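
# For example (illustrative values), make_mask([3, 1], 4) yields
#     1 1 1 0
#     1 0 0 0
# i.e. row i selects the first lens[i] positions out of tar_len.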

# Map strings to functions, for cleaner lookup by name.
method_dict = {
    "squash" : squash,
    "unpad" : unpad,
    "unpad_mask" : unpad_mask,
    "mask" : mask,
}

loss_function_name = {
    "L1Loss".lower() : torch.nn.L1Loss,
    "BCELoss".lower() : torch.nn.BCELoss,
    "MSELoss".lower() : torch.nn.MSELoss,
    "NLLLoss".lower() : torch.nn.NLLLoss,
    "KLDivLoss".lower() : torch.nn.KLDivLoss,
    "NLLLoss2dLoss".lower() : torch.nn.NLLLoss2d, # every name should end with "loss"
    "SmoothL1Loss".lower() : torch.nn.SmoothL1Loss,
    "SoftMarginLoss".lower() : torch.nn.SoftMarginLoss,
    "PoissonNLLLoss".lower() : torch.nn.PoissonNLLLoss,
    "MultiMarginLoss".lower() : torch.nn.MultiMarginLoss,
    "CrossEntropyLoss".lower() : torch.nn.CrossEntropyLoss,
    "BCEWithLogitsLoss".lower() : torch.nn.BCEWithLogitsLoss,
    "MarginRankingLoss".lower() : torch.nn.MarginRankingLoss,
    "TripletMarginLoss".lower() : torch.nn.TripletMarginLoss,
    "HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss,
    "CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss,
    "MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss,
    "MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss,
}
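
# Lookup sketch (illustrative): keys are the lowercased class names, so
#     loss_function_name["crossentropyloss"]    # -> torch.nn.CrossEntropyLoss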

class Loss(object):
    """Loss function of the algorithm,
    either a wrapper around a loss function from the framework, or a user-defined loss (requires PyTorch autograd support)
    """

    def __init__(self, args):
        """
        :param args: None or str, the name of a loss function.
        """
        if args is None:
            # this is useful when Trainer.__init__ performs type check
            self._loss = None
        elif isinstance(args, str):
            self._loss = self._borrow_from_pytorch(args)
        else:
            raise NotImplementedError

    def get(self):
        """
        :return self._loss: the loss function
        """
        return self._loss

    @staticmethod
    def _borrow_from_pytorch(loss_name):
        """Given the name of a loss function, return it from PyTorch.
        :param loss_name: str, the name of a loss function
            - cross_entropy: combines log softmax and NLL loss in a single function.
            - nll: negative log likelihood
        :return loss: a PyTorch loss
        """

        class InnerCrossEntropy:
            """A simple wrapper to guarantee input shapes."""

            def __init__(self):
                self.f = torch.nn.CrossEntropyLoss()

            def __call__(self, predict, truth):
                truth = truth.view(-1, )
                return self.f(predict, truth)

        if loss_name == "cross_entropy":
            return InnerCrossEntropy()
        elif loss_name == 'nll':
            return torch.nn.NLLLoss()
        else:
            raise NotImplementedError
    '''A Loss object is a callable object that represents a loss function.
    '''

    def __init__(self , loss_name , pre_pro = [squash], **kwargs):
        '''
        :param loss_name: str or None , the name of a loss function
        :param pre_pro : list of functions or strs, methods to reform parameters before calculating loss
            the strings will be auto translated to pre-defined functions
        :param **kwargs: kwargs for the torch loss function
        pre_pro functions should take three arguments: predict, truth, **arg
            predict and truth are the necessary parameters of the loss function
            kwargs are the extra parameters passed in when calling the loss function
        pre_pro functions should return two objects: predict and truth after processing
        '''

        if loss_name is None:
            # this is useful when Trainer.__init__ performs type check
            self._loss = None
        else:
            if not isinstance(loss_name, str):
                raise NotImplementedError
            else:
                self._loss = self._get_loss(loss_name , **kwargs)

        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]

    def add_pre_pro(self , func):
        '''Add a pre_pro function.
        :param func: a function or str, a method to reform parameters before calculating loss
            strings will be auto translated to pre-defined functions
        '''
        if not callable(func):
            func = method_dict.get(func)
        if func is None:
            return
        self.pre_pro.append(func)

    @staticmethod
    def _get_loss(loss_name , **kwargs):
        '''Get a loss function from torch.
        :param loss_name: str, the name of a loss function
        :param **kwargs: kwargs for the torch loss function
        :return: a callable loss function object
        '''
        loss_name = loss_name.strip().lower()
        loss_name = "".join(loss_name.split("_"))

        if len(loss_name) < 4 or loss_name[-4 : ] != "loss":
            loss_name += "loss"
        return loss_function_name[loss_name](**kwargs)
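
    # Normalization sketch (illustrative): "Cross_Entropy" is stripped,
    # lowercased and de-underscored to "crossentropy", then suffixed with
    # "loss", so the lookup key becomes "crossentropyloss":
    #     Loss._get_loss("Cross_Entropy")    # -> torch.nn.CrossEntropyLoss()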

    def get(self):
        '''This method exists only to let some existing code run without errors.
        '''
        return self

    def __call__(self , predict , truth , **kwargs):
        '''Call the loss function.
        predict and truth will be processed by the pre_pro methods in the order they were added.
        :param predict : Tensor, model output
        :param truth : Tensor, truth from dataset
        :param **kwargs : extra arguments, passed to the pre_pro functions
            for example, if unpad_mask() is used in pre_pro, there should be a kwarg named lens
        '''
        for f in self.pre_pro:
            if f is None:
                continue
            predict , truth = f(predict , truth , **kwargs)

        return self._loss(predict , truth)
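
# End-to-end sketch with assumed toy shapes (not part of the original commit):
#     loss = Loss("cross_entropy", pre_pro=["unpad"])
#     predict = torch.randn(2, 4, 3)            # [batch_size, max_len, tag_size]
#     truth = torch.zeros(2, 4).long()          # [batch_size, max_len]
#     l = loss(predict, truth, lens=[3, 2])     # scalar loss over the 5 real steps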
@@ -1,4 +1,4 @@
 numpy>=1.14.2
-torch==0.4.0
+torch>=0.4.0
 torchvision>=0.1.8
 tensorboardX