Skip to content

Commit

Permalink
copy(::CellDatum) to copy solutions
Browse files Browse the repository at this point in the history
This copy should copy the necessary data to make an independent copy of
CellFields solution data (e.g. dof and dirichlet values), useful to save
FEM solution modified in place (e.g. by TransientFESolution iterator).
It does not copy geometry related data (e.g. CellPoint is not copied,
just returned), those could be modified from the original instance.
  • Loading branch information
Antoine Marteau committed Oct 20, 2022
1 parent 1e49a45 commit fe6b2e1
Show file tree
Hide file tree
Showing 25 changed files with 175 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/CellData/CellDataInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ change_domain(a::CellDatum,target_domain::DomainStyle) = change_domain(a,DomainS
change_domain(a::CellDatum,input_domain::T,target_domain::T) where T<: DomainStyle = a
change_domain(a::CellDatum,input_domain::DomainStyle,target_domain::DomainStyle) = @abstractmethod

"""
Copies all fields data but not model and geometric data
"""
Base.copy(a::CellDatum) = @abstractmethod

# Tester
"""
"""
Expand Down
1 change: 1 addition & 0 deletions src/CellData/CellDofs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ end
get_data(f::CellDof) = f.cell_dof
get_triangulation(f::CellDof) = f.trian
DomainStyle(::Type{CellDof{DS}}) where DS = DS()
Base.copy(f::CellDof{DS}) where DS = CellDof{DS}(copy(f.cell_dof), f.trian, f.domain_style)

function change_domain(a::CellDof,::ReferenceDomain,::PhysicalDomain)
@notimplemented
Expand Down
14 changes: 14 additions & 0 deletions src/CellData/CellFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ end
get_triangulation(f::CellPoint) = f.trian
DomainStyle(::Type{CellPoint{DS,A,B,C}}) where {DS,A,B,C} = DS()

# Do not copy geometric data
Base.copy(f::CellPoint) = f

function change_domain(a::CellPoint,::ReferenceDomain,::PhysicalDomain)
CellPoint(a.cell_ref_point,a.cell_phys_point,a.trian,PhysicalDomain())
end
Expand Down Expand Up @@ -224,6 +227,7 @@ DomainStyle(::Type{GenericCellField{DS}}) where DS = DS()
function similar_cell_field(f::GenericCellField,cell_data,trian,ds)
GenericCellField(cell_data,trian,ds)
end
Base.copy(f::GenericCellField) = GenericCellField(copy(f.cell_field), f.trian, f.domain_style)

"""
dist = distance(polytope::ExtrusionPolytope,
Expand Down Expand Up @@ -485,6 +489,14 @@ struct OperationCellField{DS} <: CellField

new{typeof(domain_style)}(op,args,trian,domain_style,Dict())
end

"""
Copy constructor
"""
function OperationCellField(f::OperationCellField{DS}) where DS
argscopy = (copy(c) for c in f.args)
new{DS}(f.op, argscopy, f.trian, f.domain_style, f.memo)
end
end

function _get_cell_points(args::CellField...)
Expand Down Expand Up @@ -526,6 +538,7 @@ function get_data(f::OperationCellField)
end
get_triangulation(f::OperationCellField) = f.trian
DomainStyle(::Type{OperationCellField{DS}}) where DS = DS()
Base.copy(f::OperationCellField) = OperationCellField(f);

function evaluate!(cache,f::OperationCellField,x::CellPoint)
#key = (:evaluate,objectid(x))
Expand Down Expand Up @@ -693,6 +706,7 @@ end
get_data(f::CellFieldAt) = get_data(f.parent)
get_triangulation(f::CellFieldAt) = get_triangulation(f.parent)
DomainStyle(::Type{CellFieldAt{T,F}}) where {T,F} = DomainStyle(F)
Base.copy(f::CellFieldAt{T}) where T = CellFieldAt{T}(copy(f.parent))
gradient(a::CellFieldAt{P}) where P = CellFieldAt{P}(gradient(a.parent))
∇∇(a::CellFieldAt{P}) where P = CellFieldAt{P}(∇∇(a.parent))
function similar_cell_field(f::CellFieldAt{T},cell_data,trian,ds) where T
Expand Down
3 changes: 3 additions & 0 deletions src/CellData/CellQuadratures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ get_data(f::CellQuadrature) = f.cell_quad
get_triangulation(f::CellQuadrature) = f.trian
DomainStyle(::Type{CellQuadrature{DDS,IDS}}) where {DDS,IDS} = DDS()

# Do not copy geometric data
Base.copy(f::CellQuadrature) = f

function change_domain(a::CellQuadrature,::ReferenceDomain,::PhysicalDomain)
@notimplemented
end
Expand Down
1 change: 1 addition & 0 deletions src/CellData/CellStates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end

get_triangulation(f::CellState) = get_triangulation(f.points)
DomainStyle(::Type{CellState{T,P}}) where {T,P} = DomainStyle(P)
Base.copy(f::CellState) = CellState(f.points, copy(f.values))

function evaluate!(cache,f::CellState,x::CellPoint)
if f.points === x
Expand Down
9 changes: 9 additions & 0 deletions src/CellData/SkeletonCellFieldPair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ struct SkeletonCellFieldPair{
T = typeof(cf_plus_trian)
new{P,M,T}(cf_plus_plus, cf_minus_minus, cf_plus_trian)
end

"""
Copy constructor
"""
function SkeletonCellFieldPair(a::SkeletonCellFieldPair{P,M,T}) where {P,M,T}
new{P,M,T}(copy(a.cf_plus), copy(a.cf_minus), a.trian)
end
end

function SkeletonCellFieldPair(cf_plus::CellField, cf_minus::CellField)
Expand All @@ -63,6 +70,8 @@ function DomainStyle(a::SkeletonCellFieldPair)
DomainStyle(getfield(a,:cf_plus))
end

Base.copy(a::SkeletonCellFieldPair) = SkeletonCellFieldPair(a)

function get_triangulation(a::SkeletonCellFieldPair)
getfield(a,:trian)
end
Expand Down
2 changes: 2 additions & 0 deletions src/FESpaces/FESpaceInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ get_data(f::SingleFieldFEBasis) = f.cell_basis
get_triangulation(f::SingleFieldFEBasis) = f.trian
BasisStyle(::Type{SingleFieldFEBasis{BS,DS}}) where {BS,DS} = BS()
DomainStyle(::Type{SingleFieldFEBasis{BS,DS}}) where {BS,DS} = DS()
# Do not copy discretisation data
Base.copy(f::SingleFieldFEBasis) = f
function CellData.similar_cell_field(f::SingleFieldFEBasis,cell_data,trian,ds::DomainStyle)
SingleFieldFEBasis(cell_data,trian,BasisStyle(f),ds)
end
Expand Down
2 changes: 1 addition & 1 deletion src/FESpaces/FESpacesWithLinearConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Background:
#
# We build a novel fe space form a given fe space plus a set of linear constraints.
# We accept any single field fe space as input. In particular, the given fe space can also be defied
# We accept any single field fe space as input. In particular, the given fe space can also be defined
# via constraints.
#
# Assumptions:
Expand Down
10 changes: 10 additions & 0 deletions src/FESpaces/SingleFieldFESpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ get_data(f::SingleFieldFEFunction) = get_data(f.cell_field)
get_triangulation(f::SingleFieldFEFunction) = get_triangulation(f.cell_field)
DomainStyle(::Type{SingleFieldFEFunction{T}}) where T = DomainStyle(T)

function Base.copy(f::SingleFieldFEFunction)
SingleFieldFEFunction(
copy(f.cell_field),
copy(f.cell_dof_values),
copy(f.free_values),
copy(f.dirichlet_values),
f.fe_space
)
end

get_free_dof_values(f::SingleFieldFEFunction) = f.free_values
get_cell_dof_values(f::SingleFieldFEFunction) = f.cell_dof_values
get_fe_space(f::SingleFieldFEFunction) = f.fe_space
Expand Down
9 changes: 9 additions & 0 deletions src/MultiField/MultiFieldCellFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ struct MultiFieldCellField{DS<:DomainStyle} <: CellField

new{typeof(domain_style)}(single_fields,domain_style)
end

"""
Copy constructor
"""
function MultiFieldCellField(f::MultiFieldCellField{DS}) where DS
single_fields_copy = [copy(cf) for cf in f.single_fields]
new{DS}(single_fields_copy, f.domain_style)
end
end

function CellData.get_data(f::MultiFieldCellField)
Expand All @@ -35,6 +43,7 @@ function CellData.get_triangulation(f::MultiFieldCellField)
trian
end
CellData.DomainStyle(::Type{MultiFieldCellField{DS}}) where DS = DS()
Base.copy(f::MultiFieldCellField) = MultiFieldCellField(f)
num_fields(a::MultiFieldCellField) = length(a.single_fields)
Base.getindex(a::MultiFieldCellField,i::Integer) = a.single_fields[i]
Base.iterate(a::MultiFieldCellField) = iterate(a.single_fields)
Expand Down
6 changes: 6 additions & 0 deletions src/MultiField/MultiFieldFEFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ end
CellData.get_data(f::MultiFieldFEFunction) = get_data(f.multi_cell_field)
CellData.get_triangulation(f::MultiFieldFEFunction) = get_triangulation(f.multi_cell_field)
CellData.DomainStyle(::Type{MultiFieldFEFunction{T}}) where T = DomainStyle(T)
function Base.copy(f::MultiFieldFEFunction{T}) where T
sfef_copy = [ copy(ff) for ff in f.single_fe_functions ]
fv_copy = copy(f.free_values)
f_copy = MultiFieldFEFunction(fv_copy, f.fe_space, sfef_copy)
f_copy::MultiFieldFEFunction{T}
end
FESpaces.get_free_dof_values(f::MultiFieldFEFunction) = f.free_values
FESpaces.get_fe_space(f::MultiFieldFEFunction) = f.fe_space

Expand Down
2 changes: 2 additions & 0 deletions src/MultiField/MultiFieldFESpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ CellData.get_data(f::MultiFieldFEBasisComponent) = f.cell_basis
CellData.get_triangulation(f::MultiFieldFEBasisComponent) = get_triangulation(f.single_field)
FESpaces.BasisStyle(::Type{<:MultiFieldFEBasisComponent{B}}) where B = BasisStyle(B)
CellData.DomainStyle(::Type{<:MultiFieldFEBasisComponent{B}}) where B = DomainStyle(B)
# Do not copy discretisation data
Base.copy(f::MultiFieldFEBasisComponent) = f
function FESpaces.CellData.similar_cell_field(
f::MultiFieldFEBasisComponent,cell_data,trian,ds::DomainStyle)
@notimplemented
Expand Down
3 changes: 3 additions & 0 deletions src/ODEs/TransientFETools/TransientCellField.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ end
get_data(f::TransientSingleFieldCellField) = get_data(f.cellfield)
get_triangulation(f::TransientSingleFieldCellField) = get_triangulation(f.cellfield)
DomainStyle(::Type{<:TransientSingleFieldCellField{A}}) where A = DomainStyle(A)
Base.copy(f::TransientSingleFieldCellField) = TransientSingleFieldCellField(copy(f.cellfield), f.derivatives)
gradient(f::TransientSingleFieldCellField) = gradient(f.cellfield)
∇∇(f::TransientSingleFieldCellField) = ∇∇(f.cellfield)
change_domain(f::TransientSingleFieldCellField,trian::Triangulation,target_domain::DomainStyle) = change_domain(f.cellfield,trian,target_domain)
Expand Down Expand Up @@ -61,6 +62,8 @@ get_data(f::TransientFEBasis) = get_data(f.febasis)
get_triangulation(f::TransientFEBasis) = get_triangulation(f.febasis)
DomainStyle(::Type{<:TransientFEBasis{A}}) where A = DomainStyle(A)
BasisStyle(::Type{<:TransientFEBasis{A}}) where A = BasisStyle(A)
# Do not copy discretisation data
Base.copy(f::TransientFEBasis) = f
gradient(f::TransientFEBasis) = gradient(f.febasis)
∇∇(f::TransientFEBasis) = ∇∇(f.febasis)
change_domain(f::TransientFEBasis,trian::Triangulation,target_domain::DomainStyle) = change_domain(f.febasis,trian,target_domain)
Expand Down
6 changes: 6 additions & 0 deletions src/ODEs/TransientFETools/TransientMultiFieldCellField.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ end

get_triangulation(f::TransientMultiFieldCellField) = get_triangulation(f.cellfield)
DomainStyle(::Type{TransientMultiFieldCellField{A}}) where A = DomainStyle(A)

function Base.copy(f::TransientMultiFieldCellField)
tsf_copy = [ copy(tsf) for tsf in f.transient_single_fields ]
TransientMultiFieldCellField(copy(f.cellfield), f.derivatives, tsf_copy)
end

num_fields(f::TransientMultiFieldCellField) = length(f.cellfield)
gradient(f::TransientMultiFieldCellField) = gradient(f.cellfield)
∇∇(f::TransientMultiFieldCellField) = ∇∇(f.cellfield)
Expand Down
11 changes: 11 additions & 0 deletions test/MultiFieldTests/MultiFieldCellFieldsTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ X = MultiFieldFESpace([U,P])
dv, dq = get_fe_basis(Y)
du, dp = get_trial_fe_basis(X)

@test dv === copy(dv)
@test dq === copy(dq)
@test du === copy(du)
@test dp === copy(dp)

n = VectorValue(1,2)

cellmat = integrate( (ndv)*dp + dq*dp, quad)
Expand Down Expand Up @@ -151,11 +156,17 @@ source_model = CartesianDiscreteModel((0,1,0,1),(10,10))

gh = interpolate_everywhere([ifh₁,ifh₂], V₂²)

ghc = copy(gh)
@test gh !== ghc

pts = [VectorValue(rand(2)) for i=1:10]
gh₁,gh₂ = gh
gc₁,gc₂ = ghc
for pt in pts
@test gh₁(pt) fh₁(pt)
@test gh₂(pt) fh₂(pt)
@test gc₁(pt) fh₁(pt)
@test gc₂(pt) fh₂(pt)
end
end

Expand Down
4 changes: 4 additions & 0 deletions test/MultiFieldTests/MultiFieldFEFunctionsTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ xh = FEFunction(X,free_values)
test_fe_function(xh)
uh, ph = xh

xhc = copy(xh)
test_fe_function(xhc)
uhc, phc = xhc

cell_values = get_cell_dof_values(xh,trian)
@test isa(cell_values[1],ArrayBlock)

Expand Down
7 changes: 7 additions & 0 deletions test/MultiFieldTests/MultiFieldFESpacesTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ uh, ph = xh
@test isa(uh,FEFunction)
@test isa(ph,FEFunction)

xhc = copy(xh)
test_fe_function(xhc)
@test isa(xhc,FEFunction)
uh, ph = xhc
@test isa(uh,FEFunction)
@test isa(ph,FEFunction)

cell_isconstr = get_cell_isconstrained(X,trian)
@test cell_isconstr == Fill(false,num_cells(model))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,15 @@ tol = 1.0e-9
@test e1_l2 < tol
@test e2_l2 < tol

uhc = copy(uh)
uc1, uc2 = uhc
ec1 = u1 - uc1
ec2 = u2 - uc2

ec1_l2 = sqrt(sum((ec1*ec1)*dΩ))
ec2_l2 = sqrt(sum((ec2*ec2)*dΩ))

@test ec1_l2 < tol
@test ec2_l2 < tol

end # module
26 changes: 26 additions & 0 deletions test/ODEsTests/ODEsTests/ODESolversTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,nothing)
uf
@test tf==t0+dt
@test all(uf.≈x)
ufc = copy(uf)
@test all(ufc.≈x)

# ODESolutions

Expand All @@ -76,6 +78,8 @@ current, state = Base.iterate(sol)
uf, tf = current
@test tf==t0+dt
@test all(uf.≈x)
ufc = copy(uf)
@test all(ufc.≈x)

# BackwardEulerNonlinearOperator tests

Expand Down Expand Up @@ -106,6 +110,8 @@ cache = nothing
uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,cache)
@test tf==t0+dt
@test all(uf.≈1+11/9)
ufc = copy(uf)
@test all(ufc.≈1+11/9)

@test test_ode_solver(odesol,op,u0,t0,tf)
test_ode_solver(odesol,op,u0,t0,tf)
Expand All @@ -117,12 +123,16 @@ cache = nothing
uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,cache)
@test tf==t0+dt
@test all(uf.≈1+11/9)
ufc = copy(uf)
@test all(ufc.≈1+11/9)

op = ODEOperatorMock{Float64,Affine}(1.0,0.0,1.0,1)
cache = nothing
uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,cache)
@test tf==t0+dt
@test all(uf.≈1+11/9)
ufc = copy(uf)
@test all(ufc.≈1+11/9)

# RK tests

Expand All @@ -135,6 +145,8 @@ cache = nothing
uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,cache)
@test tf==t0+dt
@test all(uf.≈1+11/9)
ufc = copy(uf)
@test all(ufc.≈1+11/9)
# SDIRK 2nd order
odesol = RungeKutta(ls,dt,:SDIRK_2_1_2)
uf = copy(u0)
Expand All @@ -143,13 +155,17 @@ cache = nothing
uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,cache)
@test tf==t0+dt
@test all(uf.≈u0*(1.0+dt/(2*(1-dt))+dt*(1-2*dt)/(2*(1-dt)^2)))
ufc = copy(uf)
@test all(ufc.≈u0*(1.0+dt/(2*(1-dt))+dt*(1-2*dt)/(2*(1-dt)^2)))
# TRBDF (2nd order with some 0 on the diagonal)
odesol = RungeKutta(ls,dt,:TRBDF2_3_3_2)
uf.=1.0
cache = nothing
uf, tf, cache = solve_step!(uf,odesol,op,u0,t0,cache)
@test tf==t0+dt
@test all(uf.≈u0*1.105215241)
ufc = copy(uf)
@test all(ufc.≈u0*1.105215241)

@test test_ode_solver(odesol,op,u0,t0,tf)
test_ode_solver(odesol,op,u0,t0,tf)
Expand All @@ -172,6 +188,10 @@ aᵦ = 2*β*af .+ (1-2*β)*a0
@test tf==t0+dt
@test all(vf .≈ (v0 + dt*aᵧ))
@test all(uf .≈ (u0 + dt*v0 + 0.5*dt^2*aᵦ))
vfc = copy(vf)
ufc = copy(uf)
@test all(vfc .≈ (v0 + dt*aᵧ))
@test all(ufc .≈ (u0 + dt*v0 + 0.5*dt^2*aᵦ))

# GeneralizedAlpha test

Expand All @@ -193,5 +213,11 @@ ufθ, tf, cache = solve_step!(ufθ,odesolθ,op,u0,t0,nothing)
@test tf==t0+dt
@test all(ufα.≈ufθ)
@test all(vf.≈ 1/*dt) * (ufα-u0) + (1-1/γ)*v0)
ufαc = copy(ufα)
ufθc = copy(ufθ)
vfc = copy(vf)
@test all(ufθc.≈ufθ)
@test all(ufαc.≈ufθc)
@test all(vfc.≈ 1/*dt) * (ufα-u0) + (1-1/γ)*v0)

# end #module
4 changes: 4 additions & 0 deletions test/ODEsTests/TransientFEsTests/BoundaryHeatEquationTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,8 @@ for (uh_tn, tn) in sol_t
@test el2 < tol
end

all_sol = [ (copy(uh_tn), tn) for (uh_tn, tn) in sol_t ]
all_el2 = [ sqrt(sum( (l2( u(tn) - uhc_tn ))dΩ )) for (uhc_tn,tn) in all_sol ]
@test all( all_el2 .< tol )

end #module
Loading

0 comments on commit fe6b2e1

Please sign in to comment.