From 0eabea708bf257295451e0e56962ba05932f1f39 Mon Sep 17 00:00:00 2001 From: Rochisha Agarwal Date: Wed, 19 Jun 2024 17:20:27 +0530 Subject: [PATCH] add jax support for entropy.py --- qutip/entropy.py | 53 ++++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/qutip/entropy.py b/qutip/entropy.py index 7a5d8a7dc0..e93f42d9b8 100644 --- a/qutip/entropy.py +++ b/qutip/entropy.py @@ -2,8 +2,10 @@ 'concurrence', 'entropy_conditional', 'entangling_power', 'entropy_relative'] -from numpy import conj, e, inf, imag, inner, real, sort, sqrt -from numpy.lib.scimath import log, log2 +from .settings import settings +#from numpy import conj, e, inf, imag, inner, real, sort, sqrt +#from numpy import log, log2 +from math import e from .partial_transpose import partial_transpose from . import (ptrace, tensor, sigmay, ket2dm, expand_operator) @@ -35,17 +37,19 @@ def entropy_vn(rho, base=e, sparse=False): 1.0 """ + np = settings.core["backend"] if rho.type == 'ket' or rho.type == 'bra': rho = ket2dm(rho) vals = rho.eigenenergies(sparse=sparse) - nzvals = vals[vals != 0] + #nzvals = vals[vals != 0] + nzvals = np.clip(vals, 1e-12, None) if base == 2: - logvals = log2(nzvals) - elif base == e: - logvals = log(nzvals) + logvals = np.log2(nzvals) + elif base == np.e: + logvals = np.log(nzvals) else: raise ValueError("Base must be 2 or e.") - return float(real(-sum(nzvals * logvals))) + return np.real(-sum(nzvals * logvals)) def entropy_linear(rho): @@ -69,9 +73,10 @@ def entropy_linear(rho): 0.5 """ + np = settings.core["backend"] if rho.type == 'ket' or rho.type == 'bra': rho = ket2dm(rho) - return float(real(1.0 - (rho ** 2).tr())) + return np.real(1.0 - (rho ** 2).tr()) def concurrence(rho): @@ -94,6 +99,7 @@ def concurrence(rho): .. [1] `https://en.wikipedia.org/wiki/Concurrence_(quantum_computing)` """ + np = settings.core["backend"] if rho.isket and rho.dims != [[2, 2], [1, 1]]: raise Exception("Ket must be tensor product of two qubits.") @@ -113,11 +119,11 @@ def concurrence(rho): evals = rho_tilde.eigenenergies() # abs to avoid problems with sqrt for very small negative numbers - evals = abs(sort(real(evals))) + evals = abs(np.sort(np.real(evals))) - lsum = sqrt(evals[3]) - sqrt(evals[2]) - sqrt(evals[1]) - sqrt(evals[0]) + lsum = np.sqrt(evals[3]) - np.sqrt(evals[2]) - np.sqrt(evals[1]) - np.sqrt(evals[0]) - return max(0, lsum) + return np.maximum(0, lsum) def negativity(rho, subsys, method='tracenorm', logarithmic=False): @@ -130,6 +136,7 @@ def negativity(rho, subsys, method='tracenorm', logarithmic=False): Experimental. """ + np = settings.core["backend"] if rho.isket or rho.isbra: rho = ket2dm(rho) mask = [idx == subsys for idx, n in enumerate(rho.dims[0])] @@ -145,7 +152,7 @@ def negativity(rho, subsys, method='tracenorm', logarithmic=False): # Return the negativity value (or its logarithm if specified) if logarithmic: - return log2(2 * N + 1) + return np.log2(2 * N + 1) else: return N @@ -244,6 +251,8 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12): Section 11.3.1, pg. 511 for a detailed explanation of quantum relative entropy. """ + np = settings.core["backend"] + if rho.isket: rho = ket2dm(rho) if sigma.isket: @@ -253,9 +262,9 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12): if rho.dims != sigma.dims: raise ValueError("Inputs must have the same shape and dims.") if base == 2: - log_base = log2 - elif base == e: - log_base = log + log_base = np.log2 + elif base == np.e: + log_base = np.log else: raise ValueError("Base must be 2 or e.") # S(rho || sigma) = sum_i(p_i log p_i) - sum_ij(p_i P_ij log q_i) @@ -264,19 +273,19 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12): # intersection with the support of rho (i.e. rvecs[rvals != 0]). rvals, rvecs = _data.eigs(rho.data, rho.isherm, True) rvecs = rvecs.to_array().T - if any(abs(imag(rvals)) >= tol): + if any(abs(np.imag(rvals)) >= tol): raise ValueError("Input rho has non-real eigenvalues.") - rvals = real(rvals) + rvals = np.real(rvals) svals, svecs = _data.eigs(sigma.data, sigma.isherm, True) svecs = svecs.to_array().T - if any(abs(imag(svals)) >= tol): + if any(abs(np.imag(svals)) >= tol): raise ValueError("Input sigma has non-real eigenvalues.") - svals = real(svals) + svals = np.real(svals) # Calculate inner products of eigenvectors and return +inf if kernel # of sigma overlaps with support of rho. - P = abs(inner(rvecs, conj(svecs))) ** 2 + P = abs(np.inner(rvecs, np.conj(svecs))) ** 2 if (rvals >= tol) @ (P >= tol) @ (svals < tol): - return inf + return np.inf # Avoid -inf from log(0) -- these terms will be multiplied by zero later # anyway svals[abs(svals) < tol] = 1 @@ -285,7 +294,7 @@ def entropy_relative(rho, sigma, base=e, sparse=False, tol=1e-12): S = nzrvals @ log_base(nzrvals) - rvals @ P @ log_base(svals) # the relative entropy is guaranteed to be >= 0, so we clamp the # calculated value to 0 to avoid small violations of the lower bound. - return max(0, S) + return np.maximum(0, S) def entropy_conditional(rho, selB, base=e, sparse=False):