Skip to content

Commit

Permalink
Merge pull request qutip#2513 from qutip/dev.major
Browse files Browse the repository at this point in the history
Merge `jax` support PR in `dev.major` into `master`.
  • Loading branch information
Ericgig authored Aug 20, 2024
2 parents e86f131 + 9f412f5 commit 03487d2
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 48 deletions.
1 change: 1 addition & 0 deletions doc/changes/2461.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for `jit` and `grad` in qutip.core.metrics
1 change: 1 addition & 0 deletions doc/changes/2490.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This pull request introduces a new `NumpyBackend `class that enables dynamic selection of the numpy_backend used in `qutip`. The class facilitates switching between different numpy implementations ( `numpy` and `jax.numpy` mainly) based on the configuration specified in the `settings.core` dictionary.
1 change: 1 addition & 0 deletions doc/changes/2499.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable mcsolve with jax.grad using numpy_backend
1 change: 1 addition & 0 deletions doc/changes/2507.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`clip` gives deprecation warning, that might be a problem in the future. Hence switch to `where`
10 changes: 5 additions & 5 deletions qutip/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'hellinger_dist', 'hilbert_dist', 'average_gate_fidelity',
'process_fidelity', 'unitarity', 'dnorm']

import numpy as np
from .numpy_backend import np
from scipy import linalg as la
import scipy.sparse as sp
from .superop_reps import to_choi, _to_superpauli, to_super, kraus_to_choi
Expand Down Expand Up @@ -80,7 +80,8 @@ def fidelity(A, B):
# even for positive semidefinite matrices, small negative eigenvalues
# can be reported.
eig_vals = (sqrtmA * B * sqrtmA).eigenenergies()
return float(np.real(np.sqrt(eig_vals[eig_vals > 0]).sum()))
eig_vals_non_neg = np.where(eig_vals > 0, eig_vals, 0)
return np.real(np.sqrt(eig_vals_non_neg).sum())


def _hilbert_space_dims(oper):
Expand Down Expand Up @@ -288,7 +289,7 @@ def tracedist(A, B, sparse=False, tol=0):
diff = A - B
diff = diff.dag() * diff
vals = diff.eigenenergies(sparse=sparse, tol=tol)
return float(np.real(0.5 * np.sum(np.sqrt(np.abs(vals)))))
return np.real(0.5 * np.sum(np.sqrt(np.abs(vals))))


def hilbert_dist(A, B):
Expand Down Expand Up @@ -520,7 +521,6 @@ def dnorm(A, B=None, solver="CVXOPT", verbose=False, force_solve=False,
# of the dual map of Lambda. We can evaluate that norm much more
# easily if Lambda is completely positive, since then the largest
# eigenvalue is the same as the largest singular value.

if not force_solve and J.iscp:
S_dual = to_super(J.dual_chan())
vec_eye = operator_to_vector(qeye(S_dual.dims[1][1]))
Expand Down Expand Up @@ -575,7 +575,7 @@ def unitarity(oper):
return np.linalg.norm(Eu, 'fro')**2 / len(Eu)


def _find_poly_distance(eigenvals: np.ndarray) -> float:
def _find_poly_distance(eigenvals) -> float:
"""
Returns the distance between the origin and the convex hull of eigenvalues.
Expand Down
13 changes: 13 additions & 0 deletions qutip/core/numpy_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ..settings import settings


class NumpyBackend:
def _qutip_setting_backend(self, np):
self._qt_np = np

def __getattr__(self, name):
return getattr(self._qt_np, name)


# Initialize the numpy backend
np = NumpyBackend()
28 changes: 25 additions & 3 deletions qutip/core/options.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ..settings import settings
from .numpy_backend import np as qt_np
import numpy
from typing import overload, Literal, Any
import types

Expand All @@ -11,9 +13,13 @@ class QutipOptions:
Define basic method to wrap an ``options`` dict.
Default options are in a class _options dict.
Options can also act as properties. The ``_properties`` map options keys to
a function to call when the ``QutipOptions`` become the default.
"""

_options: dict[str, Any] = {}
_properties = {}
_settings_name = None # Where the default is in settings

def __init__(self, **options):
Expand All @@ -33,6 +39,11 @@ def __getitem__(self, key: str) -> Any:
def __setitem__(self, key: str, value: Any) -> None:
# Let the dict catch the KeyError
self.options[key] = value
if (
key in self._properties
and self is getattr(settings, self._settings_name)
):
self._properties[key](value)

def __repr__(self, full: bool = True) -> str:
out = [f"<{self.__class__.__name__}("]
Expand All @@ -47,15 +58,20 @@ def __repr__(self, full: bool = True) -> str:

def __enter__(self):
self._backup = getattr(settings, self._settings_name)
setattr(settings, self._settings_name, self)
self._set_as_global_default()

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: types.TracebackType | None,
) -> None:
setattr(settings, self._settings_name, self._backup)
self._backup._set_as_global_default()

