diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ddd4b0659..7766485c0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -10,6 +10,7 @@ jobs: test: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: group: - Core diff --git a/Project.toml b/Project.toml index 18d13462c..ae0764421 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,8 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [compat] CommonSolve = "0.2" @@ -35,6 +37,8 @@ RecipesBase = "1" Reexport = "1.0" Setfield = "1" StatsBase = "0.32.0, 0.33" +Symbolics = "5" +SymbolicUtils = "1" julia = "1.6" [extras] diff --git a/lib/DataDrivenLux/src/custom_priors.jl b/lib/DataDrivenLux/src/custom_priors.jl index 3304a62fc..6317c76e5 100644 --- a/lib/DataDrivenLux/src/custom_priors.jl +++ b/lib/DataDrivenLux/src/custom_priors.jl @@ -53,38 +53,38 @@ function Base.summary(io::IO, d::ObservedDistribution{fixed, D, E}) where {fixed end get_init(d::ObservedDistribution) = d.latent_scale -get_scale(d::ObservedDistribution) = transform(d.scale_transformation, d.latent_scale) +get_scale(d::ObservedDistribution) = TransformVariables.transform(d.scale_transformation, d.latent_scale) get_dist(d::ObservedDistribution{<:Any, D}) where {D} = D Base.show(io::IO, d::ObservedDistribution) = summary(io, d) function Distributions.logpdf(d::ObservedDistribution{false}, x::X, x̂::Y, scale::S) where {X, Y, S <: Number} - sum(map(xs -> d.errormodel(get_dist(d), xs..., transform(d.scale_transformation, scale)), + sum(map(xs -> d.errormodel(get_dist(d), xs..., TransformVariables.transform(d.scale_transformation, scale)), zip(x, x̂))) end function Distributions.logpdf(d::ObservedDistribution{true}, x::X, x̂::Y, scale::S) where {X, Y, S <: Number} sum(map(xs -> d.errormodel(get_dist(d), xs..., - transform(d.scale_transformation, d.latent_scale)), + TransformVariables.transform(d.scale_transformation, d.latent_scale)), zip(x, x̂))) end function Distributions.logpdf(d::ObservedDistribution{false}, x::X, x̂::Number, scale::S) where {X, S <: Number} sum(map(xs -> d.errormodel(get_dist(d), xs, x̂, - transform(d.scale_transformation, scale)), x)) + TransformVariables.transform(d.scale_transformation, scale)), x)) end function Distributions.logpdf(d::ObservedDistribution{true}, x::X, x̂::Number, scale::S) where {X, S <: Number} sum(map(xs -> d.errormodel(get_dist(d), xs, x̂, - transform(d.scale_transformation, d.latent_scale)), x)) + TransformVariables.transform(d.scale_transformation, d.latent_scale)), x)) end function transform_scales(d::ObservedDistribution, scale::T) where {T <: Number} - transform(d.scale_transformation, scale) + TransformVariables.transform(d.scale_transformation, scale) end """ @@ -159,7 +159,7 @@ Base.show(io::IO, p::ParameterDistribution) = summary(io, p) get_init(p::ParameterDistribution) = p.init function transform_parameter(p::ParameterDistribution, pval::T) where {T <: Number} - transform(p.transformation, pval) + TransformVariables.transform(p.transformation, pval) end get_interval(p::ParameterDistribution) = p.interval diff --git a/lib/DataDrivenSR/Project.toml b/lib/DataDrivenSR/Project.toml index 1d903495d..ecc755c8c 100644 --- a/lib/DataDrivenSR/Project.toml +++ b/lib/DataDrivenSR/Project.toml @@ -11,7 +11,7 @@ SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb" [compat] Reexport = "1.2" DataDrivenDiffEq = "1" -SymbolicRegression = "0.14" +SymbolicRegression = "0.17" julia = "1.6" [extras] diff --git a/src/DataDrivenDiffEq.jl b/src/DataDrivenDiffEq.jl index 9ce3235b9..05704a283 100644 --- a/src/DataDrivenDiffEq.jl +++ b/src/DataDrivenDiffEq.jl @@ -14,12 +14,11 @@ using Setfield @reexport using ModelingToolkit using ModelingToolkit: AbstractSystem -using ModelingToolkit: value, operation, arguments, istree, get_observed -using ModelingToolkit.Symbolics -using ModelingToolkit.SymbolicUtils -using ModelingToolkit.Symbolics: scalarize, variable +using SymbolicUtils: operation, arguments, istree, issym +using Symbolics +using Symbolics: scalarize, variable, value @reexport using ModelingToolkit: states, parameters, independent_variable, observed, - controls, get_iv + controls, get_iv, get_observed using Random using QuadGK diff --git a/src/basis/utils.jl b/src/basis/utils.jl index 7cbb25736..dd05c4ab0 100644 --- a/src/basis/utils.jl +++ b/src/basis/utils.jl @@ -1,12 +1,7 @@ ## Create linear independent basis count_operation(x::Number, op::Function, nested::Bool = true) = 0 -count_operation(x::Sym, op::Function, nested::Bool = true) = 0 -count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true) = 0 -function count_operation(x::Num, op::Function, nested::Bool = true) - count_operation(value(x), op, nested) -end - -function count_operation(x, op::Function, nested::Bool = true) +function count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true) + issym(x) && return 0 if operation(x) == op if is_unary(op) # Handles sin, cos and stuff @@ -23,6 +18,10 @@ function count_operation(x, op::Function, nested::Bool = true) return 0 end +function count_operation(x::Num, op::Function, nested::Bool = true) + count_operation(value(x), op, nested) +end + function count_operation(x, ops::AbstractArray, nested::Bool = true) return sum([count_operation(x, op, nested) for op in ops]) end