Skip to content

Commit

Permalink
Feature: Adding support for FreeParameterExpression, resolving Issue a…
Browse files Browse the repository at this point in the history
  • Loading branch information
Fe-r-oz committed Jun 5, 2024
1 parent 1d05da0 commit 5faf99e
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 7 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Braket"
uuid = "19504a0f-b47d-4348-9127-acc6cc69ef67"
authors = ["Katharine Hyatt <[email protected]>"]
version = "0.9.0"
version = "0.9.1"

[deps]
AWS = "fbe9abb3-538b-5e4e-ba9e-bc94f4f92ebc"
Expand Down Expand Up @@ -29,6 +29,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

Expand Down Expand Up @@ -64,6 +66,8 @@ SparseArrays = "1.6"
StaticArrays = "=1.9.3"
Statistics = "1.6"
StructTypes = "=1.10.0"
Symbolics = "5.28.0"
SymbolicUtils = "1.6.0"
Tar = "1.9.3"
Test = "1.6"
UUIDs = "1.6"
Expand All @@ -89,6 +93,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
73 changes: 72 additions & 1 deletion src/Braket.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Braket

export Circuit, QubitSet, Qubit, Device, AwsDevice, AwsQuantumTask, AwsQuantumTaskBatch
export metadata, status, Observable, Result, FreeParameter, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator
export metadata, status, Observable, Result, FreeParameter, FreeParameterExpression, Job, AwsQuantumJob, LocalQuantumJob, LocalSimulator, subs
export Tracker, simulator_tasks_cost, qpu_tasks_cost
export arn, cancel, state, result, results, name, download_result, id, ir, isavailable, search_devices, get_devices
export provider_name, properties, type
Expand Down Expand Up @@ -32,13 +32,18 @@ using DecFP
using Graphs
using HTTP
using StaticArrays
using Symbolics
using SymbolicUtils
using JSON3, StructTypes
using LinearAlgebra
using DataStructures
using NamedTupleTools
using OrderedCollections
using Tar

# Operator overloading for FreeParameterExpression
import Base: +, -, *, /, ^, ==

