diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3a3ed40..1c3c7e5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,6 +29,9 @@ The current setup uses: * [tox](https://tox.readthedocs.io) ... for testing with different environments. * [travis](https://travis-ci.com) ... for continuous integration. +We are actively incorporating additional features to OQuPy, +details of which can be found in [DEVELOPMENT.md](./DEVELOPMENT.md). + ## How to contribute to the code or documentation Please use the [Issues](https://github.com/tempoCollaboration/OQuPy/issues) and diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 0000000..ba706da --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,45 @@ +# Development + +The following feautres are under active development: + +* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus) + +## Experimental Support for GPUs/TPUs + +The current development branch `dev/jax` implements +experimental support for GPUs/TPUs. + +Although OQuPy is built on top of the backend-agnostic +[TensorNetwork](https://github.com/google/TensorNetwork) library, +OQuPy uses vanilla NumPy and SciPy throughout its implementation. +The `dev/jax` branch adds support for GPUs/TPUs via the +[JAX](https://jax.readthedocs.io/en/latest/) library. +A new `oqupy.backends.numerical_backend.py` module handles the +[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html), +while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there +without explicitly importing JAX-based libraries. + +### Enabling Experimental Features + +To enable experimental features switch to the `dev/jax` branch and use +```python +from oqupy.backends import enable_jax_features +enable_jax_features() +``` +Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to +initialize the jax backend by default. + +### Contributing Guidelines + +To contribute features compatible with the JAX backend, +please adhere to the following set of guidelines: + +* avoid wildcard imports of NumPy and SciPy. +* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required. +* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used. +* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`). +* convert lists or tuples to arrays when passing them as arguments inside functions. +* use `array = np.update(array, indices, values)` instead of `array[indices] = values`. +* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`. +* declare signatures for `np.vectorize` explicitly. +* avoid directly changing the `shape` attribute of an array (use `.reshape` instead) diff --git a/docs/index.rst b/docs/index.rst index 8f60b2a..16e2eaf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -112,6 +112,7 @@ Furthermore, OQuPy implements methods to ... :caption: Development pages/contributing + pages/jax_features pages/authors pages/how_to_cite pages/sharing diff --git a/docs/pages/authors.rst b/docs/pages/authors.rst index 0c7f558..b9801a7 100644 --- a/docs/pages/authors.rst +++ b/docs/pages/authors.rst @@ -4,11 +4,16 @@ Authors & Acknowledgements - Lead developer since 2020: `Gerald E. Fux `__ (gerald.e.fux@gmail.com) - Co-lead developer since 2022: `Piper - Fowler-Wright `__ (pfw1@st-andrews.ac.uk) + Fowler-Wright `__ (piperfw@gmail.com) Major code contributions ------------------------ +**Experimental features** + +- `Sampreet Kalita `__: JAX numerical backend for + GPU/TPU support + **Version 0.5.0** - `Aidan Strathearn `__: Gibbs state TEMPO [Chiu2022]. diff --git a/docs/pages/jax_features.rst b/docs/pages/jax_features.rst new file mode 100644 index 0000000..ef168c6 --- /dev/null +++ b/docs/pages/jax_features.rst @@ -0,0 +1,55 @@ +Experimental Support for GPUs/TPUs +================================== + +The current development branch ``dev/jax`` implements +experimental support for GPUs/TPUs. + +Although OQuPy is built on top of the backend-agnostic +`TensorNetwork `__ library, +OQuPy uses vanilla NumPy and SciPy throughout its implementation. +The ``dev/jax`` branch adds support for GPUs/TPUs via the +`JAX `__ library. A new +``oqupy.backends.numerical_backend.py`` module handles the +`breaking changes in JAX +NumPy `__, +while the rest of the modules utilizes ``numpy`` and ``scipy.linalg`` +instances from there without explicitly importing JAX-based libraries. + +Enabling Experimental Features +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To enable experimental features, switch to the ``dev/jax`` branch and use + +.. code:: python + + from oqupy.backends import enable_jax_features + enable_jax_features() + +Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to +initialize the jax backend by default. + +Contributing Guidelines +~~~~~~~~~~~~~~~~~~~~~~~ + +To contribute features compatible with the JAX backend, +please adhere to the following set of guidelines: + +- avoid wildcard imports of NumPy and SciPy. +- use ``from oqupy.backends.numerical_backend import np`` instead of + ``import numpy as np`` and use the alias ``default_np`` in cases + vanilla NumPy is explicitly required. +- use ``from oqupy.backends.numerical_backend import la`` instead of + ``import scipy.linalg as la``, except that for non-symmetric + eigen-decomposition, ``scipy.linalg.eig`` should be used. +- use one of ``np.dtype_complex`` (``np.dtype_float``) or + ``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``) + instead of ``np.complex_`` (``np.float_``). +- convert lists or tuples to arrays when passing them as arguments + inside functions. +- use ``array = np.update(array, indices, values)`` instead of + ``array[indices] = values``. +- use ``np.get_random_floats(seed, shape)`` instead of + ``np.random.default_rng(seed).random(shape)``. +- declare signatures for ``np.vectorize`` explicitly. +- avoid directly changing the ``shape`` attribute of an array (use + ``.reshape`` instead)