Skip to content

Commit

Permalink
Fix macro keyword parsing (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
pulsipher authored Oct 31, 2023
1 parent 064df70 commit 3a512ef
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 16 deletions.
26 changes: 15 additions & 11 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function _extract_kwargs(args)
arg_list = collect(args)
if !isempty(args) && isexpr(args[1], :parameters)
p = popfirst!(arg_list)
append!(arg_list, p.args)
append!(arg_list, (Expr(:(=), a.args[1], a.args[2]) for a in p.args))
end
extra_kwargs = filter(x -> isexpr(x, :(=)) && x.args[1] != :container &&
x.args[1] != :base_name, arg_list)
Expand Down Expand Up @@ -291,9 +291,9 @@ And data, a 2-element Array{GeneralVariableRef,1}:
noname[b]
```
"""
macro infinite_parameter(model, args...)
macro infinite_parameter(args...)
# define error message function
_error(str...) = _macro_error(:infinite_parameter, (model, args...),
_error(str...) = _macro_error(:infinite_parameter, (args...,),
__source__, str...)
# parse the arguments
pos_args, extra_kwargs, container_type, base_name = _extract_kwargs(args)
Expand All @@ -306,6 +306,8 @@ macro infinite_parameter(model, args...)
extra_kwargs)

# process the positional arguments
isempty(pos_args) && _error("No model was given.")
model = popfirst!(pos_args)
if isempty(pos_args)
p = gensym()
is_anon_single = true
Expand Down Expand Up @@ -499,16 +501,18 @@ julia> @finite_parameter(model, par2 == 42)
par2
```
"""
macro finite_parameter(model, args...)
macro finite_parameter(args...)
# process the inputs
esc_model = esc(model)
pos_args, kwargs, container_type, base_name = _extract_kwargs(args)

# make an error function
_error(str...) = _macro_error(:finite_parameter, (model, args...),
_error(str...) = _macro_error(:finite_parameter, (args...,),
__source__, str...)

# process the positional arguments
isempty(pos_args) && _error("No model was given.")
model = popfirst!(pos_args)
esc_model = esc(model)
if length(pos_args) == 1
expr = popfirst!(pos_args)
else
Expand Down Expand Up @@ -702,18 +706,18 @@ julia> @parameter_function(model, pf2[i = 1:2] == t -> g(t, i, b = 2 * i ))
pf2[2](t)
```
"""
macro parameter_function(model, args...)
# prepare the model
esc_model = esc(model)

macro parameter_function(args...)
# define error message function
_error(str...) = _macro_error(:parameter_function, (model, args...),
_error(str...) = _macro_error(:parameter_function, (args...,),
__source__, str...)

# parse the arguments
pos_args, kwargs, container_type, base_name = _extract_kwargs(args)

# process the positional arguements
isempty(pos_args) && _error("No model was given.")
model = popfirst!(pos_args)
esc_model = esc(model)
if length(pos_args) == 1
expr = popfirst!(pos_args)
else
Expand Down
3 changes: 2 additions & 1 deletion test/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ end
# test @parameter_function
@testset "@parameter_function" begin
# test errors
@test_macro_throws ErrorException @parameter_function()
@test_macro_throws ErrorException @parameter_function(m)
@test_macro_throws ErrorException @parameter_function(m, func = f5)
@test_macro_throws ErrorException @parameter_function(m, y == sin(t), Int)
Expand All @@ -272,7 +273,7 @@ end
# test anonymous singular
idx = 1
ref = GeneralVariableRef(m, idx, ParameterFunctionIndex)
@test isequal(@parameter_function(m, f5(t, x), base_name = "a"), ref)
@test isequal(@parameter_function(m, f5(t, x); base_name = "a"), ref)
@test name(ref) == "a"
@test raw_function(ref) == f5
@test isequal(parameter_refs(ref), (t, x))
Expand Down
10 changes: 6 additions & 4 deletions test/scalar_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ end
@test InfiniteOpt._core_variable_object(pref).domain == IntervalDomain(0, 1)

pref = GeneralVariableRef(m, 4, IndependentParameterIndex)
@test isequal(@infinite_parameter(m, domain = IntervalDomain(0, 1),
@test isequal(@infinite_parameter(m; domain = IntervalDomain(0, 1),
base_name = "d"), pref)
@test name(pref) == "d"

Expand All @@ -380,7 +380,7 @@ end
@test InfiniteOpt._core_variable_object(prefs[2]).domain == IntervalDomain(0, 1)

prefs = [GeneralVariableRef(m, i, IndependentParameterIndex) for i in 9:10]
@test isequal(@infinite_parameter(m, [1:2], domain = IntervalDomain(0, 1),
@test isequal(@infinite_parameter(m, [1:2]; domain = IntervalDomain(0, 1),
independent = true), prefs)
@test InfiniteOpt._core_variable_object(prefs[1]).domain == IntervalDomain(0, 1)
@test InfiniteOpt._core_variable_object(prefs[2]).domain == IntervalDomain(0, 1)
Expand All @@ -406,7 +406,7 @@ end

prefs = [GeneralVariableRef(m, i, IndependentParameterIndex) for i in 17:18]
prefs = convert(JuMP.Containers.SparseAxisArray, prefs)
@test all(isequal.(@infinite_parameter(m, i[1:2] ~ Normal(), independent = true,
@test all(isequal.(@infinite_parameter(m, i[1:2] ~ Normal(); independent = true,
container = SparseAxisArray), prefs))
@test InfiniteOpt._core_variable_object(prefs[1]).domain == UniDistributionDomain(Normal())
@test InfiniteOpt._core_variable_object(prefs[2]).domain == UniDistributionDomain(Normal())
Expand All @@ -423,6 +423,7 @@ end
distribution = Normal(),
domain = IntervalDomain(0, 1))
@test_macro_throws ErrorException @infinite_parameter(m)
@test_macro_throws ErrorException @infinite_parameter()
@test_macro_throws ErrorException @infinite_parameter(m, 0 <= z <= 1)
@test_macro_throws ErrorException @infinite_parameter(m, [1:2] in [0, 1],
independent = a)
Expand Down Expand Up @@ -1000,6 +1001,7 @@ end
@testset "@finite_parameter" begin
# test errors
@test_macro_throws ErrorException @finite_parameter(m)
@test_macro_throws ErrorException @finite_parameter()
@test_macro_throws ErrorException @finite_parameter(m, a, 2)
@test_macro_throws ErrorException @finite_parameter(m, a ~ 42)
@test_macro_throws ErrorException @finite_parameter(Model(), 2)
Expand All @@ -1015,7 +1017,7 @@ end
@test name(pref) == ""
# test vector anonymous definition
prefs = [GeneralVariableRef(m, i, FiniteParameterIndex) for i in 2:3]
@test @finite_parameter(m, [1:2] == 42, base_name = "a") == prefs
@test @finite_parameter(m, [1:2] == 42; base_name = "a") == prefs
@test InfiniteOpt._core_variable_object(prefs[1]).value == 42
@test name.(prefs) == ["a[1]", "a[2]"]
# test named definition
Expand Down

0 comments on commit 3a512ef

Please sign in to comment.