Weighted soft_rank #76
A soft, and possibly weighted, Kendall tau-b would also be a great thing to have.
Thanks for the suggestions @davips. Do you have any references for the weighted rank? It's not exactly clear to me what the intended output of the function is (even in an exact/"hard" setting). I agree that a soft Kendall tau-b would be great to have, although it may be out of scope of this library. I believe this could likely be done, as you said, without any sorting or ranking. I'm thinking you'd be able to do this with just plain PyTorch by constructing an n × n pairwise difference matrix for each variable, multiplying by some regularization value (so differences …)
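The comment above is cut off, so the exact relaxation step is a guess; a minimal sketch of the pairwise-difference idea in plain PyTorch, assuming a sigmoid is used to soften the sign of each pairwise product (function name and `reg` default are illustrative only):

```python
import torch

def soft_kendall_sketch(a, b, reg=10.0):
    # n x n pairwise difference matrix for each variable
    da = a.unsqueeze(1) - a.unsqueeze(0)
    db = b.unsqueeze(1) - b.unsqueeze(0)
    # sigmoid(reg * da * db) -> ~1 for concordant pairs, ~0 for discordant
    conc = torch.sigmoid(reg * da * db)
    # average over the strict upper triangle so each pair is counted once
    i, j = torch.triu_indices(len(a), len(a), offset=1)
    return conc[i, j].mean()
```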
@teddykoker, your suggestion is pretty similar to what I am trying now. If we find common ground, I can contribute to the library. I often use Python packages as a means to organize work done and make it available to future work, so when I have some free time it is certainly possible.

The weighting scheme is simple: higher ranks (i.e., lower values) have higher weights, defined by a custom function that could be anything, e.g., a Cauchy pdf or a harmonic progression (not as good as a distribution, but it was recommended by Vigna for weighted tau; more details later).

This is my `soft_kendalltau`:

```python
from torch import sigmoid, sum, triu_indices

def pdiffs(x):
    # all pairwise differences x[i] - x[j], for i < j
    dis = x.unsqueeze(1) - x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def surrogate_tau(a, b, reg=10):
    # sigmoid(reg * da * db) -> ~1 for concordant pairs, ~0 for discordant
    da, db = pdiffs(a), pdiffs(b)
    return sum(sigmoid(reg * da * db))

# normalize by the number of pairs, n * (n - 1) / 2
soft_kendalltau = surrogate_tau(pred, target) / (len(pred) * (len(pred) - 1) / 2)
```
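One thing worth noting: the sigmoid maps concordant pairs toward 1 and discordant pairs toward 0, so this normalized value lives in [0, 1], while Kendall's tau lives in [-1, 1]; `2 * s - 1` is the comparable quantity. A quick sanity check against SciPy's hard tau (a sketch, assuming `pred` and `target` are 1-D float tensors) could look like:

```python
import torch
from scipy.stats import kendalltau

pred = torch.randn(100, requires_grad=True)
target = torch.randn(100)

n_pairs = len(pred) * (len(pred) - 1) / 2
s = surrogate_tau(pred, target) / n_pairs        # soft value in [0, 1]
hard, _ = kendalltau(pred.detach().numpy(), target.numpy())
print(2 * s.item() - 1, hard)                    # roughly comparable for large reg
s.backward()                                     # gradients flow back to pred
```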
This is my `soft_weighted_kendalltau` (note: the original snippet called `torch.triu_indices` in `psums` without importing `torch`; using the already-imported `triu_indices` fixes that):

```python
from torch import sigmoid, sum, triu_indices

def pdiffs(x):
    # all pairwise differences x[i] - x[j], for i < j
    dis = x.unsqueeze(1) - x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def psums(x):
    # all pairwise sums x[i] + x[j], for i < j
    dis = x.unsqueeze(1) + x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def surrogate_wtau(a, b, w, reg=10):
    # each pair's concordance is weighted by the sum of its two weights
    da, db, sw = pdiffs(a), pdiffs(b), psums(w)
    return sum(sigmoid(reg * da * db) * sw) / sum(sw)

soft_weighted_kendalltau = surrogate_wtau(pred, target, weights)

# Missing part: the weights are defined according to the rank of each value in `target`.
# Such a ranking depends on `soft_rank`.
# I will try with `cau = scipy.stats.cauchy(0).pdf`:
# w = tensor([cau(r) for r in soft_rank(target)], requires_grad=True)
```
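One caveat with the commented sketch: building `w` from `scipy.stats.cauchy(0).pdf` values creates a new leaf tensor, so no gradient would flow through the soft ranks even with `requires_grad=True`. If the weights only ever come from `target` (the labels), that may not matter; still, computing the pdf in torch is simpler and keeps the option open. A sketch, assuming `torchsort.soft_rank` (which expects a 2-D `(batch, n)` tensor) and reusing `surrogate_wtau` from above:

```python
import math
from torchsort import soft_rank  # assumption: this library's soft_rank

def cauchy_weights(target):
    # soft ranks are roughly 1-based; the standard Cauchy pdf, 1 / (pi * (1 + r^2)),
    # decays with the rank value, so top-ranked elements get higher weight
    r = soft_rank(target.unsqueeze(0))[0]
    return 1.0 / (math.pi * (1.0 + r ** 2))

weights = cauchy_weights(target)  # differentiable, unlike the scipy list comprehension
soft_weighted_kendalltau = surrogate_wtau(pred, target, weights)
```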
Perhaps this will work fine. I am not experienced with regularization; I assume that lower `reg` values make the surrogate softer and easier for the gradient to descend, at the expense of a less reliable approximation to the real (hard) concept. The paper on weighted tau doesn't seem trivial to follow (at a first glance), and the implementation is so optimized that it looks more complicated than what the problem really is: scipy's `weightedtau`. In a nutshell: you could help me by taking a look at the `soft_kendalltau` and `soft_weighted_kendalltau` code above to assess whether it makes sense, and to decide whether it is within a possible scope of the library.
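For a hard reference value, `scipy.stats.weightedtau` accepts a custom `weigher` (a function mapping a zero-based rank to a nonnegative weight; the default is Vigna's hyperbolic weighing, 1/(1+r)), so the Cauchy scheme above can be checked directly. A sketch:

```python
import numpy as np
from scipy.stats import cauchy, weightedtau

x, y = np.random.randn(100), np.random.randn(100)

# swap the default hyperbolic weigher for the Cauchy pdf used above
hard_wtau, _ = weightedtau(x, y, weigher=lambda r: cauchy(0).pdf(r))
```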
A weighted `soft_rank` implementation would be a more general addition to the library, one that could be used to replace parts of the code I wrote above.
A weighted `soft_rank` would be a great addition to have! It could be weighted by element index, by a weighting function, or by a vector of weights.

My initial attempt (part of a larger code to calculate a weighted Spearman rho based on the `wcorr` package) seems to be too optimistic (near 1.0) when compared to, e.g., `weightedtau` (around 0.6 in a random test I did here). The original README's soft Spearman works fine, being just a little more optimistic (~5% in some tests) than its hard counterpart, which makes sense to me.
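The attempt itself is not quoted above, but one plausible shape for such a function (a sketch only, not the code referred to in the comment) is a weighted Pearson correlation, following the `wcorr` definition, applied to soft ranks; `torchsort.soft_rank` and both helper names are assumptions:

```python
import torch
from torchsort import soft_rank  # assumed; expects a 2-D (batch, n) tensor

def weighted_pearson(x, y, w):
    # weighted Pearson correlation, as defined in the wcorr package
    w = w / w.sum()
    mx, my = (w * x).sum(), (w * y).sum()
    cov = (w * (x - mx) * (y - my)).sum()
    var_x = (w * (x - mx) ** 2).sum()
    var_y = (w * (y - my) ** 2).sum()
    return cov / torch.sqrt(var_x * var_y)

def soft_weighted_spearman(pred, target, w, **kw):
    # Spearman rho is Pearson over ranks; soft ranks keep it differentiable
    rp = soft_rank(pred.unsqueeze(0), **kw)[0]
    rt = soft_rank(target.unsqueeze(0), **kw)[0]
    return weighted_pearson(rp, rt, w)
```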