Skip to content

Commit

Permalink
tpool
Browse files Browse the repository at this point in the history
  • Loading branch information
wistaria committed Nov 10, 2023
1 parent 6f4233b commit a0d2ccc
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 89 deletions.
103 changes: 19 additions & 84 deletions src/qailo/mps_t/mps_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from ..mps.svd import tensor_svd
from ..mps.type import MPS
from ..mps_p.projector import _full_projector
from ..operator import type as op
from .tpool import tpool


class MPS_T(MPS):
Expand All @@ -26,11 +26,10 @@ class MPS_T(MPS):
def __init__(self, tensors):
assert isinstance(tensors, list)
n = len(tensors)
self.tpool = []
self.tp = tpool()
self.tcurrent = []
for t in tensors:
self.tcurrent.append(self._register(deepcopy(t)))
self.gpool = []
self.tcurrent.append(self.tp._register(deepcopy(t)))
self.q2t = list(range(n))
self.t2q = list(range(n))
self.cp = [0, n - 1]
Expand All @@ -42,7 +41,7 @@ def _num_qubits(self):
return len(self.tcurrent)

def _tensor(self, t):
return self.tpool[self.tcurrent[t]][0]
return self.tp.tpool[self.tcurrent[t]][0]

def _canonicalize(self, p0, p1=None):
p1 = p0 if p1 is None else p1
Expand Down Expand Up @@ -99,8 +98,8 @@ def _is_canonical(self):

def _apply_one(self, p, s):
assert op.num_qubits(p) == 1
pid = self._register(p)
self.tcurrent[s] = self._contract(self.tcurrent[s], [0, 3, 2], pid, [1, 3])
pid = self.tp._register(p)
self.tcurrent[s] = self.tp._contract(self.tcurrent[s], [0, 3, 2], pid, [1, 3])
self.cp[0] = min(self.cp[0], s)
self.cp[1] = max(self.cp[1], s)

Expand All @@ -112,82 +111,18 @@ def _apply_two(self, p, s, maxdim=None, reverse=False):
tid0 = self.tcurrent[s]
tid1 = self.tcurrent[s + 1]
p0, p1 = tensor_svd(p, [[0, 2], [1, 3]])
pid0 = self._register(p0)
pid1 = self._register(p1)
pid0 = self.tp._register(p0)
pid1 = self.tp._register(p1)
if not reverse:
tid0 = self._contract(tid0, [0, 4, 3], pid0, [1, 4, 2])
tid1 = self._contract(tid1, [0, 4, 3], pid1, [1, 2, 4])
tid0 = self.tp._contract(tid0, [0, 4, 3], pid0, [1, 4, 2])
tid1 = self.tp._contract(tid1, [0, 4, 3], pid1, [1, 2, 4])
else:
tid0 = self._contract(tid0, [0, 4, 3], pid1, [2, 1, 4])
tid1 = self._contract(tid1, [0, 4, 3], pid0, [2, 4, 1])
tt0 = np.einsum(self.env[s], [0, 4], self.tpool[tid0][0], [4, 1, 2, 3])
tt1 = np.einsum(self.tpool[tid1][0], [0, 1, 2, 4], self.env[s + 2], [4, 3])
_, WLhid, WRid = self._projector(tt0, [0, 1, 4, 5], tt1, [5, 4, 2, 3], maxdim)
self.tcurrent[s] = self._contract(tid0, [0, 1, 3, 4], WRid, [3, 4, 2])
self.tcurrent[s + 1] = self._contract(WLhid, [3, 4, 0], tid1, [4, 3, 1, 2])

def _register(self, tensor):
id = len(self.tpool)
self.tpool.append([tensor, "initial", None])
return id

def _contract(self, tid0, ss0, tid1, ss1, ss2=None):
id = len(self.tpool)
if ss2 is None:
t = np.einsum(self.tpool[tid0][0], ss0, self.tpool[tid1][0], ss1)
else:
np.einsum(self.tpool[tid0][0], ss0, self.tpool[tid1][0], ss1, ss2)
self.tpool.append([t, "product", [tid0, tid1]])
return id

def _projector(self, t0, ss0, t1, ss1, maxdim=None):
S, d, WLh, WR = _full_projector(t0, ss0, t1, ss1)
d = d if maxdim is None else min(d, maxdim)
gid = len(self.gpool)
self.gpool.append([S, d, WLh, WR])
lid = len(self.tpool)
shape = WLh.shape
WLh = WLh.reshape((np.prod(shape[:-1]), shape[-1]))
WLh = WLh[:, :d].reshape(shape[:-1] + (d,))
self.tpool.append([WLh, "squeezer", [gid]])
rid = len(self.tpool)
shape = WR.shape
WR = WR.reshape((np.prod(shape[:-1]), shape[-1]))
WR = WR[:, :d].reshape(shape[:-1] + (d,))
self.tpool.append([WR, "squeezer", [gid]])
return gid, lid, rid

def _dump(self, prefix):
import json
import pickle

