Skip to content

Commit

Permalink
contraction graph
Browse files Browse the repository at this point in the history
  • Loading branch information
wistaria committed Nov 8, 2023
1 parent 06def83 commit 6f4233b
Show file tree
Hide file tree
Showing 19 changed files with 410 additions and 169 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ build
**/*.py[cod]
**/__pycache__
**/*.egg-info
**/test_mps_t-generator.pkl
**/test_mps_t-graph.json
**/test_mps_t-tensor.pkl
3 changes: 2 additions & 1 deletion src/qailo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import alg, mps, mps_p, operator, state_vector, util
from . import alg, mps, mps_p, mps_t, operator, state_vector, util
from . import operator as op
from . import state_vector as sv
from ._version import version
Expand All @@ -8,6 +8,7 @@
alg,
mps,
mps_p,
mps_t,
operator,
state_vector,
util,
Expand Down
2 changes: 1 addition & 1 deletion src/qailo/mps/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def _swap_tensors(m, s, maxdim=None):
"""
swap neighboring two tensors at s and s+1
"""
assert s in range(0, len(m.tensors) - 1)
assert s in range(0, mps.num_qubits(m) - 1)
m._apply_two(swap(), s, maxdim=maxdim)
p0, p1 = m.t2q[s], m.t2q[s + 1]
m.q2t[p0], m.q2t[p1] = s + 1, s
Expand Down
6 changes: 5 additions & 1 deletion src/qailo/mps/mps_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from ..operator import type as op
from .svd import tensor_svd
from .type import MPS


class MPS_C:
class MPS_C(MPS):
"""
MPS representation of quantum pure state
Expand All @@ -29,6 +30,9 @@ def __init__(self, tensors):
self.t2q = list(range(n))
self.cp = [0, n - 1]

def _tensor(self, t):
return self.tensors[t]

def _canonicalize(self, p0, p1=None):
p1 = p0 if p1 is None else p1
n = len(self.tensors)
Expand Down
4 changes: 2 additions & 2 deletions src/qailo/mps/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
def norm(m):
A = np.identity(1)
for t in range(num_qubits(m)):
A = np.einsum("ij,jkl->ikl", A, m.tensors[t])
A = np.einsum("ijk,ijl->kl", A, m.tensors[t].conj())
A = np.einsum("ij,jkl->ikl", A, m._tensor(t))
A = np.einsum("ijk,ijl->kl", A, m._tensor(t).conj())
return np.sqrt(np.trace(A))
4 changes: 2 additions & 2 deletions src/qailo/mps/product_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..dispatch import num_qubits
from ..state_vector.state_vector import one as sv_one
from ..state_vector.state_vector import zero as sv_zero
from ..state_vector.type import num_qubits
from ..state_vector.vector import vector
from .mps_c import MPS_C
from .svd import tensor_svd
Expand All @@ -9,7 +9,7 @@

def tensor_decomposition(v, nkeep=None, tol=1e-12):
if is_mps(v):
return v.tensors
return [v._tensor(s) for s in range(num_qubits(v))]
else:
n = num_qubits(v)
w = vector(v).reshape((1, 2**n))
Expand Down
4 changes: 2 additions & 2 deletions src/qailo/mps/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
def state_vector(m):
assert is_mps(m)
n = num_qubits(m)
v = m.tensors[0]
v = m._tensor(0)
for t in range(1, n):
ss0 = list(range(t + 1)) + [t + 3]
ss1 = [t + 3, t + 1, t + 2]
v = np.einsum(v, ss0, m.tensors[t], ss1)
v = np.einsum(v, ss0, m._tensor(t), ss1)
v = v.reshape((2,) * n)
return np.einsum(v, m.t2q).reshape((2,) * n + (1,))
9 changes: 7 additions & 2 deletions src/qailo/mps/type.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
class MPS(object):
def __init__(self):
True


def is_canonical(m):
return m._is_canonical()


def is_mps(m):
return hasattr(m, "tensors")
return isinstance(m, MPS)


def num_qubits(m):
return len(m.tensors)
return len(m.q2t)
24 changes: 14 additions & 10 deletions src/qailo/mps_p/mps_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import numpy as np

from ..mps.svd import tensor_svd
from ..mps.type import MPS
from ..operator import type as op
from .projector import compact_projector


class MPS_P:
class MPS_P(MPS):
"""
MPS representation of quantum pure state
Expand All @@ -31,22 +32,25 @@ def __init__(self, tensors):
self.cp = [0, n - 1]
# canonicalization matrices
# put sentinels (1x1 identities) at t = 0 and t = n
self.cmat = [np.identity(1)] + [None] * (n - 1) + [np.identity(1)]
self.env = [np.identity(1)] + [None] * (n - 1) + [np.identity(1)]

