-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #145 from Sampreet/pr/jax-docs
Update Documentation for Features under Development
- Loading branch information
Showing
5 changed files
with
110 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Development | ||
|
||
The following feautres are under active development: | ||
|
||
* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus) | ||
|
||
## Experimental Support for GPUs/TPUs | ||
|
||
The current development branch `dev/jax` implements | ||
experimental support for GPUs/TPUs. | ||
|
||
Although OQuPy is built on top of the backend-agnostic | ||
[TensorNetwork](https://github.com/google/TensorNetwork) library, | ||
OQuPy uses vanilla NumPy and SciPy throughout its implementation. | ||
The `dev/jax` branch adds support for GPUs/TPUs via the | ||
[JAX](https://jax.readthedocs.io/en/latest/) library. | ||
A new `oqupy.backends.numerical_backend.py` module handles the | ||
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html), | ||
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there | ||
without explicitly importing JAX-based libraries. | ||
|
||
### Enabling Experimental Features | ||
|
||
To enable experimental features switch to the `dev/jax` branch and use | ||
```python | ||
from oqupy.backends import enable_jax_features | ||
enable_jax_features() | ||
``` | ||
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to | ||
initialize the jax backend by default. | ||
|
||
### Contributing Guidelines | ||
|
||
To contribute features compatible with the JAX backend, | ||
please adhere to the following set of guidelines: | ||
|
||
* avoid wildcard imports of NumPy and SciPy. | ||
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required. | ||
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used. | ||
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`). | ||
* convert lists or tuples to arrays when passing them as arguments inside functions. | ||
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`. | ||
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`. | ||
* declare signatures for `np.vectorize` explicitly. | ||
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,16 @@ Authors & Acknowledgements | |
- Lead developer since 2020: `Gerald E. | ||
Fux <https://github.com/gefux>`__ ([email protected]) | ||
- Co-lead developer since 2022: `Piper | ||
Fowler-Wright <https://github.com/piperfw>`__ ([email protected]) | ||
Fowler-Wright <https://github.com/piperfw>`__ ([email protected]) | ||
|
||
Major code contributions | ||
------------------------ | ||
|
||
**Experimental features** | ||
|
||
- `Sampreet Kalita <https://github.com/Sampreet>`__: JAX numerical backend for | ||
GPU/TPU support | ||
|
||
**Version 0.5.0** | ||
|
||
- `Aidan Strathearn <https://github.com/aidanstrathearn>`__: Gibbs state TEMPO [Chiu2022]. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
Experimental Support for GPUs/TPUs | ||
================================== | ||
|
||
The current development branch ``dev/jax`` implements | ||
experimental support for GPUs/TPUs. | ||
|
||
Although OQuPy is built on top of the backend-agnostic | ||
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library, | ||
OQuPy uses vanilla NumPy and SciPy throughout its implementation. | ||
The ``dev/jax`` branch adds support for GPUs/TPUs via the | ||
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new | ||
``oqupy.backends.numerical_backend.py`` module handles the | ||
`breaking changes in JAX | ||
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__, | ||
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg`` | ||
instances from there without explicitly importing JAX-based libraries. | ||
|
||
Enabling Experimental Features | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
To enable experimental features, switch to the ``dev/jax`` branch and use | ||
|
||
.. code:: python | ||
from oqupy.backends import enable_jax_features | ||
enable_jax_features() | ||
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to | ||
initialize the jax backend by default. | ||
|
||
Contributing Guidelines | ||
~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
To contribute features compatible with the JAX backend, | ||
please adhere to the following set of guidelines: | ||
|
||
- avoid wildcard imports of NumPy and SciPy. | ||
- use ``from oqupy.backends.numerical_backend import np`` instead of | ||
``import numpy as np`` and use the alias ``default_np`` in cases | ||
vanilla NumPy is explicitly required. | ||
- use ``from oqupy.backends.numerical_backend import la`` instead of | ||
``import scipy.linalg as la``, except that for non-symmetric | ||
eigen-decomposition, ``scipy.linalg.eig`` should be used. | ||
- use one of ``np.dtype_complex`` (``np.dtype_float``) or | ||
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``) | ||
instead of ``np.complex_`` (``np.float_``). | ||
- convert lists or tuples to arrays when passing them as arguments | ||
inside functions. | ||
- use ``array = np.update(array, indices, values)`` instead of | ||
``array[indices] = values``. | ||
- use ``np.get_random_floats(seed, shape)`` instead of | ||
``np.random.default_rng(seed).random(shape)``. | ||
- declare signatures for ``np.vectorize`` explicitly. | ||
- avoid directly changing the ``shape`` attribute of an array (use | ||
``.reshape`` instead) |