def _set_as_global_default(self):
setattr(settings, self._settings_name, self)
for key in self._properties:
self._properties[key](self.options[key])


class CoreOptions(QutipOptions):
Expand Down Expand Up @@ -137,8 +153,13 @@ class CoreOptions(QutipOptions):
# Expect, trace, etc. will return real for hermitian matrices.
# Hermiticity checks can be slow, stop jitting, etc.
"auto_real_casting": True,
# Default backend is numpy
"numpy_backend": numpy
}
_settings_name = "core"
_properties = {
"numpy_backend": qt_np._qutip_setting_backend,
}

@overload
def __getitem__(
Expand Down Expand Up @@ -191,4 +212,5 @@ def __setitem__(self, key: str, value: Any) -> None:


# Creating the instance of core options to use everywhere.
settings.core = CoreOptions()
# settings.core = CoreOptions()
CoreOptions()._set_as_global_default()
58 changes: 29 additions & 29 deletions qutip/entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
'concurrence', 'entropy_conditional', 'entangling_power',
'entropy_relative']

from numpy import conj, e, inf, imag, inner, real, sort, sqrt
from numpy.lib.scimath import log, log2
from .core.numpy_backend import np
from .partial_transpose import partial_transpose
from . import (ptrace, tensor, sigmay, ket2dm,
expand_operator)
from .core import data as _data


def entropy_vn(rho, base=e, sparse=False):
def entropy_vn(rho, base=np.e, sparse=False):
"""
Von-Neumann entropy of density matrix
Expand Down Expand Up @@ -38,14 +37,15 @@ def entropy_vn(rho, base=e, sparse=False):
if rho.type == 'ket' or rho.type == 'bra':
rho = ket2dm(rho)
vals = rho.eigenenergies(sparse=sparse)
nzvals = vals[vals != 0]
threshold = 1e-17
nzvals = np.where(vals < threshold, threshold, vals)
if base == 2:
logvals = log2(nzvals)
elif base == e:
logvals = log(nzvals)
logvals = np.log2(nzvals)
elif base == np.e:
logvals = np.log(nzvals)
else:
raise ValueError("Base must be 2 or e.")
return float(real(-sum(nzvals * logvals)))
return np.real(-sum(nzvals * logvals))


def entropy_linear(rho):
Expand All @@ -71,7 +71,7 @@ def entropy_linear(rho):
"""
if rho.type == 'ket' or rho.type == 'bra':
rho = ket2dm(rho)
return float(real(1.0 - (rho ** 2).tr()))
return np.real(1.0 - (rho ** 2).tr())


def concurrence(rho):
Expand Down Expand Up @@ -113,11 +113,11 @@ def concurrence(rho):
evals = rho_tilde.eigenenergies()

# abs to avoid problems with sqrt for very small negative numbers
evals = abs(sort(real(evals)))
evals = abs(np.sort(np.real(evals)))

