Skip to content

Commit

Permalink
Merge pull request #469 from SciML/symbolicsv5
Browse files Browse the repository at this point in the history
Complete Symbolics v5 update
  • Loading branch information
ChrisRackauckas authored Apr 28, 2023
2 parents 9615177 + 3831358 commit 12f9a44
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 20 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
group:
- Core
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[compat]
CommonSolve = "0.2"
Expand All @@ -35,6 +37,8 @@ RecipesBase = "1"
Reexport = "1.0"
Setfield = "1"
StatsBase = "0.32.0, 0.33"
Symbolics = "5"
SymbolicUtils = "1"
julia = "1.6"

[extras]
Expand Down
14 changes: 7 additions & 7 deletions lib/DataDrivenLux/src/custom_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,38 +53,38 @@ function Base.summary(io::IO, d::ObservedDistribution{fixed, D, E}) where {fixed
end

get_init(d::ObservedDistribution) = d.latent_scale
get_scale(d::ObservedDistribution) = transform(d.scale_transformation, d.latent_scale)
get_scale(d::ObservedDistribution) = TransformVariables.transform(d.scale_transformation, d.latent_scale)
get_dist(d::ObservedDistribution{<:Any, D}) where {D} = D

Base.show(io::IO, d::ObservedDistribution) = summary(io, d)

function Distributions.logpdf(d::ObservedDistribution{false}, x::X, x̂::Y,
scale::S) where {X, Y, S <: Number}
sum(map(xs -> d.errormodel(get_dist(d), xs..., transform(d.scale_transformation, scale)),
sum(map(xs -> d.errormodel(get_dist(d), xs..., TransformVariables.transform(d.scale_transformation, scale)),
zip(x, x̂)))
end

function Distributions.logpdf(d::ObservedDistribution{true}, x::X, x̂::Y,
scale::S) where {X, Y, S <: Number}
sum(map(xs -> d.errormodel(get_dist(d), xs...,
transform(d.scale_transformation, d.latent_scale)),
TransformVariables.transform(d.scale_transformation, d.latent_scale)),
zip(x, x̂)))
end

function Distributions.logpdf(d::ObservedDistribution{false}, x::X, x̂::Number,
scale::S) where {X, S <: Number}
sum(map(xs -> d.errormodel(get_dist(d), xs, x̂,
transform(d.scale_transformation, scale)), x))
TransformVariables.transform(d.scale_transformation, scale)), x))
end

function Distributions.logpdf(d::ObservedDistribution{true}, x::X, x̂::Number,
scale::S) where {X, S <: Number}
sum(map(xs -> d.errormodel(get_dist(d), xs, x̂,
transform(d.scale_transformation, d.latent_scale)), x))
TransformVariables.transform(d.scale_transformation, d.latent_scale)), x))
end

function transform_scales(d::ObservedDistribution, scale::T) where {T <: Number}
transform(d.scale_transformation, scale)
TransformVariables.transform(d.scale_transformation, scale)
end

"""
Expand Down Expand Up @@ -159,7 +159,7 @@ Base.show(io::IO, p::ParameterDistribution) = summary(io, p)

get_init(p::ParameterDistribution) = p.init
function transform_parameter(p::ParameterDistribution, pval::T) where {T <: Number}
transform(p.transformation, pval)
TransformVariables.transform(p.transformation, pval)
end
get_interval(p::ParameterDistribution) = p.interval

Expand Down
2 changes: 1 addition & 1 deletion lib/DataDrivenSR/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
[compat]
Reexport = "1.2"
DataDrivenDiffEq = "1"
SymbolicRegression = "0.14"
SymbolicRegression = "0.17"
julia = "1.6"

[extras]
Expand Down
9 changes: 4 additions & 5 deletions src/DataDrivenDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ using Setfield

@reexport using ModelingToolkit
using ModelingToolkit: AbstractSystem
using ModelingToolkit: value, operation, arguments, istree, get_observed
using ModelingToolkit.Symbolics
using ModelingToolkit.SymbolicUtils
using ModelingToolkit.Symbolics: scalarize, variable
using SymbolicUtils: operation, arguments, istree, issym
using Symbolics
using Symbolics: scalarize, variable, value
@reexport using ModelingToolkit: states, parameters, independent_variable, observed,
controls, get_iv
controls, get_iv, get_observed

using Random
using QuadGK
Expand Down
13 changes: 6 additions & 7 deletions src/basis/utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
## Create linear independent basis
count_operation(x::Number, op::Function, nested::Bool = true) = 0
count_operation(x::Sym, op::Function, nested::Bool = true) = 0
count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true) = 0
function count_operation(x::Num, op::Function, nested::Bool = true)
count_operation(value(x), op, nested)
end

function count_operation(x, op::Function, nested::Bool = true)
function count_operation(x::SymbolicUtils.BasicSymbolic, op::Function, nested::Bool = true)
issym(x) && return 0
if operation(x) == op
if is_unary(op)
# Handles sin, cos and stuff
Expand All @@ -23,6 +18,10 @@ function count_operation(x, op::Function, nested::Bool = true)
return 0
end

function count_operation(x::Num, op::Function, nested::Bool = true)
count_operation(value(x), op, nested)
end

function count_operation(x, ops::AbstractArray, nested::Bool = true)
return sum([count_operation(x, op, nested) for op in ops])
end
Expand Down

0 comments on commit 12f9a44

Please sign in to comment.