Skip to content

Commit

Permalink
refactor norm
Browse files Browse the repository at this point in the history
  • Loading branch information
wistaria committed Nov 17, 2023
1 parent 099e784 commit 98f2217
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 25 deletions.
3 changes: 2 additions & 1 deletion src/qailo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from . import operator as op
from . import state_vector as sv
from ._version import version
from .dispatch import apply, apply_seq, num_qubits, probability, vector
from .dispatch import apply, apply_seq, norm, num_qubits, probability, vector

__all__ = [
alg,
Expand All @@ -15,6 +15,7 @@
version,
apply,
apply_seq,
norm,
num_qubits,
probability,
vector,
Expand Down
9 changes: 9 additions & 0 deletions src/qailo/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from . import mps
from . import operator as op
from . import state_vector as sv
Expand All @@ -23,6 +25,13 @@ def apply_seq(v, seq):
return v


def norm(v):
if sv.is_state_vector(v):
return np.linalg.norm(v)
elif mps.is_mps(v):
return v._norm()


def num_qubits(v):
if sv.is_state_vector(v):
return sv.num_qubits(v)
Expand Down
2 changes: 0 additions & 2 deletions src/qailo/mps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .apply import apply, apply_seq
from .mps_c import canonical_mps
from .mps_p import projector_mps
from .norm import norm
from .product_state import one, product_state, tensor_decomposition, zero
from .projector import projector
from .state_vector import state_vector
Expand All @@ -12,7 +11,6 @@
apply_seq,
canonical_mps,
projector_mps,
norm,
one,
product_state,
tensor_decomposition,
Expand Down
7 changes: 7 additions & 0 deletions src/qailo/mps/mps_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __init__(self, tensors, nkeep=None):
def _num_qubits(self):
return len(self.tensors)

def _norm(self):
A = np.identity(1)
for t in range(self._num_qubits()):
A = np.einsum("ij,jkl->ikl", A, self._tensor(t))
A = np.einsum("ijk,ijl->kl", A, self._tensor(t).conj())
return np.sqrt(np.trace(A))

def _state_vector(self):
n = self._num_qubits()
v = self._tensor(0)
Expand Down
7 changes: 7 additions & 0 deletions src/qailo/mps/mps_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def __init__(self, tensors, nkeep=None):
def _num_qubits(self):
return len(self.tensors)

def _norm(self):
A = np.identity(1)
for t in range(self._num_qubits()):
A = np.einsum("ij,jkl->ikl", A, self._tensor(t))
A = np.einsum("ijk,ijl->kl", A, self._tensor(t).conj())
return np.sqrt(np.trace(A))

def _state_vector(self):
n = self._num_qubits()
v = self._tensor(0)
Expand Down
11 changes: 0 additions & 11 deletions src/qailo/mps/norm.py

This file was deleted.

14 changes: 7 additions & 7 deletions test/mps/test_canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,40 @@ def test_canonical():
tensors.append(np.random.random((d, 2, 1)))
for mps in [q.mps.canonical_mps, q.mps.projector_mps]:
m = mps(tensors)
norm = q.mps.norm(m)
assert q.mps.norm(m) == approx(norm)
norm = q.norm(m)
assert q.norm(m) == approx(norm)
assert q.mps.is_canonical(m)

for _ in range(n):
p = np.random.randint(n)
m._canonicalize(p)
assert q.mps.norm(m) == approx(norm)
assert q.norm(m) == approx(norm)
assert q.mps.is_canonical(m)
assert q.mps.is_canonical(m)

for _ in range(n):
p = np.random.randint(n - 1)
m._canonicalize(p, p + 1)
assert q.mps.norm(m) == approx(norm)
assert q.norm(m) == approx(norm)
assert q.mps.is_canonical(m)

v = np.random.random(2**n).reshape((2,) * n + (1,))
v /= np.linalg.norm(v)
tensors = q.mps.tensor_decomposition(v, maxdim)
for mps in [q.mps.canonical_mps, q.mps.projector_mps]:
m = mps(tensors)
norm = q.mps.norm(m)
norm = q.norm(m)

for _ in range(n):
p = np.random.randint(n)
m._canonicalize(p)
assert q.mps.norm(m) == approx(norm)
assert q.norm(m) == approx(norm)
assert q.mps.is_canonical(m)

for _ in range(n):
p = np.random.randint(n - 1)
m._canonicalize(p, p + 1)
assert q.mps.norm(m) == approx(norm)
assert q.norm(m) == approx(norm)
assert q.mps.is_canonical(m)


Expand Down
8 changes: 4 additions & 4 deletions test/mps/test_move.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_swap():
for mps in [q.mps.canonical_mps, q.mps.projector_mps]:
m = mps(tensors)
q.mps.is_canonical(m)
norm = q.mps.norm(m)
norm = q.norm(m)
v = q.sv.vector(q.mps.state_vector(m))
for _ in range(64):
s = np.random.randint(n - 1)
Expand All @@ -28,7 +28,7 @@ def test_swap():
_swap_tensors(m, s)
print(q.sv.vector(q.mps.state_vector(m)))
q.mps.is_canonical(m)
assert q.mps.norm(m) == approx(norm)
assert q.norm(m) == approx(norm)

vn = q.sv.vector(q.mps.state_vector(m))
assert len(v) == len(vn)
Expand All @@ -51,7 +51,7 @@ def test_move():
for mps in [q.mps.canonical_mps, q.mps.projector_mps]:
m = mps(tensors)
q.mps.is_canonical(m)
norm = q.mps.norm(m)
norm = q.norm(m)
v = q.sv.vector(q.mps.state_vector(m))

for _ in range(16):
Expand All @@ -62,7 +62,7 @@ def test_move():
_move_qubit(m, p, s)
print(q.sv.vector(q.mps.state_vector(m)))
q.mps.is_canonical(m)
assert q.mps.norm(m) == approx(norm)
assert q.norm(m) == approx(norm)

vn = q.sv.vector(q.mps.state_vector(m))
assert len(v) == len(vn)
Expand Down

0 comments on commit 98f2217

Please sign in to comment.