diff --git a/Project.toml b/Project.toml index a7d05380..dbc95034 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -66,6 +68,8 @@ SparseArrays = "1.6" StaticArrays = "=1.9.7" Statistics = "1.6" StructTypes = "=1.10.0" +Symbolics = "=5.30.3" +SymbolicUtils = "=2.0.2" Tar = "1.9.3" Test = "1.6" UUIDs = "1.6" @@ -91,6 +95,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/docs/src/circuits.md b/docs/src/circuits.md index 79d7719a..dc6263ac 100644 --- a/docs/src/circuits.md +++ b/docs/src/circuits.md @@ -39,6 +39,7 @@ QubitSet Braket.Operator Braket.QuantumOperator FreeParameter +FreeParameterExpression depth qubits qubit_count diff --git a/src/Braket.jl b/src/Braket.jl index ec4b5206..685ce945 100644 --- a/src/Braket.jl +++ b/src/Braket.jl @@ -1,7 +1,7 @@ module Braket export Circuit, QubitSet, Qubit, Device, AwsDevice, AwsQuantumTask, AwsQuantumTaskBatch -export metadata, status, Observable, Result, FreeParameter, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator +export metadata, status, Observable, Result, FreeParameter, FreeParameterExpression, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator, subs export Tracker, simulator_tasks_cost, qpu_tasks_cost export arn, cancel, state, result, results, name, download_result, id, ir, isavailable, search_devices, get_devices export provider_name, properties, type @@ -32,6 +32,8 @@ using DecFP using Graphs using HTTP using StaticArrays +using Symbolics +using SymbolicUtils using JSON3, StructTypes using LinearAlgebra using DataStructures @@ -39,6 +41,9 @@ using NamedTupleTools using OrderedCollections using Tar +# Operator overloading for FreeParameterExpression +import Base: +, -, *, /, ^, == + include("utils.jl") """ IRType @@ -134,6 +139,107 @@ end Base.copy(fp::FreeParameter) = fp Base.show(io::IO, fp::FreeParameter) = print(io, string(fp.name)) +""" + FreeParameterExpression + FreeParameterExpression(expr::Union{FreeParameterExpression, Number, Symbolics.Num, String}) + +Struct representing a [`FreeParameterExpression`](@ref), which can be used in symbolic computations. +Instances of [`FreeParameterExpression`](@ref) can represent symbolic expressions involving [`FreeParameter`](@ref), +such as mathematical expressions with undetermined values. + +This type is often used in combination with [`FreeParameter`](@ref), which represents individual [`FreeParameter`](@ref). + +### Examples +```jldoctest +julia> α = FreeParameter(:alpha) +alpha + +julia> θ = FreeParameter(:theta) +theta + +julia> gate = FreeParameterExpression("α + 2*θ") +α + 2θ + +julia> gsub = subs(gate, Dict(:α => 2.0, :θ => 2.0)) +6.0 + +julia> gate + gate +2α + 4θ + +julia> gate * gate +(α + 2θ)^2 + +julia> gate₁ = FreeParameterExpression("phi + 2*gamma") +2gamma + phi +``` +""" +struct FreeParameterExpression + expression::Symbolics.Num + function FreeParameterExpression(expr::Symbolics.Num) + new(expr) + end +end + +FreeParameterExpression(expr::FreeParameterExpression) = FreeParameterExpression(expr.expression) +FreeParameterExpression(expr::Number) = FreeParameterExpression(Symbolics.Num(expr)) +FreeParameterExpression(expr) = throw(ArgumentError("Unsupported expression type")) + +# Function to validate the input expression string +function validate_expr(expr::String) + allowed_chars = Set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-*/^()αβγδεζηθικλμνξοπρςστυφχψωϐϑϕϖϘϙϚϛϜϝϞϟϠϡϰϱϴϵ϶ΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ ") + for char in expr + if !(char in allowed_chars) + throw(ArgumentError("Unsupported character '$char' in expression. Only ASCII letters (a-z, A-Z), digits (0-9), Greek letters (α-϶, Α-Ω), space( ), and basic mathematical symbols (+ - * / ^ ()) are allowed.")) + end + end + return expr +end + +# Function to create FreeParameterExpression from a validated string +function FreeParameterExpression(expr::String) + validated_expr = validate_expr(expr) + parsed_expr = parse_expr_to_symbolic(Meta.parse(validated_expr), @__MODULE__) + return FreeParameterExpression(parsed_expr) +end + +Base.show(io::IO, fpe::FreeParameterExpression) = print(io, fpe.expression) +Base.copy(fp::FreeParameterExpression) = fp + +function subs(fpe::FreeParameterExpression, parameter_values::Dict{Symbol, <:Number}) + param_values_num = Dict(Symbolics.variable(string(k); T=Real) => v for (k, v) in parameter_values) + subbed_expr = Symbolics.substitute(fpe.expression, param_values_num) + if isempty(Symbolics.get_variables(subbed_expr)) + subbed_expr = Symbolics.value(subbed_expr) + return subbed_expr + else + subbed_expr = Symbolics.value(subbed_expr) + return FreeParameterExpression(subbed_expr) + end +end + +Base.:+(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number, Symbolics.Num}) = + FreeParameterExpression((fpe1 isa FreeParameterExpression ? fpe1.expression : fpe1) + + (fpe2 isa FreeParameterExpression ? fpe2.expression : fpe2)) + +Base.:*(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number, Symbolics.Num}) = + FreeParameterExpression((fpe1 isa FreeParameterExpression ? fpe1.expression : fpe1) * + (fpe2 isa FreeParameterExpression ? fpe2.expression : fpe2)) + +Base.:-(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number, Symbolics.Num}) = + FreeParameterExpression((fpe1 isa FreeParameterExpression ? fpe1.expression : fpe1) - + (fpe2 isa FreeParameterExpression ? fpe2.expression : fpe2)) + +Base.:/(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number, Symbolics.Num}) = + FreeParameterExpression((fpe1 isa FreeParameterExpression ? fpe1.expression : fpe1) / + (fpe2 isa FreeParameterExpression ? fpe2.expression : fpe2)) + +Base.:^(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number, Symbolics.Num}) = + FreeParameterExpression((fpe1 isa FreeParameterExpression ? fpe1.expression : fpe1) ^ + (fpe2 isa FreeParameterExpression ? fpe2.expression : fpe2)) + +Base.:(==)(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression) = isequal(fpe1.expression, fpe2.expression) +Base.:!=(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression) = !(isequal(fpe1, fpe2)) + include("compiler_directive.jl") include("gates.jl") include("noises.jl") diff --git a/src/circuit.jl b/src/circuit.jl index 2b363d09..a8286950 100644 --- a/src/circuit.jl +++ b/src/circuit.jl @@ -20,7 +20,7 @@ mutable struct Circuit qubit_observable_mapping::Dict{Int, Observables.Observable} qubit_observable_target_mapping::Dict{Int, Tuple} qubit_observable_set::Set{Int} - parameters::Set{FreeParameter} + parameters::Set{Union{FreeParameter, FreeParameterExpression}} observables_simultaneously_measureable::Bool has_compiler_directives::Bool measure_targets::Vector{Int} diff --git a/src/gates.jl b/src/gates.jl index 0da98726..ef80b520 100644 --- a/src/gates.jl +++ b/src/gates.jl @@ -48,9 +48,12 @@ for gate_def in ( $($G) gate. """ struct $G <: AngledGate{$n_angle} - angle::NTuple{$n_angle, Union{Real, FreeParameter}} - $G(angle::T) where {T<:NTuple{$n_angle, Union{Real, FreeParameter}}} = new(angle) + angle::NTuple{$n_angle, Union{Real, FreeParameter, FreeParameterExpression}} + $G(angle::T) where {T<:NTuple{$n_angle, Union{Real, FreeParameter, FreeParameterExpression}}} = new(angle) end + $G(angles::Vararg{Union{Float64, FreeParameter, FreeParameterExpression}}) = $G(tuple(angles...)) + $G(angles::Vararg{Number}) = $G((Float64(a) for a in angles)...) + chars(::Type{$G}) = $c ir_typ(::Type{$G}) = $IR_G qubit_count(::Type{$G}) = $qc @@ -106,9 +109,9 @@ end (::Type{G})(x::Tuple{}) where {G<:Gate} = G() (::Type{G})(x::Tuple{}) where {G<:AngledGate} = throw(ArgumentError("angled gate must be constructed with at least one angle.")) (::Type{G})(x::AbstractVector) where {G<:AngledGate} = G(x...) -(::Type{G})(angle::T) where {G<:AngledGate{1}, T<:Union{Real, FreeParameter}} = G((angle,)) -(::Type{G})(angle1::T1, angle2::T2) where {T1<:Union{Real, FreeParameter}, T2<:Union{Real, FreeParameter}, G<:AngledGate{2}} = G((angle1, angle2,)) -(::Type{G})(angle1::T1, angle2::T2, angle3::T3) where {T1<:Union{Real, FreeParameter}, T2<:Union{Real, FreeParameter}, T3<:Union{Real, FreeParameter}, G<:AngledGate{3}} = G((angle1, angle2, angle3,)) +(::Type{G})(angle::T) where {G<:AngledGate{1}, T<:Union{Real, FreeParameter, FreeParameterExpression}} = G((angle,)) +(::Type{G})(angle1::T1, angle2::T2) where {T1<:Union{Real, FreeParameter, FreeParameterExpression}, T2<:Union{Real, FreeParameter, FreeParameterExpression}, G<:AngledGate{2}} = G((angle1, angle2,)) +(::Type{G})(angle1::T1, angle2::T2, angle3::T3) where {T1<:Union{Real, FreeParameter, FreeParameterExpression}, T2<:Union{Real, FreeParameter, FreeParameterExpression}, T3<:Union{Real, FreeParameter, FreeParameterExpression}, G<:AngledGate{3}} = G((angle1, angle2, angle3,)) qubit_count(g::G) where {G<:Gate} = qubit_count(G) angles(g::G) where {G<:Gate} = () angles(g::AngledGate{N}) where {N} = g.angle diff --git a/test/freeparameterexpression.jl b/test/freeparameterexpression.jl new file mode 100644 index 00000000..e5c3be81 --- /dev/null +++ b/test/freeparameterexpression.jl @@ -0,0 +1,70 @@ +using Braket, Test + +@testset "Free parameter expressions" begin + α = FreeParameter(:alpha) + θ = FreeParameter(:theta) + gate = FreeParameterExpression("α + 2*θ") + @test copy(gate) === gate + gsub = subs(gate, Dict(:α => 2.0, :θ => 2.0)) + circ = Circuit() + circ = H(circ, 0) + circ = Rx(circ, 1, gsub) + circ = Ry(circ, 0, θ) + circ = Probability(circ) + new_circ = circ(6.0) + non_para_circ = Circuit() |> (ci->H(ci, 0)) |> (ci->Rx(ci, 1, gsub)) |> (ci->Ry(ci, 0, 6.0)) |> Probability + @test new_circ == non_para_circ + ϕ = FreeParameter(:phi) + circ = apply_gate_noise!(circ, BitFlip(ϕ)) + circ = apply_gate_noise!(circ, PhaseFlip(0.1)) + new_circ = circ(theta=2.0, alpha=1.0, phi=0.2) + non_para_circ = Circuit() |> (ci->H(ci, 0)) |> (ci->Rx(ci, 1, gsub)) |> (ci->Ry(ci, 0, 2.0)) |> Probability |> (ci->apply_gate_noise!(ci, BitFlip(0.2))) |> (ci->apply_gate_noise!(ci, PhaseFlip(0.1))) + @test new_circ == non_para_circ + + # creating gates directly + gate = FreeParameterExpression("phi + 2*gamma") + gsub₁ = subs(gate, Dict(:phi => 4.0, :gamma => 4.0)) + @test gsub₁ == 12.0 + circ = Circuit() + circ = H(circ, 0) + circ = Rx(circ, 1, gsub₁) + circ = Ry(circ, 0, θ) + circ = Probability(circ) + new_circ = circ(6.0) + non_para_circ = Circuit() |> (ci->H(ci, 0)) |> (ci->Rx(ci, 1, gsub₁)) |> (ci->Ry(ci, 0, 6.0)) |> Probability + @test new_circ == non_para_circ + + # + operator + fpe1 = FreeParameterExpression("α + θ") + fpe2 = FreeParameterExpression("2 * θ") + result = fpe1 + fpe2 + @test result == FreeParameterExpression("2 * θ + α + θ") + gsub = subs(result, Dict(:α => 1.0, :θ => 1.0)) + @test gsub == 4.0 # α + 3θ == 4.0 + fpe3 = FreeParameterExpression("2 * θ") + # == operator + @test fpe3 == fpe2 + # != operator + @test fpe1 != fpe2 + show(fpe3) + # - operator + result = fpe1 - fpe2 + @test result == FreeParameterExpression("α - θ") + gsub = subs(result, Dict(:α => 1.0, :θ => 1.0)) + @test gsub == 0.0 # α - θ == 0.0 + # * operator + result = fpe1 * fpe2 + @test result == FreeParameterExpression("2(α + θ)*θ") + gsub = subs(result, Dict(:α => 1.0, :θ => 1.0)) + @test gsub == 4.0 # 2(α + θ)*θ == 4.0 + # / operator + result = fpe1 / fpe2 + @test result == FreeParameterExpression("(α + θ) / (2θ)") + gsub = subs(result, Dict(:α => 1.0, :θ => 1.0)) + @test gsub == 1.0 # (α + θ) / (2θ) == 1.0 + # ^ operator + result = fpe1 ^ fpe2 + @test result == FreeParameterExpression("(α + θ)^(2θ)") + gsub = subs(result, Dict(:α => 1.0, :θ => 1.0)) + @test gsub == 4.0 # (α + θ)^(2θ) == 4.0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 91807f74..5e8773aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,7 @@ for group in groups include("circuits.jl") include("measure.jl") include("free_parameter.jl") + include("freeparameterexpression.jl") include("gates.jl") include("observables.jl") include("noises.jl")