Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Adding support for FreeParameterExpression #85

Merged
merged 18 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

Expand Down Expand Up @@ -64,6 +66,8 @@ SparseArrays = "1.6"
StaticArrays = "=1.9.3"
Statistics = "1.6"
StructTypes = "=1.10.0"
Symbolics = "5.28.0"
SymbolicUtils = "1.6.0"
Tar = "1.9.3"
Test = "1.6"
UUIDs = "1.6"
Expand All @@ -89,6 +93,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
140 changes: 139 additions & 1 deletion src/Braket.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Braket

export Circuit, QubitSet, Qubit, Device, AwsDevice, AwsQuantumTask, AwsQuantumTaskBatch
export metadata, status, Observable, Result, FreeParameter, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator
export metadata, status, Observable, Result, FreeParameter, FreeParameterExpression, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator, subs
export Tracker, simulator_tasks_cost, qpu_tasks_cost
export arn, cancel, state, result, results, name, download_result, id, ir, isavailable, search_devices, get_devices
export provider_name, properties, type
Expand Down Expand Up @@ -32,13 +32,18 @@ using DecFP
using Graphs
using HTTP
using StaticArrays
using Symbolics
using SymbolicUtils
using JSON3, StructTypes
using LinearAlgebra
using DataStructures
using NamedTupleTools
using OrderedCollections
using Tar

# Operator overloading for FreeParameterExpression
import Base: +, -, *, /, ^, ==

include("utils.jl")
"""
IRType
Expand Down Expand Up @@ -134,6 +139,139 @@ end
Base.copy(fp::FreeParameter) = fp
Base.show(io::IO, fp::FreeParameter) = print(io, string(fp.name))

"""
FreeParameterExpression
FreeParameterExpression(expr::Union{FreeParameterExpression, Number, Symbolics.Num, String})
Struct representing a Free Parameter expression, which can be used in symbolic computations.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FreeParameter here should get a doc link

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need to add this docstring to the relevant docs/src file explicitly

Instances of `FreeParameterExpression` can represent symbolic expressions involving FreeParameters,
such as mathematical expressions with undetermined values.
This type is often used in combination with [`FreeParameter`](@ref), which represents individual FreeParameter.
### Examples
```jldoctest
julia> α = FreeParameter(:alpha)
alpha
julia> θ = FreeParameter(:theta)
theta
julia> gate = FreeParameterExpression("α + 2*θ")
α + 2θ
julia> gsub = subs(gate, Dict(:α => 2.0, :θ => 2.0))
6.0
julia> gate + gate
2α + 4θ
julia> gate * gate
(α + 2θ)^2
julia> gate₁ = FreeParameterExpression("phi + 2*gamma")
2gamma + phi
```
"""
struct FreeParameterExpression
expression::Symbolics.Num
function FreeParameterExpression(expr::Symbolics.Num)
new(expr)
end
end

FreeParameterExpression(expr::FreeParameterExpression) = FreeParameterExpression(expr.expression)
FreeParameterExpression(expr::Number) = FreeParameterExpression(Symbolics.Num(expr))
FreeParameterExpression(expr) = throw(ArgumentError("Unsupported expression type"))

function FreeParameterExpression(expr::String)
parsed_expr = parse_expr_to_symbolic(Meta.parse(expr), @__MODULE__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very suspect to me... is there a less sketchy way to accomplish this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally directly calling Meta.parse can be pretty risky!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your concern. In order to alleviate it, Added a function validate_expr that checks for each char in string and make sure that error is thrown for suspect and unknown characters. This approach ensures that input expressions are strictly validated before parsing with Meta.parse.

Tried a bunch of different ways without Meta.parse. , but other approaches were errorsome. Checked out Link: https://github.com/JuliaSymbolics/Symbolics.jl/blob/master/src/parsing.jl and ExprTools as well.

# Function to validate the input expression string
function validate_expr(expr::String)
    allowed_chars = Set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-*/^()αβγδεζηθικλμνξοπρςστυφχψω ")
    for char in expr
        if !(char in allowed_chars)
            throw(ArgumentError("Unsupported character '$char' in expression. Only ASCII letters (a-z, A-Z), Greek letters (α-ω), digits (0-9), space( ), and basic mathematical symbols (+ - * / ^ ()) are allowed."))
        end
    end
    return expr
end

# Function to create FreeParameterExpression from a validated string
function FreeParameterExpression(expr::String)
    validated_expr = validate_expr(expr)
    parsed_expr = parse_expr_to_symbolic(Meta.parse(validated_expr), @__MODULE__)
    return FreeParameterExpression(parsed_expr)
end

return FreeParameterExpression(parsed_expr)
end

Base.show(io::IO, fpe::FreeParameterExpression) = print(io, fpe.expression)
Base.copy(fp::FreeParameterExpression) = fp

function subs(fpe::FreeParameterExpression, parameter_values::Dict{Symbol, <:Number})
param_values_num = Dict(Symbolics.variable(string(k); T=Real) => v for (k, v) in parameter_values)
subbed_expr = Symbolics.substitute(fpe.expression, param_values_num)
if isempty(Symbolics.get_variables(subbed_expr))
subbed_expr = Symbolics.value(subbed_expr)
return subbed_expr
else
subbed_expr = Symbolics.value(subbed_expr)
return FreeParameterExpression(subbed_expr)
end
end

function Base.:+(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1.expression + fpe2.expression)
end

function Base.:+(fpe1::FreeParameterExpression, fpe2::Union{Number, Symbolics.Num})
return FreeParameterExpression(fpe1.expression + fpe2)
end

function Base.:+(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1 + fpe2.expression)
end

function Base.:*(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1.expression * fpe2.expression)
end

function Base.:*(fpe1::FreeParameterExpression, fpe2::Union{Number, Symbolics.Num})
return FreeParameterExpression(fpe1.expression * fpe2)
end

function Base.:*(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1 * fpe2.expression)
end

function Base.:-(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1.expression - fpe2.expression)
end

function Base.:-(fpe1::FreeParameterExpression, fpe2::Union{Number, Symbolics.Num})
return FreeParameterExpression(fpe1.expression - fpe2)
end

function Base.:-(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1 - fpe2.expression)
end

function Base.:/(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1.expression / fpe2.expression)
end

function Base.:/(fpe1::FreeParameterExpression, fpe2::Union{Number, Symbolics.Num})
return FreeParameterExpression(fpe1.expression / fpe2)
end

function Base.:/(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1 / fpe2.expression)
end

function Base.:^(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1.expression ^ fpe2.expression)
end

function Base.:^(fpe1::FreeParameterExpression, fpe2::Union{Number, Symbolics.Num})
return FreeParameterExpression(fpe1.expression ^ fpe2)
end

function Base.:^(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression)
return FreeParameterExpression(fpe1 ^ fpe2.expression)
end

function Base.:(==)(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return isequal(fpe1.expression, fpe2.expression)
end

function Base.:!=(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression)
return !(isequal(fpe1.expression, fpe2.expression))
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could slim all these down by making them one liners like:

Base.:^(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression) = FreeParameterExpression(fpe1 ^ fpe2.expression)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was giving error

Base.:^(fpe1::Union{Number, Symbolics.Num}, fpe2::FreeParameterExpression) = FreeParameterExpression(fpe1 ^ fpe2.expression)

julia> result = fpe1 ^ fpe2
ERROR: MethodError: no method matching ^(::FreeParameterExpression, ::FreeParameterExpression)

Closest candidates are:
  ^(::Number, ::FreeParameterExpression)
   @ Braket ~/Desktop/New/braket/2/Braket.jl/src/Braket.jl:211
  ^(::SymbolicUtils.Symbolic{<:Number}, ::Any)
   @ SymbolicUtils ~/.julia/packages/SymbolicUtils/qyMYa/src/types.jl:1179

We can slim it down like this:


Base.:+(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number, Symbolics.Num}) = 
    FreeParameterExpression((fpe1 isa FreeParameterExpression ? fpe1.expression : fpe1) + 
                            (fpe2 isa FreeParameterExpression ? fpe2.expression : fpe2))

include("compiler_directive.jl")
include("gates.jl")
include("noises.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mutable struct Circuit
qubit_observable_mapping::Dict{Int, Observables.Observable}
qubit_observable_target_mapping::Dict{Int, Tuple}
qubit_observable_set::Set{Int}
parameters::Set{FreeParameter}
parameters::Set{Union{FreeParameter, FreeParameterExpression}}
observables_simultaneously_measureable::Bool
has_compiler_directives::Bool
measure_targets::Vector{Int}
Expand Down
8 changes: 4 additions & 4 deletions src/gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ for gate_def in (
$($G) gate.
"""
struct $G <: AngledGate{$n_angle}
angle::NTuple{$n_angle, Union{Float64, FreeParameter}}
$G(angle::T) where {T<:NTuple{$n_angle, Union{Float64, FreeParameter}}} = new(angle)
angle::NTuple{$n_angle, Union{Float64, FreeParameter, FreeParameterExpression}}
$G(angle::T) where {T<:NTuple{$n_angle, Union{Float64, FreeParameter, FreeParameterExpression}}} = new(angle)
end
$G(angles::Vararg{Union{Float64, FreeParameter}}) = $G(tuple(angles...))
$G(angles::Vararg{Union{Float64, FreeParameter, FreeParameterExpression}}) = $G(tuple(angles...))
$G(angles::Vararg{Number}) = $G((Float64(a) for a in angles)...)
chars(::Type{$G}) = $c
ir_typ(::Type{$G}) = $IR_G
Expand Down Expand Up @@ -99,7 +99,7 @@ end
(::Type{G})(x::Tuple{}) where {G<:Gate} = G()
(::Type{G})(x::Tuple{}) where {G<:AngledGate} = throw(ArgumentError("angled gate must be constructed with at least one angle."))
(::Type{G})(x::AbstractVector) where {G<:AngledGate} = G(x...)
(::Type{G})(x::T) where {G<:AngledGate{1}, T<:Union{Float64, FreeParameter}} = G((x,))
(::Type{G})(x::T) where {G<:AngledGate{1}, T<:Union{Float64, FreeParameter, FreeParameterExpression}} = G((x,))
qubit_count(g::G) where {G<:Gate} = qubit_count(G)
angles(g::G) where {G<:Gate} = ()
angles(g::AngledGate{N}) where {N} = g.angle
Expand Down
70 changes: 70 additions & 0 deletions test/freeparameterexpression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using Braket, Test

