From 5faf99ea5f63677242c2d97d3cf755008369e443 Mon Sep 17 00:00:00 2001 From: Fe-r-oz Date: Wed, 5 Jun 2024 11:01:44 +0500 Subject: [PATCH] Feature: Adding support for FreeParameterExpression, resolving Issue #82 --- Project.toml | 8 ++++- src/Braket.jl | 73 +++++++++++++++++++++++++++++++++++++++++- src/circuit.jl | 2 +- src/gates.jl | 8 ++--- test/free_parameter.jl | 17 ++++++++++ 5 files changed, 101 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index b6ba47d4..49bcb278 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Braket" uuid = "19504a0f-b47d-4348-9127-acc6cc69ef67" authors = ["Katharine Hyatt "] -version = "0.9.0" +version = "0.9.1" [deps] AWS = "fbe9abb3-538b-5e4e-ba9e-bc94f4f92ebc" @@ -29,6 +29,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -64,6 +66,8 @@ SparseArrays = "1.6" StaticArrays = "=1.9.3" Statistics = "1.6" StructTypes = "=1.10.0" +Symbolics = "5.28.0" +SymbolicUtils = "1.6.0" Tar = "1.9.3" Test = "1.6" UUIDs = "1.6" @@ -89,6 +93,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/src/Braket.jl b/src/Braket.jl index ec4b5206..b96e81b0 100644 --- a/src/Braket.jl +++ b/src/Braket.jl @@ -1,7 +1,7 @@ module Braket export Circuit, QubitSet, Qubit, Device, AwsDevice, AwsQuantumTask, AwsQuantumTaskBatch -export metadata, status, Observable, Result, FreeParameter, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator +export metadata, status, Observable, Result, FreeParameter, FreeParameterExpression, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator, subs export Tracker, simulator_tasks_cost, qpu_tasks_cost export arn, cancel, state, result, results, name, download_result, id, ir, isavailable, search_devices, get_devices export provider_name, properties, type @@ -32,6 +32,8 @@ using DecFP using Graphs using HTTP using StaticArrays +using Symbolics +using SymbolicUtils using JSON3, StructTypes using LinearAlgebra using DataStructures @@ -39,6 +41,9 @@ using NamedTupleTools using OrderedCollections using Tar +# Operator overloading for FreeParameterExpression +import Base: +, -, *, /, ^, == + include("utils.jl") """ IRType @@ -134,6 +139,72 @@ end Base.copy(fp::FreeParameter) = fp Base.show(io::IO, fp::FreeParameter) = print(io, string(fp.name)) +""" + FreeParameterExpression + FreeParameterExpression(expr::Union{FreeParameterExpression, Number, Symbolics.Num, String}) + +Struct representing a Free Parameter expression, which can be used in symbolic computations. + +### Examples +```jldoctest +julia> fp_alpha = FreeParameter(:alpha) +alpha + +julia> fp_beta = FreeParameter(:beta) +beta + +julia> expr1 = FreeParameterExpression("2 * alpha / 3") +(2//3)*alpha + +julia> expr2 = FreeParameterExpression("alpha + 2 * beta") +alpha + 2beta +``` +""" + +struct FreeParameterExpression + expression::Symbolics.Num + + function FreeParameterExpression(expr::Union{FreeParameterExpression, Number, Symbolics.Num, String}) + if isa(expr, FreeParameterExpression) + return new(expr.expression) + elseif isa(expr, Number) + return new(Symbolics.Num(expr)) + elseif isa(expr, Symbolics.Num) + return new(expr) + elseif isa(expr, String) + parsed_expr = parse_expr_to_symbolic(Meta.parse(expr), @__MODULE__) + return new(parsed_expr) + else + throw(ArgumentError("Unsupported expression type")) + end + end +end + +Base.show(io::IO, fpe::FreeParameterExpression) = print(io, fpe.expression) +Base.copy(fp::FreeParameterExpression) = fp + +function subs(fpe::FreeParameterExpression, parameter_values::Dict{Symbol, <:Number}) + param_values_num = Dict(Symbolics.Variable(k) => v for (k, v) in parameter_values) + subbed_expr = Symbolics.substitute(fpe.expression, param_values_num) + if isempty(Symbolics.get_variables(subbed_expr)) + return subbed_expr + else + return FreeParameterExpression(subbed_expr) + end +end + +import Base: +, *, -, /, ^, == + ++(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression + fpe2) +*(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression * fpe2) +*(a::Number, fp::FreeParameter) = FreeParameterExpression("$(a) * $(fp.name)") +-(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression - fpe2) +/(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression / fpe2) +^(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression ^ fpe2) +-(fpe::FreeParameterExpression) = FreeParameterExpression(-fpe.expression) +==(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression) = Symbolics.simplify(fpe1.expression) == Symbolics.simplify(fpe2.expression) +==(fpe::FreeParameterExpression, expr::Symbolics.Num) = fpe.expression == expr + include("compiler_directive.jl") include("gates.jl") include("noises.jl") diff --git a/src/circuit.jl b/src/circuit.jl index a5a1fa67..eb69a200 100644 --- a/src/circuit.jl +++ b/src/circuit.jl @@ -20,7 +20,7 @@ mutable struct Circuit qubit_observable_mapping::Dict{Int, Observables.Observable} qubit_observable_target_mapping::Dict{Int, Tuple} qubit_observable_set::Set{Int} - parameters::Set{FreeParameter} + parameters::Set{Union{FreeParameter, FreeParameterExpression}} observables_simultaneously_measureable::Bool has_compiler_directives::Bool measure_targets::Vector{Int} diff --git a/src/gates.jl b/src/gates.jl index bac6d472..f9f3c453 100644 --- a/src/gates.jl +++ b/src/gates.jl @@ -47,10 +47,10 @@ for gate_def in ( $($G) gate. """ struct $G <: AngledGate{$n_angle} - angle::NTuple{$n_angle, Union{Float64, FreeParameter}} - $G(angle::T) where {T<:NTuple{$n_angle, Union{Float64, FreeParameter}}} = new(angle) + angle::NTuple{$n_angle, Union{Float64, FreeParameter, FreeParameterExpression}} + $G(angle::T) where {T<:NTuple{$n_angle, Union{Float64, FreeParameter, FreeParameterExpression}}} = new(angle) end - $G(angles::Vararg{Union{Float64, FreeParameter}}) = $G(tuple(angles...)) + $G(angles::Vararg{Union{Float64, FreeParameter, FreeParameterExpression}}) = $G(tuple(angles...)) $G(angles::Vararg{Number}) = $G((Float64(a) for a in angles)...) chars(::Type{$G}) = $c ir_typ(::Type{$G}) = $IR_G @@ -99,7 +99,7 @@ end (::Type{G})(x::Tuple{}) where {G<:Gate} = G() (::Type{G})(x::Tuple{}) where {G<:AngledGate} = throw(ArgumentError("angled gate must be constructed with at least one angle.")) (::Type{G})(x::AbstractVector) where {G<:AngledGate} = G(x...) -(::Type{G})(x::T) where {G<:AngledGate{1}, T<:Union{Float64, FreeParameter}} = G((x,)) +(::Type{G})(x::T) where {G<:AngledGate{1}, T<:Union{Float64, FreeParameter, FreeParameterExpression}} = G((x,)) qubit_count(g::G) where {G<:Gate} = qubit_count(G) angles(g::G) where {G<:Gate} = () angles(g::AngledGate{N}) where {N} = g.angle diff --git a/test/free_parameter.jl b/test/free_parameter.jl index 426f0e6e..0ef5d7e0 100644 --- a/test/free_parameter.jl +++ b/test/free_parameter.jl @@ -25,3 +25,20 @@ using Braket, Test @test b.name == :b @test copy(b) === b end + + +@testset "Free parameter Expression" begin + α = FreeParameter(:alpha) + θ = FreeParameter(:theta) + gate = FreeParameterExpression("α + 2*θ") + circ = Circuit() + circ = H(circ, 0) + circ = Rx(circ, 1, gate) + circ = Ry(circ, 0, θ) + circ = Ry(circ, 0, θ) + circ = Probability(circ) + #substitution + d = FreeParameterExpression("2*α + 3*θ") + e = subs(d, Dict(:α => 1.0, :θ => 2.0)) + @test e == 8.0 +end \ No newline at end of file