lsum = sqrt(evals[3]) - sqrt(evals[2]) - sqrt(evals[1]) - sqrt(evals[0])

return max(0, lsum)
sqrt_evals = np.sqrt(evals)
lsum = sqrt_evals[3] - sqrt_evals[2] - sqrt_evals[1] - sqrt_evals[0]
return np.maximum(0, lsum)


def negativity(rho, subsys, method='tracenorm', logarithmic=False):
Expand Down Expand Up @@ -145,12 +145,12 @@ def negativity(rho, subsys, method='tracenorm', logarithmic=False):

# Return the negativity value (or its logarithm if specified)
if logarithmic:
return log2(2 * N + 1)
return np.log2(2 * N + 1)
else:
return N


def entropy_mutual(rho, selA, selB, base=e, sparse=False):
def entropy_mutual(rho, selA, selB, base=np.e, sparse=False):
"""
Calculates the mutual information S(A:B) between selection
components of a system density matrix.
Expand Down Expand Up @@ -192,7 +192,7 @@ def entropy_mutual(rho, selA, selB, base=e, sparse=False):
return out


def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
def entropy_relative(rho, sigma, base=np.e, sparse=False, tol=1e-12):
"""
Calculates the relative entropy S(rho||sigma) between two density
matrices.
Expand Down Expand Up @@ -253,9 +253,9 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
if rho.dims != sigma.dims:
raise ValueError("Inputs must have the same shape and dims.")
if base == 2:
log_base = log2
elif base == e:
log_base = log
log_base = np.log2
elif base == np.e:
log_base = np.log
else:
raise ValueError("Base must be 2 or e.")
# S(rho || sigma) = sum_i(p_i log p_i) - sum_ij(p_i P_ij log q_i)
Expand All @@ -264,19 +264,19 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
# intersection with the support of rho (i.e. rvecs[rvals != 0]).
rvals, rvecs = _data.eigs(rho.data, rho.isherm, True)
rvecs = rvecs.to_array().T
if any(abs(imag(rvals)) >= tol):
if any(abs(np.imag(rvals)) >= tol):
raise ValueError("Input rho has non-real eigenvalues.")
rvals = real(rvals)
rvals = np.real(rvals)
svals, svecs = _data.eigs(sigma.data, sigma.isherm, True)
svecs = svecs.to_array().T
if any(abs(imag(svals)) >= tol):
if any(abs(np.imag(svals)) >= tol):
raise ValueError("Input sigma has non-real eigenvalues.")
svals = real(svals)
svals = np.real(svals)
# Calculate inner products of eigenvectors and return +inf if kernel
# of sigma overlaps with support of rho.
P = abs(inner(rvecs, conj(svecs))) ** 2
P = abs(np.inner(rvecs, np.conj(svecs))) ** 2
if (rvals >= tol) @ (P >= tol) @ (svals < tol):
return inf
return np.inf
# Avoid -inf from log(0) -- these terms will be multiplied by zero later
# anyway
svals[abs(svals) < tol] = 1
Expand All @@ -285,10 +285,10 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
S = nzrvals @ log_base(nzrvals) - rvals @ P @ log_base(svals)
# the relative entropy is guaranteed to be >= 0, so we clamp the
# calculated value to 0 to avoid small violations of the lower bound.
return max(0, S)
return np.maximum(0, S)


