Switchable numpy backends #80

Merged: 58 commits, Sep 21, 2024

Commits (changes from all commits):
525a664  Support for switching numpy to different module (obackhouse, Sep 14, 2024)
79f0bb3  Start on tensorflow backend (obackhouse, Sep 14, 2024)
9831eef  Getting there with tensorflow (obackhouse, Sep 14, 2024)
c074ee0  More tensorflow backend (obackhouse, Sep 15, 2024)
10f3e51  Fix RMP2 rdm2 (obackhouse, Sep 16, 2024)
8f18625  Fix einsum tests (obackhouse, Sep 16, 2024)
2e8d6a3  Fix MP2 2RDMs for tensorflow (obackhouse, Sep 16, 2024)
71a4f26  Disable CCSDt' tests for tensorflow backend (obackhouse, Sep 16, 2024)
b22fc51  Fix conversion routines for tensorflow (obackhouse, Sep 16, 2024)
cf697bd  linting (obackhouse, Sep 16, 2024)
bdb3f25  Add missing backend import (obackhouse, Sep 16, 2024)
1686f73  Add backend to workflow name (obackhouse, Sep 16, 2024)
c40979c  Print durations after tests (obackhouse, Sep 16, 2024)
4d4b985  Mark GCCSDTQ test as slow (obackhouse, Sep 16, 2024)
374f811  Don't skip slow tests by default (obackhouse, Sep 16, 2024)
b1d79a5  Custom DIIS (obackhouse, Sep 16, 2024)
c53734f  Custom DIIS to allow immutable backend (obackhouse, Sep 16, 2024)
0f26854  Fix EOM for immutable backend (obackhouse, Sep 16, 2024)
d4ed4cf  Fix BCC for immutable backend (obackhouse, Sep 16, 2024)
a5191b4  Fix test for immutable backend (obackhouse, Sep 16, 2024)
3d5c71c  Missed a file (obackhouse, Sep 16, 2024)
5e1ed79  Disable EOM tests without numpy backend because of pyscf davidson (obackhouse, Sep 16, 2024)
00337a8  Linting (obackhouse, Sep 16, 2024)
2e0ec8b  Ensure cast to backend array type (obackhouse, Sep 16, 2024)
6325b35  Add minimal test suite (obackhouse, Sep 17, 2024)
fb88665  Run minimal tests only for tensorflow (obackhouse, Sep 17, 2024)
dfa4e1e  Jax backend (obackhouse, Sep 17, 2024)
8b687ab  Minimal tests for jax and tensorflow (obackhouse, Sep 17, 2024)
03e53cc  Cupy backend (obackhouse, Sep 17, 2024)
5304b81  Fix workflow file (obackhouse, Sep 17, 2024)
2f98232  Add other backend installation options (obackhouse, Sep 17, 2024)
76a9afe  Don't add all deps to dev (obackhouse, Sep 17, 2024)
7c63102  Fix wrong inplace keyword arg (obackhouse, Sep 17, 2024)
7af8b09  Fix workflows (obackhouse, Sep 18, 2024)
3900bae  Clean up DIIS (obackhouse, Sep 18, 2024)
702c71a  Try to fix workflows again (obackhouse, Sep 18, 2024)
b222c30  That should work (obackhouse, Sep 18, 2024)
2f7108a  Linting (obackhouse, Sep 18, 2024)
ab96061  Fix TTGT name (obackhouse, Sep 19, 2024)
1b35036  Don't use swapaxes (obackhouse, Sep 19, 2024)
069c53f  Fix contract name in test (obackhouse, Sep 20, 2024)
25171d5  Remove tensorflow monkeypatched tensor methods (obackhouse, Sep 20, 2024)
2fc727b  Fix transpose typos (obackhouse, Sep 20, 2024)
334190a  Change numpy array method calls to array API calls (obackhouse, Sep 20, 2024)
5a7a019  Add astype wrapper for tf (obackhouse, Sep 20, 2024)
0bb6855  CTF backend (obackhouse, Sep 20, 2024)
ba0c768  Some fixes for CTF backend (obackhouse, Sep 20, 2024)
dc157e5  Use multiplication instead of bitwise AND (obackhouse, Sep 20, 2024)
70e3642  Fix einsum call for ctf (obackhouse, Sep 20, 2024)
4833525  ones_like and zeros_like for ctf (obackhouse, Sep 20, 2024)
ec61acd  More changes to support ctf a bit (obackhouse, Sep 20, 2024)
f4349ec  No need for monkeypatching (obackhouse, Sep 20, 2024)
f179681  linting (obackhouse, Sep 21, 2024)
c667cdf  linting (obackhouse, Sep 21, 2024)
181b672  linting (obackhouse, Sep 21, 2024)
52cbaea  Don't run double unit tests with coverage (obackhouse, Sep 21, 2024)
c8e0653  Fix damping for immutable backends (obackhouse, Sep 21, 2024)
6cbcc8c  Remove poorly conditioned test (obackhouse, Sep 21, 2024)
31 changes: 19 additions & 12 deletions .github/workflows/ci.yaml
@@ -8,17 +8,19 @@ on:

jobs:
  build:
-    name: python ${{ matrix.python-version }} on ${{matrix.os}}
+    name: python ${{ matrix.python-version }} on ${{ matrix.os }} with ${{ matrix.backend }} backend
    runs-on: ${{ matrix.os }}

    strategy:
      fail-fast: false
      matrix:
        include:
-          - {python-version: "3.9", os: ubuntu-latest, documentation: True, coverage: True}
-          - {python-version: "3.10", os: ubuntu-latest, documentation: False, coverage: True}
-          - {python-version: "3.11", os: ubuntu-latest, documentation: False, coverage: True}
-          - {python-version: "3.12", os: ubuntu-latest, documentation: False, coverage: False}
+          - {python-version: "3.9", backend: "numpy", os: ubuntu-latest, documentation: True, coverage: True, minimal: True, full: True}
+          - {python-version: "3.10", backend: "numpy", os: ubuntu-latest, documentation: False, coverage: True, minimal: True, full: True}
+          - {python-version: "3.11", backend: "numpy", os: ubuntu-latest, documentation: False, coverage: True, minimal: True, full: True}
+          - {python-version: "3.12", backend: "numpy", os: ubuntu-latest, documentation: False, coverage: False, minimal: True, full: True}
+          - {python-version: "3.12", backend: "tensorflow", os: ubuntu-latest, documentation: False, coverage: False, minimal: True, full: False}
+          - {python-version: "3.12", backend: "jax", os: ubuntu-latest, documentation: False, coverage: False, minimal: True, full: False}

    steps:
      - uses: actions/checkout@v2
@@ -36,23 +38,28 @@ jobs:
      - name: Install ebcc
        run: |
          python -m pip install wheel
-          python -m pip install .[dev]
+          python -m pip install .[dev,numpy,jax,tensorflow]
      - name: Linting
        run: |
          python -m black ebcc/ --diff --check --verbose
          python -m isort ebcc/ --diff --check-only --verbose
          python -m flake8 ebcc/ --verbose
          python -m mypy ebcc/ --verbose
-      - name: Run unit tests with coverage
+      - name: Run minimal unit tests
        run: |
-          python -m pip install pytest pytest-cov
-          pytest --cov ebcc/
-        if: matrix.coverage
+          python -m pip install pytest
+          EBCC_BACKEND=${{ matrix.backend }} pytest tests/test_minimal.py
+        if: matrix.minimal
      - name: Run unit tests
        run: |
          python -m pip install pytest
-          pytest
-        if: ${{ ! matrix.coverage }}
+          EBCC_BACKEND=${{ matrix.backend }} pytest
+        if: matrix.full && ! matrix.coverage
+      - name: Run unit tests with coverage
+        run: |
+          python -m pip install pytest pytest-cov
+          EBCC_BACKEND=${{ matrix.backend }} pytest --cov ebcc/
+        if: matrix.full && matrix.coverage
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
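The workflow above selects the array backend for each job through the EBCC_BACKEND environment variable. The snippet below is a rough local equivalent of the "Run minimal unit tests" step, written as a Python driver; invoking pytest from a script and the --durations flag are assumptions for illustration, while EBCC_BACKEND and tests/test_minimal.py are taken directly from the workflow.

import os
import sys

# The backend must be fixed before ebcc is imported anywhere in the process.
os.environ["EBCC_BACKEND"] = "tensorflow"  # or "jax", "numpy", ...

import pytest

# Run only the minimal suite, mirroring the non-NumPy CI jobs above.
sys.exit(pytest.main(["tests/test_minimal.py", "--durations=10"]))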
10 changes: 9 additions & 1 deletion ebcc/__init__.py
@@ -42,11 +42,19 @@
"""List of supported ansatz types."""
METHOD_TYPES = ["MP", "CC", "LCC", "QCI", "QCC", "DC"]

+import importlib
import os
import sys
from typing import TYPE_CHECKING

-import numpy
+"""Backend to use for NumPy operations."""
+BACKEND: str = os.environ.get("EBCC_BACKEND", "numpy")
+
+if TYPE_CHECKING:
+    # Import NumPy directly for type-checking purposes
+    import numpy
+else:
+    numpy = importlib.import_module(f"ebcc.backend._{BACKEND}")

from ebcc.core.logging import NullLogger, default_log, init_logging
from ebcc.cc import GEBCC, REBCC, UEBCC
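This hunk is the core of the PR: the numpy name exported by ebcc is resolved at import time from EBCC_BACKEND, so the rest of the package can keep writing "from ebcc import numpy". A minimal usage sketch, assuming the ebcc.backend._jax module delegates to jax.numpy in the same way _cupy.py (below) delegates to CuPy:

import os

# Must be set before the first import of ebcc; "numpy" is the default.
os.environ["EBCC_BACKEND"] = "jax"

from ebcc import numpy as np  # resolves to ebcc.backend._jax

x = np.zeros((2, 2))
print(type(x))  # a JAX array type rather than numpy.ndarray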
155 changes: 155 additions & 0 deletions ebcc/backend/__init__.py
@@ -0,0 +1,155 @@
"""Backend for NumPy operations.

Notes:
    Currently, the following backends are supported:
    - NumPy
    - CuPy
    - TensorFlow
    - JAX
    - CTF (Cyclops Tensor Framework)

    Non-NumPy backends are only lightly supported. Some functionality may not be available, and only
    minimal tests are performed. Some operations that require interaction with NumPy such as the
    PySCF interfaces may not be efficient, due to the need to convert between NumPy and the backend
    array types.
"""

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING

from ebcc import BACKEND

if TYPE_CHECKING:
    from types import ModuleType
    from typing import Union, TypeVar, Optional

    from numpy import int64, generic
    from numpy.typing import NDArray

    T = TypeVar("T", bound=generic)

if BACKEND == "numpy":
    import numpy as np
elif BACKEND == "cupy":
    import cupy as np  # type: ignore[no-redef]
elif BACKEND == "tensorflow":
    import tensorflow as tf
    import tensorflow.experimental.numpy as np  # type: ignore[no-redef]
elif BACKEND == "jax":
    import jax
    import jax.numpy as np  # type: ignore[no-redef]
elif BACKEND in ("ctf", "cyclops"):
    import ctf


def __getattr__(name: str) -> ModuleType:
    """Get the backend module."""
    return importlib.import_module(f"ebcc.backend._{BACKEND.lower()}")


def ensure_scalar(obj: Union[T, NDArray[T]]) -> T:
    """Ensure that an object is a scalar.

    Args:
        obj: Object to ensure is a scalar.

    Returns:
        Scalar object.
    """
    if BACKEND in ("numpy", "cupy", "jax"):
        return np.asarray(obj).item()  # type: ignore
    elif BACKEND == "tensorflow":
        if isinstance(obj, tf.Tensor):
            return obj.numpy().item()  # type: ignore
        return obj  # type: ignore
    elif BACKEND in ("ctf", "cyclops"):
        if isinstance(obj, ctf.tensor):
            return obj.to_nparray().item()  # type: ignore
        return obj  # type: ignore
    else:
        raise NotImplementedError(f"`ensure_scalar` not implemented for backend {BACKEND}.")


def to_numpy(array: NDArray[T], dtype: Optional[type[generic]] = None) -> NDArray[T]:
    """Convert an array to NumPy.

    Args:
        array: Array to convert.
        dtype: Data type to convert to.

    Returns:
        Array in NumPy format.

    Notes:
        This function does not guarantee a copy of the array.
    """
    if BACKEND == "numpy":
        ndarray = array
    elif BACKEND == "cupy":
        ndarray = np.asnumpy(array)  # type: ignore
    elif BACKEND == "jax":
        ndarray = np.array(array)  # type: ignore
    elif BACKEND == "tensorflow":
        ndarray = array.numpy()  # type: ignore
    elif BACKEND in ("ctf", "cyclops"):
        ndarray = array.to_nparray()  # type: ignore
    else:
        raise NotImplementedError(f"`to_numpy` not implemented for backend {BACKEND}.")
    if dtype is not None and ndarray.dtype != dtype:
        ndarray = ndarray.astype(dtype)
    return ndarray


def _put(
    array: NDArray[T],
    indices: Union[NDArray[int64], tuple[NDArray[int64], ...]],
    values: NDArray[T],
) -> NDArray[T]:
    """Put values into an array at specified indices.

    Args:
        array: Array to put values into.
        indices: Indices to put values at.
        values: Values to put into the array.

    Returns:
        Array with values put at specified indices.

    Notes:
        This function does not guarantee a copy of the array.
    """
    if BACKEND == "numpy" or BACKEND == "cupy":
        if isinstance(indices, tuple):
            indices_flat = np.ravel_multi_index(indices, array.shape)
            np.put(array, indices_flat, values)
        else:
            np.put(array, indices, values)
        return array
    elif BACKEND == "jax":
        if isinstance(indices, tuple):
            indices_flat = np.ravel_multi_index(indices, array.shape)
            array = np.put(array, indices_flat, values, inplace=False)  # type: ignore
        else:
            array = np.put(array, indices, values, inplace=False)  # type: ignore
        return array
    elif BACKEND == "tensorflow":
        if isinstance(indices, (tuple, list)):
            indices_grid = tf.meshgrid(*indices, indexing="ij")
            indices = tf.stack([np.ravel(tf.cast(idx, tf.int32)) for idx in indices_grid], axis=1)
        else:
            indices = tf.cast(tf.convert_to_tensor(indices), tf.int32)
            indices = tf.expand_dims(indices, axis=-1)
        values = np.ravel(tf.convert_to_tensor(values, dtype=array.dtype))
        return tf.tensor_scatter_nd_update(array, indices, values)  # type: ignore
    elif BACKEND in ("ctf", "cyclops"):
        # TODO MPI has to be manually managed here
        if isinstance(indices, tuple):
            indices_flat = np.ravel_multi_index(indices, array.shape)
            array.write(indices_flat, values)  # type: ignore
        else:
            array.write(indices, values)  # type: ignore
        return array
    else:
        raise NotImplementedError(f"`_put` not implemented for backend {BACKEND}.")
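A sketch of how these helpers keep calling code backend-agnostic. The caller below is hypothetical; only _put, ensure_scalar, and to_numpy come from this file, and _put must be used functionally because JAX and TensorFlow arrays are immutable.

from ebcc import numpy as np
from ebcc import backend

t = np.zeros((4, 4))

# Rebind the result rather than relying on in-place mutation, so the same
# code works for the immutable backends (JAX, TensorFlow).
idx = (np.arange(4), np.arange(4))
t = backend._put(t, idx, np.ones(4))

tr = backend.ensure_scalar(np.sum(t))  # plain Python scalar
t_np = backend.to_numpy(t)             # NumPy array for PySCF interoperability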
139 changes: 139 additions & 0 deletions ebcc/backend/_ctf.py
@@ -0,0 +1,139 @@
# type: ignore
"""Cyclops Tensor Framework backend."""

import ctf
import numpy
import opt_einsum


def __getattr__(name):
    """Get the attribute from CTF."""
    return getattr(ctf, name)


class FakeLinalg:
    """Fake linalg module for CTF."""

    def __getattr__(self, name):
        """Get the attribute from CTF's linalg module."""
        return getattr(ctf.linalg, name)

    def eigh(self, a):  # noqa: D102
        # TODO Need to determine if SCALAPACK is available
        w, v = numpy.linalg.eigh(a.to_nparray())
        w = ctf.astensor(w)
        v = ctf.astensor(v)
        return w, v

    def norm(self, a, ord=None):  # noqa: D102
        return ctf.norm(a, ord=ord)


linalg = FakeLinalg()


bool_ = numpy.bool_
inf = numpy.inf
asarray = ctf.astensor


_array = ctf.array


def array(obj, **kwargs):  # noqa: D103
    if isinstance(obj, ctf.tensor):
        return obj
    return _array(numpy.asarray(obj), **kwargs)


def astype(obj, dtype):  # noqa: D103
    return obj.astype(dtype)


def zeros_like(obj):  # noqa: D103
    return ctf.zeros(obj.shape).astype(obj.dtype)


def ones_like(obj):  # noqa: D103
    return ctf.ones(obj.shape).astype(obj.dtype)


def arange(start, stop=None, step=1, dtype=None):  # noqa: D103
    if stop is None:
        stop = start
        start = 0
    return ctf.arange(start, stop, step=step, dtype=dtype)


def argmin(obj):  # noqa: D103
    return ctf.to_nparray(obj).argmin()


def argmax(obj):  # noqa: D103
    return ctf.to_nparray(obj).argmax()


def bitwise_and(a, b):  # noqa: D103
    return a * b


def bitwise_not(a):  # noqa: D103
    return ones_like(a) - a


def concatenate(arrays, axis=None):  # noqa: D103
    if axis is None:
        axis = 0
    if axis < 0:
        axis += arrays[0].ndim
    shape = list(arrays[0].shape)
    for arr in arrays[1:]:
        for i, (a, b) in enumerate(zip(shape, arr.shape)):
            if i == axis:
                shape[i] += b
            elif a != b:
                raise ValueError("All arrays must have the same shape")

    result = ctf.zeros(shape, dtype=arrays[0].dtype)
    start = 0
    for arr in arrays:
        end = start + arr.shape[axis]
        slices = [slice(None)] * result.ndim
        slices[axis] = slice(start, end)
        result[tuple(slices)] = arr
        start = end

    return result


def _block_recursive(arrays, max_depth, depth=0):  # noqa: D103
    if depth < max_depth:
        arrs = [_block_recursive(arr, max_depth, depth + 1) for arr in arrays]
        return concatenate(arrs, axis=-(max_depth - depth))
    else:
        return arrays


def block(arrays):  # noqa: D103
    def _get_max_depth(arrays):
        if isinstance(arrays, list):
            return 1 + max([_get_max_depth(arr) for arr in arrays])
        return 0

    return _block_recursive(arrays, _get_max_depth(arrays))


def einsum(*args, optimize=True, **kwargs):
    """Evaluate an einsum expression."""
    # FIXME This shouldn't be called, except via `util.einsum`, which should have already
    #       optimised the expression. We should check if this contraction has more than
    #       two tensors and if so, raise an error.
    return ctf.einsum(*args, **kwargs)


def einsum_path(*args, **kwargs):
    """Evaluate the lowest cost contraction order for an einsum expression."""
    kwargs = dict(kwargs)
    if kwargs.get("optimize", True) is True:
        kwargs["optimize"] = "optimal"
    return opt_einsum.contract_path(*args, **kwargs)
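A quick smoke test of the shims above, as a sketch; it assumes a working CTF build (typically launched under mpirun) and exercises only functions defined in this file or provided by CTF itself.

import ctf
import numpy
from ebcc.backend import _ctf as b

a = b.array(numpy.eye(2))
o = b.ones_like(a)

c = b.concatenate([a, o], axis=0)  # shape (4, 2), stacked along axis 0
m = b.bitwise_and(a, o)            # elementwise product stands in for AND
w, v = b.linalg.eigh(a)            # falls back to numpy.linalg.eigh internally

print(c.shape, ctf.to_nparray(m), ctf.to_nparray(w))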
18 changes: 18 additions & 0 deletions ebcc/backend/_cupy.py
@@ -0,0 +1,18 @@
# type: ignore
"""CuPy backend."""

import cupy
import opt_einsum


def __getattr__(name):
    """Get the attribute from CuPy."""
    return getattr(cupy, name)


def einsum_path(*args, **kwargs):
    """Evaluate the lowest cost contraction order for an einsum expression."""
    kwargs = dict(kwargs)
    if kwargs.get("optimize", True) is True:
        kwargs["optimize"] = "optimal"
    return opt_einsum.contract_path(*args, **kwargs)
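The CuPy module shows the pattern each backend follows: a module-level __getattr__ (PEP 562) that falls through to the underlying array library, plus explicit overrides only where behaviour differs. Below is a hypothetical template for adding another backend; the module name and the xp library are assumptions for illustration, not part of this PR.

# Hypothetical ebcc/backend/_mybackend.py
"""Example backend skeleton following the same pattern as _cupy.py."""

import opt_einsum
import some_numpy_like_library as xp  # assumption: any NumPy-compatible array module


def __getattr__(name):
    """Fall through to the underlying array library for anything not overridden."""
    return getattr(xp, name)


def einsum_path(*args, **kwargs):
    """Plan contractions with opt_einsum, as the other backends do."""
    kwargs = dict(kwargs)
    if kwargs.get("optimize", True) is True:
        kwargs["optimize"] = "optimal"
    return opt_einsum.contract_path(*args, **kwargs)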