Skip to content

Commit

Permalink
some mild refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mhinsch committed Dec 5, 2024
1 parent 3f5ef0b commit 3920fbb
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 157 deletions.
162 changes: 5 additions & 157 deletions src/Observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,98 +12,9 @@ using ..StatsAccumulatorBase: add!
using MacroTools
using MacroTools: prewalk

"obtain a named tuple type with the same field types and names as `struct_T`"
tuple_type(struct_T) = NamedTuple{fieldnames(struct_T), Tuple{fieldtypes(struct_T)...}}

"construct a named tuple from `x`"
@generated function to_named_tuple(x)
if x <: NamedTuple
return :x
end

# constructor call
tuptyp = Expr(:quote, tuple_type(x))

# constructor arguments
tup = Expr(:tuple)
for i in 1:fieldcount(x)
push!(tup.args, :(getfield(x, $i)) )
end

# both put together
:($tuptyp($tup))
end


"translate accumulator types into prefixes for the header (e.g. min, max, etc.)"
stat_names(::Type{T}) where {T} = fieldnames(result_type(T))

# We could make this a generated function as well, but since the header
# is only printed once at the beginning, the additional time needed for
# runtime introspection is worth the reduction in complexity.
"Print a header for an observation type `stats_t` to `output` using field separator `FS`, name separator `NS` and line separator `LS`."
function print_header(output, stats_t; FS="\t", NS="_", LS="\n")
fn = fieldnames(stats_t)
ft = fieldtypes(stats_t)

for (i, (name, typ)) in enumerate(zip(fn, ft))
if typ <: NamedTuple
# aggregate stat
header(output, string(name), string.(fieldnames(typ)), FS, NS)
else
# single stat
print(output, string(name))
end

if i < length(fn)
print(output, FS)
end
end

print(output, LS)
end

# print header for aggregate stat
function header(out, stat_name, stat_names, FS, NS)
@assert length(stat_names) > 0

print(out, join((stat_name * NS) .* stat_names, FS))
end

# It's quite possibly overkill to make this a generated function, but we
# don't want anybody accusing us of wasting CPU cycles.
"Print results stored in `stats` to `output` using field separator `FS` and line separator `LS`."
@generated function log_results(out, stats; FS="\t", LS="\n")
fn = fieldnames(stats)
ft = fieldtypes(stats)

fn_body = Expr(:block)

# all fields of stats
for (i, (name, typ)) in enumerate(zip(fn, ft))
# aggregate stats
if typ <: NamedTuple
# go through all elements of stats.name
for (j, tname) in enumerate(fieldnames(typ))
push!(fn_body.args, :(print(out, stats.$name.$tname)))
if j < length(fieldnames(typ))
push!(fn_body.args, :(print(out, FS)))
end
end
# single values
else
push!(fn_body.args, :(print(out, stats.$name)))
end

if i < length(fn)
push!(fn_body.args, :(print(out, FS)))
end
end

push!(fn_body.args, :(print(out, LS)))

fn_body
end
include("utils.jl")
include("file_io.jl")


@inline function add_result_to_accs!(result, acc)
Expand Down Expand Up @@ -132,29 +43,6 @@ function process_single(name, typ, expr)
:($tmp_name) # result constructor
end

# it would be much nicer to use a generated function for this, but
# unfortunately we are already operating on types
"concatenate named tuple and/or struct types into one single named tuple"
function joined_named_tuple_T(types...)
ns = Expr(:tuple)
ts = Expr(:curly)
push!(ts.args, :Tuple)

for t in types
fnames = fieldnames(t)
ftypes = fieldtypes(t)

append!(ns.args, QuoteNode.(fnames))
append!(ts.args, ftypes)
end

ret = :(NamedTuple{})
push!(ret.args, ns)
push!(ret.args, ts)

eval(ret)
end


# code to declare stat property in stats struct
# creates a single named tuple type from all result types of all stats
Expand All @@ -173,11 +61,8 @@ function data_struct_elements(statname, stattypes)
prop_code
end


error_stat_syntax() = error("expected: [@if cond] @stat(<NAME>, <STAT> {, <STAT>}) <| <EXPR>")



# process a single expression in the AST handed over to the observe macro
# returns the expression unaltered unless @record or @stat is encountered
function process_expression!(ex, stats_type, stats_results, acc_temp_vars)
# local declaration of accumulator objects, goes to beginning of function
temp_vars_code = []
Expand Down Expand Up @@ -245,7 +130,7 @@ Given a declaration
@record "time" model.time
@record "N" Int length(model.population)
@for ind in model.population begin
for ind in model.population
@stat("capital", MaxMinAcc{Float64}, MeanVarAcc{FloatT}) <| ind.capital
@stat("n_alone", CountAcc) <| has_neighbours(ind)
end
Expand Down Expand Up @@ -315,42 +200,5 @@ macro observe(tname, args_and_decl...)