dic = {"prefix": prefix}
tlist = []
with open(f"{prefix}-tensor.pkl", "wb") as f:
for id, tp in enumerate(self.tpool):
m = {}
m["id"] = id
m["shape"] = tp[0].shape
m["type"] = tp[1]
m["from"] = tp[2]
tlist.append(m)
if tp[1] == "initial":
pickle.dump(tp[0], f)
dic["tensor"] = tlist
glist = []
with open(f"{prefix}-generator.pkl", "wb") as f:
for id, gp in enumerate(self.gpool):
m = {}
m["id"] = id
m["d"] = gp[1]
m["shape S"] = gp[0].shape
m["shape L"] = gp[2].shape
m["shape R"] = gp[3].shape
glist.append(m)
pickle.dump(gp[0], f)
pickle.dump(gp[2], f)
pickle.dump(gp[3], f)
dic["generator"] = glist

with open(prefix + "-graph.json", mode="w") as f:
json.dump(dic, f, indent=2)
tid0 = self.tp._contract(tid0, [0, 4, 3], pid1, [2, 1, 4])
tid1 = self.tp._contract(tid1, [0, 4, 3], pid0, [2, 4, 1])
tt0 = np.einsum(self.env[s], [0, 4], self.tp.tpool[tid0][0], [4, 1, 2, 3])
tt1 = np.einsum(self.tp.tpool[tid1][0], [0, 1, 2, 4], self.env[s + 2], [4, 3])
_, WLhid, WRid = self.tp._projector(
tt0, [0, 1, 4, 5], tt1, [5, 4, 2, 3], maxdim
)
self.tcurrent[s] = self.tp._contract(tid0, [0, 1, 3, 4], WRid, [3, 4, 2])
self.tcurrent[s + 1] = self.tp._contract(WLhid, [3, 4, 0], tid1, [4, 3, 1, 2])
79 changes: 79 additions & 0 deletions src/qailo/mps_t/tpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np

from ..mps_p.projector import _full_projector


class tpool:
def __init__(self):
self.tpool = []
self.gpool = []

def _register(self, tensor):
id = len(self.tpool)
self.tpool.append([tensor, "initial", None])
return id

def _contract(self, tid0, ss0, tid1, ss1, ss2=None):
id = len(self.tpool)
if ss2 is None:
t = np.einsum(self.tpool[tid0][0], ss0, self.tpool[tid1][0], ss1)
else:
t = np.einsum(self.tpool[tid0][0], ss0, self.tpool[tid1][0], ss1, ss2)
self.tpool.append([t, "product", tid0, ss0, tid1, ss1, ss2])
return id

def _trim(self, T, d):
"""trim last dimension of tensor"""
assert d <= T.shape[-1]
shape = T.shape
T = T.reshape((np.prod(shape[:-1]), shape[-1]))
return T[:, :d].reshape(shape[:-1] + (d,))

def _projector(self, t0, ss0, t1, ss1, maxdim=None):
S, d, WLh, WR = _full_projector(t0, ss0, t1, ss1)
d = d if maxdim is None else min(d, maxdim)
gid = len(self.gpool)
self.gpool.append([S, d, WLh, WR])
lid = len(self.tpool)
self.tpool.append([self._trim(WLh, d), "squeezer L", gid])
rid = len(self.tpool)
self.tpool.append([self._trim(WR, d), "squeezer R", gid])
return gid, lid, rid

def _dump(self, prefix):
import json

dic = {"prefix": prefix}
tlist = []
for id, tp in enumerate(self.tpool):
m = {}
m["id"] = id
m["type"] = tp[1]
m["shape"] = tp[0].shape
if tp[1] == "product":
m["from 0"] = tp[2]
m["subscripts 0"] = tp[3]
m["from 1"] = tp[4]
m["subscripts 1"] = tp[5]
if tp[6] is not None:
m["subscripts"] = tp[6]
elif tp[1] == "squeezer L" or tp[1] == "squeezer R":
m["from"] = tp[2]
tlist.append(m)
if tp[1] == "initial":
np.save(f"{prefix}-tensor-{id}", tp[0])
dic["tensor"] = tlist
glist = []
for id, gp in enumerate(self.gpool):
m = {}
m["id"] = id
m["dim from"] = gp[0].shape[0]
m["dim to"] = gp[1]
m["shape L"] = gp[2].shape
m["shape R"] = gp[3].shape
glist.append(m)
np.savez(f"{prefix}-generator-{id}", gp[0], gp[1], gp[2])
dic["generator"] = glist

with open(prefix + "-graph.json", mode="w") as f:
json.dump(dic, f, indent=2)
10 changes: 5 additions & 5 deletions test/mps_t/test_mps_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ def test_mps_t():
print("probabitily:", q.probability(v))

print("# tensor pool")
for id, tp in enumerate(v.tpool):
for id, tp in enumerate(v.tp.tpool):
print(f"{id} {tp[0].shape} {tp[1]} {tp[2]}")
assert len(v.tpool) == 25
assert len(v.tp.tpool) == 25

print("# generator pool")
for id, gp in enumerate(v.gpool):
for id, gp in enumerate(v.tp.gpool):
print(f"{id} {gp[0].shape} {gp[1]} {gp[2].shape} {gp[3].shape}")
assert len(v.gpool) == 2
assert len(v.tp.gpool) == 2

prefix = "test_mps_t"
v._dump(prefix)
v.tp._dump(prefix)


if __name__ == "__main__":
Expand Down

0 comments on commit a0d2ccc

Please sign in to comment.