From 487e1c46febe6b033ce867836426b1644182269f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20J=C3=A4rvinen?= Date: Tue, 19 Sep 2023 23:22:04 -0700 Subject: [PATCH] fix extra vector data for AtomsBase --- src/atoms.jl | 4 ++++ test/atomsbase.jl | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/atoms.jl b/src/atoms.jl index 5ae78a3..4bc933c 100644 --- a/src/atoms.jl +++ b/src/atoms.jl @@ -135,6 +135,8 @@ function Atoms(dict::Dict{String, Any}) atom_data[Symbol(key)] = arrays[key] * u"Å" elseif key in ("charge", ) # Add charge unit atom_data[Symbol(key)] = arrays[key] * u"e_au" + elseif typeof(arrays[key]) <: AbstractMatrix + atom_data[Symbol(key)] = [ collect(col) for col in eachcol(arrays[key]) ] else atom_data[Symbol(key)] = arrays[key] end @@ -198,6 +200,8 @@ function write_dict(atoms::Atoms) arrays[string(k)] = ustrip.(u"e_au", v) elseif v isa AbstractVector{<:ExtxyzType} arrays[string(k)] = v # These can be written losslessly + elseif v isa AbstractArray && eltype(v) <: AbstractVector{<:ExtxyzType} + arrays[string(k)] = reduce(hcat, v) else @warn "Writing quantities of type $(typeof(v)) is not supported in write_dict." end diff --git a/test/atomsbase.jl b/test/atomsbase.jl index 83eb85a..a403357 100644 --- a/test/atomsbase.jl +++ b/test/atomsbase.jl @@ -88,3 +88,28 @@ end end test_approx_eq(system, io_system; rtol=1e-4) end + +@testset "Extra variables for atoms" begin + text = """2 +Lattice="2.614036117884091 0.0 0.0 0.0 2.6528336296738044 0.0 0.0 0.0 3.8250280122051756" Properties=species:S:1:pos:R:3:force:R:3:tags:I:1 config_type=FLD_TiAl spacegroup="P 1" virial="5.072173561696366 0.1220123768779895 0.6518229755809941 0.1220123768779895 4.667636799854875 0.5969893898844183 0.6518229755809941 0.5969893898844183 4.700422750506493" energy=-1703.64063822 unit_cell=conventional n_minim_iter=2 pbc="T T T" +Ti 1.30924260 1.32316179 1.62637131 0.86219000 0.78737000 2.65969000 0 +Al 0.11095015 0.09471147 -0.05013464 -0.86219000 -0.78737000 -2.65969000 1 +""" + infile = tempname() + outfile = tempname() + open(infile, "w") do io + print(io, text) + end + frame = read_frame(infile) + atoms = Atoms(frame) + @test all( atoms[1][:force] .== [0.86219000, 0.78737000, 2.65969000] ) + @test all( atoms[2][:force] .== [-0.86219000, -0.78737000, -2.65969000] ) + @test atoms[1][:tags] == 0 + @test atoms[2][:tags] == 1 + ExtXYZ.save(outfile, atoms) + new_atoms = ExtXYZ.load(outfile) + @test all( atoms[1][:force] .== new_atoms[1][:force] ) + @test all( atoms[2][:force] .== new_atoms[2][:force] ) + @test atoms[1][:tags] == new_atoms[1][:tags] + @test atoms[2][:tags] == new_atoms[2][:tags] +end