Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model with GLVQ classifier #118

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Add model with GLVQ classifier #118

wants to merge 4 commits into from

Conversation

denkle
Copy link
Collaborator

@denkle denkle commented Feb 12, 2023

The code is working properly though the accuracy is lower than expected based in the results reported in the paper.

)

# Training loop
trainer.fit(self.classifier, train_ld)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got an error when running this example:

File "../site-packages/prototorch/models/abstract.py", line 189, in log_acc
    accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
TypeError: accuracy() missing 1 required positional argument: 'task'

It looks like prototorch_models is relying on an older version of torchmetrics which recently changed their API design. Downgrading to torchmetrics=0.10.3 worked for me.

@mikeheddes
Copy link
Member

Thank you @denkle for another great contribution!
I made some small tweaks to improve the backwards compatibility with torchmetrics.
Do you want some extra time to figure out why you obtain lower accuracy or is the PR ready to be merged?

@mikeheddes
Copy link
Member

Hi, just wanted to ask if you had time to look into this yet?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants