Skip to content

Commit

Permalink
update fitting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Dec 18, 2024
1 parent 95fccd1 commit c6d0bac
Showing 1 changed file with 61 additions and 45 deletions.
106 changes: 61 additions & 45 deletions tests/test_tensor/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@


@pytest.mark.parametrize("method", ("auto", "dense", "overlap"))
@pytest.mark.parametrize("normalized", (
True,
False,
"squared",
"infidelity",
"infidelity_sqrt",
))
@pytest.mark.parametrize(
"normalized",
(
True,
False,
"squared",
"infidelity",
"infidelity_sqrt",
),
)
def test_tensor_network_distance(method, normalized):
n = 6
A = qtn.TN_rand_reg(n=n, reg=3, D=2, phys_dim=2, dtype=complex)
Expand All @@ -34,87 +37,100 @@ def test_tensor_network_distance(method, normalized):


@pytest.mark.parametrize(
"method,opts",
(
("als", (("enforce_pos", False), ("solver", "lstsq"))),
("als", (("enforce_pos", True),)),
("tree", ()),
"opts",
[
dict(method="als", dense_solve=False),
dict(method="als", dense_solve=False, solver="lgmres"),
dict(
method="als",
dense_solve=True,
enforce_pos=False,
solver_dense="lstsq",
),
dict(method="als", dense_solve=True, enforce_pos=True),
dict(method="tree"),
pytest.param(
"autodiff",
(("distance_method", "dense"),),
dict(method="autodiff", distance_method="dense"),
marks=requires_autograd,
),
pytest.param(
"autodiff",
(("distance_method", "overlap"),),
dict(method="autodiff", distance_method="overlap"),
marks=requires_autograd,
),
),
],
)
@pytest.mark.parametrize("dtype", ("float64", "complex128"))
def test_fit_mps(method, opts, dtype):
def test_fit_mps(opts, dtype):
k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype=dtype)
k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype)
assert k1.distance_normalized(k2) > 1e-3
k1.fit_(k2, method=method, progbar=True, **dict(opts))
k1.fit_(k2, progbar=True, **dict(opts))
assert k1.distance_normalized(k2) < 1e-3


@pytest.mark.parametrize(
"method,opts",
(
("als", (("enforce_pos", False),)),
("als", (("enforce_pos", True),)),
"opts",
[
dict(method="als", dense_solve=False),
dict(method="als", dense_solve=False, solver="lgmres"),
dict(
method="als",
dense_solve=True,
enforce_pos=False,
solver_dense="lstsq",
),
dict(method="als", dense_solve=True, enforce_pos=True),
pytest.param(
"autodiff",
(("distance_method", "dense"),),
dict(method="autodiff", distance_method="dense"),
marks=requires_autograd,
),
pytest.param(
"autodiff",
(("distance_method", "overlap"),),
dict(method="autodiff", distance_method="overlap"),
marks=requires_autograd,
),
),
],
)
@pytest.mark.parametrize("dtype", ("float64", "complex128"))
def test_fit_rand_reg(method, opts, dtype):
def test_fit_rand_reg(opts, dtype):
r1 = qtn.TN_rand_reg(5, 4, D=2, seed=666, phys_dim=2, dtype=dtype)
k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype)
assert r1.distance(k2) > 1e-3
r1.fit_(k2, method=method, progbar=True, **dict(opts))
r1.fit_(k2, progbar=True, **dict(opts))
assert r1.distance(k2) < 1e-3


@pytest.mark.parametrize(
"method,opts",
(
("als", (("enforce_pos", False),)),
("als", (("enforce_pos", True),)),
("tree", ()),
"opts",
[
dict(method="als", dense_solve=False),
dict(method="als", dense_solve=False, solver="lgmres"),
dict(
method="als",
dense_solve=True,
enforce_pos=False,
solver_dense="lstsq",
),
dict(method="als", dense_solve=True, enforce_pos=True),
dict(method="tree"),
pytest.param(
"autodiff",
(("distance_method", "dense"),),
dict(method="autodiff", distance_method="dense"),
marks=requires_autograd,
),
pytest.param(
"autodiff",
(("distance_method", "overlap"),),
dict(method="autodiff", distance_method="overlap"),
marks=requires_autograd,
),
),
],
)
@pytest.mark.parametrize("dtype", ("float64", "complex128"))
def test_fit_partial_tags(method, opts, dtype):
def test_fit_partial_tags(opts, dtype):
k1 = qtn.MPS_rand_state(5, 3, seed=666, dtype=dtype)
k2 = qtn.MPS_rand_state(5, 3, seed=667, dtype=dtype)
d0 = k1.distance(k2)
tags = ["I0", "I2", "I4"]
k1f = k1.fit(
k2, tol=1e-3, tags=tags, method=method, progbar=True, **dict(opts)
)
k1f = k1.fit(k2, tol=1e-3, tags=tags, progbar=True, **dict(opts))
assert k1f.distance(k2) < d0
if method != "tree":
if opts["method"] != "tree":
assert (k1f[0] - k1[0]).norm() > 1e-12
assert (k1f[1] - k1[1]).norm() < 1e-12
assert (k1f[2] - k1[2]).norm() > 1e-12
Expand Down

0 comments on commit c6d0bac

Please sign in to comment.