ret
end
# go through declaration expression by expression
# each expression is translated into three bits of code:
# * additional fields for the stats type
# * additional code to run during the analysis
# * additional arguments for the stats object constructor call
#=lines = rmlines(decl).args
for (i, line) in enumerate(lines)
# single stat
if line.args[1] == Symbol("@record")
typ = nothing
@capture(line, @record(name_String, expr_)) ||
@capture(line, @record(name_String, typ_, expr_)) ||
error("expecting: @record <NAME> [<TYPE>] <EXPR>")
stats_type_c, ana_body_c, stats_constr_c =
process_single(Symbol(name), typ, expr)
# aggregate stat
elseif line.args[1] == Symbol("@for")
@capture(line, @for var_Symbol in expr_ begin block_ end) ||
error("expecting: @for <NAME> in <EXPR> <BLOCK>")
stats_type_c, ana_body_c, stats_constr_c =
process_aggregate(var, expr, block)
# everything else is copied verbatim
else
stats_type_c = []
stats_constr_c = []
ana_body_c = [esc(line)]
end
# add code to respective bits
append!(ana_func.args[2].args, ana_body_c)
append!(stats_type.args[3].args, stats_type_c)
append!(stats_constr.args, stats_constr_c)
end
=#
# add constructor call as last line of analysis function

end # module
68 changes: 68 additions & 0 deletions src/file_io.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@

# We could make this a generated function as well, but since the header
# is only printed once at the beginning, the additional time needed for
# runtime introspection is worth the reduction in complexity.
"Print a header for an observation type `stats_t` to `output` using field separator `FS`, name separator `NS` and line separator `LS`."
function print_header(output, stats_t; FS="\t", NS="_", LS="\n")
fn = fieldnames(stats_t)
ft = fieldtypes(stats_t)

for (i, (name, typ)) in enumerate(zip(fn, ft))
if typ <: NamedTuple
# aggregate stat
header(output, string(name), string.(fieldnames(typ)), FS, NS)
else
# single stat
print(output, string(name))
end

if i < length(fn)
print(output, FS)
end
end

print(output, LS)
end

# print header for aggregate stat
function header(out, stat_name, stat_names, FS, NS)
@assert length(stat_names) > 0

print(out, join((stat_name * NS) .* stat_names, FS))
end

# It's quite possibly overkill to make this a generated function, but we
# don't want anybody accusing us of wasting CPU cycles.
"Print results stored in `stats` to `output` using field separator `FS` and line separator `LS`."
@generated function log_results(out, stats; FS="\t", LS="\n")
fn = fieldnames(stats)
ft = fieldtypes(stats)

fn_body = Expr(:block)

# all fields of stats
for (i, (name, typ)) in enumerate(zip(fn, ft))
# aggregate stats
if typ <: NamedTuple
# go through all elements of stats.name
for (j, tname) in enumerate(fieldnames(typ))
push!(fn_body.args, :(print(out, stats.$name.$tname)))
if j < length(fieldnames(typ))
push!(fn_body.args, :(print(out, FS)))
end
end
# single values
else
push!(fn_body.args, :(print(out, stats.$name)))
end

if i < length(fn)
push!(fn_body.args, :(print(out, FS)))
end
end

push!(fn_body.args, :(print(out, LS)))

fn_body
end

50 changes: 50 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"obtain a named tuple type with the same field types and names as `struct_T`"
tuple_type(struct_T) = NamedTuple{fieldnames(struct_T), Tuple{fieldtypes(struct_T)...}}

"construct a named tuple from `x`"
@generated function to_named_tuple(x)
if x <: NamedTuple
return :x
end

# constructor call
tuptyp = Expr(:quote, tuple_type(x))

# constructor arguments
tup = Expr(:tuple)
for i in 1:fieldcount(x)
push!(tup.args, :(getfield(x, $i)) )
end

# both put together
:($tuptyp($tup))
end


"translate accumulator types into prefixes for the header (e.g. min, max, etc.)"
stat_names(::Type{T}) where {T} = fieldnames(result_type(T))


# it would be much nicer to use a generated function for this, but
# unfortunately we are already operating on types
"concatenate named tuple and/or struct types into one single named tuple"
function joined_named_tuple_T(types...)
ns = Expr(:tuple)
ts = Expr(:curly)
push!(ts.args, :Tuple)

for t in types
fnames = fieldnames(t)
ftypes = fieldtypes(t)

append!(ns.args, QuoteNode.(fnames))
append!(ts.args, ftypes)
end

ret = :(NamedTuple{})
push!(ret.args, ns)
push!(ret.args, ts)

eval(ret)
end

0 comments on commit 3920fbb

Please sign in to comment.