Skip to content

Commit

Permalink
Abstract representation fixes (#595)
Browse files Browse the repository at this point in the history
* Fix `phase_shift()` deserialization

* Support multi-target serialization

* Bump version to v0.15.2

* Improve comments in the serializer

* Serialize target to single int whenever possible
  • Loading branch information
HGSilveri authored Oct 9, 2023
1 parent adbac8b commit a531a50
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 21 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.15.1
0.15.2
12 changes: 10 additions & 2 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@

VARIABLE_TYPE_MAP = {"int": int, "float": float}

ExpReturnType = Union[int, float, ParamObj]
ExpReturnType = Union[int, float, list, ParamObj]


@overload
Expand All @@ -76,6 +76,13 @@ def _deserialize_parameter(param: float, vars: dict[str, Variable]) -> float:
pass


@overload
def _deserialize_parameter(
param: list[int], vars: dict[str, Variable]
) -> list[int]:
pass


@overload
def _deserialize_parameter(
param: dict[str, str], vars: dict[str, Variable]
Expand All @@ -84,7 +91,7 @@ def _deserialize_parameter(


def _deserialize_parameter(
param: Union[int, float, dict[str, Any]],
param: Union[int, float, list[int], dict[str, Any]],
vars: dict[str, Variable],
) -> Union[ExpReturnType, Variable]:
"""Deserialize a parameterized object.
Expand Down Expand Up @@ -213,6 +220,7 @@ def _deserialize_operation(seq: Sequence, op: dict, vars: dict) -> None:
seq.phase_shift_index(
_deserialize_parameter(op["phi"], vars),
*[_deserialize_parameter(t, vars) for t in op["targets"]],
basis=op["basis"],
)
elif op["op"] == "pulse":
phase = _deserialize_parameter(op["phase"], vars)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,15 @@
"type": "string"
},
"target": {
"$ref": "#/definitions/ParametrizedNum",
"description": "New target atom index"
"anyOf": [
{
"$ref": "#/definitions/ParametrizedNum"
},
{
"$ref": "#/definitions/ParametrizedNumArray"
}
],
"description": "New target atom index (or indices)"
}
},
"required": [
Expand Down
39 changes: 31 additions & 8 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@

import inspect
import json
from collections.abc import Iterable
from itertools import chain
from typing import TYPE_CHECKING, Any
from typing import Sequence as abcSequence
from typing import Union, cast
from typing import TYPE_CHECKING, Any, Union, cast

import numpy as np

import pulser
from pulser.json.abstract_repr.signatures import SIGNATURES
from pulser.json.abstract_repr.validation import validate_abstract_repr
from pulser.json.exceptions import AbstractReprError
from pulser.json.utils import stringify_qubit_ids

if TYPE_CHECKING:
from pulser.parametrized import Parametrized
from pulser.register.base_register import QubitId
from pulser.sequence import Sequence
from pulser.sequence._call import _Call
Expand Down Expand Up @@ -154,17 +155,29 @@ def serialize_abstract_sequence(
for var in seq._variables.values():
res["variables"][var.name]["value"] = [var.dtype()] * var.size

def unfold_targets(
target_ids: QubitId | Iterable[QubitId],
) -> QubitId | list[QubitId]:
if isinstance(target_ids, (int, str)):
return target_ids

targets = list(cast(Iterable, target_ids))
return targets if len(targets) > 1 else targets[0]

def convert_targets(
target_ids: Union[QubitId, abcSequence[QubitId]]
target_ids: Union[QubitId, Iterable[QubitId]],
force_list_out: bool = False,
) -> Union[int, list[int]]:
target_array = np.array(target_ids)
target_array = np.array(unfold_targets(target_ids))
og_dim = target_array.ndim
if og_dim == 0:
target_array = target_array[np.newaxis]
indices = seq.get_register(include_mappable=True).find_indices(
target_array.tolist()
)
return indices[0] if og_dim == 0 else indices
if force_list_out or og_dim > 0:
return indices
return indices[0]

def get_kwarg_default(call_name: str, kwarg_name: str) -> Any:
sig = inspect.signature(getattr(seq, call_name))
Expand Down Expand Up @@ -230,10 +243,20 @@ def remove_kwarg_if_default(
)
elif "target" in call.name:
data = get_all_args(("qubits", "channel"), call)
target: Parametrized | int | list[int]
if call.name == "target":
target = convert_targets(data["qubits"])
elif call.name == "target_index":
target = data["qubits"]
if isinstance(
data["qubits"], pulser.parametrized.Parametrized
):
# The qubit indices are given through a variable
target = data["qubits"]
else:
# Either a single index or a sequence of indices
target = cast(
Union[int, list], unfold_targets(data["qubits"])
)
else:
raise AbstractReprError(f"Unknown call '{call.name}'.")
operations.append(
Expand Down Expand Up @@ -269,7 +292,7 @@ def remove_kwarg_if_default(
elif "phase_shift" in call.name:
targets = call.args[1:]
if call.name == "phase_shift":
targets = convert_targets(targets)
targets = convert_targets(targets, force_list_out=True)
elif call.name != "phase_shift_index":
raise AbstractReprError(f"Unknown call '{call.name}'.")
operations.append(
Expand Down
8 changes: 6 additions & 2 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ def current_phase_ref(
)

if basis not in self._basis_ref:
raise ValueError("No declared channel targets the given 'basis'.")
raise ValueError(
f"No declared channel targets the given 'basis' ('{basis}')."
)

return self._basis_ref[basis][qubit].phase.last_phase

Expand Down Expand Up @@ -2079,7 +2081,9 @@ def _phase_shift(
_index: bool = False,
) -> None:
if basis not in self._basis_ref:
raise ValueError("No declared channel targets the given 'basis'.")
raise ValueError(
f"No declared channel targets the given 'basis' ('{basis}')."
)
target_ids = self._check_qubits_give_ids(*targets, _index=_index)

if not self.is_parametrized():
Expand Down
68 changes: 62 additions & 6 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from pulser.json.exceptions import AbstractReprError, DeserializeDeviceError
from pulser.parametrized.decorators import parametrize
from pulser.parametrized.paramobj import ParamObj
from pulser.parametrized.variable import VariableItem
from pulser.parametrized.variable import Variable, VariableItem
from pulser.register.register_layout import RegisterLayout
from pulser.register.special_layouts import TriangularLatticeLayout
from pulser.sequence._call import _Call
Expand Down Expand Up @@ -260,7 +260,10 @@ def sequence(self, request):
reg = Register(qubits)
device = request.param
seq = Sequence(reg, device)
seq.declare_channel("digital", "raman_local", initial_target="control")

seq.declare_channel(
"digital", "raman_local", initial_target=("control",)
)
seq.declare_channel(
"rydberg", "rydberg_local", initial_target="control"
)
Expand Down Expand Up @@ -291,7 +294,7 @@ def sequence(self, request):
seq.align("digital", "rydberg")
seq.add(pi_pulse, "rydberg")
seq.phase_shift(1.0, "control", "target", basis="ground-rydberg")
seq.target("target", "rydberg")
seq.target({"target"}, "rydberg")
seq.add(two_pi_pulse, "rydberg")

seq.delay(100, "digital")
Expand Down Expand Up @@ -348,6 +351,12 @@ def test_values(self, abstract):
assert abstract["operations"][0] == {
"op": "target",
"channel": "digital",
"target": 0, # tuple[int] is still serialized as int
}

assert abstract["operations"][1] == {
"op": "target",
"channel": "rydberg",
"target": 0,
}

Expand Down Expand Up @@ -414,6 +423,12 @@ def test_values(self, abstract):
"post_phase_shift": 0.0,
}

assert abstract["operations"][8] == {
"op": "target",
"channel": "rydberg",
"target": 1,
}

assert abstract["operations"][10] == {
"op": "delay",
"channel": "digital",
Expand Down Expand Up @@ -897,6 +912,33 @@ def test_dmm_slm_mask(self, triangular_lattice, is_empty):
assert abstract["operations"][3]["op"] == "pulse"
assert abstract["operations"][3]["channel"] == "rydberg_global"

def test_multi_qubit_target(self):
seq_ = Sequence(Register.square(2, prefix="q"), MockDevice)
var_targets = seq_.declare_variable("var_targets", dtype=int, size=4)

seq_.declare_channel(
"rydberg_local", "rydberg_local", initial_target=("q0", "q1")
)
seq_.target(["q3", "q2"], "rydberg_local")
seq_.target_index(var_targets, "rydberg_local")
seq_.target(["q0"], "rydberg_local")
seq_.target_index(var_targets[2], "rydberg_local")

abstract = json.loads(seq_.to_abstract_repr())

assert all(op["op"] == "target" for op in abstract["operations"])
assert abstract["operations"][0]["target"] == [0, 1]
assert abstract["operations"][1]["target"] == [3, 2]
assert abstract["operations"][2]["target"] == {
"variable": "var_targets"
}
assert abstract["operations"][3]["target"] == 0
assert abstract["operations"][4]["target"] == {
"expression": "index",
"lhs": {"variable": "var_targets"},
"rhs": 2,
}


def _get_serialized_seq(
operations: list[dict] = [],
Expand Down Expand Up @@ -1185,6 +1227,7 @@ def test_deserialize_variables(self, without_default):
"op",
[
{"op": "target", "target": 2, "channel": "digital"},
{"op": "target", "target": [1, 2], "channel": "digital"},
{"op": "delay", "time": 500, "channel": "global"},
{"op": "align", "channels": ["digital", "global"]},
{
Expand Down Expand Up @@ -1215,7 +1258,9 @@ def test_deserialize_variables(self, without_default):
ids=_get_op,
)
def test_deserialize_non_parametrized_op(self, op):
s = _get_serialized_seq(operations=[op])
s = _get_serialized_seq(
operations=[op], device=json.loads(MockDevice.to_abstract_repr())
)
_check_roundtrip(s)
seq = Sequence.from_abstract_repr(json.dumps(s))

Expand All @@ -1240,6 +1285,7 @@ def test_deserialize_non_parametrized_op(self, op):
elif op["op"] == "phase_shift":
assert c.name == "phase_shift_index"
assert c.args == tuple([op["phi"], *op["targets"]])
assert c.kwargs["basis"] == "digital"
elif op["op"] == "pulse":
assert c.name == "add"
assert c.kwargs["channel"] == op["channel"]
Expand Down Expand Up @@ -1390,12 +1436,17 @@ def test_deserialize_measurement(self):
"op",
[
{"op": "target", "target": var1, "channel": "digital"},
{
"op": "target",
"target": {"variable": "var1"},
"channel": "digital",
},
{"op": "delay", "time": var2, "channel": "global"},
{
"op": "phase_shift",
"phi": var1,
"targets": [2, var1],
"basis": "digital",
"basis": "ground-rydberg",
},
{
"op": "pulse",
Expand Down Expand Up @@ -1438,7 +1489,10 @@ def test_deserialize_parametrized_op(self, op):
c = seq._to_build_calls[0]
if op["op"] == "target":
assert c.name == "target_index"
assert isinstance(c.kwargs["qubits"], VariableItem)
target_type = (
VariableItem if "expression" in op["target"] else Variable
)
assert isinstance(c.kwargs["qubits"], target_type)
assert c.kwargs["channel"] == op["channel"]
elif op["op"] == "delay":
assert c.name == "delay"
Expand All @@ -1452,6 +1506,8 @@ def test_deserialize_parametrized_op(self, op):
assert c.args[1] == 2
# qubit 2 is variable
assert isinstance(c.args[2], VariableItem)
# basis is fixed
assert c.kwargs["basis"] == "ground-rydberg"
elif op["op"] == "pulse":
assert c.name == "add"
assert c.kwargs["channel"] == op["channel"]
Expand Down

0 comments on commit a531a50

Please sign in to comment.