refactor
wistaria committed Nov 3, 2023
1 parent 7263d50 commit d66e2e2
Showing 12 changed files with 94 additions and 55 deletions.
13 changes: 9 additions & 4 deletions src/qailo/mps/__init__.py
@@ -1,15 +1,20 @@
from .canonical import is_canonical
from .mps import MPS, check, norm, product_state
from .canonical import check_mps, is_canonical
from .mps import MPS
from .norm import norm
from .num_qubits import num_qubits
from .product_state import product_state
from .state_vector import state_vector
from .svd import compact_svd, tensor_svd

__all__ = [
is_canonical,
check_mps,
MPS,
check,
norm,
product_state,
norm,
num_qubits,
compact_svd,
tensor_svd,
product_state,
state_vector,
]
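
A minimal usage sketch of the reorganized exports (assuming import qailo as q, as in the tests further down; only names re-exported in this __init__.py are used):

import qailo as q

m = q.mps.product_state(3)      # now defined in qailo/mps/product_state.py
n = q.mps.num_qubits(m)         # free function replacing the MPS.num_qubits() method
ok = q.mps.check_mps(m)         # replaces q.mps.check in the package namespace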
34 changes: 33 additions & 1 deletion src/qailo/mps/canonical.py
@@ -1,6 +1,6 @@
import numpy as np

from .type import num_qubits
from .num_qubits import num_qubits


def is_canonical(m):
@@ -16,3 +16,35 @@ def is_canonical(m):
return True
else:
return False


def check_mps(m):
"""
Check the shape of mps
"""
n = num_qubits(m)

# tensor shape
dims = []
assert m.tensors[0].shape[0] == 1
dims.append(m.tensors[0].shape[0])
for t in range(1, n - 1):
dims.append(m.tensors[t].shape[0])
assert m.tensors[t].shape[0] == m.tensors[t - 1].shape[2]
assert m.tensors[t].shape[2] == m.tensors[t + 1].shape[0]
assert m.tensors[n - 1].shape[2] == 1
dims.append(m.tensors[n - 1].shape[0])
dims.append(m.tensors[n - 1].shape[2])
# print(dims)

# qubit <-> tensor mapping
for q in range(n):
assert m.t2q[m.q2t[q]] == q
for t in range(n):
assert m.q2t[m.t2q[t]] == t

# canonical position
assert m.cp[0] in range(n)
assert m.cp[1] in range(n)

return True
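
A hedged sketch of calling the new check_mps on a hand-built chain (assuming import qailo as q, as in the tests): it asserts that neighbouring bond dimensions match, that the boundary bonds have size 1, that q2t and t2q are inverse permutations, and that the canonical positions lie in range.

import numpy as np
import qailo as q

tensors = [np.random.random(s) for s in [(1, 2, 3), (3, 2, 2), (2, 2, 1)]]
m = q.mps.MPS(tensors)
assert q.mps.check_mps(m)   # raises AssertionError if any shape or index map is inconsistent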
26 changes: 3 additions & 23 deletions src/qailo/mps/mps.py
@@ -2,6 +2,7 @@

from ..operator import type as op
from ..operator.swap import swap
from .num_qubits import num_qubits
from .svd import tensor_svd


@@ -34,9 +35,6 @@ def __init__(self, tensors, q2t=None, t2q=None, cp=None):
self.cp = cp if cp is not None else [0, n - 1]
assert 0 <= self.cp[0] and self.cp[0] <= self.cp[1] and self.cp[1] < n

def num_qubits(self):
return len(self.tensors)

def canonicalize(self, p0, p1=None):
p1 = p0 if p1 is None else p1
n = len(self.tensors)
@@ -81,7 +79,7 @@ def _swap_tensors(self, s, maxdim=None):
"""
swap neighboring two tensors at s and s+1
"""
assert s in range(0, self.num_qubits() - 1)
assert s in range(0, num_qubits(self) - 1)
self._apply_two(swap(), s, maxdim=maxdim)
p0, p1 = self.t2q[s], self.t2q[s + 1]
self.q2t[p0], self.q2t[p1] = s + 1, s
@@ -118,7 +116,7 @@ def check(mps):
"""
Check the shape of mps
"""
n = mps.num_qubits()
n = num_qubits(mps)

# tensor shape
dims = []
@@ -144,21 +142,3 @@ def check(mps):
assert mps.cp[1] in range(n)

return True


def product_state(n, c=0):
assert n > 0
tensors = []
for t in range(n):
tensor = np.zeros((1, 2, 1))
tensor[0, (c >> (n - t - 1)) & 1, 0] = 1
tensors.append(tensor)
return MPS(tensors)


def norm(m):
A = np.identity(2)
for t in range(m.num_qubits()):
A = np.einsum("ij,jkl->ikl", A, m.tensors[t])
A = np.einsum("ijk,ijl->kl", A, m.tensors[t].conj())
return np.sqrt(np.trace(A))
11 changes: 11 additions & 0 deletions src/qailo/mps/norm.py
@@ -0,0 +1,11 @@
import numpy as np

from .num_qubits import num_qubits


def norm(m):
A = np.identity(2)
for t in range(num_qubits(m)):
A = np.einsum("ij,jkl->ikl", A, m.tensors[t])
A = np.einsum("ijk,ijl->kl", A, m.tensors[t].conj())
return np.sqrt(np.trace(A))
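
A small sketch of the relocated norm helper (assuming import qailo as q): it contracts the MPS with its complex conjugate site by site and returns the square root of the trace of the resulting matrix.

import numpy as np
import qailo as q

tensors = [np.random.random(s) for s in [(1, 2, 2), (2, 2, 2), (2, 2, 1)]]
m = q.mps.MPS(tensors)
print(q.mps.norm(m))   # scalar, computed by the contraction above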
2 changes: 2 additions & 0 deletions src/qailo/mps/num_qubits.py
@@ -0,0 +1,2 @@
def num_qubits(m):
return len(m.tensors)
13 changes: 13 additions & 0 deletions src/qailo/mps/product_state.py
@@ -0,0 +1,13 @@
import numpy as np

from .mps import MPS


def product_state(n, c=0):
assert n > 0
tensors = []
for t in range(n):
tensor = np.zeros((1, 2, 1))
tensor[0, (c >> (n - t - 1)) & 1, 0] = 1
tensors.append(tensor)
return MPS(tensors)
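
A hedged example of the relocated product_state (assuming import qailo as q): product_state(n, c) encodes the n-qubit computational-basis state whose bits are the binary digits of c, with qubit 0 holding the most significant bit, and every site tensor has shape (1, 2, 1).

import numpy as np
import qailo as q

m = q.mps.product_state(3, c=0b101)      # the basis state |101>
v = q.mps.state_vector(m).reshape(-1)    # dense vector of length 2**3
print(np.argmax(np.abs(v)))              # expected: 5, i.e. binary 101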
15 changes: 8 additions & 7 deletions src/qailo/mps/state_vector.py
@@ -1,21 +1,22 @@
import numpy as np

from ..util.strops import letters, replace
from .type import is_mps, num_qubits
from .num_qubits import num_qubits
from .type import is_mps


def state_vector(mps):
assert is_mps(mps)
n = num_qubits(mps)
v = mps.tensors[0]
def state_vector(m):
assert is_mps(m)
n = num_qubits(m)
v = m.tensors[0]
for t in range(1, n):
ss_v0 = letters()[: t + 2]
ss_v1 = letters()[t + 1 : t + 4]
ss_to = letters()[: t + 1] + letters()[t + 2 : t + 4]
v = np.einsum(f"{ss_v0},{ss_v1}->{ss_to}", v, mps.tensors[t])
v = np.einsum(f"{ss_v0},{ss_v1}->{ss_to}", v, m.tensors[t])
v = v.reshape((2,) * n)
ss_from = letters()[:n]
ss_to = ss_from
for p in range(n):
ss_to = replace(ss_to, p, ss_from[mps.q2t[p]])
ss_to = replace(ss_to, p, ss_from[m.q2t[p]])
return np.einsum(f"{ss_from}->{ss_to}", v).reshape((2,) * n + (1,))
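
A short usage sketch of state_vector (assuming import qailo as q): it contracts the site tensors left to right, permutes the axes from tensor order back to qubit order via q2t, and returns an array of shape (2,)*n + (1,).

import qailo as q

m = q.mps.product_state(2, c=0b10)   # |10>
v = q.mps.state_vector(m)
print(v.shape)                       # (2, 2, 1)
print(v[1, 0, 0])                    # amplitude of |10>, expected 1.0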
8 changes: 2 additions & 6 deletions src/qailo/mps/type.py
@@ -1,9 +1,5 @@
from .mps import MPS


def is_mps(mps):
return isinstance(mps, MPS)


def num_qubits(mps):
return mps.num_qubits()
def is_mps(m):
return isinstance(m, MPS)
7 changes: 3 additions & 4 deletions src/qailo/operator/hconj.py
@@ -1,12 +1,11 @@
import numpy as np

from ..util.strops import letters
from .type import is_operator, num_qubits


def hconj(op):
assert is_operator(op)
n = num_qubits(op)
ss_from = letters()[: 2 * n]
ss_to = ss_from[n : 2 * n] + ss_from[:n]
return np.einsum("{}->{}".format(ss_from, ss_to), op).conjugate()
ss_from = list(range(2 * n))
ss_to = list(range(n, 2 * n)) + list(range(n))
return np.einsum(op, ss_from, ss_to).conjugate()
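
The rewritten hconj switches from string subscripts to the integer-subscript form of np.einsum, which removes the dependence on the letters() string helper. A self-contained sketch of the same index swap on a plain 2x2 matrix (assumption: an n-qubit operator carries two groups of n legs, so for n = 1 this reduces to the ordinary Hermitian conjugate):

import numpy as np

op = np.random.random((2, 2)) + 1j * np.random.random((2, 2))   # a 1-qubit operator
n = 1
ss_from = list(range(2 * n))
ss_to = list(range(n, 2 * n)) + list(range(n))
h = np.einsum(op, ss_from, ss_to).conjugate()   # swap the two leg groups, then conjugate
assert np.allclose(h, op.conj().T)              # matches the usual matrix dagger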
2 changes: 1 addition & 1 deletion test/mps/test_canonical.py
@@ -16,7 +16,7 @@ def test_canonical():
tensors.append(np.random.random((d, 2, 1)))
mps = q.mps.MPS(tensors)
norm = q.mps.norm(mps)
assert q.mps.check(mps)
assert q.mps.check_mps(mps)

for _ in range(n):
p = np.random.randint(n)
4 changes: 2 additions & 2 deletions test/mps/test_move.py
@@ -15,7 +15,7 @@ def test_swap():
d = dn
tensors.append(np.random.random((d, 2, 1)))
mps = q.mps.MPS(tensors)
q.mps.check(mps)
q.mps.check_mps(mps)
norm = q.mps.norm(mps)
v0 = q.sv.vector(q.mps.state_vector(mps))

@@ -45,7 +45,7 @@ def test_move():
d = dn
tensors.append(np.random.random((d, 2, 1)))
mps = q.mps.MPS(tensors)
q.mps.check(mps)
q.mps.check_mps(mps)
norm = q.mps.norm(mps)
v0 = q.sv.vector(q.mps.state_vector(mps))

14 changes: 7 additions & 7 deletions test/mps/test_svd.py
@@ -3,8 +3,8 @@


def test_compact_svd():
maxn = 64
nt = 64
maxn = 16
nt = 16
for _ in range(nt):
m, n, d = np.random.randint(2, maxn, size=(3,))
A = np.random.random((m, n))
@@ -22,8 +22,8 @@


def test_svd_left():
maxn = 32
nt = 64
maxn = 16
nt = 16
for _ in range(nt):
m, n, p, d = np.random.randint(2, maxn, size=(4,))
T = np.random.random((m, n, p))
@@ -49,8 +49,8 @@


def test_svd_right():
maxn = 32
nt = 64
maxn = 16
nt = 16
for _ in range(nt):
m, n, p, d = np.random.randint(2, maxn, size=(4,))
T = np.random.random((m, n, p))
@@ -77,7 +77,7 @@

def test_svd_two():
maxn = 8
nt = 64
nt = 16
for _ in range(nt):
m, n, p, r, d = np.random.randint(2, maxn, size=(5,))
T = np.random.random((m, n, p, r))
