Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect duplicate measurement keys #1687 #1862

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cirq/google/sim/xmon_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ def find_measurement_keys(circuit: circuits.Circuit) -> Set[str]:
for _, _, gate in circuit.findall_operations_with_gate_type(
ops.MeasurementGate):
key = protocols.measurement_key(gate)
if key in keys:
raise ValueError('Repeated Measurement key {}'.format(key))
keys.add(key)
return keys

Expand Down
21 changes: 3 additions & 18 deletions cirq/google/xmon_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import timedelta
from typing import Iterable, cast, Optional, List, Union, TYPE_CHECKING

from cirq import circuits, devices, ops, protocols, value
from cirq import circuits, devices, ops, value
from cirq.google import convert_to_xmon_gates
from cirq.devices.grid_qubit import GridQubit

Expand Down Expand Up @@ -145,7 +145,6 @@ def validate_scheduled_operation(self, schedule, scheduled_operation):

def validate_circuit(self, circuit: circuits.Circuit):
super().validate_circuit(circuit)
_verify_unique_measurement_keys(circuit.all_operations())

def validate_moment(self, moment: ops.Moment):
super().validate_moment(moment)
Expand Down Expand Up @@ -173,8 +172,6 @@ def can_add_operation_into_moment(self,
return True

def validate_schedule(self, schedule):
_verify_unique_measurement_keys(
s.operation for s in schedule.scheduled_operations)
for scheduled_operation in schedule.scheduled_operations:
self.validate_scheduled_operation(schedule, scheduled_operation)

Expand Down Expand Up @@ -215,17 +212,5 @@ def __str__(self):
use_unicode_characters=True)

def _value_equality_values_(self):
return (self._measurement_duration,
self._exp_w_duration,
self._exp_z_duration,
self.qubits)


def _verify_unique_measurement_keys(operations: Iterable[ops.Operation]):
seen = set() # type: Set[str]
for op in operations:
if protocols.is_measurement(op):
key = protocols.measurement_key(op)
if key in seen:
raise ValueError('Measurement key {} repeated'.format(key))
seen.add(key)
return (self._measurement_duration, self._exp_w_duration,
self._exp_z_duration, self.qubits)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised the formatter lets lack of trailing newlines through

25 changes: 0 additions & 25 deletions cirq/google/xmon_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,31 +227,6 @@ def test_validate_scheduled_operation_not_adjacent_exp_11_exp_w():
d.validate_schedule(s)


def test_validate_circuit_repeat_measurement_keys():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this test removed with no test being added to e.g. circuit_test.py to compensate? What is checking that we detect dupes?

d = square_device(3, 3)

circuit = cirq.Circuit()
circuit.append([cirq.measure(cirq.GridQubit(0, 0), key='a'),
cirq.measure(cirq.GridQubit(0, 1), key='a')])

with pytest.raises(ValueError, match='Measurement key a repeated'):
d.validate_circuit(circuit)


def test_validate_schedule_repeat_measurement_keys():
d = square_device(3, 3)

s = cirq.Schedule(d, [
cirq.ScheduledOperation.op_at_on(
cirq.measure(cirq.GridQubit(0, 0), key='a'), cirq.Timestamp(), d),
cirq.ScheduledOperation.op_at_on(
cirq.measure(cirq.GridQubit(0, 1), key='a'), cirq.Timestamp(), d),
])

with pytest.raises(ValueError, match='Measurement key a repeated'):
d.validate_schedule(s)


def test_xmon_device_eq():
eq = cirq.testing.EqualsTester()
eq.make_equality_group(lambda: square_device(3, 3))
Expand Down
22 changes: 3 additions & 19 deletions cirq/ion/ion_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import timedelta
from typing import cast, Iterable, Optional, Union, TYPE_CHECKING

from cirq import circuits, value, devices, ops, protocols
from cirq import circuits, value, devices, ops
from cirq.line import LineQubit
from cirq.ion import convert_to_ion_gates

Expand Down Expand Up @@ -123,7 +123,6 @@ def validate_scheduled_operation(self, schedule, scheduled_operation):

def validate_circuit(self, circuit: circuits.Circuit):
super().validate_circuit(circuit)
_verify_unique_measurement_keys(circuit.all_operations())

def can_add_operation_into_moment(self,
operation: ops.Operation,
Expand All @@ -138,8 +137,6 @@ def can_add_operation_into_moment(self,
return True

def validate_schedule(self, schedule):
_verify_unique_measurement_keys(
s.operation for s in schedule.scheduled_operations)
for scheduled_operation in schedule.scheduled_operations:
self.validate_scheduled_operation(schedule, scheduled_operation)

Expand Down Expand Up @@ -180,18 +177,5 @@ def __str__(self):
use_unicode_characters=True)

def _value_equality_values_(self):
return (self._measurement_duration,
self._twoq_gates_duration,
self._oneq_gates_duration,
self.qubits)


def _verify_unique_measurement_keys(operations: Iterable[ops.Operation]):
seen = set() # type: Set[str]
for op in operations:
meas = ops.op_gate_of_type(op, ops.MeasurementGate)
if meas:
key = protocols.measurement_key(meas)
if key in seen:
raise ValueError('Measurement key {} repeated'.format(key))
seen.add(key)
return (self._measurement_duration, self._twoq_gates_duration,
self._oneq_gates_duration, self.qubits)
25 changes: 0 additions & 25 deletions cirq/ion/ion_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,31 +204,6 @@ def test_ion_device_eq():
lambda: ion_device(4))


