Skip to content

Commit

Permalink
Override optimal_detuning_off on stored calls (#588)
Browse files Browse the repository at this point in the history
* Get rid of DeprecationWarnings

* Override `optimal_detuning_off` on stored calls

* Override `optimal_detuning_off` when ever`detuning_off` is known
  • Loading branch information
HGSilveri authored Sep 27, 2023
1 parent 3e40319 commit e0943d9
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
34 changes: 32 additions & 2 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,13 @@ def _config_detuning_map(
def switch_device(
self, new_device: DeviceType, strict: bool = False
) -> Sequence:
"""Switch the device of a sequence.
"""Replicate the sequence with a different device.
This method is designed to replicate the sequence with as few changes
to the original contents as possible.
If the `strict` option is chosen, the device switch will fail whenever
it cannot guarantee that the new sequence's contents will not be
modified in the process.
Args:
new_device: The target device instance.
Expand Down Expand Up @@ -1002,7 +1008,7 @@ def declare_variable(
self._variables[name] = var
return var

@seq_decorators.store
@seq_decorators.verify_parametrization
@seq_decorators.block_if_measured
def enable_eom_mode(
self,
Expand Down Expand Up @@ -1057,6 +1063,7 @@ def enable_eom_mode(
on_pulse = Pulse.ConstantPulse(
channel_obj.min_duration, amp_on, detuning_on, 0.0
)
stored_opt_detuning_off = optimal_detuning_off
if not isinstance(on_pulse, Parametrized):
channel_obj.validate_pulse(on_pulse)
amp_on = cast(float, amp_on)
Expand All @@ -1070,6 +1077,10 @@ def enable_eom_mode(
channel_obj.min_duration, 0.0, detuning_off, 0.0
)
channel_obj.validate_pulse(off_pulse)
# Update optimal_detuning_off to match the chosen detuning_off
# This minimizes the changes to the sequence when the device
# is switched
stored_opt_detuning_off = detuning_off

if not self.is_parametrized():
phase_drift_params = _PhaseDriftParams(
Expand All @@ -1085,6 +1096,25 @@ def enable_eom_mode(
-drift, *buffer_slot.targets, basis=channel_obj.basis
)

# Manually store the call to "enable_eom_mode" so that the updated
# 'optimal_detuning_off' is stored
call_container = (
self._to_build_calls if self.is_parametrized() else self._calls
)
call_container.append(
_Call(
"enable_eom_mode",
(),
dict(
channel=channel,
amp_on=amp_on,
detuning_on=detuning_on,
optimal_detuning_off=stored_opt_detuning_off,
correct_phase_drift=correct_phase_drift,
),
)
)

@seq_decorators.store
@seq_decorators.block_if_measured
def disable_eom_mode(
Expand Down
59 changes: 46 additions & 13 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,14 +1572,24 @@ def test_deserialize_parametrized_pulse(self, op, pulse_cls):
assert pulse.kwargs["detuning"] == 1

@pytest.mark.parametrize("correct_phase_drift", (False, True, None))
def test_deserialize_eom_ops(self, correct_phase_drift):
@pytest.mark.parametrize("var_detuning_on", [False, True])
def test_deserialize_eom_ops(self, correct_phase_drift, var_detuning_on):
detuning_on = (
{
"expression": "index",
"lhs": {"variable": "detuning_on"},
"rhs": 0,
}
if var_detuning_on
else 0.0
)
s = _get_serialized_seq(
operations=[
{
"op": "enable_eom_mode",
"channel": "global",
"amp_on": 3.0,
"detuning_on": 0.0,
"detuning_on": detuning_on,
"optimal_detuning_off": -1.0,
"correct_phase_drift": correct_phase_drift,
},
Expand All @@ -1602,29 +1612,52 @@ def test_deserialize_eom_ops(self, correct_phase_drift):
"correct_phase_drift": correct_phase_drift,
},
],
variables={"duration": {"type": "int", "value": [100]}},
variables={
"duration": {"type": "int", "value": [100]},
"detuning_on": {"type": "int", "value": [0.0]},
},
device=json.loads(IroiseMVP.to_abstract_repr()),
channels={"global": "rydberg_global"},
)
if correct_phase_drift is None:
for op in s["operations"]:
del op["correct_phase_drift"]
_check_roundtrip(s)

seq = Sequence.from_abstract_repr(json.dumps(s))
# init + declare_channel + enable_eom_mode
assert len(seq._calls) == 3
# add_eom_pulse + disable_eom
assert len(seq._to_build_calls) == 2
# init + declare_channel + enable_eom_mode (if not var_detuning_on)
assert len(seq._calls) == 3 - var_detuning_on
# add_eom_pulse + disable_eom + enable_eom_mode (if var_detuning_on)
assert len(seq._to_build_calls) == 2 + var_detuning_on

if var_detuning_on:
enable_eom_call = seq._to_build_calls[0]
optimal_det_off = -1.0
else:
enable_eom_call = seq._calls[-1]
eom_conf = seq.declared_channels["global"].eom_config
optimal_det_off = eom_conf.calculate_detuning_off(
3.0, detuning_on, -1.0
)

# Roundtrip will only match if the optimal detuning off matches
# detuning_off from the start
mod_s = deepcopy(s)
mod_s["operations"][0]["optimal_detuning_off"] = optimal_det_off
_check_roundtrip(mod_s)

enable_eom_call = seq._calls[-1]
assert enable_eom_call.name == "enable_eom_mode"
assert enable_eom_call.kwargs == {
enable_eom_kwargs = enable_eom_call.kwargs.copy()
detuning_on_kwarg = enable_eom_kwargs.pop("detuning_on")
assert enable_eom_kwargs == {
"channel": "global",
"amp_on": 3.0,
"detuning_on": 0.0,
"optimal_detuning_off": -1.0,
"optimal_detuning_off": optimal_det_off,
"correct_phase_drift": bool(correct_phase_drift),
}
if var_detuning_on:
assert isinstance(detuning_on_kwarg, VariableItem)
else:
assert detuning_on_kwarg == detuning_on

disable_eom_call = seq._to_build_calls[-1]
assert disable_eom_call.name == "disable_eom_mode"
Expand All @@ -1633,7 +1666,7 @@ def test_deserialize_eom_ops(self, correct_phase_drift):
"correct_phase_drift": bool(correct_phase_drift),
}

eom_pulse_call = seq._to_build_calls[0]
eom_pulse_call = seq._to_build_calls[var_detuning_on]
assert eom_pulse_call.name == "add_eom_pulse"
assert eom_pulse_call.kwargs["channel"] == "global"
assert isinstance(eom_pulse_call.kwargs["duration"], VariableItem)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,8 +1558,8 @@ def test_slm_mask_in_ising(
seq5.add(Pulse.ConstantPulse(200, var, 0, 0), "ch")
assert seq5.is_parametrized()
seq5.config_slm_mask(targets)
seq5_str = seq5.serialize()
seq5_ = Sequence.deserialize(seq5_str)
seq5_str = seq5._serialize()
seq5_ = Sequence._deserialize(seq5_str)
assert str(seq5) == str(seq5_)


Expand Down

0 comments on commit e0943d9

Please sign in to comment.