Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthiasSachs committed Oct 31, 2023
1 parent e3e0f31 commit cb4e8ca
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,6 @@ println("Epoch: $epoch, Avg Training Loss: $(loss_traj["train"][end]/n_train), T
c_fit = params(ffm)
set_params!(fm, c_fit)


using FluxOptTools
using Optim
using Zygote

loss() = weighted_l2_loss(ffm,train)
pars = Flux.params(ffm)
loss()
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)
# copy the optimal parameters back into pars (not that this simulatenously modifies the flux model parameters `ffm.c``)
using Plots
Plots.contourf(() -> log10(1 + loss()), pars, color=:turbo, npoints=50, lnorm=1)


# Evaluate different error statistics

using ACEds.Analytics: error_stats, plot_error, plot_error_all,friction_entries
Expand Down
Empty file added examples/plot_loss.jl
Empty file.
12 changes: 6 additions & 6 deletions src/atomcutoffs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ function ACE.write_dict(cutoff::SphericalCutoff{T}) where {T}
return Dict("__id__" => "ACEds_SphericalCutoff",
"rcut" => cutoff.rcut,
"T" => T)
end
end

function ACE.read_dict(::Val{:ACEds_SphericalCutoff}, D::Dict)
rcut = D["rcut"]
T = D["T"]
return SphericalCutoff{T}(rcut)
end
function ACE.read_dict(::Val{:ACEds_SphericalCutoff}, D::Dict)
T = getfield(Base, Symbol(D["T"]))
rcut = T(D["rcut"])
return SphericalCutoff(rcut)
end

end
2 changes: 1 addition & 1 deletion src/frictionmodels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function ACE.write_dict(fm::FrictionModel)
"matrixmodels" => Dict(id=>write_dict(fm.matrixmodels[id]) for id in keys(fm.matrixmodels)))
end
function ACE.read_dict(::Val{:ACEds_FrictionModel}, D::Dict)
matrixmodels = NamedTuple(Dict(id=>read_dict(val) for (id,val) in D["matrixmodels"]))
matrixmodels = NamedTuple(Dict(Symbol(id)=>read_dict(val) for (id,val) in D["matrixmodels"]))
return FrictionModel(matrixmodels)
end

Expand Down
28 changes: 25 additions & 3 deletions src/matrixmodels/acmatrixmodels.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,42 @@
# Z2S<:Uncoupled, SPSYM<:SpeciesUnCoupled, CUTOFF<:SphericalCutoff
struct ACMatrixModel{O3S,CUTOFF,COUPLING} <: MatrixModel{O3S}
onsite::Dict{AtomicNumber,OnSiteModel{O3S,TM1}} where {TM1}
offsite::Dict{Tuple{AtomicNumber, AtomicNumber},OffSiteModel{O3S,TM2,Z2S,CUTOFF}} where {TM2, Z2S}#, CUTOFF<:SphericalCutoff}
onsite::OnSiteModels{O3S}
offsite::OffSiteModels{O3S,Z2S,CUTOFF} where {Z2S}#, CUTOFF<:SphericalCutoff}
n_rep::Int
inds::SiteInds
id::Symbol
function ACMatrixModel(onsite::OnSiteModels{O3S,TM1}, offsite::OffSiteModels{O3S,TM2,Z2S,CUTOFF}, id::Symbol, ::COUPLING) where {O3S,TM1, TM2,Z2S, CUTOFF<:SphericalCutoff, COUPLING<:Union{RowCoupling,ColumnCoupling}}
function ACMatrixModel(onsite::OnSiteModels{O3S}, offsite::OffSiteModels{O3S,Z2S,CUTOFF}, id::Symbol, ::COUPLING) where {O3S,Z2S, CUTOFF<:SphericalCutoff, COUPLING<:Union{RowCoupling,ColumnCoupling}}
_assert_offsite_keys(offsite, SpeciesUnCoupled())
@assert _n_rep(onsite) == _n_rep(offsite)
@assert length(unique([mo.cutoff for mo in values(offsite)])) == 1
@assert length(unique([mo.cutoff for mo in values(onsite)])) == 1
#@assert all([z1 in keys(onsite), z2 in keys(offsite) for (z1,z2) in zzkeys])
@show typeof(onsite)
return new{O3S,CUTOFF,COUPLING}(onsite, offsite, _n_rep(onsite), SiteInds(_get_basisinds(onsite), _get_basisinds(offsite)), id)
end
end #TODO: Add proper constructor that checks for correct Species coupling

