Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

linear_operator cat_rows performance improvement #93

Merged

Conversation

naefjo
Copy link
Contributor

@naefjo naefjo commented Mar 11, 2024

Hello :)

This PR addresses the observations about computational bottlenecks made in cornellius-gp/gpytorch#2468

In cat_rows, the schur_root is converted to a dense operator using to_dense. Hence, the subsequent inversion of the root fails to exploit the structure of the operator and defaults to stable_pinverse which uses a QR decomposition.

The PR contains the following modifications:

  • Don't convert schur_root to a dense operator unless needed for tensor assignment.
  • In root_inv_decomposition exploit structure of the resulting inversion using cholesky by casting the result to a TriangularLinearOperator instead of a DenseLinearOperator.
  • Add an option to specify the matrix size threshold where QR decomposition should be performed on the CPU instead of the GPU which is more in line with the observations made in the linked issue.

@Balandat
Copy link
Collaborator

Thanks for the contribution! Overall this makes sense to me. Would you be able to provide some benchmark results of this change relative to the previous implementation?

@naefjo
Copy link
Contributor Author

naefjo commented Mar 17, 2024

Sure thing.
Timing cat_rows in isolation seems to not make any noticable difference at all in a toy example i tried to cook up. However, it is very noticable in gpytorch's get_fantasy_model. Here is a graph showing the computation times of gpytorch's get_fantasy_model method as a function of number of datapoints based on the basic example notebook. The updates were performed in a "batched" setting, i.e. 10 points at a time. Note that this is in conjunction with the changes from cornellius-gp/gpytorch#2494.

image

BTW, the failing CI seems to be related to an updated version of mpmath downloaded here. 1.4.0 seems to have some breaking API changes. A quick google search shows that other repos were affected as well pytorch/pytorch#120995 NVIDIA/TensorRT-LLM#1145

@naefjo
Copy link
Contributor Author

naefjo commented Mar 17, 2024

Disregard my last comment about there being no difference in cat_rows. Apparently the matrices I was testing were not p.d. enough for cholesky factorization which led to root decompositions being performed with symeig which again led to root inverses being computed with stable_pinverse in cat_rows.... If the matrices are well conditioned enough for cholesky not to fail, then the result look as follows:

image

@Balandat
Copy link
Collaborator

Nice, this seems like a meaningful improvement. Thanks for the perf fix.

BTW, the failing CI seems to be related to an updated version of mpmath

Yes, thanks, I've run into this with other libraries before. #94 pins the version to avoid this issue for now. Could you rebase on that change so we can run the test?

@naefjo naefjo force-pushed the feature/online-learning-improvements branch from 716222d to 0a94f8b Compare March 18, 2024 13:54
@Balandat
Copy link
Collaborator

Not sure what is going on with the docs, but it's unrelated to this PR

@Balandat Balandat merged commit a0a9c42 into cornellius-gp:main Mar 18, 2024
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants