-
Notifications
You must be signed in to change notification settings - Fork 409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix/multiclass recall macro avg ignore index #2710
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good, can we add also test for this case...
Sure |
@Borda What do I have to modify? |
@rittik9 mind checking the changed docstest values and whether it is correct? |
Any update on when this PR could be merged? It would really help if we could update from the 0.9.3 version once this fix is merged. |
the tests/doctests need to be fixed, are you interested in submitting a suggestion on what else needs to be fixed/chnaged? |
@@ -661,6 +661,37 @@ def test_corner_case(): | |||
assert res == 1.0 | |||
|
|||
|
|||
def test_multiclass_recall_ignore_index(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems we are already testing various ignore_index with reference metric so if we had it wrong this did not pass already... it is possible that we also have a bug in the reference metric?
cc: @SkafteNicki
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking to the code and the ignore index is already applied in _multilabel_stat_scores_format
which reduces the preds/target size the same way as the reference metric so calling it with null weights in fact ignores additional index
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is we are using sklearn's recall_score
as a reference for our unittests. So even if in _reference_sklearn_precision_recall_multiclass()
function we are using remove_ignore_index
function for removing those predictions whose real values are ignore_index
class before passing it to recall_score
function, it does not matter. Because whenever average='macro'
sklearn's recall_score
will always return mean cosidering the total no. of classes (as we are passing all the classes in recall_score()
function's labels
argument). That is the reason why unittests failed in the first place. I think we need to fix the unittests to take care of ignore_index using sklearn's recall_score()
function's labels
argument. I've prepared a notebook for explanation. cc:@Borda.
Just to chime in, I think this issue is present in pretty much all metrics that make use of I see this PR fixes, some of them, but others, such as |
… wrong answer when ignore_index is specified
What does this PR do?
Fixes #2441
Details
Did you have fun?
Yes
Issue:
ignore_index
information is not being properly propagated to the final averaging step i.e. the_adjust_weights_safe_divide
function doesn't know that which class should be ignored.Solution:
ignore_index
information is preserved throughout the entire process, making sure it is correctly passed through all intermediate steps up to the final averaging stage i.e._adjust_weights_safe_divide
function ._adjust_weights_safe_divide
function to accept an additionalignore_index
parameter, which is passed through the_precision_recall_reduce
function, called in thecompute
method of theMulticlassRecall
class. This change adjusts the weights in the_adjust_weights_safe_divide
function, setting the weight of the ignored class to 0.📚 Documentation preview 📚: https://torchmetrics--2710.org.readthedocs.build/en/2710/