Skip to content

Commit

Permalink
Merge pull request #144 from Sampreet/pr/feature-numerical-backend
Browse files Browse the repository at this point in the history
JAX Numerical Backend for GPU/TPU Support
  • Loading branch information
piperfw authored Nov 10, 2024
2 parents 1ebaaa8 + f142c44 commit 6ccbd2e
Show file tree
Hide file tree
Showing 64 changed files with 1,098 additions and 687 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ The current setup uses:
* [tox](https://tox.readthedocs.io) ... for testing with different environments.
* [travis](https://travis-ci.com) ... for continuous integration.

We are actively incorporating additional features to OQuPy,
details of which can be found in [DEVELOPMENT.md](./DEVELOPMENT.md).

## How to contribute to the code or documentation
Please use the
[Issues](https://github.com/tempoCollaboration/OQuPy/issues) and
Expand Down
43 changes: 43 additions & 0 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Development

The current development branch "dev/jax" implements

* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus)

## Experimental Support for GPUs/TPUs

Although OQuPy is built on top of the backend-agnostic
[TensorNetwork](https://github.com/google/TensorNetwork) library,
OQuPy uses vanilla NumPy and SciPy throughout its implementation.

The "dev/jax" branch adds supports for GPUs/TPUs via the
[JAX](https://jax.readthedocs.io/en/latest/) library.
A new `oqupy.backends.numerical_backend.py` module handles the
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html),
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there
without explicitly importing JAX-based libraries.

### Enabling Experimental Features

To enable experimental features switch to the `dev/jax` branch and use
```python
from oqupy.backends import enable_jax_features
enable_jax_features()
```
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
initialize the jax backend by default.

### Contributing Guidelines

To contribute features compatible with the JAX backend,
please adhere to the following set of guidelines:

* avoid wildcard imports of NumPy and SciPy.
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required.
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used.
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
* convert lists or tuples to arrays when passing them as arguments inside functions.
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
* declare signatures for `np.vectorize` explicitly.
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead)
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Furthermore, OQuPy implements methods to ...
:caption: Development

pages/contributing
pages/gpu_features
pages/authors
pages/how_to_cite
pages/sharing
Expand Down
4 changes: 3 additions & 1 deletion docs/pages/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ class :class:`oqupy.pt_tebd.PtTebd`
dictionary.



Results
-------

Expand Down Expand Up @@ -207,3 +206,6 @@ module :mod:`oqupy.operators`
function :func:`oqupy.helpers.plot_correlations_with_parameters`
A helper function to plot an auto-correlation function and the sampling
points given by a set of parameters for a TEMPO computation.

function :func:`oqupy.backends.enable_jax_features`
Option to use JAX to support multiple device backends (CPUs/GPUs/TPUs).
55 changes: 55 additions & 0 deletions docs/pages/gpu_features.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
Experimental Support for GPUs/TPUs
==================================
The current development branch "dev/jax" implements experimental support
for GPUs/TPUs.

Although OQuPy is built on top of the backend-agnostic
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library,
OQuPy uses vanilla NumPy and SciPy throughout its implementation.

The "dev/jax" branch adds supports for GPUs/TPUs via the
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new
``oqupy.backends.numerical_backend.py`` module handles the
`breaking changes in JAX
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__,
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg``
instances from there without explicitly importing JAX-based libraries.

Enabling Experimental Features
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To enable experimental features, switch to the ``dev/jax`` branch and use

.. code:: python
from oqupy.backends import enable_jax_features
enable_jax_features()
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
initialize the jax backend by default.

Contributing Guidelines
~~~~~~~~~~~~~~~~~~~~~~~

To contribute features compatible with the JAX backend,
please adhere to the following set of guidelines:

- avoid wildcard imports of NumPy and SciPy.
- use ``from oqupy.backends.numerical_backend import np`` instead of
``import numpy as np`` and use the alias ``default_np`` in cases
vanilla NumPy is explicitly required.
- use ``from oqupy.backends.numerical_backend import la`` instead of
``import scipy.linalg as la``, except that for non-symmetric
eigen-decomposition, ``scipy.linalg.eig`` should be used.
- use one of ``np.dtype_complex`` (``np.dtype_float``) or
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``)
instead of ``np.complex_`` (``np.float_``).
- convert lists or tuples to arrays when passing them as arguments
inside functions.
- use ``array = np.update(array, indices, values)`` instead of
``array[indices] = values``.
- use ``np.get_random_floats(seed, shape)`` instead of
``np.random.default_rng(seed).random(shape)``.
- declare signatures for ``np.vectorize`` explicitly.
- avoid directly changing the ``shape`` attribute of an array (use
``.reshape`` instead)
43 changes: 43 additions & 0 deletions examples/simple_dynamics_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python

import sys
sys.path.insert(0, '.')
# set the 'OQUPY_BACKEND' environment variable
# to 'jax' to initialize JAX backend by default
# or switch to JAX backend using oqupy.backends
import oqupy
from oqupy.backends import enable_jax_features
# import NumPy from numerical_backend
#from oqupy.backends.numerical_backend import np
#enable_jax_features()

import matplotlib.pyplot as plt
sigma_x = oqupy.operators.sigma("x")
sigma_z = oqupy.operators.sigma("z")
up_density_matrix = oqupy.operators.spin_dm("z+")
Omega = 1.0
omega_cutoff = 5.0
alpha = 0.3

system = oqupy.System(0.5 * Omega * sigma_x)
correlations = oqupy.PowerLawSD(alpha=alpha,
zeta=1,
cutoff=omega_cutoff,
cutoff_type='exponential')
bath = oqupy.Bath(0.5 * sigma_z, correlations)
tempo_parameters = oqupy.TempoParameters(dt=0.1, tcut=3.0, epsrel=10**(-4))

dynamics = oqupy.tempo_compute(system=system,
bath=bath,
initial_state=up_density_matrix,
start_time=0.0,
end_time=2.0,
parameters=tempo_parameters,
unique=True)
t, s_z = dynamics.expectations(0.5*sigma_z, real=True)
print(s_z)
plt.plot(t, s_z, label=r'$\alpha=0.3$')
plt.xlabel(r'$t\,\Omega$')
plt.ylabel(r'$\langle\sigma_z\rangle$')
plt.savefig('simple_dynamics_jax.png')
#plt.show()
9 changes: 9 additions & 0 deletions oqupy/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Module to initialize OQuPy's backends."""

from oqupy.backends.numerical_backend import set_numerical_backends

def enable_jax_features():
"""Function to enable experimental features."""

# set numerical backend to JAX
set_numerical_backends('jax')
2 changes: 1 addition & 1 deletion oqupy/backends/node_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

from typing import Any, List, Optional, Text, Tuple, Union

import numpy as np
import tensornetwork as tn
from tensornetwork import Node
from tensornetwork.backends.base_backend import BaseBackend

from oqupy.backends.numerical_backend import np

class NodeArray:
"""NodeArray class. """
Expand Down
166 changes: 166 additions & 0 deletions oqupy/backends/numerical_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing NumPy-like and SciPy-like numerical backends.
"""

import os

import numpy as default_np
import scipy.linalg as default_la

from tensornetwork.backend_contextmanager import \
set_default_backend

import oqupy.config as oc

# store instances of the initialized backends
# this way, `oqupy.config` remains unchanged
# and `ocupy.config.DEFAULT_BACKEND` is used
# when NumPy and LinAlg are initialized
NUMERICAL_BACKEND_INSTANCES = {}

def get_numerical_backends(
backend_name: str,
):
"""Function to get numerical backend.
Parameters
----------
backend_name: str
Name of the backend. Options are `'jax'` and `'numpy'`.
Returns
-------
backends: list
NumPy and LinAlg backends.
"""

_bn = backend_name.lower()
if _bn in NUMERICAL_BACKEND_INSTANCES:
set_default_backend(_bn)
return NUMERICAL_BACKEND_INSTANCES[_bn]
assert _bn in ['jax', 'numpy'], \
"currently supported backends are `'jax'` and `'numpy'`"

if 'jax' in _bn:
try:
# explicitly import and configure jax
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jla
jax.config.update('jax_enable_x64', True)

# # TODO: GPU memory allocation (default is 0.75)
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'

# set TensorNetwork backend
set_default_backend('jax')

NUMERICAL_BACKEND_INSTANCES['jax'] = [jnp, jla]
return NUMERICAL_BACKEND_INSTANCES['jax']
except ImportError:
print("JAX not installed, defaulting to NumPy")

# set TensorNetwork backend
set_default_backend('numpy')

NUMERICAL_BACKEND_INSTANCES['numpy'] = [default_np, default_la]
return NUMERICAL_BACKEND_INSTANCES['numpy']

class NumPy:
"""
The NumPy backend employing
dynamic switching through `oqupy.config`.
"""
def __init__(self,
backend_name=oc.DEFAULT_BACKEND,
):
"""Getter for the backend."""
self.backend = get_numerical_backends(backend_name)[0]

@property
def dtype_complex(self) -> default_np.dtype:
"""Getter for the complex datatype."""
return oc.NumPyDtypeComplex

@property
def dtype_float(self) -> default_np.dtype:
"""Getter for the float datatype."""
return oc.NumPyDtypeFloat

def __getattr__(self,
name: str,
):
"""Return the backend's default attribute."""
return getattr(self.backend, name)

def update(self,
array,
indices: tuple,
values,
) -> default_np.ndarray:
"""Option to update select indices of an array with given values."""
if not isinstance(array, default_np.ndarray):
return array.at[indices].set(values)
array[indices] = values
return array

def get_random_floats(self,
seed,
shape,
):
"""Method to obtain random floats with a given seed and shape."""
random_floats = default_np.random.default_rng(seed).random(shape, \
dtype=default_np.float64)
return self.backend.array(random_floats, dtype=self.dtype_float)

class LinAlg:
"""
The Linear Algebra backend employing
dynamic switching through `oqupy.config`.
"""
def __init__(self,
backend_name=oc.DEFAULT_BACKEND,
):
"""Getter for the backend."""
self.backend = get_numerical_backends(backend_name)[1]

def __getattr__(self,
name: str,
):
"""Return the backend's default attribute."""
return getattr(self.backend, name)

# setup libraries using environment variable
# fall back to oqupy.config.DEFAULT_BACKEND
try:
BACKEND_NAME = os.environ[oc.BACKEND_ENV_VAR]
except KeyError:
BACKEND_NAME = oc.DEFAULT_BACKEND
np = NumPy(backend_name=BACKEND_NAME)
la = LinAlg(backend_name=BACKEND_NAME)

def set_numerical_backends(
backend_name: str
):
"""Function to set numerical backend.
Parameters
----------
backend_name: str
Name of the backend. Options are `'jax'` and `'numpy'`.
"""
backends = get_numerical_backends(backend_name)
np.backend = backends[0]
la.backend = backends[1]
Loading

0 comments on commit 6ccbd2e

Please sign in to comment.