Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix and more cleanup #21

Merged
merged 4 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/qasm_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Base.copy(qasm_expr::QasmExpression) = QasmExpression(qasm_expr.head, deepcopy(q

head(qasm_expr::QasmExpression) = qasm_expr.head

Base.convert(::Type{Vector{QasmExpression}}, expr::QasmExpression) = head(expr) == :array_literal ? convert(Vector{QasmExpression}, expr.args) : [expr]

AbstractTrees.children(qasm_expr::QasmExpression) = qasm_expr.args
AbstractTrees.printnode(io::IO, qasm_expr::QasmExpression) = print(io, "QasmExpression :$(qasm_expr.head)")

Expand Down
157 changes: 77 additions & 80 deletions src/visitor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,9 @@ mutable struct QasmFunctionVisitor <: AbstractVisitor
return v
end
end
function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::QasmExpression)
head(provided_arguments) == :array_literal && return QasmFunctionVisitor(parent, declared_arguments, convert(Vector{QasmExpression}, provided_arguments.args))
QasmFunctionVisitor(parent, declared_arguments, [provided_arguments])
end
function QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::QasmExpression, provided_arguments)
head(declared_arguments) == :array_literal && return QasmFunctionVisitor(parent, convert(Vector{QasmExpression}, declared_arguments.args), provided_arguments)
QasmFunctionVisitor(parent, [declared_arguments], provided_arguments)
end
QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::Vector{QasmExpression}, provided_arguments::QasmExpression) = QasmFunctionVisitor(parent, declared_arguments, convert(Vector{QasmExpression}, provided_arguments))
QasmFunctionVisitor(parent::AbstractVisitor, declared_arguments::QasmExpression, provided_arguments) = QasmFunctionVisitor(parent, convert(Vector{QasmExpression}, declared_arguments), provided_arguments)

Base.parent(v::AbstractVisitor) = v.parent

hasgate(v::AbstractVisitor, gate_name::String) = hasgate(parent(v), gate_name)
Expand Down Expand Up @@ -281,9 +276,9 @@ function evaluate_modifiers(v::V, expr::QasmExpression) where {V<:AbstractVisito
arg_val::Int = v(first(expr.args)::QasmExpression)::Int
isinteger(arg_val) || throw(QasmVisitorError("cannot apply non-integer ($arg_val) number of controls or negcontrols."))
true_inner = expr.args[2]::QasmExpression
inner = QasmExpression(head(expr), true_inner)
inner = QasmExpression(head(expr), true_inner)
while arg_val > 2
inner = QasmExpression(head(expr), inner)
inner = QasmExpression(head(expr), inner)
arg_val -= 1
end
else
Expand Down Expand Up @@ -346,6 +341,7 @@ end
evaluate_qubits(v::AbstractVisitor, qubit_targets::QasmExpression) = evaluate_qubits(v::AbstractVisitor, [qubit_targets])

function remap(ix, target_mapper::Dict{Int, Int})
isempty(target_mapper) && return ix
mapped_targets = map(t->getindex(target_mapper, t), ix.targets)
mapped_controls = map(c->getindex(target_mapper, c[1])=>c[2], ix.controls)
return (type=ix.type, arguments=ix.arguments, targets=mapped_targets, controls=mapped_controls, exponent=ix.exponent)
Expand All @@ -359,57 +355,64 @@ function process_gate_arguments(v::AbstractVisitor, gate_name::String, defined_a
def_has_arguments = !isempty(defined_arguments)
call_has_arguments = !isempty(v(called_arguments))
if def_has_arguments ⊻ call_has_arguments
def_has_arguments && throw(QasmVisitorError("gate $gate_name requires arguments but none were provided."))
def_has_arguments && throw(QasmVisitorError("gate $gate_name requires arguments but none were provided."))
call_has_arguments && throw(QasmVisitorError("gate $gate_name does not accept arguments but arguments were provided."))
end
if def_has_arguments
evaled_args = v(called_arguments)
argument_values = Dict{Symbol, Real}(Symbol(arg_name)=>argument for (arg_name, argument) in zip(defined_arguments, evaled_args))
return map(ix->bind_arguments!(ix, argument_values), gate_body)
else
return deepcopy(gate_body)
end
!def_has_arguments && return deepcopy(gate_body) # deep copy to avoid overwriting canonical definition

evaled_args = v(called_arguments)
argument_values = Dict{Symbol, Real}(Symbol(arg_name)=>argument for (arg_name, argument) in zip(defined_arguments, evaled_args))
return map(ix->bind_arguments!(ix, argument_values), gate_body)
end

function handle_gate_modifiers(ixs, mods::Vector{QasmExpression}, control_qubits::Vector{Int}, is_gphase::Bool)
for mod in Iterators.reverse(mods)
control_qubit = head(mod) ∈ (:negctrl, :ctrl) ? pop!(control_qubits) : -1
for (ii, ix) in enumerate(ixs)
if head(mod) == :pow
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=ix.exponent*mod.args[1])
elseif head(mod) == :inv
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=-ix.exponent)
# need to handle "extra" target
elseif head(mod) ∈ (:negctrl, :ctrl)
if head(mod) ∈ (:negctrl, :ctrl)
control_qubit = pop!(control_qubits)
for (ii, ix) in enumerate(ixs)
exp = ix.exponent
targets = ix.targets
controls = ix.controls
bit = head(mod) == :ctrl ? 1 : 0
if is_gphase
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=pushfirst!(ix.controls, control_qubit=>bit), exponent=ix.exponent)
else
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=pushfirst!(ix.targets, control_qubit), controls=pushfirst!(ix.controls, control_qubit=>bit), exponent=ix.exponent)
controls = pushfirst!(controls, control_qubit=>bit)
if !is_gphase
targets = pushfirst!(targets, control_qubit)
end
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=targets, controls=controls, exponent=exp)
end
elseif head(mod) == :inv
reverse!(ixs)
for (ii, ix) in enumerate(ixs)
ixs[ii] = (type=ix.type, arguments=ix.arguments, targets=ix.targets, controls=ix.controls, exponent=-ix.exponent)
end
elseif head(mod) == :pow
pow_exp = mod.args[1]
(isinteger(pow_exp) || length(ixs) == 1) || throw(QasmVisitorError("can't apply a non-integer exponent to a gate of multiple instructions")) # can't do 2.5 for a list... yet
if length(ixs) > 1
pow_exp < 0 && reverse!(ixs)
ixs = repeat(ixs, abs(pow_exp))
else
ixs[1] = (type=ixs[1].type, arguments=ixs[1].arguments, targets=ixs[1].targets, controls=ixs[1].controls, exponent=ixs[1].exponent*pow_exp)
end
end
head(mod) == :inv && reverse!(ixs)
end
return ixs
end

