Skip to content

Commit

Permalink
Add comments and tidy _rewrite function. (#59)
Browse files Browse the repository at this point in the history
* Add comments and tidy _rewrite function.

I started fixing a bug for Julia 1.6, and got thoroughly confused
with what this function was doing. Hopefully this is more readable.

* Add back special case

* Fix comments
  • Loading branch information
odow authored Nov 12, 2020
1 parent 7117b58 commit a627628
Showing 1 changed file with 239 additions and 68 deletions.
307 changes: 239 additions & 68 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,114 +291,285 @@ function _write_add_mul(vectorized, minus, current_sum, left_factors, inner_fact
end

"""
_rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym())
_rewrite(
vectorized::Bool,
minus::Bool,
inner_factor,
current_sum::Union{Symbol, Nothing},
left_factors::Vector,
right_factors::Vector,
new_var::Symbol = gensym(),
)
Return `new_var, code` such that `code` is equivalent to
```julia
new_var = prod(left_factors) * inner_factor * prod(reverse(right_factors))
```
if `current_sum` is `nothing`, and is
If `current_sum` is `nothing`, and is
```julia
new_var = current_sum op prod(left_factors) * inner_factor * prod(reverse(right_factors))
```
otherwise where `op` is `+` if `!vectorized` and `!minus`, `.+` if `vectorized` and `!minus`,
`-` if `!vectorized` and `minus` and `.-` if `vectorized` and `minus`.
otherwise where `op` is `+` if `!vectorized & !minus`, `.+` if
`vectorized & !minus`, `-` if `!vectorized & minus` and `.-` if
`vectorized & minus`.
"""
function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Symbol, Nothing}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym())
function _rewrite(
vectorized::Bool,
minus::Bool,
inner_factor,
current_sum::Union{Symbol, Nothing},
left_factors::Vector,
right_factors::Vector,
new_var::Symbol = gensym(),
)
if isexpr(inner_factor, :call)
# We need to verfify that `left_factors` and `right_factors` are empty for broadcast, see `_is_decomposable_with_factors`.
# We also need to verify that `current_sum` is `nothing` otherwise we are unsure that the elements in the containers have been copied, e.g., in
# `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
if inner_factor.args[1] == :+ || inner_factor.args[1] == :- ||
(current_sum === nothing && isempty(left_factors) && isempty(right_factors) && (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)))
block = Expr(:block)
if length(inner_factor.args) > 2 # not unary addition or subtraction
next_sum, code = _rewrite(vectorized, minus, inner_factor.args[2], current_sum, left_factors, right_factors)
push!(block.args, code)
start = 3
else
if (
inner_factor.args[1] == :+ ||
inner_factor.args[1] == :- ||
(
current_sum === nothing &&
isempty(left_factors) &&
isempty(right_factors) &&
(inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-))
)
)
# There are three cases here:
# 1. scalar addition : +(args...)
# 2. scalar subtraction : -(args...)
# 3. broadcast addition or subtraction.
# For case (3), we need to verify that current_sum, left_factors,
# and right_factors are empty, otherwise we are unsure that the
# elements in the containers have been copied, e.g., in
# `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as
# `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
code = Expr(:block)
if length(inner_factor.args) == 2
# Unary addition or subtraction.
next_sum = current_sum
start = 2
else
next_sum, new_code = _rewrite(
vectorized,
minus,
inner_factor.args[2],
current_sum,
left_factors,
right_factors,
)
push!(code.args, new_code)
start = 3
end
vectorized = vectorized || inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)
if inner_factor.args[1] == :- || inner_factor.args[1] == :(.-)
minus = !minus
end
return rewrite_sum(vectorized, minus, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block)
vectorized = (
vectorized ||
inner_factor.args[1] == :(.+) ||
inner_factor.args[1] == :(.-)
)
return rewrite_sum(
vectorized,
minus,
inner_factor.args[start:end],
next_sum,
left_factors,
right_factors,
new_var,
code,
)
elseif inner_factor.args[1] == :* && !vectorized
# We need `&& !vectorized` otherwise `x .+ A * b` would be rewritten `broadcast!(add_mul, x, A, b)`.

