Skip to content

Commit

Permalink
Performance optimizations (#671)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael F. Herbst <[email protected]>
  • Loading branch information
antoine-levitt and mfherbst authored Jun 25, 2022
1 parent bef41ff commit f06b277
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 31 deletions.
12 changes: 7 additions & 5 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,17 +402,19 @@ Return the index tuple `I` such that `G_vectors(basis)[I] == G`
or the index `i` such that `G_vectors(basis, kpoint)[i] == G`.
Returns nothing if outside the range of valid wave vectors.
"""
function index_G_vectors(basis::PlaneWaveBasis, G::AbstractVector{T}) where {T <: Integer}
@inline function index_G_vectors(basis::PlaneWaveBasis, G::AbstractVector{T}) where {T <: Integer}
# the inline declaration encourages the compiler to hoist these (G-independent) precomputations
start = .- cld.(basis.fft_size .- 1, 2)
stop = fld.(basis.fft_size .- 1, 2)
lengths = stop .- start .+ 1

function mapaxis(lengthi, Gi)
Gi >= 0 && return 1 + Gi
return 1 + lengthi + Gi
# FFTs store wavevectors as [0 1 2 -2 -1] (example for N=5)
function G_to_index(length, G)
G >= 0 && return 1 + G
return 1 + length + G
end
if all(start .<= G .<= stop)
CartesianIndex(Tuple(mapaxis.(lengths, G)))
CartesianIndex(Tuple(G_to_index.(lengths, G)))
else
nothing # Outside range of valid indices
end
Expand Down
6 changes: 4 additions & 2 deletions src/SymOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ Base.:(==)(op1::SymOp, op2::SymOp) = op1.W == op2.W && op1.w == op2.w
function Base.isapprox(op1::SymOp, op2::SymOp; atol=SYMMETRY_TOLERANCE)
op1.W == op2.W && is_approx_integer(op1.w - op2.w; tol=atol)
end
Base.one(::Type{SymOp}) = SymOp(Mat3{Int}(I), Vec3(zeros(Bool, 3)))
Base.one(::SymOp) = one(SymOp)
Base.one(::Type{SymOp}) = one(SymOp{Bool}) # Not sure about this method
Base.one(::Type{SymOp{T}}) where {T} = SymOp(Mat3{Int}(I), Vec3(zeros(T, 3)))
Base.one(::SymOp{T}) where {T} = one(SymOp{T})
Base.isone(op::SymOp) = isone(op.W) && iszero(op.w)

# group composition and inverse.
function Base.:*(op1::SymOp, op2::SymOp)
Expand Down
13 changes: 7 additions & 6 deletions src/densities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ Compute the density for a wave function `ψ` discretized on the plane-wave
grid `basis`, where the individual k-points are occupied according to `occupation`.
`ψ` should be one coefficient matrix per ``k``-point.
"""
@views @timing function compute_density(basis::PlaneWaveBasis, ψ, occupation)
@views @timing function compute_density(basis, ψ, occupation)
T = promote_type(eltype(basis), real(eltype(ψ[1])))

# we split the total iteration range (ik, n) in chunks, and parallelize over them
ik_n = [(ik, n) for ik = 1:length(basis.kpoints) for n = 1:size(ψ[ik], 2)]
chunk_length = cld(length(ik_n), Threads.nthreads())

# chunk-local variables
ρ_chunklocal = [zeros(T, basis.fft_size..., basis.model.n_spin_components)
for _ = 1:Threads.nthreads()]
ψnk_real_chunklocal = [zeros(complex(T), basis.fft_size) for _ = 1:Threads.nthreads()]
ρ_chunklocal = Array{T,4}[zeros(T, basis.fft_size..., basis.model.n_spin_components)
for _ = 1:Threads.nthreads()]
ψnk_real_chunklocal = Array{complex(T),3}[zeros(complex(T), basis.fft_size)
for _ = 1:Threads.nthreads()]

@sync for (ichunk, chunk) in enumerate(Iterators.partition(ik_n, chunk_length))
Threads.@spawn for (ik, n) in chunk # spawn a task per chunk
Expand All @@ -43,7 +44,7 @@ grid `basis`, where the individual k-points are occupied according to `occupatio

ρ = sum(ρ_chunklocal)
mpi_sum!(ρ, basis.comm_kpts)
ρ = symmetrize_ρ(basis, ρ)
ρ = symmetrize_ρ(basis, ρ; do_lowpass=false)

_check_positive(ρ)
n_elec_check = weighted_ksum(basis, sum.(occupation))
Expand Down Expand Up @@ -75,7 +76,7 @@ end
end
end
mpi_sum!(τ, basis.comm_kpts)
symmetrize_ρ(basis, τ)
symmetrize_ρ(basis, τ; do_lowpass=false)
end

total_density(ρ) = dropdims(sum(ρ; dims=4); dims=4)
Expand Down
9 changes: 5 additions & 4 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,16 @@ Perform an FFT to obtain the Fourier representation of `f_real`. If
`kpt` is given, the coefficients are truncated to the k-dependent
spherical basis set.
"""
function r_to_G(basis::PlaneWaveBasis{T}, f_real::AbstractArray) where T
f_fourier = similar(f_real, complex(promote_type(T, eltype(f_real))))
function r_to_G(basis::PlaneWaveBasis{T}, f_real::AbstractArray{U}) where {T, U}
f_fourier = similar(f_real, complex(promote_type(T, U)))
@assert length(size(f_real)) ∈ (3, 4)
# this exploits trailing index convention
for iσ = 1:size(f_real, 4)
for iσ = 1:size(f_real, 4) # this exploits trailing index convention
@views r_to_G!(f_fourier[:, :, :, iσ], basis, f_real[:, :, :, iσ])
end
f_fourier
end


# TODO optimize this
function r_to_G(basis::PlaneWaveBasis, kpt::Kpoint, f_real::AbstractArray3; kwargs...)
r_to_G!(similar(f_real, length(kpt.mapping)), basis, kpt, copy(f_real); kwargs...)
Expand Down
31 changes: 17 additions & 14 deletions src/symmetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ basis of the new ``k``-point).
"""
function apply_symop(symop::SymOp, basis, kpoint, ψk::AbstractVecOrMat)
S, τ = symop.S, symop.τ
symop == one(SymOp) && return kpoint, ψk
isone(symop) && return kpoint, ψk

# Apply S and reduce coordinates to interval [-0.5, 0.5)
# Doing this reduction is important because
Expand Down Expand Up @@ -146,19 +146,17 @@ end
"""
Apply a symmetry operation to a density.
"""
function apply_symop(symop::SymOp, basis, ρin)
symop == one(SymOp) && return ρin
symmetrize_ρ(basis, ρin; symmetries=[symop])
function apply_symop(symop::SymOp, basis, ρin; kwargs...)
isone(symop) && return ρin
symmetrize_ρ(basis, ρin; symmetries=[symop], kwargs...)
end


# Accumulates the symmetrized versions of the density ρin into ρaccu (in Fourier space).
# No normalization is performed.
function accumulate_over_symmetries!(ρaccu, ρin, basis, symmetries)
T = eltype(basis)
@timing function accumulate_over_symmetries!(ρaccu, ρin, basis::PlaneWaveBasis{T}, symmetries) where {T}
for symop in symmetries
# Common special case, where ρin does not need to be processed
if symop == one(SymOp)
if isone(symop)
ρaccu .+= ρin
continue
end
Expand All @@ -174,8 +172,13 @@ function accumulate_over_symmetries!(ρaccu, ρin, basis, symmetries)
invS = Mat3{Int}(inv(symop.S))
for (ig, G) in enumerate(G_vectors_generator(basis.fft_size))
igired = index_G_vectors(basis, invS * G)
if igired !== nothing
@inbounds ρaccu[ig] += cis2pi(-T(dot(G, symop.τ))) * ρin[igired]
isnothing(igired) && continue

if iszero(symop.τ)
@inbounds ρaccu[ig] += ρin[igired]
else
factor = cis2pi(-T(dot(G, symop.τ)))
@inbounds ρaccu[ig] += factor * ρin[igired]
end
end
end # symop
Expand All @@ -185,7 +188,7 @@ end
# Low-pass filters ρ (in Fourier) so that symmetry operations acting on it stay in the grid
function lowpass_for_symmetry!(ρ, basis; symmetries=basis.symmetries)
for symop in symmetries
symop == one(SymOp) && continue
isone(symop) && continue
for (ig, G) in enumerate(G_vectors_generator(basis.fft_size))
if index_G_vectors(basis, symop.S * G) === nothing
ρ[ig] = 0
Expand All @@ -198,13 +201,13 @@ end
"""
Symmetrize a density by applying all symmetries (by default, those of the basis) and forming the average.
"""
@views @timing function symmetrize_ρ(basis, ρ; symmetries=basis.symmetries)
ρin_fourier = r_to_G(basis, ρ)
@views @timing function symmetrize_ρ(basis, ρ; symmetries=basis.symmetries, do_lowpass=true)
ρin_fourier = r_to_G(basis, ρ)
ρout_fourier = zero(ρin_fourier)
for σ = 1:size(ρ, 4)
accumulate_over_symmetries!(ρout_fourier[:, :, :, σ],
ρin_fourier[:, :, :, σ], basis, symmetries)
lowpass_for_symmetry!(ρout_fourier[:, :, :, σ], basis; symmetries)
do_lowpass && lowpass_for_symmetry!(ρout_fourier[:, :, :, σ], basis; symmetries)
end
G_to_r(basis, ρout_fourier ./ length(symmetries))
end
Expand Down

0 comments on commit f06b277

Please sign in to comment.