function splat_gate_targets(gate_targets::Vector{Vector{Int}})
target_lengths::Vector{Int} = Int[length(t) for t in gate_targets]
longest = maximum(target_lengths)
longest = maximum(target_lengths)
must_splat::Bool = any(len->len!=1 || len != longest, target_lengths)
!must_splat && return longest, gate_targets
for target_ix in 1:length(gate_targets)
if target_lengths[target_ix] == 1
append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1))
end
for target_ix in filter(ix->target_lengths[ix] == 1, 1:length(gate_targets))
append!(gate_targets[target_ix], fill(only(gate_targets[target_ix]), longest-1))
end
return longest, gate_targets
end

function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression)
has_modifiers = length(program_expr.args) == 4
n_called_with::Int = qubit_count(v)
has_modifiers = length(program_expr.args) == 4
n_called_with::Int = qubit_count(v)
gate_targets::Vector{Int} = collect(0:n_called_with-1)
provided_arg::QasmExpression = only(program_expr.args[2].args)
evaled_arg = v(provided_arg)
Expand All @@ -421,17 +424,9 @@ function visit_gphase_call(v::AbstractVisitor, program_expr::QasmExpression)
return
end

function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
gate_name = name(program_expr)::String
raw_call_targets = program_expr.args[3]::QasmExpression
call_targets::Vector{QasmExpression} = convert(Vector{QasmExpression}, head(raw_call_targets.args[1]) == :array_literal ? raw_call_targets.args[1].args : raw_call_targets.args)::Vector{QasmExpression}
provided_args = isempty(program_expr.args[2].args) ? QasmExpression(:empty) : only(program_expr.args[2].args)::QasmExpression
has_modifiers = length(program_expr.args) == 4
hasgate(v, gate_name) || throw(QasmVisitorError("gate $gate_name not defined!"))
gate_def = gate_defs(v)[gate_name]
gate_def_v = QasmGateDefVisitor(v, gate_def.arguments, provided_args, gate_def.qubit_targets)
gate_def_v(deepcopy(gate_def.body))
gate_ixs = instructions(gate_def_v)
function process_gate_targets(v, expr, gate_def)
raw_call_targets = expr.args[3]::QasmExpression
call_targets::Vector{QasmExpression} = convert(Vector{QasmExpression}, raw_call_targets.args[1])::Vector{QasmExpression}
gate_targets = Vector{Int}[evaluate_qubits(v, call_target)::Vector{Int} for call_target in call_targets]
n_called_with = length(gate_targets)
n_defined_with = length(gate_def.qubit_targets)
Expand All @@ -440,18 +435,31 @@ function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
n_called_with = length(gate_targets[1])
gate_targets = Vector{Int}[[gt] for gt in gate_targets[1]]
end
applied_arguments = process_gate_arguments(v, gate_name, gate_def.arguments, provided_args, gate_ixs)
control_qubits::Vector{Int} = collect(0:(n_called_with-n_defined_with)-1)
modifier_remap = Dict{Int, Int}(old_qubit=>(old_qubit + length(control_qubits)) for old_qubit in 0:length(gate_def.qubit_targets))
return gate_targets, control_qubits, n_called_with, modifier_remap
end

