diff --git a/src/macros.jl b/src/macros.jl index a3acb606f..1cc1d72dd 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -51,7 +51,7 @@ function _extract_kwargs(args) arg_list = collect(args) if !isempty(args) && isexpr(args[1], :parameters) p = popfirst!(arg_list) - append!(arg_list, p.args) + append!(arg_list, (Expr(:(=), a.args[1], a.args[2]) for a in p.args)) end extra_kwargs = filter(x -> isexpr(x, :(=)) && x.args[1] != :container && x.args[1] != :base_name, arg_list) @@ -291,9 +291,9 @@ And data, a 2-element Array{GeneralVariableRef,1}: noname[b] ``` """ -macro infinite_parameter(model, args...) +macro infinite_parameter(args...) # define error message function - _error(str...) = _macro_error(:infinite_parameter, (model, args...), + _error(str...) = _macro_error(:infinite_parameter, (args...,), __source__, str...) # parse the arguments pos_args, extra_kwargs, container_type, base_name = _extract_kwargs(args) @@ -306,6 +306,8 @@ macro infinite_parameter(model, args...) extra_kwargs) # process the positional arguments + isempty(pos_args) && _error("No model was given.") + model = popfirst!(pos_args) if isempty(pos_args) p = gensym() is_anon_single = true @@ -499,16 +501,18 @@ julia> @finite_parameter(model, par2 == 42) par2 ``` """ -macro finite_parameter(model, args...) +macro finite_parameter(args...) # process the inputs - esc_model = esc(model) pos_args, kwargs, container_type, base_name = _extract_kwargs(args) # make an error function - _error(str...) = _macro_error(:finite_parameter, (model, args...), + _error(str...) = _macro_error(:finite_parameter, (args...,), __source__, str...) # process the positional arguments + isempty(pos_args) && _error("No model was given.") + model = popfirst!(pos_args) + esc_model = esc(model) if length(pos_args) == 1 expr = popfirst!(pos_args) else @@ -702,18 +706,18 @@ julia> @parameter_function(model, pf2[i = 1:2] == t -> g(t, i, b = 2 * i )) pf2[2](t) ``` """ -macro parameter_function(model, args...) - # prepare the model - esc_model = esc(model) - +macro parameter_function(args...) # define error message function - _error(str...) = _macro_error(:parameter_function, (model, args...), + _error(str...) = _macro_error(:parameter_function, (args...,), __source__, str...) # parse the arguments pos_args, kwargs, container_type, base_name = _extract_kwargs(args) # process the positional arguements + isempty(pos_args) && _error("No model was given.") + model = popfirst!(pos_args) + esc_model = esc(model) if length(pos_args) == 1 expr = popfirst!(pos_args) else diff --git a/test/expressions.jl b/test/expressions.jl index 09f57d40e..624396669 100644 --- a/test/expressions.jl +++ b/test/expressions.jl @@ -261,6 +261,7 @@ end # test @parameter_function @testset "@parameter_function" begin # test errors + @test_macro_throws ErrorException @parameter_function() @test_macro_throws ErrorException @parameter_function(m) @test_macro_throws ErrorException @parameter_function(m, func = f5) @test_macro_throws ErrorException @parameter_function(m, y == sin(t), Int) @@ -272,7 +273,7 @@ end # test anonymous singular idx = 1 ref = GeneralVariableRef(m, idx, ParameterFunctionIndex) - @test isequal(@parameter_function(m, f5(t, x), base_name = "a"), ref) + @test isequal(@parameter_function(m, f5(t, x); base_name = "a"), ref) @test name(ref) == "a" @test raw_function(ref) == f5 @test isequal(parameter_refs(ref), (t, x)) diff --git a/test/scalar_parameters.jl b/test/scalar_parameters.jl index 4c6e21c9c..bd25752e4 100644 --- a/test/scalar_parameters.jl +++ b/test/scalar_parameters.jl @@ -356,7 +356,7 @@ end @test InfiniteOpt._core_variable_object(pref).domain == IntervalDomain(0, 1) pref = GeneralVariableRef(m, 4, IndependentParameterIndex) - @test isequal(@infinite_parameter(m, domain = IntervalDomain(0, 1), + @test isequal(@infinite_parameter(m; domain = IntervalDomain(0, 1), base_name = "d"), pref) @test name(pref) == "d" @@ -380,7 +380,7 @@ end @test InfiniteOpt._core_variable_object(prefs[2]).domain == IntervalDomain(0, 1) prefs = [GeneralVariableRef(m, i, IndependentParameterIndex) for i in 9:10] - @test isequal(@infinite_parameter(m, [1:2], domain = IntervalDomain(0, 1), + @test isequal(@infinite_parameter(m, [1:2]; domain = IntervalDomain(0, 1), independent = true), prefs) @test InfiniteOpt._core_variable_object(prefs[1]).domain == IntervalDomain(0, 1) @test InfiniteOpt._core_variable_object(prefs[2]).domain == IntervalDomain(0, 1) @@ -406,7 +406,7 @@ end prefs = [GeneralVariableRef(m, i, IndependentParameterIndex) for i in 17:18] prefs = convert(JuMP.Containers.SparseAxisArray, prefs) - @test all(isequal.(@infinite_parameter(m, i[1:2] ~ Normal(), independent = true, + @test all(isequal.(@infinite_parameter(m, i[1:2] ~ Normal(); independent = true, container = SparseAxisArray), prefs)) @test InfiniteOpt._core_variable_object(prefs[1]).domain == UniDistributionDomain(Normal()) @test InfiniteOpt._core_variable_object(prefs[2]).domain == UniDistributionDomain(Normal()) @@ -423,6 +423,7 @@ end distribution = Normal(), domain = IntervalDomain(0, 1)) @test_macro_throws ErrorException @infinite_parameter(m) + @test_macro_throws ErrorException @infinite_parameter() @test_macro_throws ErrorException @infinite_parameter(m, 0 <= z <= 1) @test_macro_throws ErrorException @infinite_parameter(m, [1:2] in [0, 1], independent = a) @@ -1000,6 +1001,7 @@ end @testset "@finite_parameter" begin # test errors @test_macro_throws ErrorException @finite_parameter(m) + @test_macro_throws ErrorException @finite_parameter() @test_macro_throws ErrorException @finite_parameter(m, a, 2) @test_macro_throws ErrorException @finite_parameter(m, a ~ 42) @test_macro_throws ErrorException @finite_parameter(Model(), 2) @@ -1015,7 +1017,7 @@ end @test name(pref) == "" # test vector anonymous definition prefs = [GeneralVariableRef(m, i, FiniteParameterIndex) for i in 2:3] - @test @finite_parameter(m, [1:2] == 42, base_name = "a") == prefs + @test @finite_parameter(m, [1:2] == 42; base_name = "a") == prefs @test InfiniteOpt._core_variable_object(prefs[1]).value == 42 @test name.(prefs) == ["a[1]", "a[2]"] # test named definition