Skip to content

Commit

Permalink
workarounds for complex exponentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed May 3, 2022
1 parent 4983ddd commit 0576459
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
21 changes: 8 additions & 13 deletions src/terms/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# We cannot use `LinearAlgebra.norm` with complex numbers due to the need to use its
# analytic continuation
function norm_cplx(x)
# TODO: ForwardDiff bug (https://github.com/JuliaDiff/ForwardDiff.jl/issues/324)
sqrt(sum(x.*x))
sqrt(sum(x.^2))
end

struct PairwisePotential
Expand All @@ -17,10 +16,10 @@ Lennard—Jones terms.
The potential is dependent on the distance between to atomic positions and the pairwise
atomic types:
For a distance `d` between to atoms `A` and `B`, the potential is `V(d, params[(A, B)])`.
The parameters `max_radius` is of `1000` by default, and gives the maximum (Cartesian)
The parameters `max_radius` is of `100` by default, and gives the maximum (Cartesian)
distance between nuclei for which we consider interactions.
"""
function PairwisePotential(V, params; max_radius=1000)
function PairwisePotential(V, params; max_radius=100)
params = Dict(minmax(key[1], key[2]) => value for (key, value) in params)
PairwisePotential(V, params, max_radius)
end
Expand All @@ -43,8 +42,7 @@ end
@timing "forces: Pairwise" function compute_forces(term::TermPairwisePotential,
basis::PlaneWaveBasis{T}, ψ, occ;
kwargs...) where {T}
TT = promote_type(T, eltype(basis.model.positions[1]))
forces = zero(TT, basis.model.positions)
forces = zero(basis.model.positions)
energy_pairwise(basis.model, term.V, term.params; max_radius=term.max_radius,
forces=forces, kwargs...)
forces
Expand All @@ -65,15 +63,13 @@ end

# This could be factorised with Ewald, but the use of `symbols` would slow down the
# computationally intensive Ewald sums. So we leave it as it for now.
# TODO: *Beware* of using ForwardDiff to derive this function with complex numbers, use
# multiplications and not powers (https://github.com/JuliaDiff/ForwardDiff.jl/issues/324).
# `q` is the phonon `q`-point (`Vec3`), and `ph_disp` a list of `Vec3` displacements to
# compute the Fourier transform of the force constant matrix.
function energy_pairwise(lattice, symbols, positions, V, params;
max_radius=1000, forces=nothing, ph_disp=nothing, q=nothing)
max_radius=100, forces=nothing, ph_disp=nothing, q=nothing)
@assert length(symbols) == length(positions)

T = eltype(lattice)
T = eltype(positions[1])
if ph_disp !== nothing
@assert q !== nothing
T = promote_type(complex(T), eltype(ph_disp[1]))
Expand Down Expand Up @@ -135,9 +131,8 @@ function energy_pairwise(lattice, symbols, positions, V, params;
sum_pairwise += energy_contribution
if forces !== nothing
dE_ddist = ForwardDiff.derivative(real(zero(eltype(dist)))) do ε
res = V(dist + ε, param_ij)
[real(res), imag(res)]
end |> x -> complex(x...)
V(dist + ε, param_ij)
end
dE_dti = lattice' * ((dE_ddist / dist) * Δr)
# We need to "break" the symmetry for phonons; at equilibrium, expect
# the forces to be zero at machine precision.
Expand Down
29 changes: 29 additions & 0 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,32 @@ function Smearing.occupation(S::Smearing.FermiDirac, d::ForwardDiff.Dual{T}) whe
end
ForwardDiff.Dual{T}(Smearing.occupation(S, x), ∂occ * ForwardDiff.partials(d))
end

# Workarounds for issue https://github.com/JuliaDiff/ForwardDiff.jl/issues/324
ForwardDiff.derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input."))
@inline ForwardDiff.extract_derivative(::Type{T}, y::Complex) where {T} = zero(y)
@inline function ForwardDiff.extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: ForwardDiff.Dual}
complex(ForwardDiff.partials(T, real(y), 1), ForwardDiff.partials(T, imag(y), 1))
end
function Base.:^(x::Complex{ForwardDiff.Dual{T,V,N}}, y::Complex{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
xx = complex(ForwardDiff.value(real(x)), ForwardDiff.value(imag(x)))
yy = complex(ForwardDiff.value(real(y)), ForwardDiff.value(imag(y)))
dx = complex.(ForwardDiff.partials(real(x)), ForwardDiff.partials(imag(x)))
dy = complex.(ForwardDiff.partials(real(y)), ForwardDiff.partials(imag(y)))

expv = xx^yy
∂expv∂x = yy * xx^(yy-1)
∂expv∂y = log(xx) * expv
dxexpv = ∂expv∂x * dx
# TODO: Fishy and should be checked, but seems to catch most cases
if iszero(xx) && ForwardDiff.isconstant(real(y)) && ForwardDiff.isconstant(imag(y)) && imag(y) === zero(imag(y)) && real(y) > 0
dexpv = zero(expv)
elseif iszero(xx)
throw(DomainError(x, "mantissa cannot be zero for complex exponentiation"))
else
dyexpv = ∂expv∂y * dy
dexpv = dxexpv + dyexpv
end
complex(ForwardDiff.Dual{T,V,N}(real(expv), ForwardDiff.Partials{N,V}(tuple(real(dexpv)...))),
ForwardDiff.Dual{T,V,N}(imag(expv), ForwardDiff.Partials{N,V}(tuple(imag(dexpv)...))))
end

0 comments on commit 0576459

Please sign in to comment.