Skip to content

Commit

Permalink
add jax support for entropy.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Jun 19, 2024
1 parent 16a3708 commit 0eabea7
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions qutip/entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
'concurrence', 'entropy_conditional', 'entangling_power',
'entropy_relative']

from numpy import conj, e, inf, imag, inner, real, sort, sqrt
from numpy.lib.scimath import log, log2
from .settings import settings
#from numpy import conj, e, inf, imag, inner, real, sort, sqrt
#from numpy import log, log2
from math import e
from .partial_transpose import partial_transpose
from . import (ptrace, tensor, sigmay, ket2dm,
expand_operator)
Expand Down Expand Up @@ -35,17 +37,19 @@ def entropy_vn(rho, base=e, sparse=False):
1.0
"""
np = settings.core["backend"]
if rho.type == 'ket' or rho.type == 'bra':
rho = ket2dm(rho)
vals = rho.eigenenergies(sparse=sparse)
nzvals = vals[vals != 0]
#nzvals = vals[vals != 0]
nzvals = np.clip(vals, 1e-12, None)
if base == 2:
logvals = log2(nzvals)
elif base == e:
logvals = log(nzvals)
logvals = np.log2(nzvals)
elif base == np.e:
logvals = np.log(nzvals)
else:
raise ValueError("Base must be 2 or e.")
return float(real(-sum(nzvals * logvals)))
return np.real(-sum(nzvals * logvals))


def entropy_linear(rho):
Expand All @@ -69,9 +73,10 @@ def entropy_linear(rho):
0.5
"""
np = settings.core["backend"]
if rho.type == 'ket' or rho.type == 'bra':
rho = ket2dm(rho)
return float(real(1.0 - (rho ** 2).tr()))
return np.real(1.0 - (rho ** 2).tr())


def concurrence(rho):
Expand All @@ -94,6 +99,7 @@ def concurrence(rho):
.. [1] `https://en.wikipedia.org/wiki/Concurrence_(quantum_computing)`
"""
np = settings.core["backend"]
if rho.isket and rho.dims != [[2, 2], [1, 1]]:
raise Exception("Ket must be tensor product of two qubits.")

Expand All @@ -113,11 +119,11 @@ def concurrence(rho):
evals = rho_tilde.eigenenergies()

# abs to avoid problems with sqrt for very small negative numbers
evals = abs(sort(real(evals)))
evals = abs(np.sort(np.real(evals)))

lsum = sqrt(evals[3]) - sqrt(evals[2]) - sqrt(evals[1]) - sqrt(evals[0])
lsum = np.sqrt(evals[3]) - np.sqrt(evals[2]) - np.sqrt(evals[1]) - np.sqrt(evals[0])

return max(0, lsum)
return np.maximum(0, lsum)


def negativity(rho, subsys, method='tracenorm', logarithmic=False):
Expand All @@ -130,6 +136,7 @@ def negativity(rho, subsys, method='tracenorm', logarithmic=False):
Experimental.
"""
np = settings.core["backend"]
if rho.isket or rho.isbra:
rho = ket2dm(rho)
mask = [idx == subsys for idx, n in enumerate(rho.dims[0])]
Expand All @@ -145,7 +152,7 @@ def negativity(rho, subsys, method='tracenorm', logarithmic=False):

# Return the negativity value (or its logarithm if specified)
if logarithmic:
return log2(2 * N + 1)
return np.log2(2 * N + 1)
else:
return N

Expand Down Expand Up @@ -244,6 +251,8 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
Section 11.3.1, pg. 511 for a detailed explanation of quantum relative
entropy.
"""
np = settings.core["backend"]

if rho.isket:
rho = ket2dm(rho)
if sigma.isket:
Expand All @@ -253,9 +262,9 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
if rho.dims != sigma.dims:
raise ValueError("Inputs must have the same shape and dims.")
if base == 2:
log_base = log2
elif base == e:
log_base = log
log_base = np.log2
elif base == np.e:
log_base = np.log
else:
raise ValueError("Base must be 2 or e.")
# S(rho || sigma) = sum_i(p_i log p_i) - sum_ij(p_i P_ij log q_i)
Expand All @@ -264,19 +273,19 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
# intersection with the support of rho (i.e. rvecs[rvals != 0]).
rvals, rvecs = _data.eigs(rho.data, rho.isherm, True)
rvecs = rvecs.to_array().T
if any(abs(imag(rvals)) >= tol):
if any(abs(np.imag(rvals)) >= tol):
raise ValueError("Input rho has non-real eigenvalues.")
rvals = real(rvals)
rvals = np.real(rvals)
svals, svecs = _data.eigs(sigma.data, sigma.isherm, True)
svecs = svecs.to_array().T
if any(abs(imag(svals)) >= tol):
if any(abs(np.imag(svals)) >= tol):
raise ValueError("Input sigma has non-real eigenvalues.")
svals = real(svals)
svals = np.real(svals)
# Calculate inner products of eigenvectors and return +inf if kernel
# of sigma overlaps with support of rho.
P = abs(inner(rvecs, conj(svecs))) ** 2
P = abs(np.inner(rvecs, np.conj(svecs))) ** 2
if (rvals >= tol) @ (P >= tol) @ (svals < tol):
return inf
return np.inf
# Avoid -inf from log(0) -- these terms will be multiplied by zero later
# anyway
svals[abs(svals) < tol] = 1
Expand All @@ -285,7 +294,7 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12):
S = nzrvals @ log_base(nzrvals) - rvals @ P @ log_base(svals)
# the relative entropy is guaranteed to be >= 0, so we clamp the
# calculated value to 0 to avoid small violations of the lower bound.
return max(0, S)
return np.maximum(0, S)


def entropy_conditional(rho, selB, base=e, sparse=False):
Expand Down

0 comments on commit 0eabea7

Please sign in to comment.