Skip to content

Commit

Permalink
Some cleanup and also aliasing support (#19)
Browse files Browse the repository at this point in the history
* More small cleanup of parser

* Small cleanup in visitor

* Restore rotation parameter tests

* Basic aliasing support
  • Loading branch information
kshyatt-aws authored Nov 26, 2024
1 parent 71ea493 commit 722c4cc
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 114 deletions.
39 changes: 17 additions & 22 deletions src/parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ parse_identifier(token, qasm) = QasmExpression(:identifier, String(read_raw(toke
function parse_block_body(expr, tokens, stack, start, qasm)
is_scope = tokens[1][end] == lbrace
if is_scope
body = parse_scope(tokens, stack, start, qasm)
body_exprs = convert(Vector{QasmExpression}, collect(Iterators.reverse(body)))::Vector{QasmExpression}
scope_tokens = extract_expression(tokens, lbrace, rbrace, stack, start, qasm)
body = parse_qasm(scope_tokens, qasm, QasmExpression(:scope))
body_exprs = convert(Vector{QasmExpression}, collect(Iterators.reverse(body)))::Vector{QasmExpression}
foreach(body_expr->push!(body_exprs[1], body_expr), body_exprs[2:end])
push!(expr, body_exprs[1])
else # one line
Expand Down Expand Up @@ -104,9 +105,9 @@ function parse_function_def(tokens, stack, start, qasm)
return expr
end
function parse_gate_or_cal_def(head::Symbol, tokens, stack, start, qasm)
def_name = popfirst!(tokens)
def_name = popfirst!(tokens)
def_name[end] == identifier || throw(QasmParseError("$head must have a valid identifier as a name", stack, start, qasm))
def_name_id = parse_identifier(def_name, qasm)
def_name_id = parse_identifier(def_name, qasm)

def_args = parse_arguments_list(tokens, stack, start, qasm)
qubit_tokens = splice!(tokens, 1:findfirst(triplet->triplet[end]==lbrace, tokens)-1)
Expand Down Expand Up @@ -226,11 +227,6 @@ function extract_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, opener,
return extracted_tokens
end

function parse_scope(tokens, stack, start, qasm)
scope_tokens = extract_expression(tokens, lbrace, rbrace, stack, start, qasm)
return parse_qasm(scope_tokens, qasm, QasmExpression(:scope))
end

function parse_list_expression(tokens::Vector{Tuple{Int64, Int32, Token}}, stack, start, qasm)
expr_list = QasmExpression[]
while !isempty(tokens) && first(tokens)[end] != semicolon
Expand Down Expand Up @@ -383,14 +379,13 @@ function expression_start(tokens, stack, start, qasm)
expr_head = parse_list_expression(interior_tokens, stack, start, qasm)
elseif start_token[end] == classical_type
type_tokens = pushfirst!(tokens, start_token)
raw_expr = parse_classical_type(type_tokens, stack, start, qasm)
raw_expr = parse_classical_type(type_tokens, stack, start, qasm)
expr_head = raw_expr
if !isempty(tokens) && first(tokens)[end] == lparen
interior = extract_expression(tokens, lparen, rparen, stack, start, qasm)
expr_head = QasmExpression(:cast, raw_expr, parse_expression(interior, stack, start, qasm))
elseif !isempty(tokens) && first(tokens)[end] == identifier
expr_head = QasmExpression(:classical_declaration, raw_expr, parse_expression(tokens, stack, start, qasm))
else
expr_head = raw_expr
end
elseif start_token[end] == waveform_token && next_token[end] != identifier
expr_head = QasmExpression(:waveform)
Expand Down Expand Up @@ -418,16 +413,14 @@ end
function parse_range(expr_head, tokens, stack, start, qasm)
popfirst!(tokens)
second_colon = findfirst(triplet->triplet[end] == colon, tokens)
step = QasmExpression(:integer_literal, 1)
if !isnothing(second_colon)
step_tokens = push!(splice!(tokens, 1:second_colon-1), (-1, Int32(-1), semicolon))
popfirst!(tokens) # colon
step = parse_expression(step_tokens, stack, start, qasm)::QasmExpression
else
step = QasmExpression(:integer_literal, 1)
end
if isempty(tokens) || first(tokens)[end] == semicolon # missing stop
stop = QasmExpression(:integer_literal, -1)
else
stop = QasmExpression(:integer_literal, -1)
if !isempty(tokens) && first(tokens)[end] != semicolon # missing stop
stop = parse_expression(tokens, stack, start, qasm)::QasmExpression
end
return QasmExpression(:range, QasmExpression[expr_head, step, stop])
Expand Down Expand Up @@ -485,11 +478,11 @@ function parse_unary_op(tokens, stack, start, qasm)
expr = QasmExpression(:complex_literal, -real(next_expr.args[1]) + im*imag(next_expr.args[1]))
end
elseif head(next_expr) == :binary_op && !next_token_is_paren
# replace first argument if next token isn't a paren
left_hand_side = next_expr.args[2]::QasmExpression
new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side)
next_expr.args[2] = new_left_hand_side
expr = next_expr
# replace first argument if next token isn't a paren
left_hand_side = next_expr.args[2]::QasmExpression
new_left_hand_side = QasmExpression(:unary_op, unary_op_symbol, left_hand_side)
next_expr.args[2] = new_left_hand_side
expr = next_expr
else
expr = QasmExpression(:unary_op, unary_op_symbol, next_expr)
end
Expand Down Expand Up @@ -718,6 +711,8 @@ function parse_qasm(clean_tokens::Vector{Tuple{Int64, Int32, Token}}, qasm::Stri
push!(stack, delay_expr)
elseif token == end_token
push!(stack, QasmExpression(:end))
elseif token == alias
push!(stack, QasmExpression(:alias, parse_expression(clean_tokens, stack, start, qasm)))
elseif token == identifier || token == builtin_gate
clean_tokens = pushfirst!(clean_tokens, (start, len, token))
expr = parse_expression(clean_tokens, stack, start, qasm)
Expand Down
190 changes: 123 additions & 67 deletions src/visitor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ function name(expr::QasmExpression)::String
head(expr) == :gate_call && return name(expr.args[1]::QasmExpression)
head(expr) == :gate_definition && return name(expr.args[1]::QasmExpression)
head(expr) == :classical_assignment && return name(expr.args[1].args[2]::QasmExpression)
head(expr) == :alias && return name(expr.args[1]::QasmExpression)
head(expr) == :hw_qubit && return replace(expr.args[1], "\$"=>"")
throw(QasmVisitorError("name not defined for expressions of type $(head(expr))"))
end
Expand All @@ -323,8 +324,14 @@ function _evaluate_qubits(::Val{:indexed_identifier}, v, qubit_expr::QasmExpress
haskey(mapping, qubit_name) || throw(QasmVisitorError("Missing input variable '$qubit_name'.", "NameError"))
qubit_ix = v(qubit_expr.args[2]::QasmExpression)
qubits = Iterators.flatmap(qubit_ix) do rq
haskey(mapping, qubit_name * "[$rq]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError"))
return mapping[qubit_name * "[$rq]"]
if rq >= 0
haskey(mapping, qubit_name * "[$rq]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError"))
return mapping[qubit_name * "[$rq]"]
else
qubit_size = length(mapping[qubit_name])
haskey(mapping, qubit_name * "[$(qubit_size + rq)]") || throw(QasmVisitorError("Invalid qubit index '$rq' in '$qubit_name'.", "IndexError"))
return mapping[qubit_name * "[$(qubit_size + rq)]"]
end
end
return collect(qubits)
end
Expand Down Expand Up @@ -451,6 +458,64 @@ function visit_gate_call(v::AbstractVisitor, program_expr::QasmExpression)
return
end

function visit_function_call(v, expr, function_name)
function_def = function_defs(v)[function_name]
function_body = function_def.body::Vector{QasmExpression}
declared_args = only(function_def.arguments.args)::QasmExpression
provided_args = only(expr.args[2].args)::QasmExpression
function_v = QasmFunctionVisitor(v, declared_args, provided_args)
return_val = nothing
body_exprs::Vector{QasmExpression} = head(function_body[1]) == :scope ? function_body[1].args : function_body
for f_expr in body_exprs
if head(f_expr) == :return
return_val = function_v(f_expr.args[1])
else
function_v(f_expr)
end
end
# remap qubits and classical variables
function_args = if head(declared_args) == :array_literal
convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression}
else
declared_args
end
called_args = if head(provided_args) == :array_literal
convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression}
else
provided_args
end
reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args))
reverse_qubits_map = Dict{Int, Int}()
for variable in filter(v->head(v) (:identifier, :indexed_identifier), keys(reverse_arguments_map))
variable_name = name(variable)
if haskey(classical_defs(v), variable_name) && classical_defs(v)[variable_name].type isa SizedArray && head(reverse_arguments_map[variable]) != :const_declaration
inner_variable_name = name(reverse_arguments_map[variable])
new_val = classical_defs(function_v)[inner_variable_name].val
back_assignment = QasmExpression(:classical_assignment, QasmExpression(:binary_op, Symbol("="), variable, new_val))
v(back_assignment)
elseif haskey(qubit_defs(v), variable_name)
outer_context_map = only(evaluate_qubits(v, variable))
inner_context_map = only(evaluate_qubits(function_v, reverse_arguments_map[variable].args[1]))
reverse_qubits_map[inner_context_map] = outer_context_map
end
end
mapper = isempty(reverse_qubits_map) ? identity : ix->remap(ix, reverse_qubits_map)
push!(v, map(mapper, function_v.instructions))
return return_val
end

