Skip to content

Commit

Permalink
Merge pull request #248 from kevinsung/sample-seed
Browse files Browse the repository at this point in the history
fix MPS sample handling of RNG seed
  • Loading branch information
jcmgray authored Jul 20, 2024
2 parents 1960e05 + dd1f7e2 commit 2f41cf9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion quimb/tensor/tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from math import log, log2
from numbers import Integral

import numpy as np
import scipy.sparse.linalg as spla
from autoray import conj, dag, do, get_dtype_name, reshape, size, transpose

Expand Down Expand Up @@ -3266,8 +3267,9 @@ def sample(self, C, seed=None, info=None):
# do right canonicalization once (supplying info avoids re-performing)
psi0 = self.canonicalize(0, info=info)

rng = np.random.default_rng(seed)
for _ in range(C):
yield psi0.sample_configuration(seed=seed, info=info)
yield psi0.sample_configuration(seed=rng, info=info)

class MatrixProductOperator(TensorNetwork1DOperator, TensorNetwork1DFlat):
"""Initialise a matrix product operator, with auto labelling and tagging.
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tensor/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ def test_mps_sampling(self):
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_mps_sampling_seed(self):
N = 1
circ = qtn.CircuitMPS(N)
circ.h(0)
samples = list(circ.sample(10, seed=1234))
assert len(set(samples)) == 2

def test_permmps_sampling(self):
N = 6
circ = qtn.CircuitPermMPS(N)
Expand All @@ -710,6 +717,13 @@ def test_permmps_sampling(self):
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_permmps_sampling_seed(self):
N = 1
circ = qtn.CircuitPermMPS(N)
circ.h(0)
samples = list(circ.sample(10, seed=1234))
assert len(set(samples)) == 2


class TestCircuitGen:
@pytest.mark.parametrize(
Expand Down
8 changes: 8 additions & 0 deletions tests/test_tensor/test_tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,14 @@ def test_sample_configuration(self):
).contract()
) ** 2 == pytest.approx(omega)

def test_sample_seed(self):
psi = qtn.MPS_rand_state(10, 7)
configs = [
"".join(map(str, config))
for config, _ in psi.sample(10, seed=1234)
]
assert len(set(configs)) > 1


class TestMatrixProductOperator:
@pytest.mark.parametrize("cyclic", [False, True])
Expand Down

0 comments on commit 2f41cf9

Please sign in to comment.