PR #51 introduces JAX sparse evaluation for models. Due to limited functionality in the JAX `BCOO` sparse array type, some workarounds were required to implement a few things. Subsequent releases of JAX are expected to eliminate the need for these workarounds; this issue is a reminder of what they are.
A function `jsparse_linear_combo` is defined at the beginning of `operator_collections.py` that takes a linear combination of sparse arrays (specified as a 3d `BCOO` array), with the coefficients given in a dense 1d array. This cannot be achieved with a sparsified version of `jnp.tensordot`, as the design convention JAX uses is that such operations always output dense arrays if at least one input is dense. Hence, `jsparse_linear_combo` multiplies the coefficients against the sparse array directly via broadcasting. However, sparse-dense element-wise multiplication is, at the time of writing, limited to arrays of the same shape, so it is necessary to explicitly blow the coefficient array up to a dense array with the same shape as the sparse array (which is huge). I'm not sure whether this is done via views, in which case it would be okay, but it should be changed when possible regardless. A sketch of the workaround is below.
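For reference, a minimal sketch of this workaround (not necessarily the exact implementation in `operator_collections.py`), assuming `coeffs` has shape `(k,)` and `mats` is a `BCOO` array of shape `(k, n, n)` with one batch dimension:

```python
import jax.numpy as jnp
from jax.experimental import sparse


def jsparse_linear_combo(coeffs, mats):
    """Compute sum_j coeffs[j] * mats[j] for a 3d BCOO ``mats``."""
    # Workaround: sparse-dense element-wise multiplication currently
    # requires equal shapes, so explicitly blow the coefficients up to
    # the full dense shape (k, n, n) before multiplying.
    coeffs3d = jnp.broadcast_to(coeffs[:, None, None], mats.shape)
    # The product stays sparse; sum over the leading axis via sparsify.
    return sparse.sparsify(jnp.sum)(mats * coeffs3d, axis=0)
```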
Setting up the operators in the JAX sparse Lindblad operator collection is done using dense arrays, as sparse-sparse matrix multiplication is not yet implemented. This is relatively minor, but it would be nice to do it with sparse arrays when possible. An illustrative sketch is below.
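As an illustration (the function name here is hypothetical, not the actual collection code), the workaround amounts to forming the operator products densely and sparsifying only the result:

```python
import jax.numpy as jnp
from jax.experimental import sparse


def setup_dissipator_products(dissipators):
    """Form the L^dag @ L products of a Lindblad collection densely.

    ``dissipators`` is a dense 3d array of shape (k, n, n). Since
    sparse @ sparse is not yet implemented, the matrix products are
    taken densely and only the result is converted to BCOO.
    """
    products = jnp.matmul(dissipators.conj().transpose(0, 2, 1), dissipators)
    return sparse.BCOO.fromdense(products, n_batch=1)
```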
"Vectorized" products A @ B, where A and B are 3d arrays, with one being sparse and the other dense, are not reverse-mode differentiable. This results in LindbladModel in jax-sparse not being reverse-mode differentiable. It is however, forward mode differentiable. At some point this will change, and we will need to remove the caveat in LindbladModel.evaluation_mode that sparse mode with jax is not reverse-mode differentiable.
As of jax 0.2.26 and jaxlib 0.1.75, the above code snippets work. PR #69 removes the caveat that `LindbladModel.evaluate_rhs` cannot be reverse-mode autodiffed in sparse mode, and reverts the autodiff test case to testing reverse-mode autodiff.
Updating `jsparse_linear_combo` in `operator_collections.py` still needs to be done: while the above snippet works, simply updating `jsparse_linear_combo` results in several test failures, and the cause of these still needs to be figured out. It's possible they are all just `numpy` vs. `jax.numpy` array type errors in the test case setups. A hypothetical simplified version is sketched below.
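For concreteness, the simplified version would presumably look something like the following (hypothetical, assuming sparse-dense element-wise multiplication now broadcasts, so the explicit blow-up of the coefficients is no longer needed):

```python
import jax.numpy as jnp
from jax.experimental import sparse


def jsparse_linear_combo(coeffs, mats):
    # With broadcasting supported, multiply the (k, 1, 1)-shaped
    # coefficients directly against the (k, n, n) BCOO array and
    # sum over the leading axis.
    return sparse.sparsify(jnp.sum)(mats * coeffs[:, None, None], axis=0)
```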