Skip to content

Commit

Permalink
Merge pull request #145 from Sampreet/pr/jax-docs
Browse files Browse the repository at this point in the history
Update Documentation for Features under Development
  • Loading branch information
piperfw authored Nov 11, 2024
2 parents 1ebaaa8 + 996d3e5 commit e7b5db8
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ The current setup uses:
* [tox](https://tox.readthedocs.io) ... for testing with different environments.
* [travis](https://travis-ci.com) ... for continuous integration.

We are actively incorporating additional features to OQuPy,
details of which can be found in [DEVELOPMENT.md](./DEVELOPMENT.md).

## How to contribute to the code or documentation
Please use the
[Issues](https://github.com/tempoCollaboration/OQuPy/issues) and
Expand Down
45 changes: 45 additions & 0 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Development

The following feautres are under active development:

* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus)

## Experimental Support for GPUs/TPUs

The current development branch `dev/jax` implements
experimental support for GPUs/TPUs.

Although OQuPy is built on top of the backend-agnostic
[TensorNetwork](https://github.com/google/TensorNetwork) library,
OQuPy uses vanilla NumPy and SciPy throughout its implementation.
The `dev/jax` branch adds support for GPUs/TPUs via the
[JAX](https://jax.readthedocs.io/en/latest/) library.
A new `oqupy.backends.numerical_backend.py` module handles the
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html),
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there
without explicitly importing JAX-based libraries.

### Enabling Experimental Features

To enable experimental features switch to the `dev/jax` branch and use
```python
from oqupy.backends import enable_jax_features
enable_jax_features()
```
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
initialize the jax backend by default.

### Contributing Guidelines

To contribute features compatible with the JAX backend,
please adhere to the following set of guidelines:

* avoid wildcard imports of NumPy and SciPy.
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required.
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used.
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
* convert lists or tuples to arrays when passing them as arguments inside functions.
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
* declare signatures for `np.vectorize` explicitly.
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead)
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Furthermore, OQuPy implements methods to ...
:caption: Development

pages/contributing
pages/jax_features
pages/authors
pages/how_to_cite
pages/sharing
Expand Down
7 changes: 6 additions & 1 deletion docs/pages/authors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ Authors & Acknowledgements
- Lead developer since 2020: `Gerald E.
Fux <https://github.com/gefux>`__ ([email protected])
- Co-lead developer since 2022: `Piper
Fowler-Wright <https://github.com/piperfw>`__ ([email protected])
Fowler-Wright <https://github.com/piperfw>`__ ([email protected])

Major code contributions
------------------------

**Experimental features**

- `Sampreet Kalita <https://github.com/Sampreet>`__: JAX numerical backend for
GPU/TPU support

**Version 0.5.0**

- `Aidan Strathearn <https://github.com/aidanstrathearn>`__: Gibbs state TEMPO [Chiu2022].
Expand Down
55 changes: 55 additions & 0 deletions docs/pages/jax_features.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
Experimental Support for GPUs/TPUs
==================================

The current development branch ``dev/jax`` implements
experimental support for GPUs/TPUs.

Although OQuPy is built on top of the backend-agnostic
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library,
OQuPy uses vanilla NumPy and SciPy throughout its implementation.
The ``dev/jax`` branch adds support for GPUs/TPUs via the
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new
``oqupy.backends.numerical_backend.py`` module handles the
`breaking changes in JAX
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__,
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg``
instances from there without explicitly importing JAX-based libraries.

Enabling Experimental Features
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To enable experimental features, switch to the ``dev/jax`` branch and use

.. code:: python
from oqupy.backends import enable_jax_features
enable_jax_features()
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
initialize the jax backend by default.

Contributing Guidelines
~~~~~~~~~~~~~~~~~~~~~~~

To contribute features compatible with the JAX backend,
please adhere to the following set of guidelines:

- avoid wildcard imports of NumPy and SciPy.
- use ``from oqupy.backends.numerical_backend import np`` instead of
``import numpy as np`` and use the alias ``default_np`` in cases
vanilla NumPy is explicitly required.
- use ``from oqupy.backends.numerical_backend import la`` instead of
``import scipy.linalg as la``, except that for non-symmetric
eigen-decomposition, ``scipy.linalg.eig`` should be used.
- use one of ``np.dtype_complex`` (``np.dtype_float``) or
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``)
instead of ``np.complex_`` (``np.float_``).
- convert lists or tuples to arrays when passing them as arguments
inside functions.
- use ``array = np.update(array, indices, values)`` instead of
``array[indices] = values``.
- use ``np.get_random_floats(seed, shape)`` instead of
``np.random.default_rng(seed).random(shape)``.
- declare signatures for ``np.vectorize`` explicitly.
- avoid directly changing the ``shape`` attribute of an array (use
``.reshape`` instead)

0 comments on commit e7b5db8

Please sign in to comment.