function ACE.write_dict(M::ACMatrixModel{O3S,CUTOFF,COUPLING}) where {O3S,CUTOFF,COUPLING}
return Dict("__id__" => "ACEds_ACMatrixModel",
"onsite" => ACE.write_dict(M.onsite),
#Dict(zz=>write_dict(val) for (zz,val) in M.onsite),
"offsite" => ACE.write_dict(M.offsite),
# => Dict(zz=>write_dict(val) for (zz,val) in M.offsite),
"id" => string(M.id),
"O3S" => write_dict(O3S),
"CUTOFF" => write_dict(CUTOFF),
"COUPLING" => write_dict(COUPLING()))
end
function ACE.read_dict(::Val{:ACEds_ACMatrixModel}, D::Dict)
onsite = ACE.read_dict(D["onsite"])
offsite = ACE.read_dict(D["offsite"])
#Dict(zz=>read_dict(val) for (zz,val) in D["onsite"])
#offsite = Dict(zz=>read_dict(val) for (zz,val) in D["offsite"])
id = Symbol(D["id"])
coupling = read_dict(D["COUPLING"])
return ACMatrixModel(onsite, offsite, id, coupling)
end

function ACE.set_params!(mb::ACMatrixModel, θ::NamedTuple)
ACE.set_params!(mb, :onsite, θ.onsite)
ACE.set_params!(mb, :offsite, θ.offsite)
Expand Down
107 changes: 24 additions & 83 deletions src/matrixmodels/matrixmodels-io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,88 +2,29 @@ ACE.write_dict(v::SVector{N,T}) where {N,T} = v
ACE.read_dict(v::SVector{N,T}) where {N,T} = v


function ACE.write_dict(m::OnSiteModel{O3S,TM}) where {O3S,TM}
return Dict("__id__" => "ACEds_OnSiteModel",
"linbasis" => write_dict(m.linmodel.basis),
"c" => write_dict(params(m.linmodel)),
"cutoff" => write_dict(m.cutoff)
)
end

function ACE.read_dict(::Val{:ACEds_OnSiteModel}, D::Dict)
linbasis = ACE.read_dict(D["linbasis"])
c = ACE.read_dict(D["c"])
cutoff = ACE.read_dict(D["cutoff"])
return OnSiteModel(linbasis, cutoff, c)
end

function ACE.write_dict(m::OffSiteModel{O3S,TM,Z2S,CUTOFF}) where {O3S,TM,Z2S,CUTOFF}
return Dict("__id__" => "ACEds_OffSiteModel",
"linbasis" => write_dict(m.linmodel.basis),
"c" => write_dict(params(m.linmodel)),
"cutoff" => write_dict(m.cutoff),
"Z2S" => write_dict(Z2S()))
end

function ACE.read_dict(::Val{:ACEds_OffSiteModel}, D::Dict)
linbasis = ACE.read_dict(D["linbasis"])
c = ACE.read_dict(D["c"])
cutoff = ACE.read_dict(D["cutoff"])
Z2S = ACE.read_dict(D["Z2S"])
bondbais = BondBasis(linbasis,Z2S)
return OffSiteModel(bondbais, cutoff, c)
end
#linbasis = BondBasis(linbasis,::Z2SYM)

function ACE.write_dict(z2s::Z2S) where {Z2S<:Z2Symmetry}
return Dict("__id__" => string("ACEds_Z2Symmetry"), "z2s"=>typeof(z2s))
end
function ACE.read_dict(::Val{:ACEds_Z2Symmetry}, D::Dict)
return D["z2s"]()
end

