-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a jax backend based on dia sparse matrix, JaxDia
.
#20
Conversation
Pull Request Test Coverage Report for Build 8504122727Details
💛 - Coveralls |
…into feature.jaxdiag
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, I stress tested mesolve() with this quite a bit in benchmarking, and didnt have any issues. Left a couple of comments on wrong/missing docstrings.
One random comment, more on qutip jax than jaxdia, would making the ode solver use diffraxs PIDController by default instead of ConstantStepSize cause any problems (assuming its still the default) ? in practice the performance of the ConstantStepSize was terrible (for the cases I tried)
src/qutip_jax/binops.py
Outdated
Perform the operation | ||
left + scale*right | ||
where `left` and `right` are matrices, and `scale` is an optional complex | ||
scalar. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like the wrong docstring here...
|
||
|
||
@jit | ||
def kron_jaxdia(left, right): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing docstring?
state = _float2cplx(y) | ||
H, = args | ||
d_state = H.matmul_data(t, JaxArray(state)) | ||
return _cplx2float(d_state._jxa) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a minor thing, does the new complex support in diffrax essentially do the same thing? though given all the warnings using complex in diffrax directly throws out, maybe better to stick with this manual treatment at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No sure, I did not play too much with it yet.
We can change to PIDController before making a release, but not in this PR. |
Add a dia sparse data layer for Jax.
JaxDia
's data are jax.array, (differentiable), but the structure is meta data.This allow most functions to support jit.
tidyup
does not jit, so I did not include automated tidyup.The dia format is good for operator with few diagonals, passable for ket and horrible for bra.
I made the function so it can be supported as the operators in
mesolve
.Included specialisations:
isherm
,iszero
,isdiag
add
,sub
,mul
,matmul
,multiply
,kron
neg
,adjoint
,transpose
,conj
zeros
,identity
,diag
,one_element
expect
,expect_super
,trace
extract
,tidyup
Won't do:
indices
,dimensions
,reshape_jaxarray
,column_stack_jaxarray
,column_unstack_jaxarray
,split_columns_jaxarray
,trace_oper_ket
.Reshaping, permuting breaks the diagonals structure.
inv
,expm
,eigs
,svd
,solve
. Non-trivial, better done in dense format.Maybe:
project
,pow
,inner
,inner_op
,norm...
,ptrace
#19 is a branch off this one, and should be merged before this one.