Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Quantum Chinese Chess] Add save_snapshot() and restore_last_snapshot() to QuantumWorld #173

Merged
merged 6 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions unitary/alpha/quantum_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,22 @@ def clear(self) -> None:
"""
self.circuit = cirq.Circuit()
self.effect_history: List[Tuple[cirq.Circuit, Dict[QuantumObject, int]]] = []
# This variable is used to save the length of current effect history before each move is made,
# so that if we later undo we know how many effects we need to pop out, since each move could
# consisit of several effects.
madcpf marked this conversation as resolved.
Show resolved Hide resolved
self.effect_history_length = []
madcpf marked this conversation as resolved.
Show resolved Hide resolved
self.object_name_dict: Dict[str, QuantumObject] = {}
self.ancilla_names: Set[str] = set()
# When `compile_to_qubits` is True, this tracks the mapping of the
# original qudits to the compiled qubits.
self.compiled_qubits: Dict[cirq.Qid, List[cirq.Qid]] = {}
self.post_selection: Dict[QuantumObject, int] = {}
# This variable is used to save the qubit remapping dictionary before each move, so that if
# we later undo we know how to reverse the mapping.
self.qubit_remapping_dict: List[Dict[cirq.Qid, cirq.Qid]] = []
# This variable is used to save the length of qubit_remapping_dict before each move is made,
# so that if we later undo we know how to remap the qubits.
self.qubit_remapping_dict_length = []

def copy(self) -> "QuantumWorld":
new_objects = []
Expand All @@ -95,7 +105,17 @@ def copy(self) -> "QuantumWorld":
(circuit.copy(), copy.copy(post_selection))
for circuit, post_selection in self.effect_history
]
new_world.effect_history_length = self.effect_history_length.copy()
new_world.post_selection = new_post_selection
# copy qubit_remapping_dict
for remap in self.qubit_remapping_dict:
new_dict = {}
for key_obj, value_obj in remap.items():
new_dict[
new_world.get_object_by_name(key_obj.name)
] = new_world.get_object_by_name(value_obj.name)
new_world.qubit_remapping_dict.append(new_dict)
new_world.qubit_remapping_dict_length = self.qubit_remapping_dict_length.copy()
return new_world

def add_object(self, obj: QuantumObject):
Expand Down Expand Up @@ -257,7 +277,8 @@ def add_effect(self, op_list: List[cirq.Operation]):
self._append_op(op)

def undo_last_effect(self):
"""Restores the `QuantumWorld` to the state before the last effect.
"""Restores the circuit and post selection dictionary of `QuantumWorld` to the
state before the last effect.

Note that pop() is considered to be an effect for the purposes
of this call.
Expand All @@ -269,6 +290,52 @@ def undo_last_effect(self):
raise IndexError("No effects to undo")
self.circuit, self.post_selection = self.effect_history.pop()

def save_snapshot(self) -> None:
"""Saves the current length of the effect history and qubit_remapping_dict."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want a better description here. Is the idea that we save a snapshot every move?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Doug, yes for chess we default to save a snapshot after each player's move, so later on if they choose to undo those quantum properties could be restored. I'm adding this into the comment. (In the end it's up to the game developer to decide in what granularity they allow the player to undo.)

self.effect_history_length.append(len(self.effect_history))
self.qubit_remapping_dict_length.append(len(self.qubit_remapping_dict))

def restore_last_snapshot(self) -> None:
"""Restores the `QuantumWorld` to the last snapshot (which was saved after the last move
finished), which includes
- reversing the mapping of qubits, if any,
- restoring the post selection dictionary,
- restoring the circuit.
"""
if (
len(self.effect_history_length) <= 1
or len(self.qubit_remapping_dict_length) <= 1
):
# length == 1 corresponds to the initial state, and no more restore could be made.
raise ValueError("Unable to restore any more.")

# Recover the mapping of qubits to the last snapshot, and remove any related post selection memory.
# Note that this need to be done before calling `undo_last_effect()`, otherwise the remapping does not
# work as expected.
self.qubit_remapping_dict_length.pop()
last_length = self.qubit_remapping_dict_length[-1]
while len(self.qubit_remapping_dict) > last_length:
qubit_remapping_dict = self.qubit_remapping_dict.pop()
if len(qubit_remapping_dict) == 0:
continue
# Reverse the mapping.
self.circuit = self.circuit.transform_qubits(
lambda q: qubit_remapping_dict.get(q, q)
)
# Clear relevant qubits from the post selection dictionary.
# TODO(): rethink if this is necessary, given that undo_last_effect()
# will also restore post selection dictionary.
for obj in qubit_remapping_dict.keys():
if obj in self.post_selection:
self.post_selection.pop(obj)