def test_validate_circuit_repeat_measurement_keys():
d = ion_device(3)

circuit = cirq.Circuit()
circuit.append([cirq.measure(cirq.LineQubit(0), key='a'),
cirq.measure(cirq.LineQubit(1), key='a')])

with pytest.raises(ValueError, match='Measurement key a repeated'):
d.validate_circuit(circuit)


def test_validate_schedule_repeat_measurement_keys():
d = ion_device(3)

s = cirq.Schedule(d, [
cirq.ScheduledOperation.op_at_on(
cirq.measure(cirq.LineQubit(0), key='a'), cirq.Timestamp(), d),
cirq.ScheduledOperation.op_at_on(
cirq.measure(cirq.LineQubit(1), key='a'), cirq.Timestamp(), d),
])

with pytest.raises(ValueError, match='Measurement key a repeated'):
d.validate_schedule(s)


def test_ion_device_str():
assert str(ion_device(3)).strip() == """
0───1───2
Expand Down
13 changes: 13 additions & 0 deletions cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def run_sweep(
else program.to_circuit())
if not circuit.has_measurements():
raise ValueError("Circuit has no measurements to sample.")
_verify_unique_measurement_keys(circuit)
param_resolvers = study.to_resolvers(params)

trial_results = [] # type: List[study.TrialResult]
Expand Down Expand Up @@ -290,6 +291,7 @@ def simulate_sweep(
"""
circuit = (program if isinstance(program, circuits.Circuit)
else program.to_circuit())
_verify_unique_measurement_keys(circuit)
param_resolvers = study.to_resolvers(params)

trial_results = []
Expand Down Expand Up @@ -336,6 +338,7 @@ def simulate_moment_steps(
Iterator that steps through the simulation, simulating each
moment and returning a StepResult for each moment.
"""
_verify_unique_measurement_keys(circuit)
return self._simulator_iterator(
circuit,
study.ParamResolver(param_resolver),
Expand Down Expand Up @@ -546,3 +549,13 @@ def qubit_map(self) -> Dict[ops.Qid, int]:
the result.
"""
return self._final_simulator_state.qubit_map


def _verify_unique_measurement_keys(circuit: circuits.Circuit):
result = collections.Counter(
protocols.measurement_key(op, default=None)
for op in ops.flatten_op_tree(iter(circuit)))
result[None] = 0
duplicates = [k for k, v in result.most_common() if v > 1]
if duplicates:
raise ValueError('Measurement key {} repeated'.format(duplicates))
4 changes: 4 additions & 0 deletions cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_run_simulator_run():
expected_measurements = {'a': np.array([[1]])}
simulator._run.return_value = expected_measurements
circuit = mock.Mock(cirq.Circuit)
circuit.__iter__ = mock.Mock(return_value=iter([]))
param_resolver = mock.Mock(cirq.ParamResolver)
expected_result = cirq.TrialResult.from_single_parameter_set(
measurements=expected_measurements, params=param_resolver)
Expand All @@ -48,6 +49,7 @@ def test_run_simulator_sweeps():
expected_measurements = {'a': np.array([[1]])}
simulator._run.return_value = expected_measurements
circuit = mock.Mock(cirq.Circuit)
circuit.__iter__ = mock.Mock(return_value=iter([]))
param_resolvers = [mock.Mock(cirq.ParamResolver),
mock.Mock(cirq.ParamResolver)]
expected_results = [
Expand Down Expand Up @@ -83,6 +85,7 @@ def steps(*args, **kwargs):

simulator._simulator_iterator.side_effect = steps
circuit = mock.Mock(cirq.Circuit)
circuit.__iter__ = mock.Mock(return_value=iter([]))
param_resolver = mock.Mock(cirq.ParamResolver)
qubit_order = mock.Mock(cirq.QubitOrder)
result = simulator.simulate(program=circuit,
Expand Down Expand Up @@ -112,6 +115,7 @@ def steps(*args, **kwargs):

simulator._simulator_iterator.side_effect = steps
circuit = mock.Mock(cirq.Circuit)
circuit.__iter__ = mock.Mock(return_value=iter([]))
param_resolvers = [mock.Mock(cirq.ParamResolver),
mock.Mock(cirq.ParamResolver)]
qubit_order = mock.Mock(cirq.QubitOrder)
Expand Down