diff --git a/Project.toml b/Project.toml index 87295f3..3fa760b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CompositionalNetworks" uuid = "4b67e4b5-442d-4ef5-b760-3f5df3a57537" authors = ["Jean-François Baffier"] -version = "0.1.1" +version = "0.1.2" [deps] ConstraintDomains = "5800fd60-8556-4464-8d61-84ebf7a0bedb" diff --git a/src/learn.jl b/src/learn.jl index 0f006b9..4fc540a 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -1,17 +1,17 @@ -function _partial_search_space(domains, concept; sol_number=100) +function _partial_search_space(domains, concept, param=nothing; sol_number=100) solutions = Set{Vector{Int}}() non_sltns = Set{Vector{Int}}() while length(solutions) < 100 || length(non_sltns) < 100 config = map(_draw, domains) - c = concept(config) + c = concept(config; param = param) c && length(solutions) < 100 && push!(solutions, config) !c && length(non_sltns) < 100 && push!(non_sltns, config) end return solutions, non_sltns end -function _complete_search_space(domains, concept) +function _complete_search_space(domains, concept, param=nothing) solutions = Set{Vector{Int}}() non_sltns = Set{Vector{Int}}() @@ -24,8 +24,10 @@ function _complete_search_space(domains, concept) @warn message space_size end + f = isnothing(param) ? ((x; param = p) -> concept(x)) : concept + configurations = product(map(d -> _get_domain(d), domains)...) - foreach(c -> (cv = collect(c); push!(concept(cv) ? solutions : non_sltns, cv)), configurations) + foreach(c -> (cv = collect(c); push!(f(cv; param=param) ? solutions : non_sltns, cv)), configurations) return solutions, non_sltns end @@ -54,7 +56,7 @@ function explore_learn_compose(concept; domains, param=nothing, ) dom_size = maximum(_length, domains) if search == :complete - X_sols, X = _complete_search_space(domains, concept) + X_sols, X = _complete_search_space(domains, concept, param) union!(X, X_sols) return learn_compose(X, X_sols, dom_size, param; local_iter=local_iter, global_iter=global_iter, action=action) @@ -72,7 +74,9 @@ function _compose_to_string(symbols, name) co = _reduce_symbols(symbols[4], ", ", false; prefix=CN * "_co_") julia_string = """ - $name = x -> fill(x, $tr_length) .|> $tr |> $ar |> $ag |> $co + function $name(x; param=nothing, dom_size) + fill(x, $tr_length) .|> map(f -> (y -> f(y; param=param)), $tr) |> $ar |> $ag |> (y -> $co(y; param=param, dom_size=dom_size, nvars=length(x))) + end """ return julia_string