# Recover the effects up to the last snapshot by popping effects out of the
# effect history of the board until its length equals the last snapshot's length.
self.effect_history_length.pop()
last_length = self.effect_history_length[-1]
while len(self.effect_history) > last_length:
self.undo_last_effect()

def _suggest_num_reps(self, sample_size: int) -> int:
"""Guess the number of raw samples needed to get sample_size results.
Assume that each post-selection is about 50/50.
Expand Down Expand Up @@ -323,6 +390,7 @@ def unhook(self, object: QuantumObject) -> None:
object.qubit: new_ancilla.qubit,
new_ancilla.qubit: object.qubit,
}
self.qubit_remapping_dict.append(qubit_remapping_dict)
self.circuit = self.circuit.transform_qubits(
lambda q: qubit_remapping_dict.get(q, q)
)
Expand All @@ -348,7 +416,7 @@ def force_measurement(
qubit_remapping_dict.update(
{*zip(obj_qubits, new_obj_qubits), *zip(new_obj_qubits, obj_qubits)}
)

self.qubit_remapping_dict.append(qubit_remapping_dict)
self.circuit = self.circuit.transform_qubits(
lambda q: qubit_remapping_dict.get(q, q)
)
Expand Down
76 changes: 76 additions & 0 deletions unitary/alpha/quantum_world_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def test_copy(simulator, compile_to_qubits):
alpha.Flip()(light2)
assert board.pop([light1])[0] == Light.RED
assert board.pop([light2])[0] == Light.GREEN
board.save_snapshot()

board2 = board.copy()

Expand All @@ -345,9 +346,15 @@ def test_copy(simulator, compile_to_qubits):
assert board.circuit is not board2.circuit
assert board.effect_history == board2.effect_history
assert board.effect_history is not board2.effect_history
assert board.effect_history_length == board2.effect_history_length
assert board.qubit_remapping_dict_length == board2.qubit_remapping_dict_length
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally these asserts should also have some sort a message (second argument to assert) to help callers debug issues

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think most cases here are kind of self explanatory and would be a bit redundant to add messages. Thanks.

assert board.ancilla_names == board2.ancilla_names
assert board.ancilla_names is not board2.ancilla_names
assert len(board2.post_selection) == 2
assert [key.name for key in board2.qubit_remapping_dict[-1].keys()] == [
"l2",
"ancilla_l2_0",
]
madcpf marked this conversation as resolved.
Show resolved Hide resolved

# Assert that they now evolve independently
board2.undo_last_effect()
Expand Down Expand Up @@ -775,3 +782,72 @@ def test_get_correlated_histogram_with_entangled_qobjects(simulator, compile_to_

histogram = world.get_correlated_histogram()
assert histogram.keys() == {(0, 0, 1, 1, 0), (0, 1, 0, 0, 1)}


@pytest.mark.parametrize(
("simulator", "compile_to_qubits"),
[
(cirq.Simulator, False),
(cirq.Simulator, True),
# Cannot use SparseSimulator without `compile_to_qubits` due to issue #78.
(alpha.SparseSimulator, True),
],
)
def test_save_and_restore_snapshot(simulator, compile_to_qubits):
light1 = alpha.QuantumObject("l1", Light.GREEN)
light2 = alpha.QuantumObject("l2", Light.RED)
light3 = alpha.QuantumObject("l3", Light.RED)
light4 = alpha.QuantumObject("l4", Light.RED)
light5 = alpha.QuantumObject("l5", Light.RED)

# Initial state.
world = alpha.QuantumWorld(
[light1, light2, light3, light4, light5],
sampler=simulator(),
compile_to_qubits=compile_to_qubits,
)
# Snapshot #0
world.save_snapshot()
circuit_0 = world.circuit.copy()
# one effect from Flip()
assert world.effect_history_length == [1]
assert world.qubit_remapping_dict_length == [0]
assert world.post_selection == {}

# First move.
alpha.Split()(light1, light2, light3)
# Snapshot #1
world.save_snapshot()
circuit_1 = world.circuit.copy()
# one more effect from Split()
assert world.effect_history_length == [1, 2]
assert world.qubit_remapping_dict_length == [0, 0]
assert world.post_selection == {}

# Second move, which includes multiple effects and post selection.
alpha.Flip()(light2)
alpha.Split()(light3, light4, light5)
world.force_measurement(light4, Light.RED)
world.unhook(light5)
# Snapshot #2
world.save_snapshot()
# 2 more effects from Flip() and Split()
assert world.effect_history_length == [1, 2, 4]
# 2 mapping from force_measurement() and unhook()
assert world.qubit_remapping_dict_length == [0, 0, 2]
# 1 post selection from force_measurement
assert len(world.post_selection) == 1

# Restore to snapshot #1
madcpf marked this conversation as resolved.
Show resolved Hide resolved
world.restore_last_snapshot()
assert world.effect_history_length == [1, 2]
assert world.qubit_remapping_dict_length == [0, 0]
assert world.circuit == circuit_1
assert world.post_selection == {}

# Restore to snapshot #0
world.restore_last_snapshot()
assert world.effect_history_length == [1]
assert world.qubit_remapping_dict_length == [0]
assert world.circuit == circuit_0
assert world.post_selection == {}
80 changes: 66 additions & 14 deletions unitary/examples/quantum_chinese_chess/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __init__(self):
self.game_state = GameState.CONTINUES
self.current_player = self.board.current_player
self.debug_level = 3
# This variable is used to save the classical properties of the whole board before each move is
# made, so that if we later undo we could recover the earlier classical state.
self.classical_properties_history = []
madcpf marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def parse_input_string(str_to_parse: str) -> Tuple[List[str], List[str]]:
Expand Down Expand Up @@ -422,15 +425,10 @@ def next_move(self) -> Tuple[bool, str]:
# TODO(): make it look like the normal board. Right now it's only for debugging purposes.
print(self.board.board.peek(convert_to_enum=False))
elif input_str.lower() == "undo":
output = "Undo last quantum effect."
# Right now it's only for debugging purposes, since it has following problems:
# TODO(): there are several problems here:
# 1) the classical piece information is not reversed back.
# ==> we may need to save the change of classical piece information of each step.
# 2) last move involved multiple effects.
# ==> we may need to save number of effects per move, and undo that number of times.
self.board.board.undo_last_effect()
return True, output
if self.undo():
return True, "Undoing."
return False, "Unable to undo any more."

else:
try:
# The move is success if no ValueError is raised.
Expand Down Expand Up @@ -476,6 +474,56 @@ def game_over(self) -> None:
# TODO(): add the following checks
# - If player 0 made N repeatd back-and_forth moves in a row.

def save_snapshot(self) -> None:
"""Saves the current length of the effect history, qubit_remapping_dict, and the current classical states of all pieces."""
# Save the current length of the effect history and qubit_remapping_dict.
self.board.board.save_snapshot()

# Save the classical states of all pieces.
snapshot = []
for row in range(10):
for col in "abcdefghi":
piece = self.board.board[f"{col}{row}"]
snapshot.append(
[piece.type_.value, piece.color.value, piece.is_entangled]
)
self.classical_properties_history.append(snapshot)

def undo(self) -> bool:
"""Undo the last move, which includes reset quantum effects and classical properties, and remapping
qubits.