# we might need to recurse on multiple arguments, e.g.,
# (x+y)*(x+y)
# special case, only recurse on one argument and don't create temporary objects
if isone(mapreduce(_is_complex_expr, +, inner_factor.args)) &&
isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args))
# A multiplication expression *(args...). We need `!vectorized`
# otherwise `x .+ A * b` would be rewritten
# `broadcast!(add_mul, x, A, b)`.
# We might need to recurse on multiple arguments, e.g., (x+y)*(x+y).
# As a special case, only recurse on one argument and don't create
# temporary objects
if (
isone(mapreduce(_is_complex_expr, +, inner_factor.args)) &&
isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args))
)
# `findfirst` return the index in `2:...` so we need to add `1`.
which_idx = 1 + findfirst(2:length(inner_factor.args)) do i
_is_decomposable_with_factors(inner_factor.args[i])
end
return _rewrite(
vectorized, minus, inner_factor.args[which_idx], current_sum,
vcat(left_factors, [esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]),
vcat(right_factors, [esc(inner_factor.args[i]) for i in length(inner_factor.args):-1:(which_idx + 1)]),
new_var)
vectorized,
minus,
inner_factor.args[which_idx],
current_sum,
vcat(
left_factors,
[esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]
),
vcat(
right_factors,
[
esc(inner_factor.args[i])
for i in length(inner_factor.args):-1:(which_idx + 1)
],
),
new_var,
)
else
blk = Expr(:block)
code = Expr(:block)
for i in 2:length(inner_factor.args)
if _is_complex_expr(inner_factor.args[i])
new_var_, parsed = rewrite(inner_factor.args[i])
push!(blk.args, parsed)
inner_factor.args[i] = new_var_
arg = inner_factor.args[i]
if _is_complex_expr(arg) # `arg` needs rewriting.
new_arg, new_arg_code = rewrite(arg)
push!(code.args, new_arg_code)
inner_factor.args[i] = new_arg
else
inner_factor.args[i] = esc(inner_factor.args[i])
inner_factor.args[i] = esc(arg)
end
end
push!(blk.args, _write_add_mul(
vectorized, minus, current_sum, left_factors,
inner_factor.args[2:end], right_factors, new_var
))
return new_var, blk
push!(
code.args,
_write_add_mul(
vectorized,
minus,
current_sum,
left_factors,
inner_factor.args[2:end],
right_factors,
new_var,
),
)
return new_var, code
end
elseif inner_factor.args[1] == :^ && _is_complex_expr(inner_factor.args[2]) && !vectorized
# We need `&& !vectorized` otherwise `A .+ (A + A)^2` would be rewritten `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`.
MulType = :(MA.promote_operation(*, typeof($(inner_factor.args[2])), typeof($(inner_factor.args[2]))))
if inner_factor.args[3] == 2
elseif (
inner_factor.args[1] == :^ &&
_is_complex_expr(inner_factor.args[2]) &&
!vectorized
)
# An expression like `base ^ exponent`, where the `base` is a
# non-trivial expression that also needs to be re-written. We need
# `!vectorized` otherwise `A .+ (A + A)^2` would be rewritten as
# `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`.
MulType = :(
MA.promote_operation(
*,
typeof($(inner_factor.args[2])),
typeof($(inner_factor.args[2]))
)
)
if inner_factor.args[3] == 0
# If the exponent is 0, rewrite
# new_var = base^0
# as
# new_var = 1
return _rewrite(
vectorized,
minus,
:(one($MulType)),
current_sum,
left_factors,
right_factors,
new_var,
)
elseif inner_factor.args[3] == 1
# If the exponent is 1, rewrite
# new_var = base^1
# as
# new_var = base
return _rewrite(
vectorized,
minus,
:(convert($MulType, $(inner_factor.args[2]))),
current_sum,
left_factors,
right_factors,
new_var,
)
elseif inner_factor.args[3] == 2
# If the exponent is 2, rewrite
# new_var = base^2
# as
# new_base = base_rewrite
# new_var = base_rewrite * base_rewrite
new_var_, parsed = rewrite(inner_factor.args[2])
square_expr = _write_add_mul(
vectorized, minus, current_sum, left_factors,
(new_var_, new_var_), right_factors, new_var
vectorized,
minus,
current_sum,
left_factors,
(new_var_, new_var_),
right_factors,
new_var,
)
return new_var, Expr(:block, parsed, square_expr)
elseif inner_factor.args[3] == 1
return _rewrite(vectorized, minus, :(convert($MulType, $(inner_factor.args[2]))), current_sum, left_factors, right_factors, new_var)
elseif inner_factor.args[3] == 0
return _rewrite(vectorized, minus, :(one($MulType)), current_sum, left_factors, right_factors, new_var)
else
new_var_, parsed = rewrite(inner_factor.args[2])
power_expr = _write_add_mul(
vectorized, minus, current_sum, left_factors,
(Expr(:call, :^, new_var_, esc(inner_factor.args[3])),),
right_factors, new_var
# In the general case, rewrite
# new_var = base^exponent
# as
# new_base = base_rewrite
# new_var = base_rewrite^(exponent)
new_base, base_rewrite = rewrite(inner_factor.args[2])
new_expr = _write_add_mul(
vectorized,
minus,
current_sum,
left_factors,
(Expr(:call, :^, new_base, esc(inner_factor.args[3])),),
right_factors,
new_var,
)
return new_var, Expr(:block, parsed, power_expr)
return new_var, Expr(:block, base_rewrite, new_expr)
end
elseif inner_factor.args[1] == :/ && !vectorized
# Rewrite
# new_var = numerator / denominator
# as
# new_var = numerator * (1 / denominator)
@assert length(inner_factor.args) == 3
numerator = inner_factor.args[2]
denom = inner_factor.args[3]
return _rewrite(vectorized, minus, numerator, current_sum, left_factors, vcat(esc(:(1 / $denom)), right_factors), new_var)
elseif length(inner_factor.args) >= 2 && (isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten))
return new_var, _parse_generator(vectorized, minus, inner_factor, current_sum, left_factors, right_factors, new_var)
return _rewrite(
vectorized,
minus,
inner_factor.args[2],
current_sum,
left_factors,
vcat(esc(:(1 / $(inner_factor.args[3]))), right_factors),
new_var,
)
elseif (
length(inner_factor.args) >= 2 &&
(
isexpr(inner_factor.args[2], :generator) ||
isexpr(inner_factor.args[2], :flatten)
)
)
# A generator statement.
code = _parse_generator(
vectorized,
minus,
inner_factor,
current_sum,
left_factors,
right_factors,
new_var,
)
return new_var, code
end
elseif isexpr(inner_factor, :curly)
Base.error("The curly syntax (sum{},prod{},norm2{}) is no longer supported. Expression: `$inner_factor`.")
end
if isa(inner_factor, Expr) && _is_comparison(inner_factor)
if isexpr(inner_factor, :curly)
error(
"The curly syntax (sum{},prod{},norm2{}) is no longer supported. " *
"Expression: `$inner_factor`."
)
elseif isa(inner_factor, Expr) && _is_comparison(inner_factor)
error("Unexpected comparison in expression `$inner_factor`.")
end
if isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor)
elseif isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor)
error("Unexpected assignment in expression `$inner_factor`.")
end
# at the lowest level
return new_var, _write_add_mul(vectorized, minus, current_sum, left_factors, (esc(inner_factor),), right_factors, new_var)
# None of the special cases were hit! This probably means we are vectorized.
code = _write_add_mul(
vectorized,
minus,
current_sum,
left_factors,
(esc(inner_factor),),
right_factors,
new_var,
)
return new_var, code
end

0 comments on commit a627628

Please sign in to comment.