import torch
from allennlp.nn.util import tiny_value_of_dtype


def masked_log_softmax(vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked. We need a very small value instead of a
        # zero in the mask for these cases.
        vector = vector + (mask + tiny_value_of_dtype(vector.dtype)).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)
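A minimal demo of the masking behavior (a sketch assuming the function above is in scope, with AllenNLP's tiny_value_of_dtype helper from allennlp.nn.util): masked positions get log(0 + tiny) added to their logits, so their log-probabilities come out as large negative numbers rather than -inf, and exponentiating them still gives strictly positive values.

import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
mask = torch.tensor([[True, True, False]])

log_probs = masked_log_softmax(logits, mask)
print(log_probs)        # masked slot: a very large negative number, not -inf
print(log_probs.exp())  # masked slot: negligibly small but strictly positive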
Sorry, I don't understand: if we only mask the vector (vector = type_linear_output @ span_linear_output) before passing it to the log_softmax function, how do we make sure the numerator exp(sim(s_{i,j}, e_k)) in Equation (5) is positive?
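For what it's worth, a small sketch of why the numerator stays positive (assuming float32 and a tiny value of 1e-13, which is what AllenNLP's tiny_value_of_dtype uses for single precision): adding log(mask + tiny) to a logit multiplies its exponential by (mask + tiny), which is strictly greater than zero even where mask == 0.

import torch

v = torch.tensor(2.0)    # an arbitrary logit at a masked (mask == 0) position
tiny = 1e-13             # assumed: AllenNLP's tiny value for float32

# exp(v + log(0 + tiny)) == exp(v) * tiny: tiny, but strictly positive.
masked = v + torch.log(torch.tensor(tiny))
print(torch.exp(masked))                                      # ~7.4e-13 > 0
print(torch.isclose(torch.exp(masked), torch.exp(v) * tiny))  # tensor(True)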