Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Torch implementation (all classes excluding SL3) #20

Merged
merged 24 commits into from
Feb 25, 2024
Merged

Conversation

angadbajwa
Copy link
Contributor

@angadbajwa angadbajwa commented Feb 2, 2024

first pass at full torch implementation. making this now so you guys can start looking at it while I finish it up.

checklist of things left to add (that I know of):

  • np cross-validation. maybe a quick addition to brute force a logm test in batch
  • documentation. base class is definitely lacking relative to the np base.py and the dimensionality rules need to be explained so it doesn't get confusing
  • edit np test_so3.py because I added ordering to the to_euler function
  • dimensionality isn't fulllllly enforced. the rules are followed with inter/intra-function interactions, but we could idiot-proof this further. it's possible that the user forgetting that you need a batch dimension could lead to some funky behaviour especially in the lower-dimensional classes (i.e. SO2)

… (N, P, 1) structure where N is the batch dim. and P is the parameterization dim.
…ation for angles near pi, and batchify pi mask
… name overlap with 'MatrixLieGroup' needs to be addressed
…rch to avoid naming conflict with np implementation
Copy link
Member

@CharlesCossette CharlesCossette left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just need to add torch to this list i think:

https://github.com/decargroup/pymlg/blob/main/setup.py#L83

@CharlesCossette
Copy link
Member

Also, for the dimensionality thing, if its just something thats not supported right now you can throw a value error

raise ValueError("Input must be an [N x 2 x 2] matrix")

and write a comment that its a TODO or something

@angadbajwa angadbajwa merged commit 0715e40 into main Feb 25, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants