diff --git a/unitary/alpha/quantum_world.py b/unitary/alpha/quantum_world.py index 43d401ce..d5e6390d 100644 --- a/unitary/alpha/quantum_world.py +++ b/unitary/alpha/quantum_world.py @@ -336,7 +336,7 @@ def force_measurement( def peek( self, - objects: Optional[Sequence[QuantumObject]] = None, + objects: Optional[Sequence[Union[QuantumObject, str]]] = None, count: int = 1, convert_to_enum: bool = True, _existing_list: Optional[List[List[Union[enum.Enum, int]]]] = None, @@ -364,8 +364,13 @@ def peek( measure_circuit = self.circuit.copy() if objects is None: - objects = self.public_objects - measure_set = set(objects) + quantum_objects = self.public_objects + else: + quantum_objects = [ + self[obj_or_str] if isinstance(obj_or_str, str) else obj_or_str + for obj_or_str in objects + ] + measure_set = set(quantum_objects) measure_set.update(self.post_selection.keys()) measure_circuit.append( [ @@ -390,7 +395,7 @@ def peek( rtn_list.append( [ self._interpret_result(results.measurements[obj.name][rep]) - for obj in objects + for obj in quantum_objects ] ) if len(rtn_list) == count: @@ -398,12 +403,16 @@ def peek( if len(rtn_list) < count: # We post-selected too much, get more reps return self.peek( - objects, count, convert_to_enum, rtn_list, _num_reps=num_reps * 10 + quantum_objects, + count, + convert_to_enum, + rtn_list, + _num_reps=num_reps * 10, ) if convert_to_enum: rtn_list = [ - [objects[idx].enum_type(meas) for idx, meas in enumerate(res)] + [quantum_objects[idx].enum_type(meas) for idx, meas in enumerate(res)] for res in rtn_list ] @@ -411,17 +420,22 @@ def peek( def pop( self, - objects: Optional[Sequence[QuantumObject]] = None, + objects: Optional[Sequence[Union[QuantumObject, str]]] = None, convert_to_enum: bool = True, ) -> List[Union[enum.Enum, int]]: self.effect_history.append( (self.circuit.copy(), copy.copy(self.post_selection)) ) if objects is None: - objects = self.public_objects - results = self.peek(objects, convert_to_enum=convert_to_enum) + quantum_objects = self.public_objects + else: + quantum_objects = [ + self[obj_or_str] if isinstance(obj_or_str, str) else obj_or_str + for obj_or_str in objects + ] + results = self.peek(quantum_objects, convert_to_enum=convert_to_enum) for idx, result in enumerate(results[0]): - self.force_measurement(objects[idx], result) + self.force_measurement(quantum_objects[idx], result) return results[0] @@ -490,3 +504,9 @@ def get_binary_probabilities( for one_probs in full_probs: binary_probs.append(1 - one_probs[0]) return binary_probs + + def __getitem__(self, name: str) -> QuantumObject: + quantum_object = self.object_name_dict.get(name, None) + if not quantum_object: + raise KeyError(f"{name} did not exist in this world.") + return quantum_object diff --git a/unitary/alpha/quantum_world_test.py b/unitary/alpha/quantum_world_test.py index aa5c4e2e..b656fcfb 100644 --- a/unitary/alpha/quantum_world_test.py +++ b/unitary/alpha/quantum_world_test.py @@ -51,6 +51,10 @@ def test_get_object_by_name(compile_to_qubits): assert board.get_object_by_name("test") == light assert board.get_object_by_name("test2") == light2 assert board.get_object_by_name("test3") == None + assert board["test"] == light + assert board["test2"] == light2 + with pytest.raises(KeyError): + _ = board["test3"] @pytest.mark.parametrize("compile_to_qubits", [False, True]) @@ -62,10 +66,12 @@ def test_one_qubit(simulator, compile_to_qubits): ) assert board.peek() == [[Light.GREEN]] assert board.peek([light], count=2) == [[Light.GREEN], [Light.GREEN]] + assert board.peek(["test"], count=2) == [[Light.GREEN], [Light.GREEN]] light = alpha.QuantumObject("test", 1) board = alpha.QuantumWorld([light], compile_to_qubits=compile_to_qubits) assert board.peek() == [[1]] assert board.peek([light], count=2) == [[1], [1]] + assert board.peek(["test"], count=2) == [[1], [1]] assert board.pop() == [1] @@ -81,6 +87,8 @@ def test_two_qubits(simulator, compile_to_qubits): assert board.peek(convert_to_enum=False) == [[1, 0]] assert board.peek([light], count=2) == [[Light.GREEN], [Light.GREEN]] assert board.peek([light2], count=2) == [[Light.RED], [Light.RED]] + assert board.peek(["green"], count=2) == [[Light.GREEN], [Light.GREEN]] + assert board.peek(["red"], count=2) == [[Light.RED], [Light.RED]] assert board.peek(count=3) == [ [Light.GREEN, Light.RED], [Light.GREEN, Light.RED], @@ -173,6 +181,8 @@ def test_pop(simulator, compile_to_qubits): assert not all(result[0] == 0 for result in results) assert not all(result[0] == 1 for result in results) popped = board.pop([light2])[0] + popped2 = board.pop(["l2"])[0] + assert popped == popped2 results = board.peek([light2, light3], count=200) assert len(results) == 200 assert all(result[0] == popped for result in results)