Skip to content

Commit

Permalink
Adding DetuningMap, DMM (#539)
Browse files Browse the repository at this point in the history
* adding dmm, draw to DetuningMap

* Testing avoiding circular import

* Refactoring to avoid circular imports

* Fix broken UTs

* Import sorting

* Fixing plot DetuningMap

* Serialization/Deserialization of DMM in device

* Fixing typos

* Testing DMM and DetuningMap

* Fixing typo

* adding xfails, fixing type

* Remove DMM from devices and finish UTs

* Take into account review comments

* Add annotations

* Error in Global and Local

* Defining _DMMSchedule

* Fixing nits

* Fixing typo

---------

Co-authored-by: HGSilveri <[email protected]>
  • Loading branch information
a-corni and HGSilveri authored Jul 24, 2023
1 parent 5db7595 commit d5ac020
Show file tree
Hide file tree
Showing 18 changed files with 677 additions and 35 deletions.
1 change: 1 addition & 0 deletions pulser-core/pulser/channels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
"""The various hardware channel types."""

from pulser.channels.channels import Microwave, Raman, Rydberg
from pulser.channels.dmm import DMM
20 changes: 18 additions & 2 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from dataclasses import MISSING, dataclass, field, fields
from typing import Any, Literal, Optional, Type, TypeVar, cast

import numpy as np
Expand Down Expand Up @@ -94,7 +94,7 @@ def basis(self) -> str:
def _internal_param_valid_options(self) -> dict[str, tuple[str, ...]]:
"""Internal parameters and their valid options."""
return dict(
name=("Rydberg", "Raman", "Microwave"),
name=("Rydberg", "Raman", "Microwave", "DMM"),
basis=("ground-rydberg", "digital", "XY"),
addressing=("Local", "Global"),
)
Expand Down Expand Up @@ -262,6 +262,14 @@ def Local(
min_avg_amp: The minimum average amplitude of a pulse (when not
zero).
"""
# Can't initialize a channel whose addressing is determined internally
for cls_field in fields(cls):
if cls_field.name == "addressing":
break
if not cls_field.init and cls_field.default is not MISSING:
raise NotImplementedError(
f"{cls} cannot be initialized from `Local` method."
)
return cls(
"Local",
max_abs_detuning,
Expand Down Expand Up @@ -299,6 +307,14 @@ def Global(
min_avg_amp: The minimum average amplitude of a pulse (when not
zero).
"""
# Can't initialize a channel whose addressing is determined internally
for cls_field in fields(cls):
if cls_field.name == "addressing":
break
if not cls_field.init and cls_field.default is not MISSING:
raise NotImplementedError(
f"{cls} cannot be initialized from `Global` method."
)
return cls("Global", max_abs_detuning, max_amp, **kwargs)

def validate_duration(self, duration: int) -> int:
Expand Down
2 changes: 1 addition & 1 deletion pulser-core/pulser/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Rydberg(Channel):
"""Rydberg beam channel.
Channel targeting the transition between the ground and rydberg states,
thus enconding the 'ground-rydberg' basis. See base class.
thus encoding the 'ground-rydberg' basis. See base class.
"""

eom_config: Optional[RydbergEOM] = None
Expand Down
74 changes: 74 additions & 0 deletions pulser-core/pulser/channels/dmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2023 Pulser Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the detuning map modulator."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal, Optional

from pulser.channels.base_channel import Channel


@dataclass(init=True, repr=False, frozen=True)
class DMM(Channel):
"""Defines a Detuning Map Modulator (DMM) Channel.
A Detuning Map Modulator can be used to define `Global` detuning Pulses
(of zero amplitude and phase). These Pulses are locally modulated by the
weights of a `DetuningMap`, thus providing a local control over the
detuning. The detuning of the pulses added to a DMM has to be negative,
between 0 and `bottom_detuning`. Channel targeting the transition between
the ground and rydberg states, thus encoding the 'ground-rydberg' basis.
Note:
The protocol to add pulses to the DMM Channel is by default
"no-delay".
Args:
bottom_detuning: Minimum possible detuning (in rad/µs), must be below
zero.
clock_period: The duration of a clock cycle (in ns). The duration of a
pulse or delay instruction is enforced to be a multiple of the
clock cycle.
min_duration: The shortest duration an instruction can take.
max_duration: The longest duration an instruction can take.
min_avg_amp: The minimum average amplitude of a pulse (when not zero).
mod_bandwidth: The modulation bandwidth at -3dB (50% reduction), in
MHz.
"""

bottom_detuning: Optional[float] = field(default=None, init=True)
addressing: Literal["Global"] = field(default="Global", init=False)
max_abs_detuning: Optional[float] = field(init=False, default=None)
max_amp: float = field(default=1e-16, init=False) # can't be 0
min_retarget_interval: Optional[int] = field(init=False, default=None)
fixed_retarget_t: Optional[int] = field(init=False, default=None)
max_targets: Optional[int] = field(init=False, default=None)

def __post_init__(self) -> None:
super().__post_init__()
if self.bottom_detuning and self.bottom_detuning > 0:
raise ValueError("bottom_detuning must be negative.")

@property
def basis(self) -> Literal["ground-rydberg"]:
"""The addressed basis name."""
return "ground-rydberg"

def _undefined_fields(self) -> list[str]:
optional = [
"bottom_detuning",
"max_duration",
]
return [field for field in optional if getattr(self, field) is None]
57 changes: 50 additions & 7 deletions pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from scipy.spatial.distance import pdist, squareform

from pulser.channels.base_channel import Channel
from pulser.channels.dmm import DMM
from pulser.devices.interaction_coefficients import c6_dict
from pulser.json.abstract_repr.serializer import AbstractReprEncoder
from pulser.json.abstract_repr.validation import validate_abstract_repr
Expand All @@ -34,7 +35,8 @@

DIMENSIONS = Literal[2, 3]

ALWAYS_OPTIONAL_PARAMS = ("max_sequence_duration", "max_runs")
ALWAYS_OPTIONAL_PARAMS = ("max_sequence_duration", "max_runs", "dmm_objects")
PARAMS_WITH_ABSTR_REPR = ("channel_objects", "channel_ids", "dmm_objects")


@dataclass(frozen=True, repr=False)
Expand All @@ -49,6 +51,9 @@ class BaseDevice(ABC):
channel_ids: Custom IDs for each channel object. When defined,
an ID must be given for each channel. If not defined, the IDs are
generated internally based on the channels' names and addressing.
dmm_objects: The DMM subclass instances specifying each channel in the
device. They are referenced by their order in the list, with the ID
"dmm_[index in dmm_objects]".
rybderg_level: The value of the principal quantum number :math:`n`
when the Rydberg level used is of the form
:math:`|nS_{1/2}, m_j = +1/2\rangle`.
Expand Down Expand Up @@ -83,6 +88,7 @@ class BaseDevice(ABC):
reusable_channels: bool = field(default=False, init=False)
channel_ids: tuple[str, ...] | None = None
channel_objects: tuple[Channel, ...] = field(default_factory=tuple)
dmm_objects: tuple[DMM, ...] = field(default_factory=tuple)

def __post_init__(self) -> None:
def type_check(
Expand Down Expand Up @@ -155,6 +161,16 @@ def type_check(
for ch_obj in self.channel_objects:
type_check("All channels", Channel, value_override=ch_obj)

for dmm_obj in self.dmm_objects:
type_check("All DMM channels", DMM, value_override=dmm_obj)

# TODO: Check that device has dmm objects if it supports SLM mask
# once DMM is supported for serialization
# if self.supports_slm_mask and not self.dmm_objects:
# raise ValueError(
# "One DMM object should be defined to support SLM mask."
# )

if self.channel_ids is not None:
if not (
isinstance(self.channel_ids, (tuple, list))
Expand All @@ -174,6 +190,12 @@ def type_check(
"When defined, the number of channel IDs must"
" match the number of channel objects."
)
if set(self.channel_ids) & set(self.dmm_channels.keys()):
raise ValueError(
"When defined, the names of channel IDs must be different"
" than the names of DMM channels 'dmm_0', 'dmm_1', ... ."
)

else:
# Make the channel IDs from the default IDs
ids_counter: Counter = Counter()
Expand Down Expand Up @@ -203,7 +225,7 @@ def to_tuple(obj: tuple | list) -> tuple:

# Turns mutable lists into immutable tuples
for param in self._params():
if "channel" in param:
if "channel" in param or param == "dmm_objects":
object.__setattr__(self, param, to_tuple(getattr(self, param)))

@property
Expand All @@ -216,6 +238,13 @@ def channels(self) -> dict[str, Channel]:
"""Dictionary of available channels on this device."""
return dict(zip(cast(tuple, self.channel_ids), self.channel_objects))

@property
def dmm_channels(self) -> dict[str, DMM]:
"""Dictionary of available DMM channels on this device."""
return {
f"dmm_{i}": dmm_obj for (i, dmm_obj) in enumerate(self.dmm_objects)
}

@property
def supported_bases(self) -> set[str]:
"""Available electronic transitions for control and measurement."""
Expand Down Expand Up @@ -420,18 +449,26 @@ def _to_dict(self) -> dict[str, Any]:

@abstractmethod
def _to_abstract_repr(self) -> dict[str, Any]:
ex_params = ("channel_objects", "channel_ids")
defaults = get_dataclass_defaults(fields(self))
params = self._params()
for p in ex_params:
params.pop(p, None)
for p in ALWAYS_OPTIONAL_PARAMS:
if params[p] == defaults[p]:
params.pop(p, None)
ch_list = []
for ch_name, ch_obj in self.channels.items():
ch_list.append(ch_obj._to_abstract_repr(ch_name))
return {"version": "1", "channels": ch_list, **params}
# Add version and channels to params
params.update({"version": "1", "channels": ch_list})
dmm_list = []
for dmm_name, dmm_obj in self.dmm_channels.items():
dmm_list.append(dmm_obj._to_abstract_repr(dmm_name))
# Add dmm channels if different than default
if "dmm_objects" in params:
params["dmm_channels"] = dmm_list
# Delete parameters of PARAMS_WITH_ABSTR_REPR in params
for p in PARAMS_WITH_ABSTR_REPR:
params.pop(p, None)
return params

def to_abstract_repr(self) -> str:
"""Serializes the Sequence into an abstract JSON object."""
Expand Down Expand Up @@ -541,7 +578,7 @@ def _specs(self, for_docs: bool = False) -> str:
)

ch_lines = ["\nChannels:"]
for name, ch in self.channels.items():
for name, ch in {**self.channels, **self.dmm_channels}.items():
if for_docs:
ch_lines += [
f" - ID: '{name}'",
Expand All @@ -556,6 +593,12 @@ def _specs(self, for_docs: bool = False) -> str:
"\t"
+ r"- Maximum :math:`|\delta|`:"
+ f" {ch.max_abs_detuning:.4g} rad/µs"
)
if not isinstance(ch, DMM)
else (
"\t"
+ r"- Bottom :math:`|\delta|`:"
+ f" {ch.bottom_detuning:.4g} rad/µs"
),
f"\t- Minimum average amplitude: {ch.min_avg_amp} rad/µs",
]
Expand Down
9 changes: 9 additions & 0 deletions pulser-core/pulser/devices/_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
max_duration=2**26,
),
),
# TODO: Add DMM once it is supported for serialization
# dmm_objects=(
# DMM(
# clock_period=4,
# min_duration=16,
# max_duration=2**26,
# bottom_detuning=-20,
# ),
# ),
)