function ACE.write_dict(coupling::COUPLING) where {COUPLING<:NoiseCoupling}
return Dict("__id__" => string("ACEds_NoiseCoupling"), "coupling"=>typeof(coupling))
end
function ACE.read_dict(::Val{:ACEds_NoiseCoupling}, D::Dict)
return D["coupling"]()
end

function ACE.write_dict(M::ACMatrixModel{O3S,CUTOFF,COUPLING}) where {O3S,CUTOFF,COUPLING}
return Dict("__id__" => "ACEds_ACMatrixModel",
"onsite" => Dict(zz=>write_dict(val) for (zz,val) in M.onsite),
"offsite" => Dict(zz=>write_dict(val) for (zz,val) in M.offsite),
"id" => string(M.id),
"O3S" => write_dict(O3S),
"CUTOFF" => write_dict(CUTOFF),
"COUPLING" => write_dict(COUPLING()))
end
function ACE.read_dict(::Val{:ACEds_ACMatrixModel}, D::Dict)
onsite = Dict(zz=>read_dict(val) for (zz,val) in D["onsite"])
offsite = Dict(zz=>read_dict(val) for (zz,val) in D["offsite"])
id = Symbol(D["id"])
coupling = read_dict(D["COUPLING"])
return ACMatrixModel(onsite, offsite, id, coupling)
end

function ACE.write_dict(M::PWCMatrixModel{O3S,CUTOFF,COUPLING}) where {O3S,CUTOFF,COUPLING}
return Dict("__id__" => "ACEds_PWCMatrixModel",
"offsite" => Dict(zz=>write_dict(val) for (zz,val) in M.offsite),
"id" => string(M.id))
end
function ACE.read_dict(::Val{:ACEds_PWCMatrixModel}, D::Dict)
offsite = Dict(zz=>read_dict(val) for (zz,val) in D["offsite"])
id = Symbol(D["id"])
return PWCMatrixModel(offsite, id)
end

function ACE.write_dict(M::OnsiteOnlyMatrixModel)
return Dict("__id__" => "ACEds_OnsiteOnlyMatrixModel",
"onsite" => Dict(zz=>write_dict(val) for (zz,val) in M.onsite),
"id" => string(M.id))
end
function ACE.read_dict(::Val{:ACEds_OnsiteOnlyMatrixModel}, D::Dict)
onsite = Dict(zz=>read_dict(val) for (zz,val) in D["onsite"])
id = Symbol(D["id"])
return OnsiteOnlyMatrixModel(onsite, id)
end
# function ACE.write_dict(z2s::Z2S) where {Z2S<:Z2Symmetry}
# return Dict("__id__" => string("ACEds_Z2Symmetry"), "z2s"=>typeof(z2s))
# end
# function ACE.read_dict(::Val{:ACEds_Z2Symmetry}, D::Dict)
# return D["z2s"]()
# end

# function ACE.write_dict(coupling::COUPLING) where {COUPLING<:NoiseCoupling}
# return Dict("__id__" => string("ACEds_NoiseCoupling"), "coupling"=>typeof(coupling))
# end
# function ACE.read_dict(::Val{:ACEds_NoiseCoupling}, D::Dict)
# return D["coupling"]()
# end

# function ACE.write_dict(M::PWCMatrixModel{O3S,CUTOFF,COUPLING}) where {O3S,CUTOFF,COUPLING}
# return Dict("__id__" => "ACEds_PWCMatrixModel",
# "offsite" => Dict(zz=>write_dict(val) for (zz,val) in M.offsite),
# "id" => string(M.id))
# end
# function ACE.read_dict(::Val{:ACEds_PWCMatrixModel}, D::Dict)
# offsite = Dict(zz=>read_dict(val) for (zz,val) in D["offsite"])
# id = Symbol(D["id"])
# return PWCMatrixModel(offsite, id)
# end
Loading

0 comments on commit cb4e8ca

Please sign in to comment.