Skip to content

Commit

Permalink
Merge pull request #51 from JuliaDiffEq/minor_basis_fix
Browse files Browse the repository at this point in the history
Minor basis fix
  • Loading branch information
AlCap23 authored Feb 20, 2020
2 parents 0d5c3bc + a8a2d44 commit a13e1c8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
24 changes: 13 additions & 11 deletions src/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ function Basis(basis::AbstractVector{Operation}, variables::AbstractVector{Opera
bs = unique(basis)
fix_single_vars_in_basis!(bs, variables)

vs = sort!([b for b in [ModelingToolkit.vars(bs)...] if !b.known], by = x -> x.name)
ps = sort!([b for b in [ModelingToolkit.vars(bs)...] if b.known], by = x -> x.name )
vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables]
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters]

f_ = ModelingToolkit.build_function(bs, vs, ps, (), simplified_expr, Val{false})[1]
return Basis(bs, variables, parameters, f_)
end

function update!(b::Basis)
vs = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if !bi.known], by = x->x.name)
ps = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if bi.known], by = x->x.name)
function update!(basis::Basis)
vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables(basis)]
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters(basis)]

b.f_ = ModelingToolkit.build_function(b.basis, vs, ps, (), simplified_expr, Val{false})[1]
basis.f_ = ModelingToolkit.build_function(basis.basis, vs, ps, (), simplified_expr, Val{false})[1]
return
end

function Base.push!(b::Basis, ops::AbstractArray{Operation})
@inbounds for o in ops
push!(b.basis, o)
Expand Down Expand Up @@ -139,10 +139,12 @@ Base.length(b::Basis) = length(b.basis)
ModelingToolkit.parameters(b::Basis) = b.parameter
variables(b::Basis) = b.variables

function jacobian(b::Basis)
vs = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if !bi.known], by = x-> x.name)
ps = sort!([bi for bi in [ModelingToolkit.vars(b.basis)...] if bi.known], by = x-> x.name)
j = calculate_jacobian(b.basis, variables(b))
function jacobian(basis::Basis)

vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables(basis)]
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters(basis)]

j = calculate_jacobian(basis.basis, variables(basis))
return ModelingToolkit.build_function(expand_derivatives.(j), vs, ps, (), simplified_expr, Val{false})[1]
end

Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ using Test
unique!(basis)
@test size(basis) == size(h)

@variables a
g = [u[1]; u[3]; a]
basis = Basis(g, [u; a])
@test basis([1; 2; 3; 4]) == [1; 3; 4]
g = [1.0*u[1]; 1.0*u[3]; 1.0*u[2]]
basis = Basis(g, u, parameters = [])
X = ones(Float64, 3, 10)
Expand Down Expand Up @@ -230,7 +234,7 @@ end
X = sol[:, :] + 1e-3*randn(size(sol[:,:])...)
set_threshold!(opt, 3.5e-1)
Ψ = SInDy(X, DX, basis, maxiter = 10000, opt = opt, denoise = true, normalize = true)

estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
sol_4 = solve(estimator,Tsit5(), saveat = dt)
@test norm(sol[:,:] - sol_4[:,:], 2) < 5e-1
Expand Down

0 comments on commit a13e1c8

Please sign in to comment.