IroiseMVP = Device(
Expand Down
1 change: 1 addition & 0 deletions pulser-core/pulser/devices/_mock_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@
Raman.Local(None, None, max_duration=None),
Microwave.Global(None, None, max_duration=None),
),
# TODO: Add DMM once it is supported for serialization
)
12 changes: 10 additions & 2 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RydbergEOM,
)
from pulser.devices import Device, VirtualDevice
from pulser.devices._device_datacls import PARAMS_WITH_ABSTR_REPR
from pulser.json.abstract_repr.signatures import (
BINARY_OPERATORS,
UNARY_OPERATORS,
Expand Down Expand Up @@ -346,12 +347,19 @@ def _deserialize_device_object(obj: dict[str, Any]) -> Device | VirtualDevice:
params: dict[str, Any] = dict(
channel_ids=tuple(ch_ids), channel_objects=tuple(ch_objs)
)
ex_params = ("channel_objects", "channel_ids")
if "dmm_channels" in obj:
params["dmm_objects"] = tuple(
_deserialize_channel(dmm_ch) for dmm_ch in obj["dmm_channels"]
)
device_fields = dataclasses.fields(device_cls)
device_defaults = get_dataclass_defaults(device_fields)
for param in device_fields:
use_default = param.name not in obj and param.name in device_defaults
if not param.init or param.name in ex_params or use_default:
if (
not param.init
or param.name in PARAMS_WITH_ABSTR_REPR
or use_default
):
continue
if param.name == "pre_calibrated_layouts":
key = "pre_calibrated_layouts"
Expand Down
2 changes: 1 addition & 1 deletion pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from pulser.json.abstract_repr.signatures import SIGNATURES
from pulser.json.abstract_repr.validation import validate_abstract_repr
from pulser.json.exceptions import AbstractReprError
from pulser.register.base_register import QubitId

if TYPE_CHECKING:
from pulser.register.base_register import QubitId
from pulser.sequence import Sequence
from pulser.sequence._call import _Call

Expand Down
Loading

0 comments on commit d5ac020

Please sign in to comment.