Allow other dtype than complex128 #22

base: master
Conversation
Pull Request Test Coverage Report for Build 5455627965
💛 - Coveralls
Hi Eric, I did not finish reviewing this PR, but I wanted to ask the following:

- Should we consider registering a `"jax64"` dtype to address the complex64 case? In qutip-tensorflow I used a `_BaseTfTensor` class to represent both types of tensors. I then wrapped this class to create `TfTensor64` and `TfTensor128`. So in qutip-tensorflow these two classes are different but virtually the same. It does the job, but there may be cleaner approaches. It does require registering the specialisation twice (or perhaps three times if you want to allow operations between the 64- and 128-bit types).
- I see the use of `fast_constructor` instead of the `__init__` method. Is it to provide a significant speedup in construction time?
```diff
 @jax.jit
 def _cplx2float(arr):
-    return jnp.stack([arr.real, arr.imag])
+    if jnp.iscomplexobj(arr):
+        return jnp.stack([arr.real, arr.imag])
+    return arr
```
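To illustrate the changed behaviour of `_cplx2float`: complex input is still split into a stacked (real, imaginary) pair of float arrays, but real input now passes through untouched instead of being stacked with a zero imaginary part. A self-contained sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def _cplx2float(arr):
    # The dtype is static under jit, so this Python branch is
    # resolved once at trace time, not on every call.
    if jnp.iscomplexobj(arr):
        return jnp.stack([arr.real, arr.imag])
    return arr

cplx = jnp.array([1.0 + 2.0j, 3.0 - 4.0j])
real = jnp.array([1.0, 2.0])

out_c = _cplx2float(cplx)  # shape (2, 2): row 0 = real parts, row 1 = imag parts
out_r = _cplx2float(real)  # unchanged, shape (2,)
```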
```diff
 @jax.jit
 def _float2cplx(arr):
-    return arr[0] + 1j * arr[1]
+    if arr.ndim == 3:
+        return arr[0] + 1j * arr[1]
+    return arr
```
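The two helpers are inverses for complex matrices: a 2-D complex matrix becomes a 3-D stacked float array and back. A round-trip sketch (both functions copied from the diff above so the example is self-contained):

```python
import jax
import jax.numpy as jnp

@jax.jit
def _cplx2float(arr):
    if jnp.iscomplexobj(arr):
        return jnp.stack([arr.real, arr.imag])
    return arr

@jax.jit
def _float2cplx(arr):
    # A stacked (real, imag) pair of 2-D matrices has ndim == 3;
    # anything else is assumed to already be in its final form.
    if arr.ndim == 3:
        return arr[0] + 1j * arr[1]
    return arr

mat = jnp.array([[1.0 + 2.0j, 0.0], [0.0, 3.0 - 1.0j]])
stacked = _cplx2float(mat)            # shape (2, 2, 2), real dtype
roundtrip = _float2cplx(stacked)      # recovers the complex matrix
```

Note that the `ndim == 3` heuristic relies on the data layer only ever holding 2-D matrices; a genuinely 3-D real array would be misinterpreted as a stacked pair.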
Test
It is covered in `test_non_cplx128_Diffrax`.

I would prefer not to. Also having a
Our data layer always uses double-precision complex, but this is not always ideal: GPUs often don't support double precision well.

With this PR, the default is still to convert any array to complex128, but it can be overridden.

However, this is not usable through the QuTiP interface with this PR. I am not too sure how to make it available to the user....
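For context on why a hard-coded complex128 data layer is awkward in JAX: JAX itself defaults to single precision, and double-precision dtypes must be opted into explicitly. A minimal sketch of that behaviour (this is standard JAX configuration, not part of this PR):

```python
# By default, JAX silently downcasts float64/complex128 inputs to
# float32/complex64. Double precision must be enabled globally, and
# before any arrays are created.
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.asarray(1.0 + 0.0j)
# With x64 mode enabled, the array keeps double precision (complex128);
# without it, the same call would produce complex64.
```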