Skip to content

Commit

Permalink
fix @if execution; fixes #14
Browse files Browse the repository at this point in the history
  • Loading branch information
mhinsch committed Oct 12, 2023
1 parent a8e4019 commit 01c9105
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/Observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,14 @@ function process_aggregate(var, collection, decls)
# data struct
prop_code = data_struct_elements(statname, stattypes)
push!(stat_type_code, prop_code)


# code for this @stat line
this_stat_code = []

# code to store result of user code (to be fed into stats objects)
# (inside loop)
tmp_name = gensym("tmp_" * statname)
push!(loop_code, :($tmp_name = $(esc(expr))))
push!(this_stat_code, :($tmp_name = $(esc(expr))))

# expression that merges all results for this stat into single named tuple
res_expr = length(stattypes) > 1 ? :(merge()) : :(identity())
Expand All @@ -221,14 +224,15 @@ function process_aggregate(var, collection, decls)

add = :($(esc(:add!))($(esc(vname)), $tmp_name))
# add value to accumulator
if condition == nothing
push!(loop_code, add)
else
push!(loop_code, Expr(:if, esc(condition), add))
end
push!(this_stat_code, add)
# add to named tuple argument of constructor call
push!(res_expr.args, :(to_named_tuple($(esc(:results))($(esc(vname))))))
end
if condition == nothing
append!(loop_code, this_stat_code)
else
push!(loop_code, Expr(:if, esc(condition), Expr(:block, this_stat_code...)))
end

# another argument for the main constructor call
push!(res_code, res_expr)
Expand Down
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,20 @@ result = observe(Data2, m, 42)

end

const run_count = [0]

@observe Data3 model begin
@for ind in model.population begin
@if ind.capital > 0 @stat("count", CountAcc) <| (run_count[1]+=1; ind.n==10)
end
end


@testset "if" begin
m = Model()
result = observe(Data3, m)

@test run_count[1] == 2
@test result.count.n == 1
end

0 comments on commit 01c9105

Please sign in to comment.