@testset "Free parameter expressions" begin
Fe-r-oz marked this conversation as resolved.
Show resolved Hide resolved
α = FreeParameter(:alpha)
θ = FreeParameter(:theta)
gate = FreeParameterExpression("α + 2*θ")
@test copy(gate) === gate
gsub = subs(gate, Dict( => 2.0, => 2.0))
circ = Circuit()
circ = H(circ, 0)
circ = Rx(circ, 1, gsub)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But here gsub is still the FPE with the values subbed in. What happens if I do:

circ = Rx(circ, 1, gate)

?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will show the FPE before evaluating it via substitute at the backend wrapped as subs:

julia> circ = Circuit()

T : |Result Types|
                
T : |Result Types|

julia> circ = H(circ, 0)
T  : |0|Result Types|
                     
q0 : -H--------------
                     
T  : |0|Result Types|


julia> circ = Rx(circ, 1, gate)
T  : |    0     |Result Types|
                              
q0 : -H-----------------------
                              
q1 : -Rx(α + 2θ)--------------
                              
T  : |    0     |Result Types|


circ = Ry(circ, 0, θ)
circ = Probability(circ)
new_circ = circ(6.0)
non_para_circ = Circuit() |> (ci->H(ci, 0)) |> (ci->Rx(ci, 1, gsub)) |> (ci->Ry(ci, 0, 6.0)) |> Probability
@test new_circ == non_para_circ
ϕ = FreeParameter(:phi)
circ = apply_gate_noise!(circ, BitFlip(ϕ))
circ = apply_gate_noise!(circ, PhaseFlip(0.1))
new_circ = circ(theta=2.0, alpha=1.0, phi=0.2)
non_para_circ = Circuit() |> (ci->H(ci, 0)) |> (ci->Rx(ci, 1, gsub)) |> (ci->Ry(ci, 0, 2.0)) |> Probability |> (ci->apply_gate_noise!(ci, BitFlip(0.2))) |> (ci->apply_gate_noise!(ci, PhaseFlip(0.1)))
@test new_circ == non_para_circ

# creating gates directly
gate = FreeParameterExpression("phi + 2*gamma")
gsub₁ = subs(gate, Dict(:phi => 4.0, :gamma => 4.0))
@test gsub₁ == 12.0
circ = Circuit()
circ = H(circ, 0)
circ = Rx(circ, 1, gsub₁)
circ = Ry(circ, 0, θ)
circ = Probability(circ)
new_circ = circ(6.0)
non_para_circ = Circuit() |> (ci->H(ci, 0)) |> (ci->Rx(ci, 1, gsub₁)) |> (ci->Ry(ci, 0, 6.0)) |> Probability
@test new_circ == non_para_circ

# + operator
fpe1 = FreeParameterExpression("α + θ")
fpe2 = FreeParameterExpression("2 * θ")
result = fpe1 + fpe2
@test result == FreeParameterExpression("2 * θ + α + θ")
gsub = subs(result, Dict( => 1.0, => 1.0))
@test gsub == 4.0 # α + 3θ == 4.0
fpe3 = FreeParameterExpression("2 * θ")
# == operator
@test fpe3 == fpe2
# != operator
@test fpe1 != fpe2
show(fpe3)
# - operator
result = fpe1 - fpe2
@test result == FreeParameterExpression("α - θ")
gsub = subs(result, Dict( => 1.0, => 1.0))
@test gsub == 0.0 # α - θ == 0.0
# * operator
result = fpe1 * fpe2
@test result == FreeParameterExpression("2(α + θ)*θ")
gsub = subs(result, Dict( => 1.0, => 1.0))
@test gsub == 4.0 # 2(α + θ)*θ == 4.0
# / operator
result = fpe1 / fpe2
@test result == FreeParameterExpression("(α + θ) / (2θ)")
gsub = subs(result, Dict( => 1.0, => 1.0))
@test gsub == 1.0 # (α + θ) / (2θ) == 1.0
# ^ operator
result = fpe1 ^ fpe2
@test result == FreeParameterExpression("(α + θ)^(2θ)")
gsub = subs(result, Dict( => 1.0, => 1.0))
@test gsub == 4.0 # (α + θ)^(2θ) == 4.0
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ for group in groups
include("circuits.jl")
include("measure.jl")
include("free_parameter.jl")
include("freeparameterexpression.jl")
include("gates.jl")
include("observables.jl")
include("noises.jl")
Expand Down