gpu_device() on prob
YichengDWu committed Nov 25, 2023
1 parent ff9b95c commit 6b8a28c
Showing 9 changed files with 24 additions and 63 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -42,7 +42,8 @@ TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
 [extensions]
 SophonOptimisersExt = "Optimisers"
 SophonTaylorDiffExt = "TaylorDiff"
-SophonTaylorDiffLuxExt = ["TaylorDiff", "LuxCUDA"]
+SophonTaylorDiffLuxCUDAExt = ["TaylorDiff", "LuxCUDA"]
+SophonLuxCUDAExt = "LuxCUDA"

 [compat]
 Adapt = "3"
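The catch-all `SophonTaylorDiffLuxExt` is renamed to `SophonTaylorDiffLuxCUDAExt`, and a new single-dependency `SophonLuxCUDAExt` is added. A minimal sketch of how the new extension loads, assuming Julia 1.9+ package extensions:

```julia
using Sophon
Base.get_extension(Sophon, :SophonLuxCUDAExt)  # nothing: LuxCUDA not loaded yet

using LuxCUDA                                  # loading the weak dep triggers the extension
Base.get_extension(Sophon, :SophonLuxCUDAExt)  # now returns the extension module
```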
3 changes: 1 addition & 2 deletions docs/src/qa.md
@@ -4,8 +4,7 @@ A: To train the model on a single GPU, do the following:

 ```julia
 using Lux, LuxCUDA
-device = gpu_device()
-pinn = PINN(...) |> device
+prob = Sophon.discretize(...) |> gpu_device()
 ```

 ## Q: How can I monitor the loss for each loss function?
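The answer now moves the whole discretized problem to the GPU instead of just the `PINN`. A minimal end-to-end sketch of the new workflow; `pde_system`, `pinn`, and `sampler` are assumed to be set up as usual, and `Adam` from OptimizationOptimisers plus `NonAdaptiveTraining` are illustrative choices:

```julia
using Lux, LuxCUDA
using Optimization, OptimizationOptimisers

# pde_system, pinn, and sampler defined beforehand
prob = Sophon.discretize(pde_system, pinn, sampler, NonAdaptiveTraining())
prob = prob |> gpu_device()   # moves u0 and the sampled datasets to the GPU
res = Optimization.solve(prob, Adam(0.001); maxiters=2000)
```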
12 changes: 12 additions & 0 deletions ext/SophonLuxCUDAExt.jl
@@ -0,0 +1,12 @@
+module SophonLuxCUDAExt
+
+using Adapt, Lux, LuxCUDA, Sophon, SciMLBase
+
+# Adapt the initial parameters and datasets of a discretized problem to CuArrays.
+function (::LuxCUDADevice)(prob::SciMLBase.OptimizationProblem)
+    u0 = adapt(CuArray, prob.u0)
+    p = [adapt(CuArray, prob.p[i]) for i in eachindex(prob.p)]
+    return remake(prob; u0=u0, p=p)
+end
+
+end
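Because `gpu_device()` returns a `LuxCUDADevice` functor when CUDA is functional, overloading its call operator keeps a single `|> gpu_device()` spelling for networks and problems alike. A short usage sketch:

```julia
device = gpu_device()    # LuxCUDADevice when a CUDA GPU is available
prob_gpu = device(prob)  # dispatches to the method above; same as prob |> device
```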
1 change: 0 additions & 1 deletion src/Sophon.jl
@@ -38,7 +38,6 @@ include("layers/nets.jl")
 include("layers/utils.jl")
 include("layers/operators.jl")

-include("pde/componentarrays.jl")
 include("pde/pinn_types.jl")
 include("pde/utils.jl")
 include("pde/sym_utils.jl")
15 changes: 0 additions & 15 deletions src/pde/componentarrays.jl

This file was deleted.

4 changes: 1 addition & 3 deletions src/pde/discretize.jl
@@ -76,11 +76,9 @@ function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
                     adtype=Optimization.AutoZygote())
     datasets = sample(pde_system, sampler)
     init_params = Lux.fmap(Base.Fix1(broadcast, fdtype), pinn.init_params)
-    init_params = _ComponentArray(init_params)
+    init_params = ComponentArray(init_params)

     datasets = map(Base.Fix1(broadcast, fdtype), datasets)
-    datasets = init_params isa AbstractGPUComponentVector ?
-               map(Base.Fix1(adapt, get_gpu_adaptor()), datasets) : datasets
     pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
                                                     derivative, derivative_bc,
                                                     fdtype)
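With the GPU transfer handled by the device functor, the private `_ComponentArray` wrapper is no longer needed and the parameters are flattened with a plain `ComponentArray`. A small illustration of that flattening (the names are illustrative, not Sophon's):

```julia
using ComponentArrays

init_params = (u = (weight = rand(4, 2), bias = zeros(4)),)
θ = ComponentArray(init_params)  # one flat vector with named blocks
θ.u.weight                       # 4×2 view into the flat storage
length(θ)                        # 12 = 4 * 2 + 4
```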
3 changes: 1 addition & 2 deletions src/pde/pinn_types.jl
@@ -2,8 +2,7 @@
     PINN(chain, rng::AbstractRNG=Random.default_rng())
     PINN(rng::AbstractRNG=Random.default_rng(); kwargs...)

-A container for a neural network, its states and its initial parameters. Call `Lux.gpu_device()`
-and `Lux.cpu_device()` to move the neural network to the GPU and CPU respectively.
+A container for a neural network, its states and its initial parameters.
 The default element type of the parameters is `Float64`.

 ## Fields
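The docstring drops the sentence about `Lux.gpu_device()` and `Lux.cpu_device()`, since the device is now applied to the discretized problem rather than to the network container. Constructing a `PINN` is unchanged; for example, mirroring the networks used in the test suite:

```julia
pinn = PINN(DiscreteFourierFeature(2, 1, 2, 2))      # a single network
pinn = PINN(u = DiscreteFourierFeature(2, 1, 2, 2),
            v = DiscreteFourierFeature(2, 1, 2, 2))  # one network per unknown
```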
25 changes: 7 additions & 18 deletions src/pde/utils.jl
@@ -1,7 +1,3 @@
-function isongpu(nt::NamedTuple)
-    return any(x -> x isa AbstractGPUArray, Lux.fcollect(nt))
-end
-
 function get_l2_loss_function(loss_function, dataset)
     loss(θ) = mean(abs2, loss_function(dataset, θ))
     return loss
@@ -14,35 +10,28 @@ This function is only used for the first order derivative.
 """
 forwarddiff(phi, t, εs, order, θ) = ForwardDiff.gradient(sum ∘ Base.Fix2(phi, θ), t)

-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
-    ldev = Symbol("Lux$(dev)Device")
-    ladaptor = Symbol("Lux$(dev)Adaptor")
-    @eval @inline get_adaptor(::$(ldev)) = $(ladaptor)()
-end
-@inline get_gpu_adaptor() = get_adaptor(gpu_device())
-
-@memoize maybe_adapt(x::AbstractGPUArray, ε) = convert(parameterless_type(x), ε)
-@memoize maybe_adapt(x, ε) = ε
-ChainRulesCore.@non_differentiable maybe_adapt(x, ε)
+@memoize maybe_convert(x::AbstractGPUArray, ε) = convert(parameterless_type(x), ε)
+@memoize maybe_convert(x, ε) = ε
+ChainRulesCore.@non_differentiable maybe_convert(x, ε)

 @inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{1}) where {T<:AbstractFloat}
-    ε = maybe_adapt(x, ε_)
+    ε = maybe_convert(x, ε_)
     return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* (h / 2)
 end

 @inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{2}) where {T<:AbstractFloat}
-    ε = maybe_adapt(x, ε_)
+    ε = maybe_convert(x, ε_)
     return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* h^2
 end

 @inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{3}) where {T<:AbstractFloat}
-    ε = maybe_adapt(x, ε_)
+    ε = maybe_convert(x, ε_)
     return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) -
             phi(x .- 2 .* ε, θ)) .* h^3 ./ 2
 end

 @inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{4}) where {T<:AbstractFloat}
-    ε = maybe_adapt(x, ε_)
+    ε = maybe_convert(x, ε_)
     return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .-
             4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* h^4
 end
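The rename from `maybe_adapt` to `maybe_convert` matches what the helper does: on the CPU it returns `ε` unchanged, while for an `AbstractGPUArray` input it converts `ε` to the array type of `x` so the stencil broadcasts stay on-device. Below is a standalone sketch of the first-order stencil; the reading that `ε_` carries the step along one coordinate and `h` its reciprocal is inferred from the `(h / 2)` factor, not stated in the source:

```julia
phi(x, θ) = sin.(sum(x; dims=1))  # stand-in for a trained network

step = cbrt(eps(Float64))
ε = [step, 0.0]                   # perturb the first coordinate only
h = inv(step)

x = [0.4 0.8; 0.1 0.3]            # 2 coordinates × 2 sample points
central = (phi(x .+ ε, nothing) .- phi(x .- ε, nothing)) .* (h / 2)
exact = cos.(sum(x; dims=1))      # analytic ∂/∂x₁ of sin(x₁ + x₂)
maximum(abs.(central .- exact))   # tiny under this convention
```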
21 changes: 0 additions & 21 deletions test/runtests.jl
@@ -288,27 +288,6 @@ rng = Random.default_rng()
         @test_nowarn AdaptiveTraining((θ, p) -> p, 5)
         @test_nowarn AdaptiveTraining(((θ, p) -> p, (θ, p) -> θ), (3, 4, 5))
     end
-
-    #=
-    @testset "GPU" begin
-        @testset "single model" begin
-            pinn = PINN(DiscreteFourierFeature(2, 1, 2, 2))
-            pinn = pinn |> gpu
-            @test getdata(pinn.init_params) isa CuArray
-            phi = pinn.phi
-            @test phi.state.weight isa CuArray
-        end
-
-        @testset "multiple models" begin
-            pinn = PINN(u = DiscreteFourierFeature(2, 1, 2, 2),
-                        v = DiscreteFourierFeature(2, 1, 2, 2))
-            pinn = pinn |> gpu
-            @test getdata(pinn.init_params) isa CuArray
-            phi = pinn.phi
-            @test phi.u.state.weight isa CuArray
-        end
-    end
-    =#
 end

 @testset "BetaSampler" begin include("betasampler.jl") end
