Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Full support for Local Detuning #86

Merged
merged 11 commits into from
Aug 28, 2024
180 changes: 180 additions & 0 deletions src/ahs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import StructTypes

export AtomArrangementItem, AtomArrangement, TimeSeriesItem, TimeSeries, Field, DrivingField
export ShiftingField, Hamiltonian, AnalogHamiltonianSimulation, Pattern, vacant, filled, SiteType, discretize
export LocalDetuning, stitch

"""
Hamiltonian
Expand Down Expand Up @@ -271,3 +272,182 @@ function discretize(ahs::AnalogHamiltonianSimulation, device::Device)
discretize(ahs.register, properties), map(h->discretize(h, properties), ahs.hamiltonian)
)
end

"""
LocalDetuning <: Hamiltonian
Struct representing a Hamiltonian term `H_{shift}` representing the [local detuning](https://aws.amazon.com/blogs/quantum-computing/local-detuning-now-available-on-queras-aquila-device-with-braket-direct/) that changes the energy of the Rydberg level in an AnalogHamiltonianSimulation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can AnalogHamiltonianSimulation have a doc link?

```math
H_{shift} (t) := -\\Delta(t) \\sum_k h_k | r_k \\rangle \\langle r_k |
```
where:
- ``\\Delta(t)`` is the magnitude of the frequency shift in rad/s,
- ``h_k`` is the site coefficient of atom ``k``, a dimensionless real number between 0 and 1,
- ``|r_k \\rangle`` is the Rydberg state of atom ``k``.
The sum ``\\sum_k`` is taken over all target atoms.
Fields:
- `magnitude::Field`: Field containing the global magnitude time series Delta(t),
where time is measured in seconds (s) and values are measured in rad/s, and the
local pattern h_k of dimensionless real numbers between 0 and 1.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should follow the proper doctest format like:

# Examples

[doctest here]

```julia
julia> magnitude = Field(TimeSeries(OrderedDict([0 => TimeSeriesItem(0, 1), 1 => TimeSeriesItem(1, 2)]), true, -1))
julia> local_detuning = LocalDetuning(magnitude)
```
"""
struct LocalDetuning <: Hamiltonian
magnitude::Field
end

"""
LocalDetuning(times::Vector{<:Number}, values::Vector{<:Number}, pattern::Vector{<:Number})
kshyatt-aws marked this conversation as resolved.
Show resolved Hide resolved
```julia
julia> times₁ = [0, 0.1, 0.2, 0.3];
julia> glob_amplitude₁ = [0.5, 0.8, 0.9, 1.0];
julia> pattern₁ = [0.3, 0.7, 0.6, -0.5, 0, 1.6];
julia> s₁ = LocalDetuning(times₁, glob_amplitude₁, pattern₁)
LocalDetuning(Field(TimeSeries(OrderedCollections.OrderedDict{Number, TimeSeriesItem}(0.0 => TimeSeriesItem(0.0, 0.5), 0.1 => TimeSeriesItem(0.1, 0.8), 0.2 => TimeSeriesItem(0.2, 0.9), 0.3 => TimeSeriesItem(0.3, 1.0)), true, -1), Pattern(Number[0.3, 0.7, 0.6, -0.5, 0.0, 1.6])))
```
"""
function LocalDetuning(times::Vector{<:Number}, values::Vector{<:Number}, pattern::Vector{<:Number})
if length(times) != length(values)
throw(ArgumentError("The length of the times and values lists must be equal."))
end

time_series = TimeSeries()
for (t, v) in zip(times, values)
time_series[t] = v
end

field = Field(time_series, Pattern(pattern))
LocalDetuning(field)
end


ir(ld::LocalDetuning) = IR.LocalDetuning(ir(ld.magnitude))
"""
stitch(ld1::LocalDetuning, ld2::LocalDetuning; boundary::Symbol="mean") -> LocalDetuning
[`stitch`](@ref) two shifting fields based on the `TimeSeries.stitch` method.
The time points of the second LocalDetuning are shifted such that the first time point of
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LocalDetuning can get a doc link here too

the second LocalDetuning coincides with the last time point of the first LocalDetuning.
The boundary point value is handled according to the boundary argument value.
# Arguments:
ld1::LocalDetuning: The first LocalDetuning to be stitched.
ld2::LocalDetuning: The second LocalDetuning to be stitched.
boundary::Symbol="mean": The boundary point handler. Possible options are "mean", "left", "right".
"""
function stitch(ld1::LocalDetuning, ld2::LocalDetuning, boundary::Symbol=:mean)
if ld1.magnitude.pattern.series != ld2.magnitude.pattern.series
throw(ArgumentError("The LocalDetuning pattern for both fields must be equal."))
end

new_ts = stitch(ld1.magnitude.time_series, ld2.magnitude.time_series, boundary)
LocalDetuning(Field(new_ts, ld1.magnitude.pattern))
end

"""
stitch(ts1::TimeSeries, ts2::TimeSeries; boundary::Symbol="mean")
[`stitch`](@ref) two shifting fields based on the `TimeSeries.stitch` method.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TimeSeries.stitch isn't a valid Julia function reference? It should be just stitch

The time points of the second `TimeSeries` are shifted such that the first time point of
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can TimeSeries have a doc link?

the second `TimeSeries` coincides with the last time point of the first `TimeSeries`.
The boundary point value is handled according to the `boundary` argument value.
# Arguments:
- `ts1::TimeSeries`: The first `TimeSeries` to be stitched.
- `ts2::TimeSeries`: The second `TimeSeries` to be stitched.
- `boundary::Symbol="mean"`: The boundary point handler. Possible options are "mean", "left", "right".
# Examples:
```julia
julia> times₁ = [0, 0.1, 0.2, 0.3];
julia> glob_amplitude₁ = [0.5, 0.8, 0.9, 1.0];
julia> pattern₁ = [0.3, 0.7, 0.6, -0.5, 0, 1.6];
julia> times₂ = [0, 0.1, 0.2, 0.3];
julia> glob_amplitude₂ = [0.5, 0.8, 0.9, 1.0];
julia> pattern₂ = pattern₁;
julia> s₂ = LocalDetuning(times₂, glob_amplitude₂, pattern₂)
LocalDetuning(Field(TimeSeries(OrderedCollections.OrderedDict{Number, TimeSeriesItem}(0.0 => TimeSeriesItem(0.0, 0.5), 0.1 => TimeSeriesItem(0.1, 0.8), 0.2 => TimeSeriesItem(0.2, 0.9), 0.3 => TimeSeriesItem(0.3, 1.0)), true, -1), Pattern(Number[0.3, 0.7, 0.6, -0.5, 0.0, 1.6])))
julia> s₁ = LocalDetuning(times₁, glob_amplitude₁, pattern₁)
LocalDetuning(Field(TimeSeries(OrderedCollections.OrderedDict{Number, TimeSeriesItem}(0.0 => TimeSeriesItem(0.0, 0.5), 0.1 => TimeSeriesItem(0.1, 0.8), 0.2 => TimeSeriesItem(0.2, 0.9), 0.3 => TimeSeriesItem(0.3, 1.0)), true, -1), Pattern(Number[0.3, 0.7, 0.6, -0.5, 0.0, 1.6])))
julia> stitchedₗ = stitch(s₁, s₂, :mean)
LocalDetuning(Field(TimeSeries(OrderedCollections.OrderedDict{Number, TimeSeriesItem}(0.0 => TimeSeriesItem(0.0, 0.75), 0.1 => TimeSeriesItem(0.1, 0.8), 0.2 => TimeSeriesItem(0.2, 0.9), 0.3 => TimeSeriesItem(0.3, 1.0)), true, 1), Pattern(Number[0.3, 0.7, 0.6, -0.5, 0.0, 1.6])))
julia> stitchedₗ = stitch(s₁, s₂, :left)
LocalDetuning(Field(TimeSeries(OrderedCollections.OrderedDict{Number, TimeSeriesItem}(0.0 => TimeSeriesItem(0.3, 1.0), 0.1 => TimeSeriesItem(0.1, 0.8), 0.2 => TimeSeriesItem(0.2, 0.9), 0.3 => TimeSeriesItem(0.3, 1.0)), true, 1), Pattern(Number[0.3, 0.7, 0.6, -0.5, 0.0, 1.6])))
julia> stitchedₗ = stitch(s₁, s₂, :right)
LocalDetuning(Field(TimeSeries(OrderedCollections.OrderedDict{Number, TimeSeriesItem}(0.0 => TimeSeriesItem(0.0, 0.5), 0.1 => TimeSeriesItem(0.1, 0.8), 0.2 => TimeSeriesItem(0.2, 0.9), 0.3 => TimeSeriesItem(0.3, 1.0)), true, 1), Pattern(Number[0.3, 0.7, 0.6, -0.5, 0.0, 1.6])))
```
"""
function stitch(ts1::TimeSeries, ts2::TimeSeries, boundary::Symbol)
merged_series = deepcopy(ts1.series)
first_time_ts2 = first(collect(keys(ts2.series)))
last_time_ts1 = last(collect(keys(ts1.series)))

if boundary == :mean
merged_value = (ts1.series[last_time_ts1].value + ts2.series[first_time_ts2].value) / 2
merged_series[first_time_ts2] = TimeSeriesItem(first_time_ts2, merged_value)
elseif boundary == :left
merged_series[first_time_ts2] = ts1.series[last_time_ts1]
elseif boundary == :right
merged_series[first_time_ts2] = ts2.series[first_time_ts2]
else
throw(ArgumentError("Invalid boundary condition: $boundary"))
end

for (t, v) in ts2.series
if t != first_time_ts2
merged_series[t] = v
end
end

largest_time = maximum(keys(merged_series))
if largest_time isa Float64
largest_time = Int(ceil(largest_time))
end

TimeSeries(merged_series, true, largest_time)
end
"""
discretize(ld::LocalDetuning, properties::DiscretizationProperties) -> LocalDetuning
Creates a discretized version of the `LocalDetuning`.
# Arguments:
- `ld::LocalDetuning`: The `LocalDetuning` to discretize.
- `properties::DiscretizationProperties`: Capabilities of a device that represent the
resolution with which the device can implement the parameters.
"""
function discretize(ld::LocalDetuning, properties::DiscretizationProperties)
local_detuning_parameters = properties.rydberg.rydbergLocal
time_resolution = local_detuning_parameters.timeResolution
value_resolution = local_detuning_parameters.commonDetuningResolution
pattern_resolution = local_detuning_parameters.localDetuningResolution

discretized_magnitude = discretize(ld.magnitude, time_resolution, value_resolution, pattern_resolution)
LocalDetuning(discretized_magnitude)
end

function discretize(ld::LocalDetuning, properties::DiscretizationProperties)
local_detuning_parameters = properties.rydberg.rydbergLocal
time_resolution = Dec128(local_detuning_parameters.timeResolution)
value_resolution = Dec128(local_detuning_parameters.commonDetuningResolution)
pattern_resolution = Dec128(local_detuning_parameters.localDetuningResolution)
discretized_magnitude = discretize(ld.magnitude, time_resolution, value_resolution, pattern_resolution)
LocalDetuning(discretized_magnitude)
end
61 changes: 61 additions & 0 deletions test/local_detuning.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using Braket, Test, JSON3, StructTypes, Mocking, UUIDs, DecFP

Mocking.activate()

struct MockRydbergLocal
kshyatt-aws marked this conversation as resolved.
Show resolved Hide resolved
timeResolution::Dec128
commonDetuningResolution::Dec128
localDetuningResolution::Dec128
end

struct MockRydberg
c6coefficient::Dec128
rydbergGlobal::Braket.RydbergGlobal
rydbergLocal::MockRydbergLocal
end

struct MockDiscretizationProperties
rydberg::MockRydberg
end

@testset "AHS.LocalDetuning" begin

@testset "LocalDetuning" begin
times₁ = [0, 0.1, 0.2, 0.3]
times₂ = [0.3, 0.1, 0.2, 0]
glob_amplitude₁ = [0.5, 0.8, 0.9, 1.0]
pattern₁ = [0.3, 0.7, 0.6, -0.5, 0, 1.6]
s₁ = LocalDetuning(times₁, glob_amplitude₁, pattern₁)
s₂ = LocalDetuning(times₂, glob_amplitude₁, pattern₁)
@test s₁.magnitude.pattern.series == [0.3, 0.7, 0.6, -0.5, 0, 1.6]
@test s₁.magnitude.time_series.sorted == true
@test s₂.magnitude.time_series.sorted == false
end

@testset "LocalDetuning Stitch: Mean, Left, Right" begin
times₁ = [0, 0.1, 0.2, 0.3]
glob_amplitude₁ = [0.5, 0.8, 0.9, 1.0]
pattern₁ = [0.3, 0.7, 0.6, -0.5, 0, 1.6]
times₂ = [0, 0.1, 0.2, 0.3]
glob_amplitude₂ = [0.5, 0.8, 0.9, 1.0]
pattern₂ = pattern₁
s₂ = LocalDetuning(times₂, glob_amplitude₂, pattern₂)
s₁ = LocalDetuning(times₁, glob_amplitude₁, pattern₁)
stitchedₗ = stitch(s₁, s₂, :left)
@test stitchedₗ.magnitude.pattern == s₁.magnitude.pattern
end

@testset "LocalDetuning: Discretize" begin
times₅ = [0, 0.1, 0.2]
values₅ = [0.2, 0.5, 0.7]
pattern₅ = [0.1, 0.3, 0.5]
ld = LocalDetuning(times₅, values₅, pattern₅)
properties = Braket.DiscretizationProperties(
Braket.Lattice(Braket.Area(Dec128("1e-3"), Dec128("1e-3")), Braket.Geometry(Dec128("1e-7"), Dec128("1e-7"), Dec128("1e-7"), 200)),
MockRydberg(Dec128("1e-6"), Braket.RydbergGlobal((Dec128("1.0"), Dec128("1e6")), Dec128("400.0"), Dec128("0.2"), (Dec128("1.0"), Dec128("1e6")), Dec128("0.2"), Dec128("0.2"), (Dec128("1.0"), Dec128("1e6")), Dec128("5e-7"), Dec128("1e-9"), Dec128("1e-5"), Dec128("0.0"), Dec128("100.0")),
MockRydbergLocal(Dec128("1e-9"), Dec128("2000.0"), Dec128("0.01")))
)
discretizedₗ = discretize(ld, properties)
@test discretized_ld isa LocalDetuning
end
end