function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
gate_name = name(program_expr)::String
provided_args = isempty(program_expr.args[2].args) ? QasmExpression(:empty) : only(program_expr.args[2].args)::QasmExpression
has_modifiers = length(program_expr.args) == 4
mods::Vector{QasmExpression} = has_modifiers ? convert(Vector{QasmExpression}, program_expr.args[4].args) : QasmExpression[]
if !isempty(control_qubits)
modifier_remap = Dict{Int, Int}(old_qubit=>(old_qubit + length(control_qubits)) for old_qubit in 0:length(gate_def.qubit_targets))
for ii in 1:length(applied_arguments)
applied_arguments[ii] = remap(applied_arguments[ii], modifier_remap)
end
hasgate(v, gate_name) || throw(QasmVisitorError("gate $gate_name not defined!"))
gate_def = gate_defs(v)[gate_name]
gate_def_v = QasmGateDefVisitor(v, gate_def.arguments, provided_args, gate_def.qubit_targets)
gate_def_v(deepcopy(gate_def.body))
gate_ixs = instructions(gate_def_v)
# generate instruction list based on provided arguments to gate
applied_arguments = process_gate_arguments(v, gate_name, gate_def.arguments, provided_args, gate_ixs)
gate_targets, control_qubits, n_called_with, modifier_remap = process_gate_targets(v, program_expr, gate_def)
for ii in 1:length(applied_arguments) # first apply any needed control qubits to the entire gate, shuffling targets
applied_arguments[ii] = remap(applied_arguments[ii], modifier_remap)
end
# go through individual instructions, applying the modifiers to each argument
applied_arguments = handle_gate_modifiers(applied_arguments, mods, control_qubits, false)
longest, gate_targets = splat_gate_targets(gate_targets)
for splatted_ix in 1:longest
for splatted_ix in 1:longest # then splat if necessary
target_mapper = Dict{Int, Int}(g_ix=>gate_targets[g_ix+1][splatted_ix] for g_ix in 0:n_called_with-1)
push!(v, map(ix->remap(ix, target_mapper), applied_arguments))
end
Expand All @@ -473,17 +481,8 @@ function visit_function_call(v, expr, function_name)
function_v(f_expr)
end
end
# remap qubits and classical variables
function_args = if head(declared_args) == :array_literal
convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression}
else
declared_args
end
called_args = if head(provided_args) == :array_literal
convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression}
else
provided_args
end
function_args = convert(Vector{QasmExpression}, declared_args)::Vector{QasmExpression}
called_args = convert(Vector{QasmExpression}, provided_args)::Vector{QasmExpression}
reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args))
reverse_qubits_map = Dict{Int, Int}()
for variable in filter(v->head(v) ∈ (:identifier, :indexed_identifier), keys(reverse_arguments_map))
Expand Down Expand Up @@ -598,9 +597,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
delete!(classical_defs(v), loop_variable_name)
elseif head(program_expr) == :switch
case_val = v(program_expr.args[1])
all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end])
default = findfirst(expr->head(expr) == :default, all_cases)
case_val = v(program_expr.args[1])
all_cases = convert(Vector{QasmExpression}, program_expr.args[2:end])
default = findfirst(expr->head(expr) == :default, all_cases)
case_found = false
for case in all_cases
if head(case) == :case && case_val ∈ v(case.args[1])
Expand All @@ -614,7 +613,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
foreach(v, convert(Vector{QasmExpression}, all_cases[default].args))
end
elseif head(program_expr) == :alias
alias_name = name(program_expr)
alias_name = name(program_expr)
right_hand_side = program_expr.args[1].args[1].args[end]
if head(right_hand_side) == :binary_op
right_hand_side.args[1] == Symbol("++") || throw(QasmVisitorError("right hand side of alias must be either an identifier or concatenation"))
Expand All @@ -624,8 +623,8 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
is_right_qubit = haskey(qubit_mapping(v), name(concat_right))
(is_left_qubit ⊻ is_right_qubit) && throw(QasmVisitorError("cannot concatenate qubit and classical arrays"))
if is_left_qubit
left_qs = v(concat_left)
right_qs = v(concat_right)
left_qs = v(concat_left)
right_qs = v(concat_right)
alias_qubits = collect(vcat(left_qs, right_qs))
qubit_size = length(alias_qubits)
qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size)
Expand All @@ -636,7 +635,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
else # both classical
left_array = classical_defs(v)[name(concat_left)]
right_array = classical_defs(v)[name(concat_right)]
new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type)))
new_size = QasmExpression(:binary_op, :+, only(size(left_array.type)), only(size(right_array.type)))
if left_array.type isa SizedBitVector
classical_defs(v)[alias_name] = ClassicalVariable(alias_name, new_size, vcat(left_array.val, right_array.val), false)
else
Expand Down Expand Up @@ -775,9 +774,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
new_val = evaluate_binary_op(op, left_val, right_val)
end
if length(inds) > 1
var.val[inds] .= new_val
var.val[inds] .= new_val
else
var.val[inds] = new_val
var.val[inds] = new_val
end
end
elseif head(program_expr) == :classical_declaration
Expand Down Expand Up @@ -835,11 +834,9 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
gate_arguments = gate_def[2]::QasmExpression
gate_def_targets = gate_def[3]::QasmExpression
gate_body = gate_def[4]::QasmExpression
single_argument = !isempty(gate_arguments.args) && head(gate_arguments.args[1]) == :array_literal
argument_exprs = single_argument ? gate_arguments.args[1].args::Vector{Any} : gate_arguments.args::Vector{Any}
argument_exprs = !isempty(gate_arguments.args) ? convert(Vector{QasmExpression}, gate_arguments.args[1]) : QasmExpression[]
argument_names = String[arg.args[1] for arg::QasmExpression in argument_exprs]
single_target = head(gate_def_targets.args[1]) == :array_literal
qubit_targets = single_target ? map(name, gate_def_targets.args[1].args)::Vector{String} : map(name, gate_def_targets.args)::Vector{String}
qubit_targets = map(name, convert(Vector{QasmExpression}, gate_def_targets.args[1]))::Vector{String}
v.gate_defs[gate_name] = GateDefinition(gate_name, argument_names, qubit_targets, gate_body)
elseif head(program_expr) == :function_call
function_name = name(program_expr)
Expand Down
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,11 @@ Quasar.builtin_gates[] = complex_builtin_gates
gate cxx_2 c, a {
pow(1/2) @ pow(4) @ cx c, a;
}
gate cxx_3 c, a {
pow(1/2) @ pow(4) @ cx c, a;
i c;
i a;
}
gate cxxx c, a {
pow(1) @ pow(two) @ cx c, a;
}
Expand All @@ -1220,6 +1225,7 @@ Quasar.builtin_gates[] = complex_builtin_gates
cx q1, q2; // flip
cxx_1 q1, q3; // don't flip
cxx_2 q1, q4; // don't flip
pow(2) @ cxx_3 q1, q4; // don't flip
cx q1, q5; // flip
x q3; // flip
x q4; // flip
Expand All @@ -1236,6 +1242,12 @@ Quasar.builtin_gates[] = complex_builtin_gates
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 1], controls=[0=>1], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 2], controls=[0=>1], exponent=2.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0),
(type="i", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0),
(type="i", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 3], controls=[0=>1], exponent=2.0),
(type="i", arguments=InstructionArgument[], targets=[0], controls=Pair{Int,Int}[], exponent=1.0),
(type="i", arguments=InstructionArgument[], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[0, 4], controls=[0=>1], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[2], controls=Pair{Int,Int}[], exponent=1.0),
(type="u", arguments=InstructionArgument[π, 0, π], targets=[3], controls=Pair{Int,Int}[], exponent=1.0),
Expand Down
Loading