include("utils.jl")
"""
IRType
Expand Down Expand Up @@ -134,6 +139,72 @@ end
Base.copy(fp::FreeParameter) = fp
Base.show(io::IO, fp::FreeParameter) = print(io, string(fp.name))

"""
FreeParameterExpression
FreeParameterExpression(expr::Union{FreeParameterExpression, Number, Symbolics.Num, String})
Struct representing a Free Parameter expression, which can be used in symbolic computations.
### Examples
```jldoctest
julia> fp_alpha = FreeParameter(:alpha)
alpha
julia> fp_beta = FreeParameter(:beta)
beta
julia> expr1 = FreeParameterExpression("2 * alpha / 3")
(2//3)*alpha
julia> expr2 = FreeParameterExpression("alpha + 2 * beta")
alpha + 2beta
```
"""

struct FreeParameterExpression
expression::Symbolics.Num

function FreeParameterExpression(expr::Union{FreeParameterExpression, Number, Symbolics.Num, String})
if isa(expr, FreeParameterExpression)
return new(expr.expression)
elseif isa(expr, Number)
return new(Symbolics.Num(expr))
elseif isa(expr, Symbolics.Num)
return new(expr)
elseif isa(expr, String)
parsed_expr = parse_expr_to_symbolic(Meta.parse(expr), @__MODULE__)
return new(parsed_expr)
else
throw(ArgumentError("Unsupported expression type"))
end
end
end

Base.show(io::IO, fpe::FreeParameterExpression) = print(io, fpe.expression)
Base.copy(fp::FreeParameterExpression) = fp

function subs(fpe::FreeParameterExpression, parameter_values::Dict{Symbol, <:Number})
param_values_num = Dict(Symbolics.Variable(k) => v for (k, v) in parameter_values)
subbed_expr = Symbolics.substitute(fpe.expression, param_values_num)
if isempty(Symbolics.get_variables(subbed_expr))
return subbed_expr
else
return FreeParameterExpression(subbed_expr)
end
end

import Base: +, *, -, /, ^, ==

+(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression + fpe2)
*(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression * fpe2)
*(a::Number, fp::FreeParameter) = FreeParameterExpression("$(a) * $(fp.name)")
-(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression - fpe2)
/(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression / fpe2)
^(fpe1::FreeParameterExpression, fpe2::Union{FreeParameterExpression, Number}) = FreeParameterExpression(fpe1.expression ^ fpe2)
-(fpe::FreeParameterExpression) = FreeParameterExpression(-fpe.expression)
==(fpe1::FreeParameterExpression, fpe2::FreeParameterExpression) = Symbolics.simplify(fpe1.expression) == Symbolics.simplify(fpe2.expression)
==(fpe::FreeParameterExpression, expr::Symbolics.Num) = fpe.expression == expr

include("compiler_directive.jl")
include("gates.jl")
include("noises.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/circuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mutable struct Circuit
qubit_observable_mapping::Dict{Int, Observables.Observable}
qubit_observable_target_mapping::Dict{Int, Tuple}
qubit_observable_set::Set{Int}
parameters::Set{FreeParameter}
parameters::Set{Union{FreeParameter, FreeParameterExpression}}
observables_simultaneously_measureable::Bool
has_compiler_directives::Bool
measure_targets::Vector{Int}
Expand Down
8 changes: 4 additions & 4 deletions src/gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ for gate_def in (
$($G) gate.
"""
struct $G <: AngledGate{$n_angle}
angle::NTuple{$n_angle, Union{Float64, FreeParameter}}
$G(angle::T) where {T<:NTuple{$n_angle, Union{Float64, FreeParameter}}} = new(angle)
angle::NTuple{$n_angle, Union{Float64, FreeParameter, FreeParameterExpression}}
$G(angle::T) where {T<:NTuple{$n_angle, Union{Float64, FreeParameter, FreeParameterExpression}}} = new(angle)
end
$G(angles::Vararg{Union{Float64, FreeParameter}}) = $G(tuple(angles...))
$G(angles::Vararg{Union{Float64, FreeParameter, FreeParameterExpression}}) = $G(tuple(angles...))
$G(angles::Vararg{Number}) = $G((Float64(a) for a in angles)...)
chars(::Type{$G}) = $c
ir_typ(::Type{$G}) = $IR_G
Expand Down Expand Up @@ -99,7 +99,7 @@ end
(::Type{G})(x::Tuple{}) where {G<:Gate} = G()
(::Type{G})(x::Tuple{}) where {G<:AngledGate} = throw(ArgumentError("angled gate must be constructed with at least one angle."))
(::Type{G})(x::AbstractVector) where {G<:AngledGate} = G(x...)
(::Type{G})(x::T) where {G<:AngledGate{1}, T<:Union{Float64, FreeParameter}} = G((x,))
(::Type{G})(x::T) where {G<:AngledGate{1}, T<:Union{Float64, FreeParameter, FreeParameterExpression}} = G((x,))
qubit_count(g::G) where {G<:Gate} = qubit_count(G)
angles(g::G) where {G<:Gate} = ()
angles(g::AngledGate{N}) where {N} = g.angle
Expand Down
17 changes: 17 additions & 0 deletions test/free_parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,20 @@ using Braket, Test
@test b.name == :b
@test copy(b) === b
end


@testset "Free parameter Expression" begin
α = FreeParameter(:alpha)
θ = FreeParameter(:theta)
gate = FreeParameterExpression("α + 2*θ")
circ = Circuit()
circ = H(circ, 0)
circ = Rx(circ, 1, gate)
circ = Ry(circ, 0, θ)
circ = Ry(circ, 0, θ)
circ = Probability(circ)
#substitution
d = FreeParameterExpression("2*α + 3*θ")
e = subs(d, Dict( => 1.0, => 2.0))
@test e == 8.0
end

0 comments on commit 5faf99e

Please sign in to comment.