def _tensor(self, t):
return self.tensors[t]

def _canonicalize(self, p0, p1=None):
p1 = p0 if p1 is None else p1
n = len(self.tensors)
assert 0 <= p0 and p0 <= p1 and p1 < n
if self.cp[0] < p0:
for t in range(self.cp[0], p0):
A = np.einsum(self.cmat[t], [0, 3], self.tensors[t], [3, 1, 2])
_, self.cmat[t + 1] = tensor_svd(A, [[0, 1], [2]], "left")
A = np.einsum(self.env[t], [0, 3], self.tensors[t], [3, 1, 2])
_, self.env[t + 1] = tensor_svd(A, [[0, 1], [2]], "left")
self.cp[0] = p0
self.cp[1] = max(p0, self.cp[1])
if self.cp[1] > p1:
for t in range(self.cp[1], p1, -1):
A = np.einsum(self.tensors[t], [0, 1, 3], self.cmat[t + 1], [3, 2])
self.cmat[t], _ = tensor_svd(A, [[0], [1, 2]], "right")
A = np.einsum(self.tensors[t], [0, 1, 3], self.env[t + 1], [3, 2])
self.env[t], _ = tensor_svd(A, [[0], [1, 2]], "right")
self.cp[1] = p1

def _is_canonical(self):
Expand Down Expand Up @@ -76,14 +80,14 @@ def _is_canonical(self):
for t in range(0, self.cp[0]):
A = np.einsum(A, [0, 3], self.tensors[t], [3, 1, 2])
A = np.einsum(A, [2, 3, 1], self.tensors[t].conj(), [2, 3, 0])
B = np.einsum(self.cmat[t + 1], [2, 1], self.cmat[t + 1].conj(), [2, 0])
B = np.einsum(self.env[t + 1], [2, 1], self.env[t + 1].conj(), [2, 0])
assert A.shape == B.shape
assert np.allclose(A, B)
A = np.identity(1)
for t in range(n - 1, self.cp[1], -1):
A = np.einsum(self.tensors[t], [0, 1, 3], A, [3, 2])
A = np.einsum(self.tensors[t].conj(), [1, 2, 3], A, [0, 2, 3])
B = np.einsum(self.cmat[t], [0, 2], self.cmat[t].conj(), [1, 2])
B = np.einsum(self.env[t], [0, 2], self.env[t].conj(), [1, 2])
assert np.allclose(A, B)
return True

Expand All @@ -107,8 +111,8 @@ def _apply_two(self, p, s, maxdim=None, reverse=False):
else:
t0 = np.einsum(t0, [0, 4, 3], p1, [2, 1, 4])
t1 = np.einsum(t1, [0, 4, 3], p0, [2, 4, 1])
tt0 = np.einsum(self.cmat[s], [0, 4], t0, [4, 1, 2, 3])
tt1 = np.einsum(t1, [0, 1, 2, 4], self.cmat[s + 2], [4, 3])
tt0 = np.einsum(self.env[s], [0, 4], t0, [4, 1, 2, 3])
tt1 = np.einsum(t1, [0, 1, 2, 4], self.env[s + 2], [4, 3])
_, WLh, WR = compact_projector(tt0, [0, 1, 4, 5], tt1, [5, 4, 2, 3], maxdim)
self.tensors[s] = np.einsum(t0, [0, 1, 3, 4], WR, [3, 4, 2])
self.tensors[s + 1] = np.einsum(WLh, [3, 4, 0], t1, [4, 3, 1, 2])
4 changes: 2 additions & 2 deletions src/qailo/mps_p/product_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def product_state(states):
return MPS_P(tensors)


def zero(n=1, logger=None):
def zero(n=1):
return product_state([sv_zero()] * n)


