Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU LU factorization and linear solvers #1451

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

abeleinin
Copy link
Contributor

Proposed changes

Adds mx.linalg.solve backend following the numpy implementation.

For the LAPACK routines, I chose to use a combination of sgetrf and sgetrs, which is equivalent to calling sgesv (https://stackoverflow.com/a/36065619) because it removes a transpose of a in the C++.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@awni
Copy link
Member

awni commented Oct 26, 2024

This looks good to me. I'm debating with myself if we should try to do this in high-level ops rather than add a new primitive for it.

On the one-hand it is definitely going to be way less efficient. On the other hand it will run on the GPU and have transform support out of the box.

I'm also ok landing this and rearranging the back-end at a later time if ti makes sense.

@abeleinin
Copy link
Contributor Author

Were you thinking of an approach using linalg.lu_factor and linalg.solve_triangular instead? Then linalg.solve could just use those ops. I think @barronalex mentioned possibly working on linalg.solve_triangular in a previous review, so maybe those linalg ops should be a higher priority than just linalg.solve support.

@awni
Copy link
Member

awni commented Oct 27, 2024

Then linalg.solve could just use those ops. I think @barronalex mentioned possibly working on linalg.solve_triangular

I think exactly that would be better and more flexible than what we have now. I was even wondering if we could do linalg.solve_triangular using existing MLX ops rather than wrap the lapack routine.

@abeleinin
Copy link
Contributor Author

Ok, working on adding a LUF primitive and subsequent linalg.lu_factor, linalg.solve_triangular, and linalg.solve functions.

@abeleinin abeleinin changed the title CPU mx.linalg.solve CPU LU factorization and linear solvers Oct 31, 2024
@abeleinin
Copy link
Contributor Author

Added LUF primitive which adds support for linalg.lu and linalg.lu_factor functions.

Also added two linear solvers, linalg.solve_triangular and linalg.solve using higher-level ops.

This pr closes #30

@sck-at-ucy
Copy link

This will be really useful, thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants