Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: rust extension interface #663

Open
wants to merge 39 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
b43e303
Merge `append_child` with `add_measurement` where possible
shumpohl Feb 22, 2022
243d1f7
Insert extra loop for SequencePT and ForLoopPT
shumpohl Feb 22, 2022
d688af9
Fix error in sequencept create_program
shumpohl Feb 22, 2022
f241d56
Fix first test
shumpohl Feb 22, 2022
04f7d8d
Use qupulse_rs WIP
shumpohl Feb 24, 2022
d6c7a87
Unify empty loop drop logic
shumpohl Feb 25, 2022
02a6633
Adjust tests to changes
shumpohl Feb 25, 2022
f878b70
Increase Loop debuggability by creating a more correct __repr__
shumpohl Feb 25, 2022
f2b2abf
Prepare Loop replacement
shumpohl Feb 25, 2022
f8a0e6e
Merge branch 'issues/647_loop_cleanup' into feat/qupulse_rs
shumpohl Feb 25, 2022
a3f6a67
Better docs and test
shumpohl May 13, 2022
4d523b6
Merge branch 'master' into feat/qupulse_rs
shumpohl Jun 29, 2022
18e91ee
Conditional rust reüplacement import
shumpohl Jun 29, 2022
cb0e78a
Improve testability with rust extension
shumpohl Jun 29, 2022
6b6ba18
Fix more tests by relaxing the assumptions
shumpohl Jun 29, 2022
f41f9ef
Some constant pulse template generalizations
shumpohl Jun 29, 2022
cbe8532
Skip test in python 3.7
shumpohl Jun 29, 2022
f9d339a
Fix wrong legacy import
shumpohl Jun 29, 2022
5947765
Move Loop -> to_single_waveform code
shumpohl Jun 30, 2022
0570c79
Make some create_program tests ProgramBuilder aware
shumpohl Jun 30, 2022
7a2fcee
Make Program runtime_chackable
shumpohl Jul 12, 2022
52c13d4
Fix SubsetWaveform constant_value_dict
shumpohl Jul 12, 2022
375f404
Move compability and waveform trafo code to Loop class
shumpohl Jul 12, 2022
b893752
Add more rust waveforms
shumpohl Jul 12, 2022
8cd42d7
Fix tests
shumpohl Jul 12, 2022
1cf69be
Merge remote-tracking branch 'qutechlab/feat/qupulse_rs' into feat/qu…
shumpohl Jul 12, 2022
39a55c7
Add expressions and scopes from rust extension
shumpohl Sep 6, 2022
d8d7067
Use equality oeprator semantics of TimeType.__value for retry on NotI…
shumpohl Sep 6, 2022
ac9a85a
Use duck typing for AnonymousSerializable serialization
shumpohl Sep 6, 2022
148b43d
Add matplotlib to test requirements
Sep 7, 2022
3881b28
rework replacement code
Sep 16, 2022
a9168cb
Do not use expression sympy interface if not required
shumpohl Sep 19, 2022
639d552
Less usage of expression internals
shumpohl Sep 20, 2022
e453670
Custom subclass check for expressions
shumpohl Sep 20, 2022
87e45e6
Fix, improve and cleanup tests
shumpohl Sep 20, 2022
af743a2
Make more tests rust extension friendly
shumpohl Sep 20, 2022
9613cf6
Fix all tests
shumpohl Sep 20, 2022
af7f000
Merge remote-tracking branch 'qutech/master' into HEAD
Nov 14, 2022
b5b994c
Merge branch 'master' into feat/qupulse_rs
Nov 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions qupulse/_program/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,79 @@
"""This is a private package meaning there are no stability guarantees."""
from abc import ABC, abstractmethod
from typing import Optional, Union, Sequence, ContextManager, Mapping

import numpy as np

from qupulse._program.waveforms import Waveform
from qupulse.utils.types import MeasurementWindow, TimeType
from qupulse._program.volatile import VolatileRepetitionCount

try:
import qupulse_rs
except ImportError:
qupulse_rs = None
RsProgramBuilder = None
else:
from qupulse_rs.replacements import ProgramBuilder as RsProgramBuilder

try:
from typing import Protocol, runtime_checkable
except ImportError:
Protocol = object

def runtime_checkable(cls):
return cls


RepetitionCount = Union[int, VolatileRepetitionCount]


@runtime_checkable
class Program(Protocol):
"""This protocol is used to inspect and or manipulate programs"""

def to_single_waveform(self) -> Waveform:
pass

def get_measurement_windows(self) -> Mapping[str, np.ndarray]:
pass

@property
def duration(self) -> TimeType:
raise NotImplementedError()

def make_compatible_inplace(self):
# TODO: rename?
pass


class ProgramBuilder(Protocol):
"""This protocol is used by PulseTemplate to build the program."""

def append_leaf(self, waveform: Waveform,
measurements: Optional[Sequence[MeasurementWindow]] = None,
repetition_count: int = 1):
pass

def potential_child(self, measurements: Optional[Sequence[MeasurementWindow]],
repetition_count: Union[VolatileRepetitionCount, int] = 1) -> ContextManager['ProgramBuilder']:
"""

Args:
measurements: Measurements to attach to the potential child. Is not repeated with repetition_count.
repetition_count:

Returns:

"""

def to_program(self) -> Optional[Program]:
pass


def default_program_builder() -> ProgramBuilder:
if RsProgramBuilder is None:
from qupulse._program._loop import Loop
return Loop()
else:
return RsProgramBuilder()
155 changes: 108 additions & 47 deletions qupulse/_program/_loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping
import contextlib
from typing import Union, Dict, Iterable, Tuple, cast, List, Optional, Generator, Mapping, ContextManager, Sequence
from collections import defaultdict
from enum import Enum
import warnings
Expand All @@ -15,6 +16,7 @@
from qupulse.utils.tree import Node, is_tree_circular
from qupulse.utils.numeric import smallest_factor_ge

from qupulse._program import ProgramBuilder, Program
from qupulse._program.waveforms import SequenceWaveform, RepetitionWaveform

__all__ = ['Loop', 'make_compatible', 'MakeCompatibleWarning']
Expand Down Expand Up @@ -101,6 +103,9 @@ def add_measurements(self, measurements: Iterable[MeasurementWindow]):
Args:
measurements: Measurements to add
"""
warnings.warn("Loop.add_measurements is deprecated since qupulse 0.7 and will be removed in a future version.",
DeprecationWarning,
stacklevel=2)
body_duration = float(self.body_duration)
if body_duration == 0:
measurements = measurements
Expand Down Expand Up @@ -198,23 +203,47 @@ def encapsulate(self) -> None:
self._measurements = None
self.assert_tree_integrity()

def _get_repr(self, first_prefix, other_prefixes) -> Generator[str, None, None]:
def __repr__(self):
kwargs = []

repetition_count = self._repetition_definition
if repetition_count != 1:
kwargs.append(f"repetition_count={repetition_count!r}")

waveform = self._waveform
if waveform:
kwargs.append(f"waveform={waveform!r}")

children = self.children
if children:
try:
kwargs.append(f"children={self._children_repr()}")
except RecursionError:
kwargs.append("children=[...]")

measurements = self._measurements
if measurements:
kwargs.append(f"measurements={measurements!r}")

return f"Loop({','.join(kwargs)})"

def _get_str(self, first_prefix, other_prefixes) -> Generator[str, None, None]:
if self.is_leaf():
yield '%sEXEC %r %d times' % (first_prefix, self._waveform, self.repetition_count)
else:
yield '%sLOOP %d times:' % (first_prefix, self.repetition_count)

for elem in self:
yield from cast(Loop, elem)._get_repr(other_prefixes + ' ->', other_prefixes + ' ')
yield from cast(Loop, elem)._get_str(other_prefixes + ' ->', other_prefixes + ' ')

def __repr__(self) -> str:
def __str__(self) -> str:
is_circular = is_tree_circular(self)
if is_circular:
return '{}: Circ {}'.format(id(self), is_circular)

str_len = 0
repr_list = []
for sub_repr in self._get_repr('', ''):
for sub_repr in self._get_str('', ''):
str_len += len(sub_repr)

if self.MAX_REPR_SIZE and str_len > self.MAX_REPR_SIZE:
Expand Down Expand Up @@ -404,6 +433,21 @@ def _merge_single_child(self):
self._invalidate_duration()
return True

@contextlib.contextmanager
def potential_child(self,
measurements: Optional[List[MeasurementWindow]],
repetition_count: Union[VolatileRepetitionCount, int] = 1) -> ContextManager['Loop']:
if repetition_count != 1 and measurements:
# current design requires an extra level of nesting here because the measurements are NOT to be repeated
# with the repetition count
inner_child = Loop(repetition_count=repetition_count)
child = Loop(measurements=measurements, children=[inner_child])
else:
inner_child = child = Loop(measurements=measurements, repetition_count=repetition_count)
yield inner_child
if inner_child.waveform or len(inner_child):
self.append_child(child)

def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')):
"""Apply the specified actions to cleanup the Loop.

Expand Down Expand Up @@ -451,6 +495,32 @@ def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]:
else:
return self.repetition_count, tuple(child.get_duration_structure() for child in self)

def to_single_waveform(self) -> Waveform:
if self.is_leaf():
if self.repetition_count == 1:
return self.waveform
else:
return RepetitionWaveform.from_repetition_count(self.waveform, self.repetition_count)
else:
if len(self) == 1:
sequenced_waveform = to_waveform(cast(Loop, self[0]))
else:
sequenced_waveform = SequenceWaveform.from_sequence([to_waveform(cast(Loop, sub_program))
for sub_program in self])
if self.repetition_count > 1:
return RepetitionWaveform.from_repetition_count(sequenced_waveform, self.repetition_count)
else:
return sequenced_waveform

def append_leaf(self, waveform: Waveform,
measurements: Optional[Sequence[MeasurementWindow]] = None,
repetition_count: int = 1):
self.append_child(waveform=waveform, measurements=measurements, repetition_count=repetition_count)

def to_program(self) -> Optional['Loop']:
if self.waveform or self.children:
return self

def reverse_inplace(self):
if self.is_leaf():
self._waveform = self._waveform.reversed()
Expand All @@ -465,29 +535,45 @@ def reverse_inplace(self):
for name, begin, length in self._measurements
]

def make_compatible_inplace(self, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType):
program = self
comp_level = _is_compatible(program,
min_len=minimal_waveform_length,
quantum=waveform_quantum,
sample_rate=sample_rate)
if comp_level == _CompatibilityLevel.incompatible_fraction:
raise ValueError(
'The program duration in samples {} is not an integer'.format(program.duration * sample_rate))
if comp_level == _CompatibilityLevel.incompatible_too_short:
raise ValueError('The program is too short to be a valid waveform. \n'
' program duration in samples: {} \n'
' minimal length: {}'.format(program.duration * sample_rate, minimal_waveform_length))
if comp_level == _CompatibilityLevel.incompatible_quantum:
raise ValueError('The program duration in samples {} '
'is not a multiple of quantum {}'.format(program.duration * sample_rate, waveform_quantum))

elif comp_level == _CompatibilityLevel.action_required:
warnings.warn(
"qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG."
" This might take some time. If you need this pulse more often it makes sense to write it in a "
"way which is more AWG friendly.", MakeCompatibleWarning)

_make_compatible(program,
min_len=minimal_waveform_length,
quantum=waveform_quantum,
sample_rate=sample_rate)

else:
assert comp_level == _CompatibilityLevel.compatible


class ChannelSplit(Exception):
def __init__(self, channel_sets):
self.channel_sets = channel_sets


def to_waveform(program: Loop) -> Waveform:
if program.is_leaf():
if program.repetition_count == 1:
return program.waveform
else:
return RepetitionWaveform.from_repetition_count(program.waveform, program.repetition_count)
else:
if len(program) == 1:
sequenced_waveform = to_waveform(cast(Loop, program[0]))
else:
sequenced_waveform = SequenceWaveform.from_sequence(
[to_waveform(cast(Loop, sub_program))
for sub_program in program])
if program.repetition_count > 1:
return RepetitionWaveform.from_repetition_count(sequenced_waveform, program.repetition_count)
else:
return sequenced_waveform
return program.to_single_waveform()


class _CompatibilityLevel(Enum):
Expand Down Expand Up @@ -568,32 +654,7 @@ def _make_compatible(program: Loop, min_len: int, quantum: int, sample_rate: Tim

def make_compatible(program: Loop, minimal_waveform_length: int, waveform_quantum: int, sample_rate: TimeType):
""" check program for compatibility to AWG requirements, make it compatible if necessary and possible"""
comp_level = _is_compatible(program,
min_len=minimal_waveform_length,
quantum=waveform_quantum,
sample_rate=sample_rate)
if comp_level == _CompatibilityLevel.incompatible_fraction:
raise ValueError('The program duration in samples {} is not an integer'.format(program.duration * sample_rate))
if comp_level == _CompatibilityLevel.incompatible_too_short:
raise ValueError('The program is too short to be a valid waveform. \n'
' program duration in samples: {} \n'
' minimal length: {}'.format(program.duration * sample_rate, minimal_waveform_length))
if comp_level == _CompatibilityLevel.incompatible_quantum:
raise ValueError('The program duration in samples {} '
'is not a multiple of quantum {}'.format(program.duration * sample_rate, waveform_quantum))

elif comp_level == _CompatibilityLevel.action_required:
warnings.warn("qupulse will now concatenate waveforms to make the pulse/program compatible with the chosen AWG."
" This might take some time. If you need this pulse more often it makes sense to write it in a "
"way which is more AWG friendly.", MakeCompatibleWarning)

_make_compatible(program,
min_len=minimal_waveform_length,
quantum=waveform_quantum,
sample_rate=sample_rate)

else:
assert comp_level == _CompatibilityLevel.compatible
program.make_compatible_inplace(minimal_waveform_length, waveform_quantum, sample_rate)


def roll_constant_waveforms(program: Loop, minimal_waveform_quanta: int, waveform_quantum: int, sample_rate: TimeType):
Expand Down
21 changes: 19 additions & 2 deletions qupulse/_program/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@

from qupulse import ChannelID
from qupulse.comparable import Comparable
from qupulse.utils.types import SingletonABCMeta
from qupulse.utils.types import SingletonABCMeta, use_rs_replacements

try:
import qupulse_rs
except ImportError:
qupulse_rs = None
transformation_rs = None
else:
from qupulse_rs.replacements import transformation as transformation_rs


class Transformation(Comparable):
Expand Down Expand Up @@ -325,4 +333,13 @@ def chain_transformations(*transformations: Transformation) -> Transformation:
elif len(parsed_transformations) == 1:
return parsed_transformations[0]
else:
return ChainedTransformation(*parsed_transformations)
return ChainedTransformation(*parsed_transformations)



if transformation_rs:
use_rs_replacements(globals(), transformation_rs, Transformation)

py_chain_transformations = chain_transformations
rs_chain_transformations = transformation_rs.chain_transformations
chain_transformations = rs_chain_transformations
31 changes: 25 additions & 6 deletions qupulse/_program/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,18 @@
from qupulse.expressions import ExpressionScalar
from qupulse.pulses.interpolation import InterpolationStrategy
from qupulse.utils import checked_int_cast, isclose
from qupulse.utils.types import TimeType, time_from_float, FrozenDict
from qupulse.utils.types import TimeType, time_from_float, FrozenDict, use_rs_replacements
from qupulse._program.transformation import Transformation
from qupulse.utils import pairwise

try:
import qupulse_rs
except ImportError:
qupulse_rs = None
waveforms_rs = None
else:
from qupulse_rs.replacements import waveforms as waveforms_rs

class ConstantFunctionPulseTemplateWarning(UserWarning):
""" This warning indicates a constant waveform is constructed from a FunctionPulseTemplate """
pass
Expand Down Expand Up @@ -850,8 +858,8 @@ def unsafe_sample(self,
return output_array

@property
def compare_key(self) -> Tuple[Any, int]:
return self._body.compare_key, self._repetition_count
def compare_key(self) -> Tuple[int, Any]:
return self._repetition_count, self._body

def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
return RepetitionWaveform.from_repetition_count(
Expand Down Expand Up @@ -987,9 +995,14 @@ def unsafe_sample(self,
return self.inner_waveform.unsafe_sample(channel, sample_times, output_array)

def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
d = self._inner_waveform.constant_value_dict()
if d is not None:
return {ch: d[ch] for ch in self._channel_subset}
constant_values = {}
for ch in self.defined_channels:
value = self._inner_waveform.constant_value(ch)
if value is None:
return
else:
constant_values[ch] = value
return constant_values

def constant_value(self, channel: ChannelID) -> Optional[float]:
if channel not in self._channel_subset:
Expand Down Expand Up @@ -1231,3 +1244,9 @@ def compare_key(self) -> Hashable:

def reversed(self) -> 'Waveform':
return self._inner




if waveforms_rs:
use_rs_replacements(globals(), waveforms_rs, Waveform)
Loading