def one(n=1, logger=None):
def one(n=1):
return product_state([sv_one()] * n)
9 changes: 9 additions & 0 deletions src/qailo/mps_t/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .mps_t import MPS_T
from .product_state import one, product_state, zero

__all__ = [
MPS_T,
one,
product_state,
zero,
]
193 changes: 193 additions & 0 deletions src/qailo/mps_t/mps_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
from copy import deepcopy

import numpy as np

from ..mps.svd import tensor_svd
from ..mps.type import MPS
from ..mps_p.projector import _full_projector
from ..operator import type as op


class MPS_T(MPS):
"""
MPS representation of quantum pure state
shape of tensors: [du, dp, dl]
du: dimension of upper leg (1 for top tensor)
dp: dimension of physical leg (typically 2)
dl: dimension of lower leg (1 for bottom tensor)
canonical position: cp in range(n)
0 <= cp(0) <= cp(1) < n
tensors [0...cp(0)-1]: top canonical
tensors [cp(1)+1...n-1]: bottom canonical
"""

def __init__(self, tensors):
assert isinstance(tensors, list)
n = len(tensors)
self.tpool = []
self.tcurrent = []
for t in tensors:
self.tcurrent.append(self._register(deepcopy(t)))
self.gpool = []
self.q2t = list(range(n))
self.t2q = list(range(n))
self.cp = [0, n - 1]
# canonicalization matrices
# put sentinels (1x1 identities) at t = 0 and t = n
self.env = [np.identity(1)] + [None] * (n - 1) + [np.identity(1)]

def _num_qubits(self):
return len(self.tcurrent)

def _tensor(self, t):
return self.tpool[self.tcurrent[t]][0]

def _canonicalize(self, p0, p1=None):
p1 = p0 if p1 is None else p1
assert 0 <= p0 and p0 <= p1 and p1 < self._num_qubits()
if self.cp[0] < p0:
for t in range(self.cp[0], p0):
A = np.einsum(self.env[t], [0, 3], self._tensor(t), [3, 1, 2])
_, self.env[t + 1] = tensor_svd(A, [[0, 1], [2]], "left")
self.cp[0] = p0
self.cp[1] = max(p0, self.cp[1])
if self.cp[1] > p1:
for t in range(self.cp[1], p1, -1):
A = np.einsum(self._tensor(t), [0, 1, 3], self.env[t + 1], [3, 2])
self.env[t], _ = tensor_svd(A, [[0], [1, 2]], "right")
self.cp[1] = p1

def _is_canonical(self):
# tensor shape
n = len(self.tcurrent)
dims = []
assert self._tensor(0).shape[0] == 1
dims.append(self._tensor(0).shape[0])
for t in range(1, n - 1):
dims.append(self._tensor(t).shape[0])
assert self._tensor(t).shape[0] == self._tensor(t - 1).shape[2]
assert self._tensor(t).shape[2] == self._tensor(t + 1).shape[0]
assert self._tensor(n - 1).shape[2] == 1
dims.append(self._tensor(n - 1).shape[0])
dims.append(self._tensor(n - 1).shape[2])

# qubit <-> tensor mapping
for q in range(n):
assert self.t2q[self.q2t[q]] == q
for t in range(n):
assert self.q2t[self.t2q[t]] == t

# canonicality
assert self.cp[0] in range(n)
assert self.cp[1] in range(n)
A = np.identity(1)
for t in range(0, self.cp[0]):
A = np.einsum(A, [0, 3], self._tensor(t), [3, 1, 2])
A = np.einsum(A, [2, 3, 1], self._tensor(t).conj(), [2, 3, 0])
B = np.einsum(self.env[t + 1], [2, 1], self.env[t + 1].conj(), [2, 0])
assert A.shape == B.shape
assert np.allclose(A, B)
A = np.identity(1)
for t in range(n - 1, self.cp[1], -1):
A = np.einsum(self._tensor(t), [0, 1, 3], A, [3, 2])
A = np.einsum(self._tensor(t).conj(), [1, 2, 3], A, [0, 2, 3])
B = np.einsum(self.env[t], [0, 2], self.env[t].conj(), [1, 2])
assert np.allclose(A, B)
return True

