Skip to content

Commit

Permalink
Detect duplicate mesurement keys. (#2604)
Browse files Browse the repository at this point in the history
* Detect duplicate mesurement keys.

Detect duplicate mesurement keys.

* name the test properly

* Add _verify_unique_measurement_keys  mocks for alredy existing tests

* correct the error string in test

* Fix simulator test

* Fix unit tests

* format correction

* Fix broken tests in other folders

* Change to _ = cirq.sample(circuit) to signal that what's returned is deliberately ignored.
  • Loading branch information
iamvamsikrishnad authored Dec 2, 2019
1 parent 912c82a commit 07955c5
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
4 changes: 2 additions & 2 deletions cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,8 +922,8 @@ def test_random_seed_does_not_modify_global_state_terminal_measurements():
def test_random_seed_does_not_modify_global_state_non_terminal_measurements():
a = cirq.NamedQubit('a')
circuit = cirq.Circuit(
cirq.X(a)**0.5, cirq.measure(a),
cirq.X(a)**0.5, cirq.measure(a))
cirq.X(a)**0.5, cirq.measure(a, key='a0'),
cirq.X(a)**0.5, cirq.measure(a, key='a1'))

sim = cirq.DensityMatrixSimulator(seed=1234)
result1 = sim.run(circuit, repetitions=50)
Expand Down
13 changes: 13 additions & 0 deletions cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def run_sweep(
if not program.has_measurements():
raise ValueError("Circuit has no measurements to sample.")

_verify_unique_measurement_keys(program)

trial_results = [] # type: List[study.TrialResult]
for param_resolver in study.to_resolvers(params):
measurements = self._run(circuit=program,
Expand Down Expand Up @@ -573,3 +575,14 @@ def _qubit_map_to_shape(qubit_map: Dict[ops.Qid, int]) -> Tuple[int, ...]:
'Invalid qubit_map. Duplicate qubit index. Map is <{!r}>.'.format(
qubit_map))
return tuple(qid_shape)


def _verify_unique_measurement_keys(circuit: circuits.Circuit):
result = collections.Counter(
protocols.measurement_key(op, default=None)
for op in ops.flatten_op_tree(iter(circuit)))
result[None] = 0
duplicates = [k for k, v in result.most_common() if v > 1]
if duplicates:
raise ValueError('Measurement key {} repeated'.format(
",".join(duplicates)))
15 changes: 15 additions & 0 deletions cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_run_simulator_run():
expected_measurements = {'a': np.array([[1]])}
simulator._run.return_value = expected_measurements
circuit = mock.Mock(cirq.Circuit)
circuit.__iter__ = mock.Mock(return_value=iter([]))
param_resolver = mock.Mock(cirq.ParamResolver)
expected_result = cirq.TrialResult.from_single_parameter_set(
measurements=expected_measurements, params=param_resolver)
Expand All @@ -48,6 +49,7 @@ def test_run_simulator_sweeps():
expected_measurements = {'a': np.array([[1]])}
simulator._run.return_value = expected_measurements
circuit = mock.Mock(cirq.Circuit)
circuit.__iter__ = mock.Mock(return_value=iter([]))
param_resolvers = [mock.Mock(cirq.ParamResolver),
mock.Mock(cirq.ParamResolver)]
expected_results = [
Expand Down Expand Up @@ -314,6 +316,19 @@ def test_simulation_trial_result_qubit_map():
assert result.qubit_map == {q[0]: 0, q[1]: 1}


def test_verify_unique_measurement_keys():
q = cirq.LineQubit.range(2)
circuit = cirq.Circuit()
circuit.append([
cirq.measure(q[0], key='a'),
cirq.measure(q[1], key='a'),
cirq.measure(q[0], key='b'),
cirq.measure(q[1], key='b')
])
with pytest.raises(ValueError, match='Measurement key a,b repeated'):
_ = cirq.sample(circuit)


def test_simulate_with_invert_mask():

class PlusGate(cirq.Gate):
Expand Down
4 changes: 2 additions & 2 deletions cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,8 @@ def test_random_seed_does_not_modify_global_state_terminal_measurements():
def test_random_seed_does_not_modify_global_state_non_terminal_measurements():
a = cirq.NamedQubit('a')
circuit = cirq.Circuit(
cirq.X(a)**0.5, cirq.measure(a),
cirq.X(a)**0.5, cirq.measure(a))
cirq.X(a)**0.5, cirq.measure(a, key='a0'),
cirq.X(a)**0.5, cirq.measure(a, key='a1'))

sim = cirq.Simulator(seed=1234)
result1 = sim.run(circuit, repetitions=50)
Expand Down

0 comments on commit 07955c5

Please sign in to comment.