Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add return_is_mutable kwarg to MA.rewrite #312

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ function _is_decomposable_with_factors(ex)
end

"""
rewrite(expr; move_factors_into_sums::Bool = true) -> Tuple{Symbol,Expr}
rewrite(
expr;
move_factors_into_sums::Bool = true,
return_is_mutable::Bool = false,
) -> Tuple{Symbol,Expr[,Bool]}

Rewrites the expression `expr` to use mutable arithmetics.

Expand Down Expand Up @@ -318,34 +322,64 @@ variable = MA.operate!!(*, y, term)

The latter can produce an additional allocation if there is an efficient
fallback for `add_mul` and not for `*(y, term)`.

## `return_is_mutable`

If `return_is_mutable = true`, this function returns three arguments. The third
is a `Bool` indicating if the returned expression can be safely mutated without
changing the user's original expression.

`return_is_mutable` cannot be `true` if `move_factors_into_sums = true`.
"""
function rewrite(x; kwargs...)
function rewrite(x; return_is_mutable::Bool = false, kwargs...)
variable = gensym()
if return_is_mutable
code, is_mutable = rewrite_and_return(x; return_is_mutable, kwargs...)
return variable, :($variable = $code), is_mutable
end
code = rewrite_and_return(x; kwargs...)
return variable, :($variable = $code)
end

"""
rewrite_and_return(expr; move_factors_into_sums::Bool = true) -> Expr
rewrite_and_return(
expr;
move_factors_into_sums::Bool = true,
return_is_mutable::Bool = false,
) -> Expr

Rewrite the expression `expr` using mutable arithmetics and return an expression
in which the last statement is equivalent to `expr`.

See [`rewrite`](@ref) for an explanation of the keyword argument.
See [`rewrite`](@ref) for an explanation of the keyword arguments.
"""
function rewrite_and_return(expr; move_factors_into_sums::Bool = true)
function rewrite_and_return(
expr;
move_factors_into_sums::Bool = true,
return_is_mutable::Bool = false,
)
if move_factors_into_sums
@assert !return_is_mutable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this going to be an issue if JuMP starts using return_is_mutable = true ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JuMP uses move_factors_into_sums = false

root, stack = _rewrite(false, false, expr, nothing, [], [])
else
stack = quote end
root, _ = _rewrite_generic(stack, expr)
return quote
let
$stack
$root
end
end
end
return quote
stack = quote end
root, is_mutable = _rewrite_generic(stack, expr)
code = quote
let
$stack
$root
end
end
if return_is_mutable
return code, is_mutable
end
return code
end

function _is_comparison(ex::Expr)
Expand Down
32 changes: 32 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,38 @@ function test_rewrite_ifelse()
return
end

function test_return_is_mutable()
function _rewrite(expr)
return MA.rewrite(
expr;
move_factors_into_sums = false,
return_is_mutable = true,
)
end
x, expr, is_mutable = _rewrite(1)
@test x isa Symbol
@test Meta.isexpr(expr, :(=), 2)
@test is_mutable
y = 1
x, expr, is_mutable = _rewrite(:(y))
@test x isa Symbol
@test Meta.isexpr(expr, :(=), 2)
@test !is_mutable
x, expr, is_mutable = _rewrite(:(1 + 1))
@test x isa Symbol
@test Meta.isexpr(expr, :(=), 2)
@test is_mutable
@test_throws(
AssertionError,
MA.rewrite(
:(1 + 1);
move_factors_into_sums = true,
return_is_mutable = true,
),
)
return
end

end # module

TestRewriteGeneric.runtests()
Loading