def _apply_one(self, p, s):
assert op.num_qubits(p) == 1
pid = self._register(p)
self.tcurrent[s] = self._contract(self.tcurrent[s], [0, 3, 2], pid, [1, 3])
self.cp[0] = min(self.cp[0], s)
self.cp[1] = max(self.cp[1], s)

def _apply_two(self, p, s, maxdim=None, reverse=False):
"""
apply 2-qubit operator on neighboring tensors, s and s+1
"""
self._canonicalize(s, s + 1)
tid0 = self.tcurrent[s]
tid1 = self.tcurrent[s + 1]
p0, p1 = tensor_svd(p, [[0, 2], [1, 3]])
pid0 = self._register(p0)
pid1 = self._register(p1)
if not reverse:
tid0 = self._contract(tid0, [0, 4, 3], pid0, [1, 4, 2])
tid1 = self._contract(tid1, [0, 4, 3], pid1, [1, 2, 4])
else:
tid0 = self._contract(tid0, [0, 4, 3], pid1, [2, 1, 4])
tid1 = self._contract(tid1, [0, 4, 3], pid0, [2, 4, 1])
tt0 = np.einsum(self.env[s], [0, 4], self.tpool[tid0][0], [4, 1, 2, 3])
tt1 = np.einsum(self.tpool[tid1][0], [0, 1, 2, 4], self.env[s + 2], [4, 3])
_, WLhid, WRid = self._projector(tt0, [0, 1, 4, 5], tt1, [5, 4, 2, 3], maxdim)
self.tcurrent[s] = self._contract(tid0, [0, 1, 3, 4], WRid, [3, 4, 2])
self.tcurrent[s + 1] = self._contract(WLhid, [3, 4, 0], tid1, [4, 3, 1, 2])

def _register(self, tensor):
id = len(self.tpool)
self.tpool.append([tensor, "initial", None])
return id

def _contract(self, tid0, ss0, tid1, ss1, ss2=None):
id = len(self.tpool)
if ss2 is None:
t = np.einsum(self.tpool[tid0][0], ss0, self.tpool[tid1][0], ss1)
else:
np.einsum(self.tpool[tid0][0], ss0, self.tpool[tid1][0], ss1, ss2)
self.tpool.append([t, "product", [tid0, tid1]])
return id

def _projector(self, t0, ss0, t1, ss1, maxdim=None):
S, d, WLh, WR = _full_projector(t0, ss0, t1, ss1)
d = d if maxdim is None else min(d, maxdim)
gid = len(self.gpool)
self.gpool.append([S, d, WLh, WR])
lid = len(self.tpool)
shape = WLh.shape
WLh = WLh.reshape((np.prod(shape[:-1]), shape[-1]))
WLh = WLh[:, :d].reshape(shape[:-1] + (d,))
self.tpool.append([WLh, "squeezer", [gid]])
rid = len(self.tpool)
shape = WR.shape
WR = WR.reshape((np.prod(shape[:-1]), shape[-1]))
WR = WR[:, :d].reshape(shape[:-1] + (d,))
self.tpool.append([WR, "squeezer", [gid]])
return gid, lid, rid

def _dump(self, prefix):
import json
import pickle

dic = {"prefix": prefix}
tlist = []
with open(f"{prefix}-tensor.pkl", "wb") as f:
for id, tp in enumerate(self.tpool):
m = {}
m["id"] = id
m["shape"] = tp[0].shape
m["type"] = tp[1]
m["from"] = tp[2]
tlist.append(m)
if tp[1] == "initial":
pickle.dump(tp[0], f)
dic["tensor"] = tlist
glist = []
with open(f"{prefix}-generator.pkl", "wb") as f:
for id, gp in enumerate(self.gpool):
m = {}
m["id"] = id
m["d"] = gp[1]
m["shape S"] = gp[0].shape
m["shape L"] = gp[2].shape
m["shape R"] = gp[3].shape
glist.append(m)
pickle.dump(gp[0], f)
pickle.dump(gp[2], f)
pickle.dump(gp[3], f)
dic["generator"] = glist

with open(prefix + "-graph.json", mode="w") as f:
json.dump(dic, f, indent=2)
Loading

0 comments on commit 6f4233b

Please sign in to comment.