
Weighted soft_rank #76

Open · davips opened this issue Aug 16, 2023 · 4 comments
davips commented Aug 16, 2023

A weighted soft_rank would be a great addition to have!
It could be weighted by element index, by a weighting function, or by a vector of weights.

My initial attempt was this (part of a larger piece of code that calculates a weighted Spearman rho, based on the wCorr package):

from torch import cumsum, diff, hstack
from torchsort import soft_rank

def wsrank(x, w, regularization="l2", regularization_strength=1.):
    """
    >>> import torch
    >>> soft_rank(torch.tensor([[1., 2, 2, 3]]))
    tensor([[1.5000, 2.5000, 2.5000, 3.5000]])
    >>> wsrank(torch.tensor([[1., 2, 2, 3]]), torch.tensor([1., 1, 1, 1]))
    tensor([1.5000, 2.5000, 2.5000, 3.5000])
    >>> wsrank(torch.tensor([[1., 2, 3, 4]]), torch.tensor([1., 1/2, 1/3, 1/4]))
    tensor([1.0000, 1.5000, 1.8333, 2.0833])
    >>> wsrank(torch.tensor([[1., 2, 3, 4]], requires_grad=True), torch.tensor([1., 1/2, 1/3, 1/4])).sum().backward()
    """
    # soft ranks of the single-row batch, flattened to 1-D
    r = soft_rank(x, regularization=regularization, regularization_strength=regularization_strength).view(x.shape[1])
    # rank increments: the first rank, then the successive differences
    d = hstack([r[0], diff(r)])
    # weighted cumulative sum of the increments
    s = cumsum(d * w, dim=0)
    return s

However, it seems to be too optimistic (near 1.0) when compared to, e.g., scipy's weightedtau (around 0.6 in a random test I ran). The soft Spearman from the original README works fine, being just a little more optimistic (~5% in some tests) than its hard counterpart, which makes sense to me.

davips commented Aug 16, 2023

A soft, and possibly weighted, Kendall tau-b would also be a great thing to have.
It probably doesn't need sorting or ranking, just a sigmoid-like function to soften the agreement/disagreement/tie counters.
We could take advantage of the GPU by parallelizing the naive O(n²) approach, instead of the O(n log n) Knight algorithm adopted by S. Vigna in the SciPy/Cython implementation.

teddykoker (owner) commented
Thanks for the suggestions @davips. Do you have any references for the weighted rank? It's not exactly clear to me what the intended output of the function is (even in an exact/"hard" setting).

I agree that a soft Kendall tau-b would be great to have, although it may be out of scope for this library. I believe this could likely be done, as you said, without any sorting or ranking. I'm thinking you'd be able to do this with just plain PyTorch by constructing an n x n pairwise difference matrix for each variable, multiplying by some regularization value (so differences $\to \{-\infty, 0, \infty\}$ as the regularization term $\to \infty$), then handling the concordant/discordant pairs with activation functions. Definitely feasible with differentiable operations, but I'm not sure how "smooth" it would be for low regularization values. Not sure if I have the time to implement this myself, but I'd be happy to iterate on any ideas you might have.
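For concreteness, a minimal sketch of that idea in plain PyTorch (the function name and the tanh activation are my own assumptions, not an API of this library):

import torch

def soft_kendall_sketch(x, y, reg=10.0):
    # n x n matrices of pairwise differences for each 1-D variable
    dx = x.unsqueeze(1) - x.unsqueeze(0)
    dy = y.unsqueeze(1) - y.unsqueeze(0)
    # tanh(reg * dx * dy) tends to +1 for concordant pairs and -1 for discordant pairs as reg grows
    concordance = torch.tanh(reg * dx * dy)
    # average over the n * (n - 1) / 2 distinct pairs (strict upper triangle)
    n = x.shape[0]
    iu = torch.triu_indices(n, n, offset=1)
    return concordance[iu[0], iu[1]].mean()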

davips commented Sep 24, 2023

@teddykoker, your suggestion is pretty similar to what I am trying now. If we find common ground, I can contribute to the library. I often use Python packages as a means to organize work done and make it available to future work, so when I have some free time it is certainly possible.

The weighting scheme is simple: higher ranks (i.e., lower values) get higher weights, defined by a custom function that could be anything, e.g., a Cauchy density or a harmonic progression (the latter is not a proper distribution, but it was recommended by Vigna for weighted tau; more details later).

This is my soft_kendalltau so far:

from torch import sigmoid, sum, triu_indices

def pdiffs(x):
    # all pairwise differences x[i] - x[j] for i < j (strict upper triangle)
    dis = x.unsqueeze(1) - x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def surrogate_tau(a, b, reg=10):
    # sigmoid(reg * da * db) approaches 1 for concordant pairs and 0 for discordant pairs as reg grows
    da, db = pdiffs(a), pdiffs(b)
    return sum(sigmoid(reg * da * db))

# pred and target are 1-D tensors of the same length; normalize by the number of pairs
soft_kendalltau = surrogate_tau(pred, target) / (len(pred) * (len(pred) - 1) / 2)
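As a quick sanity check (hypothetical tensors, just to show that gradients flow through the surrogate):

import torch

pred = torch.tensor([0.3, 0.1, 0.9, 0.4], requires_grad=True)
target = torch.tensor([1., 2., 3., 4.])
tau = surrogate_tau(pred, target) / (len(pred) * (len(pred) - 1) / 2)
tau.backward()  # pred.grad is now populated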

This is my soft_weighted_kendalltau so far:

from torch import sigmoid, sum, triu_indices

def pdiffs(x):
    # all pairwise differences x[i] - x[j] for i < j (strict upper triangle)
    dis = x.unsqueeze(1) - x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def psums(x):
    # all pairwise sums x[i] + x[j] for i < j, used to combine the weights of each pair
    dis = x.unsqueeze(1) + x
    indices = triu_indices(*dis.shape, offset=1)
    return dis[indices[0], indices[1]]

def surrogate_wtau(a, b, w, reg=10):
    da, db, sw = pdiffs(a), pdiffs(b), psums(w)
    # each soft concordance is weighted by the summed weights of its pair, then normalized
    return sum(sigmoid(reg * da * db) * sw) / sum(sw)

soft_weighted_kendalltau = surrogate_wtau(pred, target, weights)

# Missing part: the weights are defined according to the rank of each value in `target`.
#               Such a ranking depends on `soft_rank`.
#               I will try with `cau = scipy.stats.cauchy(0).pdf`:
#               w = tensor([cau(r) for r in soft_rank(target)], requires_grad=True)
#               Perhaps this will work fine.
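One possible way to fill in that missing part while keeping everything differentiable is to compute the Cauchy density directly in torch rather than going through scipy (whose cauchy.pdf would detach the graph). This is only a sketch; soft_rank is assumed to come from torchsort and to take a 2-D (batch, n) tensor:

import math
import torch
from torchsort import soft_rank

def cauchy_weights(target, regularization_strength=1.0):
    # soft ranks of the target, flattened back to 1-D
    r = soft_rank(target.unsqueeze(0), regularization_strength=regularization_strength).squeeze(0)
    # standard Cauchy pdf 1 / (pi * (1 + r^2)): top ranks (small rank values) get higher weight
    return 1.0 / (math.pi * (1.0 + r ** 2))

# weights = cauchy_weights(target)
# soft_weighted_kendalltau = surrogate_wtau(pred, target, weights)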

I am not experienced with regularization. I assume that lower reg values make the surrogate softer and easier for gradient descent, at the expense of a less faithful approximation of the real (hard) concept.

The paper on weighted tau doesn't seem trivial to follow (at first glance), and the implementation is so optimized that it looks more complicated than the problem really is: scipy weighted tau

In a nutshell: you could help me by taking a look at the soft_kendalltau and soft_weighted_kendalltau code above, to assess whether it makes sense and to decide whether it falls within a possible scope of the library.

davips commented Sep 24, 2023

A weighted soft_rank implementation would be a more general addition to the library, and could be used to replace parts of the code I wrote above.
However, it may be too trivial, as it seems to be just soft_rank(x) * weigher(soft_rank(x)).
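A rough sketch of that composition (the weigher argument and the hyperbolic example below are just illustrations; soft_rank is the library's existing function):

from torchsort import soft_rank

def weighted_soft_rank(x, weigher, **kwargs):
    # x: 2-D (batch, n) tensor, as soft_rank expects; extra kwargs are forwarded to soft_rank
    r = soft_rank(x, **kwargs)
    return r * weigher(r)

# e.g., a hyperbolic weigher, similar in spirit to the default of scipy's weightedtau:
# weighted = weighted_soft_rank(x, lambda r: 1.0 / r)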
