diff --git a/src/Observation.jl b/src/Observation.jl index 9d1f42a..84a408a 100644 --- a/src/Observation.jl +++ b/src/Observation.jl @@ -197,11 +197,14 @@ function process_aggregate(var, collection, decls) # data struct prop_code = data_struct_elements(statname, stattypes) push!(stat_type_code, prop_code) - + + # code for this @stat line + this_stat_code = [] + # code to store result of user code (to be fed into stats objects) # (inside loop) tmp_name = gensym("tmp_" * statname) - push!(loop_code, :($tmp_name = $(esc(expr)))) + push!(this_stat_code, :($tmp_name = $(esc(expr)))) # expression that merges all results for this stat into single named tuple res_expr = length(stattypes) > 1 ? :(merge()) : :(identity()) @@ -221,14 +224,15 @@ function process_aggregate(var, collection, decls) add = :($(esc(:add!))($(esc(vname)), $tmp_name)) # add value to accumulator - if condition == nothing - push!(loop_code, add) - else - push!(loop_code, Expr(:if, esc(condition), add)) - end + push!(this_stat_code, add) # add to named tuple argument of constructor call push!(res_expr.args, :(to_named_tuple($(esc(:results))($(esc(vname)))))) end + if condition == nothing + append!(loop_code, this_stat_code) + else + push!(loop_code, Expr(:if, esc(condition), Expr(:block, this_stat_code...))) + end # another argument for the main constructor call push!(res_code, res_expr) diff --git a/test/runtests.jl b/test/runtests.jl index 67ba85d..ffecf2f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -95,3 +95,20 @@ result = observe(Data2, m, 42) end +const run_count = [0] + +@observe Data3 model begin + @for ind in model.population begin + @if ind.capital > 0 @stat("count", CountAcc) <| (run_count[1]+=1; ind.n==10) + end +end + + +@testset "if" begin + m = Model() + result = observe(Data3, m) + + @test run_count[1] == 2 + @test result.count.n == 1 +end +