The following features are under active development:
The current development branch `dev/jax` implements experimental support for GPUs/TPUs.

Although OQuPy is built on top of the backend-agnostic TensorNetwork library, OQuPy uses vanilla NumPy and SciPy throughout its implementation. The `dev/jax` branch adds support for GPUs/TPUs via the JAX library.
A new `oqupy.backends.numerical_backend.py` module handles the breaking changes in JAX NumPy relative to vanilla NumPy. The rest of the modules use the `numpy` and `scipy.linalg` instances provided by that module and never import JAX-based libraries explicitly.
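
The idea of routing every numerical call through one module can be sketched with a small proxy object. This is a minimal, hypothetical illustration that uses only NumPy. The class `_NumPyProxy` and the way it carries `dtype_complex`/`dtype_float` are assumptions, not OQuPy's actual implementation.

```python
import numpy


class _NumPyProxy:
    """Hypothetical stand-in for a backend-aware ``np`` object.

    Attribute lookups that are not defined on the proxy fall through to the
    wrapped module, so calling code can write ``np.zeros(...)`` without
    knowing which library actually provides it.
    """

    def __init__(self, module, dtype_complex, dtype_float):
        self._module = module
        # Extra attributes that hide backend-specific default dtypes.
        self.dtype_complex = dtype_complex
        self.dtype_float = dtype_float

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for names such as
        # "zeros" or "einsum" that live on the wrapped module.
        return getattr(self._module, name)


# Vanilla NumPy backend; a JAX backend would wrap jax.numpy instead (assumption).
np = _NumPyProxy(numpy, dtype_complex=numpy.complex128, dtype_float=numpy.float64)
default_np = numpy  # hypothetical alias for code that explicitly needs vanilla NumPy

a = np.zeros((2, 2), dtype=np.dtype_complex)
```

Because modules only ever import `np` from the backend module, changing the wrapped library in one place changes every call site at once.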
To enable the experimental features, switch to the `dev/jax` branch and use:

    from oqupy.backends import enable_jax_features
    enable_jax_features()

Alternatively, the `OQUPY_BACKEND` environment variable may be set to `jax` to initialize the JAX backend by default.
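
A package might read such a variable at import time as sketched below. The helper `backend_from_env`, the default value `"numpy"` and the set of accepted names are hypothetical choices for illustration only. The actual handling lives in `oqupy.backends`.

```python
import os

# Hypothetical set of accepted backend names.
SUPPORTED_BACKENDS = ("numpy", "jax")


def backend_from_env(default="numpy"):
    """Return the backend name requested via OQUPY_BACKEND (hypothetical helper)."""
    name = os.environ.get("OQUPY_BACKEND", default).strip().lower()
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unsupported OQUPY_BACKEND={name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    return name


# The variable must be set before the selection code runs,
# e.g. `OQUPY_BACKEND=jax python script.py` from a shell.
os.environ["OQUPY_BACKEND"] = "jax"
print(backend_from_env())  # jax
```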
To contribute features compatible with the JAX backend, please adhere to the following set of guidelines:
- avoid wildcard imports of NumPy and SciPy.
- use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np`, and use the alias `default_np` in cases where vanilla NumPy is explicitly required.
- use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except for non-symmetric eigendecomposition, where `scipy.linalg.eig` should be used.
- use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
- convert lists or tuples to arrays when passing them as arguments inside functions.
- use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
- use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
- declare signatures for `np.vectorize` explicitly.
- avoid directly changing the `shape` attribute of an array (use `.reshape` instead).
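
The `np.update` and `.reshape` guidelines both follow from JAX arrays being immutable. JAX does not allow in-place item assignment or setting `.shape`, and provides `array.at[indices].set(values)` for functional updates instead. Below is a minimal NumPy-only sketch of a function with the same calling convention as `np.update(array, indices, values)`. Its body is an assumption about how a vanilla NumPy backend could implement it, not OQuPy's code.

```python
import numpy as np


def update(array, indices, values):
    """Return a copy of ``array`` with ``array[indices]`` replaced by ``values``.

    Hypothetical NumPy counterpart of JAX's ``array.at[indices].set(values)``:
    the input is left untouched, so calling code behaves the same on both
    backends as long as it only uses the returned array.
    """
    result = np.array(array, copy=True)
    result[indices] = values
    return result


a = np.zeros(4)
b = update(a, slice(1, 3), 5.0)  # functional update instead of a[1:3] = 5.0
c = b.reshape(2, 2)              # reshape instead of assigning b.shape
```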
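
Declaring a signature tells `np.vectorize` which axes form the core dimensions of each call. Without one, it applies the function to every scalar element. Both `numpy.vectorize` and `jax.numpy.vectorize` accept a `signature` argument in the generalized-ufunc format. The short sketch below uses vanilla NumPy, and the function `trace_of` is purely illustrative.

```python
import numpy as np


def trace_of(matrix):
    # Operates on a single (n, n) matrix and returns a scalar.
    return np.trace(matrix)


# "(n,n)->()" declares that each call consumes an n-by-n core block and
# returns a scalar; any leading axes are treated as batch dimensions.
batched_trace = np.vectorize(trace_of, signature="(n,n)->()")

stack = np.arange(12.0).reshape(3, 2, 2)
traces = batched_trace(stack)
```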
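
The `np.get_random_floats(seed, shape)` guideline hides a difference between the libraries. NumPy draws from a stateful `Generator`, whereas JAX derives random numbers from an explicit key, e.g. `jax.random.uniform(jax.random.PRNGKey(seed), shape)`. A hypothetical NumPy-side implementation with the same signature could be as simple as:

```python
import numpy as np


def get_random_floats(seed, shape):
    """Uniform floats in [0, 1) drawn from a freshly seeded generator.

    Hypothetical NumPy counterpart of the backend helper: creating a new
    generator on every call makes the result depend only on ``seed`` and
    ``shape``, mirroring how JAX derives random numbers from an explicit key.
    """
    return np.random.default_rng(seed).random(shape)


x = get_random_floats(1, (2, 3))
y = get_random_floats(1, (2, 3))
```

The two libraries use different pseudo-random generators, so the same seed will generally not produce the same numbers on both backends.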