From 76e643864be06ddd70523a6649b70c970013ef64 Mon Sep 17 00:00:00 2001 From: Vamsi Krishna Devabathini Date: Tue, 23 Jul 2019 19:08:01 +0530 Subject: [PATCH 1/2] Detect duplicate measurement keys --- cirq/google/sim/xmon_simulator.py | 2 -- cirq/google/xmon_device.py | 11 ----------- cirq/google/xmon_device_test.py | 25 ------------------------- cirq/ion/ion_device.py | 13 ------------- cirq/ion/ion_device_test.py | 25 ------------------------- cirq/sim/simulator.py | 17 +++++++++++++++++ cirq/sim/simulator_test.py | 4 ++++ 7 files changed, 21 insertions(+), 76 deletions(-) diff --git a/cirq/google/sim/xmon_simulator.py b/cirq/google/sim/xmon_simulator.py index 65e16040ee1..a5b934c5a87 100644 --- a/cirq/google/sim/xmon_simulator.py +++ b/cirq/google/sim/xmon_simulator.py @@ -295,8 +295,6 @@ def find_measurement_keys(circuit: circuits.Circuit) -> Set[str]: for _, _, gate in circuit.findall_operations_with_gate_type( ops.MeasurementGate): key = protocols.measurement_key(gate) - if key in keys: - raise ValueError('Repeated Measurement key {}'.format(key)) keys.add(key) return keys diff --git a/cirq/google/xmon_device.py b/cirq/google/xmon_device.py index 160a1d11ed2..b0abcf56917 100644 --- a/cirq/google/xmon_device.py +++ b/cirq/google/xmon_device.py @@ -145,7 +145,6 @@ def validate_scheduled_operation(self, schedule, scheduled_operation): def validate_circuit(self, circuit: circuits.Circuit): super().validate_circuit(circuit) - _verify_unique_measurement_keys(circuit.all_operations()) def validate_moment(self, moment: ops.Moment): super().validate_moment(moment) @@ -173,8 +172,6 @@ def can_add_operation_into_moment(self, return True def validate_schedule(self, schedule): - _verify_unique_measurement_keys( - s.operation for s in schedule.scheduled_operations) for scheduled_operation in schedule.scheduled_operations: self.validate_scheduled_operation(schedule, scheduled_operation) @@ -221,11 +218,3 @@ def _value_equality_values_(self): self.qubits) -def _verify_unique_measurement_keys(operations: Iterable[ops.Operation]): - seen = set() # type: Set[str] - for op in operations: - if protocols.is_measurement(op): - key = protocols.measurement_key(op) - if key in seen: - raise ValueError('Measurement key {} repeated'.format(key)) - seen.add(key) diff --git a/cirq/google/xmon_device_test.py b/cirq/google/xmon_device_test.py index 3483afbd9b1..5dd5dcde6b5 100644 --- a/cirq/google/xmon_device_test.py +++ b/cirq/google/xmon_device_test.py @@ -227,31 +227,6 @@ def test_validate_scheduled_operation_not_adjacent_exp_11_exp_w(): d.validate_schedule(s) -def test_validate_circuit_repeat_measurement_keys(): - d = square_device(3, 3) - - circuit = cirq.Circuit() - circuit.append([cirq.measure(cirq.GridQubit(0, 0), key='a'), - cirq.measure(cirq.GridQubit(0, 1), key='a')]) - - with pytest.raises(ValueError, match='Measurement key a repeated'): - d.validate_circuit(circuit) - - -def test_validate_schedule_repeat_measurement_keys(): - d = square_device(3, 3) - - s = cirq.Schedule(d, [ - cirq.ScheduledOperation.op_at_on( - cirq.measure(cirq.GridQubit(0, 0), key='a'), cirq.Timestamp(), d), - cirq.ScheduledOperation.op_at_on( - cirq.measure(cirq.GridQubit(0, 1), key='a'), cirq.Timestamp(), d), - ]) - - with pytest.raises(ValueError, match='Measurement key a repeated'): - d.validate_schedule(s) - - def test_xmon_device_eq(): eq = cirq.testing.EqualsTester() eq.make_equality_group(lambda: square_device(3, 3)) diff --git a/cirq/ion/ion_device.py b/cirq/ion/ion_device.py index beb10e5d898..67223c37923 100644 --- a/cirq/ion/ion_device.py +++ b/cirq/ion/ion_device.py @@ -123,7 +123,6 @@ def validate_scheduled_operation(self, schedule, scheduled_operation): def validate_circuit(self, circuit: circuits.Circuit): super().validate_circuit(circuit) - _verify_unique_measurement_keys(circuit.all_operations()) def can_add_operation_into_moment(self, operation: ops.Operation, @@ -138,8 +137,6 @@ def can_add_operation_into_moment(self, return True def validate_schedule(self, schedule): - _verify_unique_measurement_keys( - s.operation for s in schedule.scheduled_operations) for scheduled_operation in schedule.scheduled_operations: self.validate_scheduled_operation(schedule, scheduled_operation) @@ -185,13 +182,3 @@ def _value_equality_values_(self): self._oneq_gates_duration, self.qubits) - -def _verify_unique_measurement_keys(operations: Iterable[ops.Operation]): - seen = set() # type: Set[str] - for op in operations: - meas = ops.op_gate_of_type(op, ops.MeasurementGate) - if meas: - key = protocols.measurement_key(meas) - if key in seen: - raise ValueError('Measurement key {} repeated'.format(key)) - seen.add(key) \ No newline at end of file diff --git a/cirq/ion/ion_device_test.py b/cirq/ion/ion_device_test.py index 304b3c8b7a8..a61cd120c1f 100644 --- a/cirq/ion/ion_device_test.py +++ b/cirq/ion/ion_device_test.py @@ -204,31 +204,6 @@ def test_ion_device_eq(): lambda: ion_device(4)) -def test_validate_circuit_repeat_measurement_keys(): - d = ion_device(3) - - circuit = cirq.Circuit() - circuit.append([cirq.measure(cirq.LineQubit(0), key='a'), - cirq.measure(cirq.LineQubit(1), key='a')]) - - with pytest.raises(ValueError, match='Measurement key a repeated'): - d.validate_circuit(circuit) - - -def test_validate_schedule_repeat_measurement_keys(): - d = ion_device(3) - - s = cirq.Schedule(d, [ - cirq.ScheduledOperation.op_at_on( - cirq.measure(cirq.LineQubit(0), key='a'), cirq.Timestamp(), d), - cirq.ScheduledOperation.op_at_on( - cirq.measure(cirq.LineQubit(1), key='a'), cirq.Timestamp(), d), - ]) - - with pytest.raises(ValueError, match='Measurement key a repeated'): - d.validate_schedule(s) - - def test_ion_device_str(): assert str(ion_device(3)).strip() == """ 0───1───2 diff --git a/cirq/sim/simulator.py b/cirq/sim/simulator.py index 956759a421e..6fc72d594f7 100644 --- a/cirq/sim/simulator.py +++ b/cirq/sim/simulator.py @@ -73,6 +73,7 @@ def run_sweep( else program.to_circuit()) if not circuit.has_measurements(): raise ValueError("Circuit has no measurements to sample.") + _verify_unique_measurement_keys(circuit) param_resolvers = study.to_resolvers(params) trial_results = [] # type: List[study.TrialResult] @@ -290,6 +291,7 @@ def simulate_sweep( """ circuit = (program if isinstance(program, circuits.Circuit) else program.to_circuit()) + _verify_unique_measurement_keys(circuit) param_resolvers = study.to_resolvers(params) trial_results = [] @@ -336,6 +338,7 @@ def simulate_moment_steps( Iterator that steps through the simulation, simulating each moment and returning a StepResult for each moment. """ + _verify_unique_measurement_keys(circuit) return self._simulator_iterator( circuit, study.ParamResolver(param_resolver), @@ -546,3 +549,17 @@ def qubit_map(self) -> Dict[ops.Qid, int]: the result. """ return self._final_simulator_state.qubit_map + + +def _verify_unique_measurement_keys(circuit: circuits.Circuit): + result = collections.Counter( + protocols.measurement_key(op, default=None) + for op in ops.flatten_op_tree(iter(circuit)) + ) + result[None] = 0 + duplicates = [k for k, v in result.most_common() if v > 1] + if duplicates: + raise ValueError( + 'Measurement key {} repeated'.format(duplicates)) + + diff --git a/cirq/sim/simulator_test.py b/cirq/sim/simulator_test.py index 778c18c6468..ce16295c373 100644 --- a/cirq/sim/simulator_test.py +++ b/cirq/sim/simulator_test.py @@ -29,6 +29,7 @@ def test_run_simulator_run(): expected_measurements = {'a': np.array([[1]])} simulator._run.return_value = expected_measurements circuit = mock.Mock(cirq.Circuit) + circuit.__iter__ = mock.Mock(return_value=iter([])) param_resolver = mock.Mock(cirq.ParamResolver) expected_result = cirq.TrialResult.from_single_parameter_set( measurements=expected_measurements, params=param_resolver) @@ -48,6 +49,7 @@ def test_run_simulator_sweeps(): expected_measurements = {'a': np.array([[1]])} simulator._run.return_value = expected_measurements circuit = mock.Mock(cirq.Circuit) + circuit.__iter__ = mock.Mock(return_value=iter([])) param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)] expected_results = [ @@ -83,6 +85,7 @@ def steps(*args, **kwargs): simulator._simulator_iterator.side_effect = steps circuit = mock.Mock(cirq.Circuit) + circuit.__iter__ = mock.Mock(return_value=iter([])) param_resolver = mock.Mock(cirq.ParamResolver) qubit_order = mock.Mock(cirq.QubitOrder) result = simulator.simulate(program=circuit, @@ -112,6 +115,7 @@ def steps(*args, **kwargs): simulator._simulator_iterator.side_effect = steps circuit = mock.Mock(cirq.Circuit) + circuit.__iter__ = mock.Mock(return_value=iter([])) param_resolvers = [mock.Mock(cirq.ParamResolver), mock.Mock(cirq.ParamResolver)] qubit_order = mock.Mock(cirq.QubitOrder) From 91cb54757774213b1a73441fac7f40a6c7b5d103 Mon Sep 17 00:00:00 2001 From: Vamsi Krishna Devabathini Date: Tue, 23 Jul 2019 19:31:09 +0530 Subject: [PATCH 2/2] Indentation changes --- cirq/google/xmon_device.py | 10 +++------- cirq/ion/ion_device.py | 9 +++------ cirq/sim/simulator.py | 8 ++------ 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/cirq/google/xmon_device.py b/cirq/google/xmon_device.py index b0abcf56917..a0f2e479239 100644 --- a/cirq/google/xmon_device.py +++ b/cirq/google/xmon_device.py @@ -15,7 +15,7 @@ from datetime import timedelta from typing import Iterable, cast, Optional, List, Union, TYPE_CHECKING -from cirq import circuits, devices, ops, protocols, value +from cirq import circuits, devices, ops, value from cirq.google import convert_to_xmon_gates from cirq.devices.grid_qubit import GridQubit @@ -212,9 +212,5 @@ def __str__(self): use_unicode_characters=True) def _value_equality_values_(self): - return (self._measurement_duration, - self._exp_w_duration, - self._exp_z_duration, - self.qubits) - - + return (self._measurement_duration, self._exp_w_duration, + self._exp_z_duration, self.qubits) \ No newline at end of file diff --git a/cirq/ion/ion_device.py b/cirq/ion/ion_device.py index 67223c37923..70c8b8d0727 100644 --- a/cirq/ion/ion_device.py +++ b/cirq/ion/ion_device.py @@ -15,7 +15,7 @@ from datetime import timedelta from typing import cast, Iterable, Optional, Union, TYPE_CHECKING -from cirq import circuits, value, devices, ops, protocols +from cirq import circuits, value, devices, ops from cirq.line import LineQubit from cirq.ion import convert_to_ion_gates @@ -177,8 +177,5 @@ def __str__(self): use_unicode_characters=True) def _value_equality_values_(self): - return (self._measurement_duration, - self._twoq_gates_duration, - self._oneq_gates_duration, - self.qubits) - + return (self._measurement_duration, self._twoq_gates_duration, + self._oneq_gates_duration, self.qubits) \ No newline at end of file diff --git a/cirq/sim/simulator.py b/cirq/sim/simulator.py index 6fc72d594f7..6f98151c440 100644 --- a/cirq/sim/simulator.py +++ b/cirq/sim/simulator.py @@ -554,12 +554,8 @@ def qubit_map(self) -> Dict[ops.Qid, int]: def _verify_unique_measurement_keys(circuit: circuits.Circuit): result = collections.Counter( protocols.measurement_key(op, default=None) - for op in ops.flatten_op_tree(iter(circuit)) - ) + for op in ops.flatten_op_tree(iter(circuit))) result[None] = 0 duplicates = [k for k, v in result.most_common() if v > 1] if duplicates: - raise ValueError( - 'Measurement key {} repeated'.format(duplicates)) - - + raise ValueError('Measurement key {} repeated'.format(duplicates)) \ No newline at end of file