def entropy_conditional(rho, selB, base=e, sparse=False):
def entropy_conditional(rho, selB, base=np.e, sparse=False):
"""
Calculates the conditional entropy :math:`S(A|B)=S(A,B)-S(B)`
of a selected density matrix component.
Expand Down Expand Up @@ -372,9 +372,9 @@ def entangling_power(U):
raise Exception("U must be a two-qubit gate.")

from qutip.core.gates import swap
swap13 = expand_operator(swap(), [2, 2, 2, 2], [1, 3])
swap13 = expand_operator(swap(dtype=U.dtype), [2, 2, 2, 2], [1, 3])
a = tensor(U, U).dag() * swap13 * tensor(U, U) * swap13
Uswap = swap() * U
Uswap = swap(dtype=U.dtype) * U
b = tensor(Uswap, Uswap).dag() * swap13 * tensor(Uswap, Uswap) * swap13

return 5.0/9 - 1.0/36 * (a.tr() + b.tr()).real
2 changes: 1 addition & 1 deletion qutip/solver/mcsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

__all__ = ['mcsolve', "MCSolver"]

import numpy as np
from ..core.numpy_backend import np
from numpy.typing import ArrayLike
from numpy.random import SeedSequence
from time import time
Expand Down
16 changes: 8 additions & 8 deletions qutip/solver/multitraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from time import time
from .solver_base import Solver
from ..core import QobjEvo, Qobj
import numpy as np
from ..core.numpy_backend import np
from numpy.typing import ArrayLike
from numpy.random import SeedSequence
from numpy.random import SeedSequence, default_rng
from numbers import Number
from typing import Any, Callable
import bisect
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, rhs, *, options=None):
else:
raise TypeError("The system should be a QobjEvo")
self.options = options
self.seed_sequence = np.random.SeedSequence()
self.seed_sequence = SeedSequence()
self._integrator = self._get_integrator()
self._state_metadata = {}
self.stats = self._initialize_stats()
Expand Down Expand Up @@ -360,15 +360,15 @@ def _read_seed(self, seed, ntraj):
"""
if seed is None:
seeds = self.seed_sequence.spawn(ntraj)
elif isinstance(seed, np.random.SeedSequence):
elif isinstance(seed, SeedSequence):
seeds = seed.spawn(ntraj)
elif not isinstance(seed, list):
seeds = np.random.SeedSequence(seed).spawn(ntraj)
seeds = SeedSequence(seed).spawn(ntraj)
elif len(seed) >= ntraj:
seeds = [
seed_ if (isinstance(seed_, np.random.SeedSequence)
seed_ if (isinstance(seed_, SeedSequence)
or hasattr(seed_, 'random'))
else np.random.SeedSequence(seed_)
else SeedSequence(seed_)
for seed_ in seed[:ntraj]
]
else:
Expand All @@ -391,7 +391,7 @@ def _get_generator(self, seed):
bit_gen = getattr(np.random, self.options['bitgenerator'])
generator = np.random.Generator(bit_gen(seed))
else:
generator = np.random.default_rng(seed)
generator = default_rng(seed)
return generator


Expand Down
2 changes: 1 addition & 1 deletion qutip/solver/multitrajresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from typing import TypedDict
import numpy as np
from ..core.numpy_backend import np

from copy import copy

Expand Down
2 changes: 1 addition & 1 deletion qutip/solver/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from typing import TypedDict, Any, Callable
import numpy as np
from ..core.numpy_backend import np
from numpy.typing import ArrayLike
from ..core import Qobj, QobjEvo, expect

Expand Down
22 changes: 22 additions & 0 deletions qutip/tests/core/test_numpy_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import numpy
from unittest.mock import Mock

from qutip.core.numpy_backend import np
from qutip import CoreOptions

# Mocking JAX to demonstrate backend switching
mock_jax = Mock()
mock_np = numpy


class TestNumpyBackend:
def test_getattr_numpy(self):
with CoreOptions(numpy_backend=mock_np):
assert np.sum([1, 2, 3]) == numpy.sum([1, 2, 3])
assert np.sum is numpy.sum

def test_getattr_jax(self):
with CoreOptions(numpy_backend=mock_jax):
mock_jax.sum = Mock(return_value="jax_sum")
assert np.sum([1, 2, 3]) == "jax_sum"
Loading

0 comments on commit 03487d2

Please sign in to comment.