function declaration_init(v, expr::QasmExpression)
var_type = expr.args[1].args[1]
init = if var_type isa SizedNumber
undef
elseif var_type isa SizedArray
fill(undef, v(var_type.size))
elseif var_type isa SizedBitVector
falses(max(0, v(var_type.size)))
end
return init, var_type
end

(v::AbstractVisitor)(i::Number) = i
(v::AbstractVisitor)(i::String) = i
(v::AbstractVisitor)(i::BitVector) = i
Expand Down Expand Up @@ -548,6 +613,58 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
isnothing(default) && throw(QasmVisitorError("no case matched and no default defined."))
foreach(v, convert(Vector{QasmExpression}, all_cases[default].args))
end
elseif head(program_expr) == :alias
alias_name = name(program_expr)
right_hand_side = program_expr.args[1].args[1].args[end]
if head(right_hand_side) == :binary_op
right_hand_side.args[1] == Symbol("++") || throw(QasmVisitorError("right hand side of alias must be either an identifier or concatenation"))
concat_left = right_hand_side.args[2]
concat_right = right_hand_side.args[3]
is_left_qubit = haskey(qubit_mapping(v), name(concat_left))
is_right_qubit = haskey(qubit_mapping(v), name(concat_right))
(is_left_qubit is_right_qubit) && throw(QasmVisitorError("cannot concatenate qubit and classical arrays"))
if is_left_qubit
left_qs = v(concat_left)
right_qs = v(concat_right)
alias_qubits = collect(vcat(left_qs, right_qs))
qubit_size = length(alias_qubits)
qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size)
qubit_mapping(v)[alias_name] = alias_qubits
for qubit_i in 0:qubit_size-1
qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]]
end
else # both classical
throw(QasmVisitorError("classical array concatenation not yet supported!"))
end
elseif head(right_hand_side) == :identifier
referent_name = name(right_hand_side)
is_qubit = haskey(qubit_mapping(v), referent_name)
if is_qubit
qubit_defs(v)[alias_name] = qubit_defs(v)[referent_name]
qubit_mapping(v)[alias_name] = qubit_mapping(v)[referent_name]
qubit_size = length(qubit_mapping(v)[alias_name])
for qubit_i in 0:qubit_size-1
qubit_mapping(v)["$alias_name[$qubit_i]"] = qubit_mapping(v)["$referent_name[$qubit_i]"]
end
else
classical_defs(v)[alias_name] = classical_defs(v)[referent_name]
end
elseif head(right_hand_side) == :indexed_identifier
referent_name = name(right_hand_side)
is_qubit = haskey(qubit_mapping(v), referent_name)
if is_qubit
alias_qubits = v(right_hand_side)
qubit_size = length(alias_qubits)
qubit_defs(v)[alias_name] = Qubit(alias_name, qubit_size)
qubit_mapping(v)[alias_name] = collect(alias_qubits)
for qubit_i in 0:qubit_size-1
qubit_mapping(v)["$alias_name[$qubit_i]"] = [alias_qubits[qubit_i+1]]
end
else
referent = classical_defs(v)[referent_name]
classical_defs(v)[alias_name] = ClassicalVariable(alias_name, referent.type, view(referent.val, v(right_hand_side.args[end]) .+ 1), referent.is_const)
end
end
elseif head(program_expr) == :identifier
id_name = name(program_expr)
haskey(classical_defs(v), id_name) && return classical_defs(v)[id_name].val
Expand Down Expand Up @@ -610,7 +727,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
condition_value = while_v(program_expr.args[1])
end
elseif head(program_expr) == :classical_assignment
op = program_expr.args[1].args[1]::Symbol
op = program_expr.args[1].args[1]::Symbol
left_hand_side = program_expr.args[1].args[2]::QasmExpression
right_hand_side = program_expr.args[1].args[3]
var_name = name(left_hand_side)::String
Expand Down Expand Up @@ -651,14 +768,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
end
elseif head(program_expr) == :classical_declaration
var_type = program_expr.args[1].args[1]
init = if var_type isa SizedNumber
undef
elseif var_type isa SizedArray
fill(undef, v(var_type.size))
elseif var_type isa SizedBitVector
falses(max(0, v(var_type.size)))
end
init, var_type = declaration_init(v, program_expr)
# no initial value
if head(program_expr.args[2]) == :identifier
var_name = name(program_expr.args[2])
Expand All @@ -671,14 +781,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
end
elseif head(program_expr) == :const_declaration
head(program_expr.args[2]) == :classical_assignment || throw(QasmVisitorError("const declaration must assign an initial value."))
var_type = program_expr.args[1].args[1]
init = if var_type isa SizedNumber
undef
elseif var_type isa SizedArray
fill(undef, v(var_type.size))
elseif var_type isa SizedBitVector
falses(max(0, v(var_type.size)))
end
init, var_type = declaration_init(v, program_expr)
op, left_hand_side, right_hand_side = program_expr.args[2].args[1].args
var_name = name(left_hand_side)
v.classical_defs[var_name] = ClassicalVariable(var_name, var_type, init, false)
Expand Down Expand Up @@ -737,54 +840,7 @@ function (v::AbstractVisitor)(program_expr::QasmExpression)
return return_val[1]
else
hasfunction(v, function_name) || throw(QasmVisitorError("function $function_name not defined!"))
function_def = function_defs(v)[function_name]
function_body = function_def.body::Vector{QasmExpression}
declared_args = only(function_def.arguments.args)::QasmExpression
provided_args = only(program_expr.args[2].args)::QasmExpression
function_v = QasmFunctionVisitor(v, declared_args, provided_args)
return_val = nothing
body_exprs::Vector{QasmExpression} = head(function_body[1]) == :scope ? function_body[1].args : function_body
for f_expr in body_exprs
if head(f_expr) == :return
return_val = function_v(f_expr.args[1])
else
function_v(f_expr)
end
end
# remap qubits and classical variables
function_args = if head(declared_args) == :array_literal
convert(Vector{QasmExpression}, declared_args.args)::Vector{QasmExpression}
else
declared_args
end
called_args = if head(provided_args) == :array_literal
convert(Vector{QasmExpression}, provided_args.args)::Vector{QasmExpression}
else
provided_args
end
arguments_map = Dict{QasmExpression, QasmExpression}(zip(function_args, called_args))
reverse_arguments_map = Dict{QasmExpression, QasmExpression}(zip(called_args, function_args))
reverse_qubits_map = Dict{Int, Int}()
for variable in keys(reverse_arguments_map)
if head(variable) (:identifier, :indexed_identifier)
variable_name = name(variable)
if haskey(classical_defs(v), variable_name) && classical_defs(v)[variable_name].type isa SizedArray
if head(reverse_arguments_map[variable]) != :const_declaration
inner_variable_name = name(reverse_arguments_map[variable])
new_val = classical_defs(function_v)[inner_variable_name].val
back_assignment = QasmExpression(:classical_assignment, QasmExpression(:binary_op, Symbol("="), variable, new_val))
v(back_assignment)
end
elseif haskey(qubit_defs(v), variable_name)
outer_context_map = only(evaluate_qubits(v, variable))
inner_context_map = only(evaluate_qubits(function_v, reverse_arguments_map[variable].args[1]))
reverse_qubits_map[inner_context_map] = outer_context_map
end
end
end
mapper = isempty(reverse_qubits_map) ? identity : ix->remap(ix, reverse_qubits_map)
push!(v, map(mapper, function_v.instructions))
return return_val
return visit_function_call(v, program_expr, function_name)
end
elseif head(program_expr) == :function_definition
function_def = program_expr.args
Expand Down
Loading

0 comments on commit 722c4cc

Please sign in to comment.