From d5ac0207443c3a24d026aebacc9b4c09b5b9b371 Mon Sep 17 00:00:00 2001 From: Antoine Cornillot <61453516+a-corni@users.noreply.github.com> Date: Mon, 24 Jul 2023 15:12:32 +0200 Subject: [PATCH] Adding DetuningMap, DMM (#539) * adding dmm, draw to DetuningMap * Testing avoiding circular import * Refactoring to avoid circular imports * Fix broken UTs * Import sorting * Fixing plot DetuningMap * Serialization/Deserialization of DMM in device * Fixing typos * Testing DMM and DetuningMap * Fixing typo * adding xfails, fixing type * Remove DMM from devices and finish UTs * Take into account review comments * Add annotations * Error in Global and Local * Defining _DMMSchedule * Fixing nits * Fixing typo --------- Co-authored-by: HGSilveri --- pulser-core/pulser/channels/__init__.py | 1 + pulser-core/pulser/channels/base_channel.py | 20 +- pulser-core/pulser/channels/channels.py | 2 +- pulser-core/pulser/channels/dmm.py | 74 +++++++ pulser-core/pulser/devices/_device_datacls.py | 57 +++++- pulser-core/pulser/devices/_devices.py | 9 + pulser-core/pulser/devices/_mock_device.py | 1 + .../pulser/json/abstract_repr/deserializer.py | 12 +- .../pulser/json/abstract_repr/serializer.py | 2 +- pulser-core/pulser/register/_reg_drawer.py | 73 +++++-- pulser-core/pulser/register/base_register.py | 25 +++ pulser-core/pulser/register/mappable_reg.py | 17 ++ .../pulser/register/register_layout.py | 27 +++ pulser-core/pulser/register/weight_maps.py | 106 ++++++++++ pulser-core/pulser/sequence/_schedule.py | 6 + tests/test_abstract_repr.py | 41 ++++ tests/test_devices.py | 46 ++++- tests/test_dmm.py | 193 ++++++++++++++++++ 18 files changed, 677 insertions(+), 35 deletions(-) create mode 100644 pulser-core/pulser/channels/dmm.py create mode 100644 pulser-core/pulser/register/weight_maps.py create mode 100644 tests/test_dmm.py diff --git a/pulser-core/pulser/channels/__init__.py b/pulser-core/pulser/channels/__init__.py index 9645bddbd..8fea9deea 100644 --- a/pulser-core/pulser/channels/__init__.py +++ b/pulser-core/pulser/channels/__init__.py @@ -14,3 +14,4 @@ """The various hardware channel types.""" from pulser.channels.channels import Microwave, Raman, Rydberg +from pulser.channels.dmm import DMM diff --git a/pulser-core/pulser/channels/base_channel.py b/pulser-core/pulser/channels/base_channel.py index 3afabd99a..2edd2b5fd 100644 --- a/pulser-core/pulser/channels/base_channel.py +++ b/pulser-core/pulser/channels/base_channel.py @@ -17,7 +17,7 @@ import warnings from abc import ABC, abstractmethod -from dataclasses import dataclass, field, fields +from dataclasses import MISSING, dataclass, field, fields from typing import Any, Literal, Optional, Type, TypeVar, cast import numpy as np @@ -94,7 +94,7 @@ def basis(self) -> str: def _internal_param_valid_options(self) -> dict[str, tuple[str, ...]]: """Internal parameters and their valid options.""" return dict( - name=("Rydberg", "Raman", "Microwave"), + name=("Rydberg", "Raman", "Microwave", "DMM"), basis=("ground-rydberg", "digital", "XY"), addressing=("Local", "Global"), ) @@ -262,6 +262,14 @@ def Local( min_avg_amp: The minimum average amplitude of a pulse (when not zero). """ + # Can't initialize a channel whose addressing is determined internally + for cls_field in fields(cls): + if cls_field.name == "addressing": + break + if not cls_field.init and cls_field.default is not MISSING: + raise NotImplementedError( + f"{cls} cannot be initialized from `Local` method." + ) return cls( "Local", max_abs_detuning, @@ -299,6 +307,14 @@ def Global( min_avg_amp: The minimum average amplitude of a pulse (when not zero). """ + # Can't initialize a channel whose addressing is determined internally + for cls_field in fields(cls): + if cls_field.name == "addressing": + break + if not cls_field.init and cls_field.default is not MISSING: + raise NotImplementedError( + f"{cls} cannot be initialized from `Global` method." + ) return cls("Global", max_abs_detuning, max_amp, **kwargs) def validate_duration(self, duration: int) -> int: diff --git a/pulser-core/pulser/channels/channels.py b/pulser-core/pulser/channels/channels.py index 390286b71..9a95a31e1 100644 --- a/pulser-core/pulser/channels/channels.py +++ b/pulser-core/pulser/channels/channels.py @@ -41,7 +41,7 @@ class Rydberg(Channel): """Rydberg beam channel. Channel targeting the transition between the ground and rydberg states, - thus enconding the 'ground-rydberg' basis. See base class. + thus encoding the 'ground-rydberg' basis. See base class. """ eom_config: Optional[RydbergEOM] = None diff --git a/pulser-core/pulser/channels/dmm.py b/pulser-core/pulser/channels/dmm.py new file mode 100644 index 000000000..744fd2ee5 --- /dev/null +++ b/pulser-core/pulser/channels/dmm.py @@ -0,0 +1,74 @@ +# Copyright 2023 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the detuning map modulator.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal, Optional + +from pulser.channels.base_channel import Channel + + +@dataclass(init=True, repr=False, frozen=True) +class DMM(Channel): + """Defines a Detuning Map Modulator (DMM) Channel. + + A Detuning Map Modulator can be used to define `Global` detuning Pulses + (of zero amplitude and phase). These Pulses are locally modulated by the + weights of a `DetuningMap`, thus providing a local control over the + detuning. The detuning of the pulses added to a DMM has to be negative, + between 0 and `bottom_detuning`. Channel targeting the transition between + the ground and rydberg states, thus encoding the 'ground-rydberg' basis. + + Note: + The protocol to add pulses to the DMM Channel is by default + "no-delay". + + Args: + bottom_detuning: Minimum possible detuning (in rad/µs), must be below + zero. + clock_period: The duration of a clock cycle (in ns). The duration of a + pulse or delay instruction is enforced to be a multiple of the + clock cycle. + min_duration: The shortest duration an instruction can take. + max_duration: The longest duration an instruction can take. + min_avg_amp: The minimum average amplitude of a pulse (when not zero). + mod_bandwidth: The modulation bandwidth at -3dB (50% reduction), in + MHz. + """ + + bottom_detuning: Optional[float] = field(default=None, init=True) + addressing: Literal["Global"] = field(default="Global", init=False) + max_abs_detuning: Optional[float] = field(init=False, default=None) + max_amp: float = field(default=1e-16, init=False) # can't be 0 + min_retarget_interval: Optional[int] = field(init=False, default=None) + fixed_retarget_t: Optional[int] = field(init=False, default=None) + max_targets: Optional[int] = field(init=False, default=None) + + def __post_init__(self) -> None: + super().__post_init__() + if self.bottom_detuning and self.bottom_detuning > 0: + raise ValueError("bottom_detuning must be negative.") + + @property + def basis(self) -> Literal["ground-rydberg"]: + """The addressed basis name.""" + return "ground-rydberg" + + def _undefined_fields(self) -> list[str]: + optional = [ + "bottom_detuning", + "max_duration", + ] + return [field for field in optional if getattr(self, field) is None] diff --git a/pulser-core/pulser/devices/_device_datacls.py b/pulser-core/pulser/devices/_device_datacls.py index 0e1940870..2dca543b0 100644 --- a/pulser-core/pulser/devices/_device_datacls.py +++ b/pulser-core/pulser/devices/_device_datacls.py @@ -24,6 +24,7 @@ from scipy.spatial.distance import pdist, squareform from pulser.channels.base_channel import Channel +from pulser.channels.dmm import DMM from pulser.devices.interaction_coefficients import c6_dict from pulser.json.abstract_repr.serializer import AbstractReprEncoder from pulser.json.abstract_repr.validation import validate_abstract_repr @@ -34,7 +35,8 @@ DIMENSIONS = Literal[2, 3] -ALWAYS_OPTIONAL_PARAMS = ("max_sequence_duration", "max_runs") +ALWAYS_OPTIONAL_PARAMS = ("max_sequence_duration", "max_runs", "dmm_objects") +PARAMS_WITH_ABSTR_REPR = ("channel_objects", "channel_ids", "dmm_objects") @dataclass(frozen=True, repr=False) @@ -49,6 +51,9 @@ class BaseDevice(ABC): channel_ids: Custom IDs for each channel object. When defined, an ID must be given for each channel. If not defined, the IDs are generated internally based on the channels' names and addressing. + dmm_objects: The DMM subclass instances specifying each channel in the + device. They are referenced by their order in the list, with the ID + "dmm_[index in dmm_objects]". rybderg_level: The value of the principal quantum number :math:`n` when the Rydberg level used is of the form :math:`|nS_{1/2}, m_j = +1/2\rangle`. @@ -83,6 +88,7 @@ class BaseDevice(ABC): reusable_channels: bool = field(default=False, init=False) channel_ids: tuple[str, ...] | None = None channel_objects: tuple[Channel, ...] = field(default_factory=tuple) + dmm_objects: tuple[DMM, ...] = field(default_factory=tuple) def __post_init__(self) -> None: def type_check( @@ -155,6 +161,16 @@ def type_check( for ch_obj in self.channel_objects: type_check("All channels", Channel, value_override=ch_obj) + for dmm_obj in self.dmm_objects: + type_check("All DMM channels", DMM, value_override=dmm_obj) + + # TODO: Check that device has dmm objects if it supports SLM mask + # once DMM is supported for serialization + # if self.supports_slm_mask and not self.dmm_objects: + # raise ValueError( + # "One DMM object should be defined to support SLM mask." + # ) + if self.channel_ids is not None: if not ( isinstance(self.channel_ids, (tuple, list)) @@ -174,6 +190,12 @@ def type_check( "When defined, the number of channel IDs must" " match the number of channel objects." ) + if set(self.channel_ids) & set(self.dmm_channels.keys()): + raise ValueError( + "When defined, the names of channel IDs must be different" + " than the names of DMM channels 'dmm_0', 'dmm_1', ... ." + ) + else: # Make the channel IDs from the default IDs ids_counter: Counter = Counter() @@ -203,7 +225,7 @@ def to_tuple(obj: tuple | list) -> tuple: # Turns mutable lists into immutable tuples for param in self._params(): - if "channel" in param: + if "channel" in param or param == "dmm_objects": object.__setattr__(self, param, to_tuple(getattr(self, param))) @property @@ -216,6 +238,13 @@ def channels(self) -> dict[str, Channel]: """Dictionary of available channels on this device.""" return dict(zip(cast(tuple, self.channel_ids), self.channel_objects)) + @property + def dmm_channels(self) -> dict[str, DMM]: + """Dictionary of available DMM channels on this device.""" + return { + f"dmm_{i}": dmm_obj for (i, dmm_obj) in enumerate(self.dmm_objects) + } + @property def supported_bases(self) -> set[str]: """Available electronic transitions for control and measurement.""" @@ -420,18 +449,26 @@ def _to_dict(self) -> dict[str, Any]: @abstractmethod def _to_abstract_repr(self) -> dict[str, Any]: - ex_params = ("channel_objects", "channel_ids") defaults = get_dataclass_defaults(fields(self)) params = self._params() - for p in ex_params: - params.pop(p, None) for p in ALWAYS_OPTIONAL_PARAMS: if params[p] == defaults[p]: params.pop(p, None) ch_list = [] for ch_name, ch_obj in self.channels.items(): ch_list.append(ch_obj._to_abstract_repr(ch_name)) - return {"version": "1", "channels": ch_list, **params} + # Add version and channels to params + params.update({"version": "1", "channels": ch_list}) + dmm_list = [] + for dmm_name, dmm_obj in self.dmm_channels.items(): + dmm_list.append(dmm_obj._to_abstract_repr(dmm_name)) + # Add dmm channels if different than default + if "dmm_objects" in params: + params["dmm_channels"] = dmm_list + # Delete parameters of PARAMS_WITH_ABSTR_REPR in params + for p in PARAMS_WITH_ABSTR_REPR: + params.pop(p, None) + return params def to_abstract_repr(self) -> str: """Serializes the Sequence into an abstract JSON object.""" @@ -541,7 +578,7 @@ def _specs(self, for_docs: bool = False) -> str: ) ch_lines = ["\nChannels:"] - for name, ch in self.channels.items(): + for name, ch in {**self.channels, **self.dmm_channels}.items(): if for_docs: ch_lines += [ f" - ID: '{name}'", @@ -556,6 +593,12 @@ def _specs(self, for_docs: bool = False) -> str: "\t" + r"- Maximum :math:`|\delta|`:" + f" {ch.max_abs_detuning:.4g} rad/µs" + ) + if not isinstance(ch, DMM) + else ( + "\t" + + r"- Bottom :math:`|\delta|`:" + + f" {ch.bottom_detuning:.4g} rad/µs" ), f"\t- Minimum average amplitude: {ch.min_avg_amp} rad/µs", ] diff --git a/pulser-core/pulser/devices/_devices.py b/pulser-core/pulser/devices/_devices.py index 85eac8155..bbfe8f6d5 100644 --- a/pulser-core/pulser/devices/_devices.py +++ b/pulser-core/pulser/devices/_devices.py @@ -56,6 +56,15 @@ max_duration=2**26, ), ), + # TODO: Add DMM once it is supported for serialization + # dmm_objects=( + # DMM( + # clock_period=4, + # min_duration=16, + # max_duration=2**26, + # bottom_detuning=-20, + # ), + # ), ) IroiseMVP = Device( diff --git a/pulser-core/pulser/devices/_mock_device.py b/pulser-core/pulser/devices/_mock_device.py index 28c05eca8..206d04559 100644 --- a/pulser-core/pulser/devices/_mock_device.py +++ b/pulser-core/pulser/devices/_mock_device.py @@ -31,4 +31,5 @@ Raman.Local(None, None, max_duration=None), Microwave.Global(None, None, max_duration=None), ), + # TODO: Add DMM once it is supported for serialization ) diff --git a/pulser-core/pulser/json/abstract_repr/deserializer.py b/pulser-core/pulser/json/abstract_repr/deserializer.py index 8413d62a2..48d74360f 100644 --- a/pulser-core/pulser/json/abstract_repr/deserializer.py +++ b/pulser-core/pulser/json/abstract_repr/deserializer.py @@ -31,6 +31,7 @@ RydbergEOM, ) from pulser.devices import Device, VirtualDevice +from pulser.devices._device_datacls import PARAMS_WITH_ABSTR_REPR from pulser.json.abstract_repr.signatures import ( BINARY_OPERATORS, UNARY_OPERATORS, @@ -346,12 +347,19 @@ def _deserialize_device_object(obj: dict[str, Any]) -> Device | VirtualDevice: params: dict[str, Any] = dict( channel_ids=tuple(ch_ids), channel_objects=tuple(ch_objs) ) - ex_params = ("channel_objects", "channel_ids") + if "dmm_channels" in obj: + params["dmm_objects"] = tuple( + _deserialize_channel(dmm_ch) for dmm_ch in obj["dmm_channels"] + ) device_fields = dataclasses.fields(device_cls) device_defaults = get_dataclass_defaults(device_fields) for param in device_fields: use_default = param.name not in obj and param.name in device_defaults - if not param.init or param.name in ex_params or use_default: + if ( + not param.init + or param.name in PARAMS_WITH_ABSTR_REPR + or use_default + ): continue if param.name == "pre_calibrated_layouts": key = "pre_calibrated_layouts" diff --git a/pulser-core/pulser/json/abstract_repr/serializer.py b/pulser-core/pulser/json/abstract_repr/serializer.py index 5c2c0a3f5..c76e196d9 100644 --- a/pulser-core/pulser/json/abstract_repr/serializer.py +++ b/pulser-core/pulser/json/abstract_repr/serializer.py @@ -26,9 +26,9 @@ from pulser.json.abstract_repr.signatures import SIGNATURES from pulser.json.abstract_repr.validation import validate_abstract_repr from pulser.json.exceptions import AbstractReprError -from pulser.register.base_register import QubitId if TYPE_CHECKING: + from pulser.register.base_register import QubitId from pulser.sequence import Sequence from pulser.sequence._call import _Call diff --git a/pulser-core/pulser/register/_reg_drawer.py b/pulser-core/pulser/register/_reg_drawer.py index 7c6e3ae32..81a5cdeaf 100644 --- a/pulser-core/pulser/register/_reg_drawer.py +++ b/pulser-core/pulser/register/_reg_drawer.py @@ -18,14 +18,15 @@ from collections.abc import Mapping from collections.abc import Sequence as abcSequence from itertools import combinations -from typing import Optional +from typing import TYPE_CHECKING, Optional import matplotlib.pyplot as plt import numpy as np from matplotlib import collections as mc from scipy.spatial import KDTree -from pulser.register.base_register import QubitId +if TYPE_CHECKING: + from pulser.register.base_register import QubitId class RegDrawer: @@ -58,6 +59,7 @@ def _draw_2D( qubit_colors: Mapping[QubitId, str] = dict(), masked_qubits: set[QubitId] = set(), are_traps: bool = False, + dmm_qubits: Mapping[QubitId, float] = {}, ) -> None: ordered_qubit_colors = RegDrawer._compute_ordered_qubit_colors( ids, qubit_colors @@ -73,21 +75,34 @@ def _draw_2D( ax.scatter(pos[:, ix], pos[:, iy], alpha=0.7, **params) # Draw square halo around masked qubits - if masked_qubits: - mask_pos = [] + if ( + masked_qubits + and dmm_qubits + and masked_qubits != set(dmm_qubits.keys()) + ): + raise ValueError("masked qubits and dmm qubits must be the same.") + elif masked_qubits: + dmm_qubits = { + masked_qubit: 1.0 / len(masked_qubits) + for masked_qubit in masked_qubits + } + + if dmm_qubits: + dmm_pos = [] for i, c in zip(ids, pos): - if i in masked_qubits: - mask_pos.append(c) - mask_arr = np.array(mask_pos) + if i in dmm_qubits.keys(): + dmm_pos.append(c) + dmm_arr = np.array(dmm_pos) ax.scatter( - mask_arr[:, ix], - mask_arr[:, iy], + dmm_arr[:, ix], + dmm_arr[:, iy], marker="s", s=1200, - alpha=0.2, + alpha=0.2 + * np.array(list(dmm_qubits.values())) + / max(dmm_qubits.values()), c="black", ) - axes = "xyz" ax.set_xlabel(axes[ix] + " (µm)") @@ -106,6 +121,7 @@ def _draw_2D( i = 0 bbs = {} final_plot_ids: list[str] = [] + final_plot_det_map: list = [] while i < len(plot_ids): r = plot_pos[i] j = i + 1 @@ -120,26 +136,41 @@ def _draw_2D( else: j += 1 # Sort qubits in plot_ids[i] according to masked status + det_map = [ + q for q, weight in dmm_qubits.items() if weight > 0.0 + ] plot_ids[i] = sorted( plot_ids[i], - key=lambda s: s in [str(q) for q in masked_qubits], + key=lambda s: s in det_map, ) - # Merge all masked qubits - has_masked = False + # Merge all masked qubits with their detuning + # if the detunings are not all the same (masked qubits then) + has_det_map = False + is_mask = len(set([dmm_qubits[q] for q in det_map])) == 1 for j in range(len(plot_ids[i])): - if plot_ids[i][j] in [str(q) for q in masked_qubits]: - plot_ids[i][j:] = [", ".join(plot_ids[i][j:])] - has_masked = True + if plot_ids[i][j] in [str(q) for q in det_map]: + qubit_det = [] + for q in plot_ids[i][j:]: + extra_label = ( + f": {dmm_qubits[int(q)]:.2f}" + if not is_mask + else "" + ) + qubit_det.append(q + extra_label) + plot_ids[i][j:] = [", ".join(qubit_det)] + has_det_map = True break # Add a square bracket that encloses all masked qubits - if has_masked: + if has_det_map: plot_ids[i][-1] = "[" + plot_ids[i][-1] + "]" + # Lower the fontsize if detuning is shown (not a mask) + if not is_mask: + final_plot_det_map.append(i) # Merge what remains final_plot_ids.append(", ".join(plot_ids[i])) bbs[final_plot_ids[i]] = overlap i += 1 - - for q, coords in zip(final_plot_ids, plot_pos): + for i, (q, coords) in enumerate(zip(final_plot_ids, plot_pos)): bb = ( dict(boxstyle="square", fill=False, ec="gray", ls="--") if bbs[q] @@ -154,7 +185,7 @@ def _draw_2D( va=v_al, wrap=True, bbox=bb, - fontsize=12, + fontsize=12 if i not in final_plot_det_map else 8.3, multialignment="right", ) txt._get_wrap_line_width = lambda: 50.0 diff --git a/pulser-core/pulser/register/base_register.py b/pulser-core/pulser/register/base_register.py index 1062bf25d..473b10a87 100644 --- a/pulser-core/pulser/register/base_register.py +++ b/pulser-core/pulser/register/base_register.py @@ -33,6 +33,7 @@ from numpy.typing import ArrayLike from pulser.json.utils import obj_to_dict +from pulser.register.weight_maps import DetuningMap if TYPE_CHECKING: from pulser.register.register_layout import RegisterLayout @@ -204,6 +205,30 @@ def _validate_layout( " register's coordinates." ) + def define_detuning_map( + self, detuning_weights: Mapping[QubitId, float] + ) -> DetuningMap: + """Defines a DetuningMap for some qubits of the register. + + Args: + detuning_weights: A mapping between the IDs of the targeted qubits + and detuning weights (between 0 and 1, their sum must be equal + to 1). + + Returns: + A DetuningMap associating detuning weights to the trap coordinates + of the targeted qubits. + """ + if not set(detuning_weights.keys()) <= set(self.qubit_ids): + raise ValueError( + "The qubit ids linked to detuning weights have to be defined" + " in the register." + ) + return DetuningMap( + [self.qubits[qubit_id] for qubit_id in detuning_weights], + list(detuning_weights.values()), + ) + @abstractmethod def _to_dict(self) -> dict[str, Any]: """Serializes the object. diff --git a/pulser-core/pulser/register/mappable_reg.py b/pulser-core/pulser/register/mappable_reg.py index 29c789c2c..d11f5a06e 100644 --- a/pulser-core/pulser/register/mappable_reg.py +++ b/pulser-core/pulser/register/mappable_reg.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from pulser.register.base_register import BaseRegister, QubitId from pulser.register.register_layout import RegisterLayout + from pulser.register.weight_maps import DetuningMap class MappableRegister: @@ -122,6 +123,22 @@ def find_indices(self, id_list: abcSequence[QubitId]) -> list[int]: ) return [self.qubit_ids.index(id) for id in id_list] + def define_detuning_map( + self, detuning_weights: Mapping[int, float] + ) -> DetuningMap: + """Defines a DetuningMap for some trap ids of the register layout. + + Args: + detuning_weights: A mapping between the IDs of the targeted traps + and detuning weights (between 0 and 1, their sum must be equal + to 1). + + Returns: + A DetuningMap associating detuning weights to the trap coordinates + of the targeted traps. + """ + return self._layout.define_detuning_map(detuning_weights) + def _to_dict(self) -> dict[str, Any]: return obj_to_dict(self, self._layout, *self._qubit_ids) diff --git a/pulser-core/pulser/register/register_layout.py b/pulser-core/pulser/register/register_layout.py index a8aa8143a..22d0917b7 100644 --- a/pulser-core/pulser/register/register_layout.py +++ b/pulser-core/pulser/register/register_layout.py @@ -15,10 +15,12 @@ from __future__ import annotations +from collections.abc import Mapping from collections.abc import Sequence as abcSequence from dataclasses import dataclass from functools import cached_property from hashlib import sha256 +from operator import itemgetter from typing import Any, Optional, cast import matplotlib.pyplot as plt @@ -31,6 +33,7 @@ from pulser.register.mappable_reg import MappableRegister from pulser.register.register import Register from pulser.register.register3d import Register3D +from pulser.register.weight_maps import DetuningMap COORD_PRECISION = 6 @@ -185,6 +188,30 @@ def define_register( reg = reg_class(qubits, layout=self, trap_ids=trap_ids) return reg + def define_detuning_map( + self, detuning_weights: Mapping[int, float] + ) -> DetuningMap: + """Defines a DetuningMap for some trap ids of the register layout. + + Args: + detuning_weights: A mapping between the IDs of the targeted traps + and detuning weights (between 0 and 1, their sum must be equal + to 1). + + Returns: + A DetuningMap associating detuning weights to the trap coordinates + of the targeted traps. + """ + if not set(detuning_weights.keys()) <= set(self.traps_dict): + raise ValueError( + "The trap ids of detuning weights have to be integers" + f" between 0 and {self.number_of_traps}." + ) + return DetuningMap( + itemgetter(*detuning_weights.keys())(self.traps_dict), + list(detuning_weights.values()), + ) + def draw( self, blockade_radius: Optional[float] = None, diff --git a/pulser-core/pulser/register/weight_maps.py b/pulser-core/pulser/register/weight_maps.py new file mode 100644 index 000000000..2e3e1fd23 --- /dev/null +++ b/pulser-core/pulser/register/weight_maps.py @@ -0,0 +1,106 @@ +# Copyright 2023 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines weight maps on top of traps.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, cast + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from numpy.typing import ArrayLike + +from pulser.register._reg_drawer import RegDrawer + + +@dataclass +class WeightMap(RegDrawer): + """Defines a generic map of weights on traps. + + The sum of the provided weights must be equal to 1. + + Args: + trap_coordinates: An array containing the coordinates of the traps. + weights: A list weights to associate to the traps. + """ + + trap_coordinates: ArrayLike + weights: list[float] + + def __post_init__(self) -> None: + if len(cast(list, self.trap_coordinates)) != len(self.weights): + raise ValueError("Number of traps and weights don't match.") + if not np.all(np.array(self.weights) >= 0): + raise ValueError("All weights must be non-negative.") + if not np.isclose(sum(self.weights), 1.0, atol=1e-16): + raise ValueError("The sum of the weights should be 1.") + + def draw( + self, + with_labels: bool = True, + fig_name: str | None = None, + kwargs_savefig: dict = {}, + custom_ax: Optional[Axes] = None, + show: bool = True, + ) -> None: + """Draws the detuning map. + + Args: + with_labels: If True, writes the qubit ID's + next to each qubit. + fig_name: The name on which to save the figure. + If None the figure will not be saved. + kwargs_savefig: Keywords arguments for + ``matplotlib.pyplot.savefig``. Not applicable if `fig_name` + is ``None``. + custom_ax: If present, instead of creating its own Axes object, + the function will use the provided one. Warning: if fig_name + is set, it may save content beyond what is drawn in this + function. + show: Whether or not to call `plt.show()` before returning. When + combining this plot with other ones in a single figure, one may + need to set this flag to False. + """ + pos = np.array(self.trap_coordinates) + if custom_ax is None: + _, custom_ax = self._initialize_fig_axes(pos) + + super()._draw_2D( + custom_ax, + pos, + [i for i, _ in enumerate(cast(list, self.trap_coordinates))], + with_labels=with_labels, + dmm_qubits=dict(enumerate(self.weights)), + ) + + if fig_name is not None: + plt.savefig(fig_name, **kwargs_savefig) + + if show: + plt.show() + + +@dataclass +class DetuningMap(WeightMap): + """Defines a DetuningMap. + + A DetuningMap associates a detuning weight to the coordinates of a trap. + The sum of the provided weights must be equal to 1. + + Args: + trap_coordinates: an array containing the coordinates of the traps. + weights: A list of detuning weights to associate to the traps. + """ diff --git a/pulser-core/pulser/sequence/_schedule.py b/pulser-core/pulser/sequence/_schedule.py index f2cad62ed..816c5c912 100644 --- a/pulser-core/pulser/sequence/_schedule.py +++ b/pulser-core/pulser/sequence/_schedule.py @@ -24,6 +24,7 @@ from pulser.channels.base_channel import Channel from pulser.pulse import Pulse from pulser.register.base_register import QubitId +from pulser.register.weight_maps import DetuningMap from pulser.sampler.samples import ChannelSamples, _PulseTargetSlot from pulser.waveforms import ConstantWaveform @@ -247,6 +248,11 @@ def __iter__(self) -> Iterator[_TimeSlot]: yield slot +@dataclass +class _DMMSchedule(_ChannelSchedule): + detuning_map: DetuningMap + + class _Schedule(Dict[str, _ChannelSchedule]): def __init__(self, max_duration: int | None = None): self.max_duration = max_duration diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index e37afab4e..1e7307cda 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -27,6 +27,7 @@ from pulser import Pulse, Register, Register3D, Sequence, devices from pulser.channels import Rydberg +from pulser.channels.dmm import DMM from pulser.channels.eom import RydbergBeam, RydbergEOM from pulser.devices import AnalogDevice, Chadoq2, Device, IroiseMVP, MockDevice from pulser.json.abstract_repr.deserializer import ( @@ -235,6 +236,46 @@ def test_optional_channel_fields(self, ch_obj): dev_str = device.to_abstract_repr() assert device == deserialize_device(dev_str) + @pytest.fixture + def chadoq2_with_dmm(self): + # TODO: Delete once Chadoq2 actually has a DMM + dmm = DMM( + bottom_detuning=-1, + clock_period=1, + min_duration=1, + max_duration=1e6, + mod_bandwidth=20, + ) + return replace(Chadoq2, dmm_objects=(dmm,)) + + @pytest.mark.xfail( + raises=jsonschema.exceptions.ValidationError, strict=True + ) + def test_abstract_repr_dmm_serialize(self, chadoq2_with_dmm): + chadoq2_with_dmm.to_abstract_repr() + + @pytest.mark.xfail(raises=DeserializeDeviceError, strict=True) + @pytest.mark.parametrize( + "skip_validation", + [ + False, # Fails validation + True, # Fails because the DMM channel is deserialized as Rydberg + ], + ) + def test_abstract_repr_dmm_deserialize( + self, chadoq2_with_dmm, monkeypatch, skip_validation + ): + ser_device = json.dumps(chadoq2_with_dmm, cls=AbstractReprEncoder) + if skip_validation: + + def dummy(*args, **kwargs): + return True + + # Patches jsonschema.validate with a function that returns True + monkeypatch.setattr(jsonschema, "validate", dummy) + device = deserialize_device(ser_device) + assert device == chadoq2_with_dmm + def validate_schema(instance): with open( diff --git a/tests/test_devices.py b/tests/test_devices.py index c9981bdc8..1214aa320 100644 --- a/tests/test_devices.py +++ b/tests/test_devices.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from dataclasses import FrozenInstanceError +from dataclasses import FrozenInstanceError, replace from unittest.mock import patch import numpy as np @@ -21,6 +21,7 @@ import pulser from pulser.channels import Microwave, Raman, Rydberg +from pulser.channels.dmm import DMM from pulser.devices import Chadoq2, Device, VirtualDevice from pulser.register import Register, Register3D from pulser.register.register_layout import RegisterLayout @@ -74,6 +75,11 @@ def test_params(): "'interaction_coeff_xy' must be a 'float'," " not ''.", ), + ( + "dmm_objects", + ("DMM(bottom_detuning=-1)",), + "All DMM channels must be of type 'DMM', not 'str'", + ), ("max_sequence_duration", 1.02, None), ("max_runs", 1e8, None), ], @@ -130,6 +136,16 @@ def test_post_init_value_errors(test_params, param, value, msg): VirtualDevice(**test_params) +# TODO: Add test of comptability SLM-DMM once DMM is added for serialization +# def test_post_init_slm_dmm_compatibility(test_params): +# test_params["supports_slm_mask"] = True +# test_params["dmm_objects"] = () +# with pytest.raises(ValueError, +# match="One DMM object should be defined to support SLM mask." +# ): +# VirtualDevice(**test_params) + + potential_params = ["max_atom_num", "max_radial_distance"] always_none_allowed = ["max_sequence_duration", "max_runs"] @@ -384,3 +400,31 @@ def test_device_params(): assert set(all_params) - set(all_virtual_params) == { "pre_calibrated_layouts" } + + +def test_dmm_channels(): + dmm = DMM( + bottom_detuning=-1, + clock_period=1, + min_duration=1, + max_duration=1e6, + mod_bandwidth=20, + ) + device = replace(Chadoq2, dmm_objects=(dmm,)) + assert len(device.dmm_channels) == 1 + assert device.dmm_channels["dmm_0"] == dmm + with pytest.raises( + ValueError, + match=( + "When defined, the names of channel IDs must be different" + " than the names of DMM channels 'dmm_0', 'dmm_1', ... ." + ), + ): + device = replace( + Chadoq2, + dmm_objects=(dmm,), + channel_objects=(Rydberg.Global(None, None),), + channel_ids=("dmm_0",), + ) + assert not dmm.is_virtual() + assert DMM().is_virtual() diff --git a/tests/test_dmm.py b/tests/test_dmm.py new file mode 100644 index 000000000..22055dd63 --- /dev/null +++ b/tests/test_dmm.py @@ -0,0 +1,193 @@ +# Copyright 2020 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import cast +from unittest.mock import patch + +import numpy as np +import pytest + +from pulser.channels.dmm import DMM +from pulser.register.base_register import BaseRegister +from pulser.register.mappable_reg import MappableRegister +from pulser.register.register_layout import RegisterLayout +from pulser.register.weight_maps import DetuningMap + + +@pytest.fixture +def layout() -> RegisterLayout: + return RegisterLayout([[0, 0], [1, 0], [0, 1], [1, 1]]) + + +@pytest.fixture +def register(layout: RegisterLayout) -> BaseRegister: + return layout.define_register(0, 1, 2, 3, qubit_ids=(0, 1, 2, 3)) + + +@pytest.fixture +def map_reg(layout: RegisterLayout) -> MappableRegister: + return layout.make_mappable_register(4) + + +@pytest.fixture +def det_dict() -> dict[int, float]: + return {0: 0.7, 1: 0.3, 2: 0} + + +@pytest.fixture +def det_map(layout: RegisterLayout, det_dict: dict[int, float]) -> DetuningMap: + return layout.define_detuning_map(det_dict) + + +@pytest.fixture +def slm_dict() -> dict[int, float]: + return {0: 1 / 3, 1: 1 / 3, 2: 1 / 3} + + +@pytest.fixture +def slm_map(layout: RegisterLayout, slm_dict: dict[int, float]) -> DetuningMap: + return layout.define_detuning_map(slm_dict) + + +@pytest.mark.parametrize("bad_key", [{"1": 1.0}, {4: 1.0}]) +def test_define_detuning_att( + layout: RegisterLayout, + register: BaseRegister, + map_reg: MappableRegister, + bad_key: dict, +): + for reg in (layout, map_reg): + with pytest.raises( + ValueError, + match=( + "The trap ids of detuning weights have to be integers" + " between 0 and 4" + ), + ): + reg.define_detuning_map(bad_key) # type: ignore + with pytest.raises( + ValueError, + match=( + "The qubit ids linked to detuning weights have to be defined in" + " the register." + ), + ): + register.define_detuning_map(bad_key) + + +def test_bad_init( + layout: RegisterLayout, + register: BaseRegister, + map_reg: MappableRegister, +): + with pytest.raises( + ValueError, match="Number of traps and weights don't match." + ): + DetuningMap([(0, 0), (1, 0)], [0]) + + bad_weights = {0: -1.0, 1: 1.0, 2: 1.0} + bad_sum = {0: 0.1, 2: 0.9, 3: 0.1} + for reg in (layout, map_reg, register): + with pytest.raises( + ValueError, match="All weights must be non-negative." + ): + reg.define_detuning_map(bad_weights) # type: ignore + with pytest.raises( + ValueError, match="The sum of the weights should be 1." + ): + reg.define_detuning_map(bad_sum) # type: ignore + + +def test_init( + layout: RegisterLayout, + register: BaseRegister, + map_reg: MappableRegister, + det_dict: dict[int, float], + slm_dict: dict[int, float], +): + for reg in (layout, map_reg, register): + for detuning_map_dict in (det_dict, slm_dict): + detuning_map = cast( + DetuningMap, + reg.define_detuning_map(detuning_map_dict), # type: ignore + ) + assert np.all( + [ + detuning_map_dict[i] == detuning_map.weights[i] + for i in range(len(detuning_map_dict)) + ] + ) + assert np.all( + [ + layout.coords[i] + == np.array(detuning_map.trap_coordinates)[i] + for i in range(len(detuning_map_dict)) + ] + ) + + +def test_draw(det_map, slm_map, patch_plt_show): + for detuning_map in (det_map, slm_map): + detuning_map.draw(with_labels=True, show=True, custom_ax=None) + with patch("matplotlib.pyplot.savefig"): + detuning_map.draw(fig_name="det_map.pdf") + with pytest.raises( + ValueError, match="masked qubits and dmm qubits must be the same." + ): + slm_map._draw_2D( + slm_map._initialize_fig_axes(np.array(slm_map.trap_coordinates))[ + 1 + ], + np.array(slm_map.trap_coordinates), + [i for i, _ in enumerate(cast(list, slm_map.trap_coordinates))], + with_labels=True, + dmm_qubits=dict(enumerate(slm_map.weights)), + masked_qubits={ + 1, + }, + ) + + +def test_DMM(): + dmm = DMM( + bottom_detuning=-1, + clock_period=1, + min_duration=1, + max_duration=1e6, + mod_bandwidth=20, + ) + assert dmm.basis == "ground-rydberg" + assert dmm.addressing == "Global" + assert dmm.bottom_detuning == -1 + assert dmm.max_amp == 1e-16 + for value in ( + dmm.max_abs_detuning, + dmm.min_retarget_interval, + dmm.fixed_retarget_t, + dmm.max_targets, + ): + assert value is None + with pytest.raises(ValueError, match="bottom_detuning must be negative."): + DMM(bottom_detuning=1) + with pytest.raises( + NotImplementedError, + match=f"{DMM} cannot be initialized from `Global` method.", + ): + DMM.Global(None, None, bottom_detuning=1) + with pytest.raises( + NotImplementedError, + match=f"{DMM} cannot be initialized from `Local` method.", + ): + DMM.Local(None, None, bottom_detuning=1)