Returns True if the undo is success, and False otherwise.
"""
world = self.board.board
if (
len(world.effect_history_length) <= 1
or len(world.qubit_remapping_dict_length) <= 1
or len(self.classical_properties_history) <= 1
):
# length == 1 corresponds to the initial state, and no more undo could be made.
return False

# Recover the mapping of qubits to the last snapshot, remove any related post selection memory,
# and recover the effects up to the last snapshot (which was saved after the last move finished).
try:
world.restore_last_snapshot()
except:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is a specific exception you can capture here instead of the blanked except?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now the possible error is only ValueError. But to be future proof I think maybe we could just include all kinds of exceptions here, instead of

try:
abc
except ValueError:
return False
except:
return False

Thanks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't think it's a good idea to catch all exceptions. Some you might not want to catch, like "KeyboardInterrupt" for instance. This is also going to swallow errors and hide them which will make debugging harder too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Doug. I'm updating the code to only catch ValueError. Other errors will be raised instead of swallowed.

return False

# Recover the classical properties of all pieces to the last snapshot.
self.classical_properties_history.pop()
snapshot = self.classical_properties_history[-1]
index = 0
for row in range(10):
for col in "abcdefghi":
madcpf marked this conversation as resolved.
Show resolved Hide resolved
piece = world[f"{col}{row}"]
piece.type_ = Type(snapshot[index][0])
piece.color = Color(snapshot[index][1])
piece.is_entangled = snapshot[index][2]
index += 1
return True

def play(self) -> None:
"""The loop where each player takes turn to play."""
while True:
Expand All @@ -487,11 +535,15 @@ def play(self) -> None:
print("\nPlease re-enter your move.")
continue
print(output)
# TODO(): maybe we should not check game_over() when an undo is made.
# Check if the game is over.
self.game_over()
# TODO(): no need to do sampling if the last move was CLASSICAL.
self.update_board_by_sampling()
if output != "Undoing.":
# Check if the game is over.
self.game_over()
# Update any empty or occupied pieces' classical state.
# TODO(): no need to do sampling if the last move was CLASSICAL.
probs = self.update_board_by_sampling()
# Save the current states.
self.save_snapshot()
# TODO(): pass probs into the following method to print probabilities.
print(self.board)
if self.game_state == GameState.CONTINUES:
# If the game continues, switch the player.
Expand Down
Loading
Loading