Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a jax backend based on dia sparse matrix, JaxDia. #20

Merged
merged 32 commits into from
Apr 1, 2024

Conversation

Ericgig
Copy link
Member

@Ericgig Ericgig commented Jun 21, 2023

Add a dia sparse data layer for Jax.
JaxDia's data are jax.array, (differentiable), but the structure is meta data.
This allow most functions to support jit.
tidyup does not jit, so I did not include automated tidyup.
The dia format is good for operator with few diagonals, passable for ket and horrible for bra.
I made the function so it can be supported as the operators in mesolve.

Included specialisations:

  • isherm, iszero, isdiag
  • add, sub, mul, matmul, multiply, kron
  • neg, adjoint, transpose, conj
  • zeros, identity, diag, one_element
  • expect, expect_super, trace
  • extract, tidyup

Won't do:

  • indices, dimensions, reshape_jaxarray, column_stack_jaxarray, column_unstack_jaxarray, split_columns_jaxarray, trace_oper_ket.
    Reshaping, permuting breaks the diagonals structure.
  • inv, expm, eigs, svd, solve. Non-trivial, better done in dense format.

Maybe:

  • project, pow, inner, inner_op, norm..., ptrace

#19 is a branch off this one, and should be merged before this one.

@coveralls
Copy link

coveralls commented Jun 21, 2023

Pull Request Test Coverage Report for Build 8504122727

Details

  • 527 of 585 (90.09%) changed or added relevant lines in 12 files are covered.
  • 3 unchanged lines in 2 files lost coverage.
  • Overall coverage increased (+0.06%) to 90.462%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/qutip_jax/ode.py 8 9 88.89%
src/qutip_jax/qutip_trees.py 3 4 75.0%
src/qutip_jax/binops.py 152 156 97.44%
src/qutip_jax/properties.py 40 44 90.91%
src/qutip_jax/create.py 49 60 81.67%
src/qutip_jax/jaxdia.py 89 102 87.25%
src/qutip_jax/qobjevo.py 83 107 77.57%
Files with Coverage Reduction New Missed Lines %
src/qutip_jax/qobjevo.py 1 81.36%
src/qutip_jax/jaxarray.py 2 85.71%
Totals Coverage Status
Change from base Build 7964056840: 0.06%
Covered Lines: 1195
Relevant Lines: 1321

💛 - Coveralls

@Ericgig Ericgig marked this pull request as ready for review July 11, 2023 15:11
src/qutip_jax/jaxdia.py Outdated Show resolved Hide resolved
Copy link
Member

@nwlambert nwlambert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, I stress tested mesolve() with this quite a bit in benchmarking, and didnt have any issues. Left a couple of comments on wrong/missing docstrings.

One random comment, more on qutip jax than jaxdia, would making the ode solver use diffraxs PIDController by default instead of ConstantStepSize cause any problems (assuming its still the default) ? in practice the performance of the ConstantStepSize was terrible (for the cases I tried)

Perform the operation
left + scale*right
where `left` and `right` are matrices, and `scale` is an optional complex
scalar.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the wrong docstring here...



@jit
def kron_jaxdia(left, right):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docstring?

state = _float2cplx(y)
H, = args
d_state = H.matmul_data(t, JaxArray(state))
return _cplx2float(d_state._jxa)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a minor thing, does the new complex support in diffrax essentially do the same thing? though given all the warnings using complex in diffrax directly throws out, maybe better to stick with this manual treatment at the moment.

Copy link
Member Author

@Ericgig Ericgig Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No sure, I did not play too much with it yet.

@Ericgig
Copy link
Member Author

Ericgig commented Apr 1, 2024

We can change to PIDController before making a release, but not in this PR.

@Ericgig Ericgig merged commit 5f1c2bf into qutip:master Apr 1, 2024
3 checks passed
@Ericgig Ericgig deleted the feature.jaxdiag branch April 1, 2024 07:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants