Merge pull request #252 from YichengDWu/luxnew
update to the new Lux interface
YichengDWu authored Nov 25, 2023
2 parents b0035d6 + 798e747 commit 6edc431
Showing 13 changed files with 74 additions and 113 deletions.
16 changes: 8 additions & 8 deletions Project.toml
@@ -5,13 +5,11 @@ version = "0.4.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,6 +21,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
@@ -31,46 +30,46 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Sobol = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[weakdeps]
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"

[extensions]
SophonOptimisersExt = "Optimisers"
SophonTaylorDiffExt = "TaylorDiff"
SophonTaylorDiffLuxCUDAExt = ["TaylorDiff", "LuxCUDA"]
SophonLuxCUDAExt = "LuxCUDA"

[compat]
Adapt = "3"
CUDA = "5"
ChainRulesCore = "1"
ComponentArrays = "0.15"
Distributions = "0.25"
DomainSets = "0.5, 0.6, 0.7"
ForwardDiff = "0.10"
GPUArrays = "9"
GPUArraysCore = "0.1"
LRUCache = "1"
Lux = "0.5.6"
LuxCUDA = "0.3"
MacroTools = "0.5"
Memoize = "0.4"
ModelingToolkit = "8"
NNlib = "0.9"
NNlib = "0.8, 0.9"
Optimisers = "0.2"
Optimization = "3"
OptimizationOptimisers = "0.1"
PackageExtensionCompat = "1"
ProgressBars = "1.5"
QuasiMonteCarlo = "0.2, 0.3"
Requires = "1"
RuntimeGeneratedFunctions = "0.5"
SciMLBase = "2"
Sobol = "1, 2"
StaticArrays = "1.5"
StaticArraysCore = "1"
StatsBase = "0.33, 0.34"
Symbolics = "5"
julia = "1.8"
@@ -83,6 +82,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"

[targets]
test = ["Test", "Zygote", "ModelingToolkit", "DomainSets", "OptimizationOptimJL", "TaylorDiff"]
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -6,6 +6,7 @@ DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -18,5 +19,6 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sophon = "077df616-1c15-4d29-b519-7542a62df138"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8 changes: 4 additions & 4 deletions docs/src/qa.md
@@ -1,10 +1,10 @@
## Q: How can I train the model using GPUs?
## Q: How can I train the model using my GPU?

A: To train the model on GPUs, invoke the gpu function on instances of PINN:
A: To train the model on a single GPU, do the following:

```julia
using Lux
pinn = gpu(PINN(...))
using Lux, LuxCUDA
prob = Sophon.discretize(...) |> gpu_device()
```
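
For context, a fuller (hypothetical) version of that answer's snippet; `pde_system` is a placeholder for an already-defined `ModelingToolkit.PDESystem`, and the network and sampler settings are illustrative only:

```julia
using Lux, LuxCUDA, Sophon, Optimization, OptimizationOptimisers

pinn = PINN(FullyConnected(2, 1, sin; hidden_dims=16, num_layers=4))
prob = Sophon.discretize(pde_system, pinn, QuasiRandomSampler(200, 100),
                         NonAdaptiveTraining())
prob = prob |> gpu_device()            # move u0 and the datasets to the GPU
res  = Optimization.solve(prob, Adam(); maxiters=1000)
```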
## Q: How can I monitor the loss for each loss function?

8 changes: 3 additions & 5 deletions docs/src/tutorials/helmholtz.md
@@ -38,7 +38,7 @@ bcs = [u(-1,y) ~ 0, u(1,y) ~ 0, u(x, -1) ~ 0, u(x, 1) ~ 0]
Note that the boundary conditions are compatible with periodicity, which allows us to apply [`BACON`](@ref).
```@example helmholtz
chain = BACON(2, 1, 5, 2; hidden_dims = 32, num_layers=5)
pinn = PINN(chain) # call `gpu` on it if you want to use gpu
pinn = PINN(chain)
sampler = QuasiRandomSampler(300, 100)
strategy = NonAdaptiveTraining()
@@ -50,14 +50,12 @@ prob = Sophon.discretize(helmholtz, pinn, sampler, strategy)
Let's plot the result.
```@example helmholtz
phi = pinn.phi
ps = res.u
xs, ys= [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
u_analytic(x,y) = sinpi(a1*x)*sinpi(a2*y)
u_real = [u_analytic(x,y) for x in xs, y in ys]
phi_cpu = cpu(phi) # in case you are using GPU
ps_cpu = cpu(res.u)
u_pred = [sum(phi_cpu(([x,y]), ps_cpu)) for x in xs, y in ys]
u_pred = [sum(phi(([x,y]), ps)) for x in xs, y in ys]
using CairoMakie
axis = (xlabel="x", ylabel="y", title="Analytical Solution")
11 changes: 11 additions & 0 deletions ext/SophonLuxCUDAExt.jl
@@ -0,0 +1,11 @@
module SophonLuxCUDAExt

using Lux, LuxCUDA, Sophon, Optimization, Adapt

function (::LuxCUDADevice)(prob::OptimizationProblem)
u0 = adapt(CuArray, prob.u0)
p = Tuple(adapt(CuArray, prob.p[i]) for i in 1:length(prob.p)) # have to use tuple here...
return Optimization.OptimizationProblem(prob.f, u0, p)
end

end
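
The new extension teaches `LuxCUDADevice` to move an already-discretized `OptimizationProblem` onto the GPU, adapting `u0` (the network parameters) and every dataset in `p`. Hypothetical usage, assuming `prob` came from `Sophon.discretize` on a CUDA-capable machine:

```julia
using Lux, LuxCUDA, Sophon   # loading LuxCUDA activates SophonLuxCUDAExt

dev = gpu_device()           # a LuxCUDADevice when CUDA is functional
prob_gpu = dev(prob)         # dispatches to the method defined above
```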
9 changes: 1 addition & 8 deletions ext/SophonTaylorDiffExt.jl
@@ -115,7 +115,7 @@ end
end

@inline function taylordiff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{N}) where {T <: Number, N}
ε = Sophon.maybe_adapt(x, ε_)
ε = Sophon.maybe_convert(x, ε_)
return TaylorDiff.derivative(Base.Fix2(phi, θ), x, ε, Val{N+1}())
end

@@ -147,13 +147,6 @@ for l in 1:4
end
end

# avoid NaN
function Base.:*(A::Union{Sophon.CuMatrix{T}, LinearAlgebra.Transpose{T, Sophon.CuArray}},
B::Sophon.CuMatrix{TaylorScalar{T, N}}) where {T, N}
C = similar(B, (size(A, 1), size(B, 2)))
fill!(C, zero(eltype(C)))
return LinearAlgebra.mul!(C, A, B)
end

function __init__()
@static if VERSION >= v"1.9.0"
12 changes: 12 additions & 0 deletions ext/SophonTaylorDiffLuxCUDAExt.jl
@@ -0,0 +1,12 @@
module SophonTaylorDiffLuxCUDAExt

using TaylorDiff, LuxCUDA, Sophon, LinearAlgebra  # LinearAlgebra is needed for Transpose and mul! below

function Base.:*(A::Union{Sophon.CuMatrix{T}, LinearAlgebra.Transpose{T, Sophon.CuArray}},
B::Sophon.CuMatrix{TaylorScalar{T, N}}) where {T, N}
C = similar(B, (size(A, 1), size(B, 2)))
fill!(C, zero(eltype(C)))
return LinearAlgebra.mul!(C, A, B)
end

end
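
This is the NaN guard deleted from `SophonTaylorDiffExt.jl` above, relocated so it only loads when both TaylorDiff and LuxCUDA are available. Our reading of the zero-fill (the diff itself doesn't say): `similar` returns uninitialized GPU memory, and the generic matmul fallback used for `TaylorScalar` element types effectively accumulates into `C`, so stray NaN bit patterns would survive because `NaN * 0` is still `NaN`; filling `C` with zeros first makes the result well-defined.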
11 changes: 4 additions & 7 deletions src/Sophon.jl
@@ -13,7 +13,7 @@ using ComponentArrays
import SciMLBase
import SciMLBase: parameterless_type, __solve, build_solution, NullParameters
using StatsBase, QuasiMonteCarlo
using Adapt, ChainRulesCore, CUDA, GPUArrays, GPUArraysCore
using Adapt, ChainRulesCore, GPUArraysCore
import GPUArraysCore: AbstractGPUArray
import QuasiMonteCarlo
import Sobol
@@ -26,7 +26,7 @@ using ForwardDiff
using MacroTools
using MacroTools: prewalk, postwalk
using Requires
using StaticArrays: SVector
using StaticArraysCore: SVector

RuntimeGeneratedFunctions.init(@__MODULE__)

@@ -38,19 +38,16 @@ include("layers/nets.jl")
include("layers/utils.jl")
include("layers/operators.jl")

include("pde/componentarrays.jl")
include("pde/pinn_types.jl")
include("pde/utils.jl")
include("pde/sym_utils.jl")
include("pde/training_strategies.jl")
include("pde/pinnsampler.jl")
include("pde/discretize.jl")

using PackageExtensionCompat
function __init__()
@static if !isdefined(Base, :get_extension)
@require Optimisers="3bd65402-5787-11e9-1adc-39752487f4e2" begin include("../ext/SophonOptimisersExt.jl") end
@require TaylorDiff="b36ab563-344f-407b-a36a-4f200bebf99c" begin include("../ext/SophonTaylorDiffExt.jl") end
end
@require_extensions
end
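
On Julia ≥ 1.9 the `@require_extensions` call is a no-op, since the extensions declared in `Project.toml` load natively; on Julia 1.8 (still allowed by the `julia = "1.8"` compat bound) PackageExtensionCompat wires up the same extensions through Requires, replacing the hand-written `@require` block deleted above.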

export @showprogress
15 changes: 0 additions & 15 deletions src/pde/componentarrays.jl

This file was deleted.

4 changes: 1 addition & 3 deletions src/pde/discretize.jl
@@ -76,11 +76,9 @@ function discretize(pde_system, pinn::PINN, sampler::PINNSampler,
adtype=Optimization.AutoZygote())
datasets = sample(pde_system, sampler)
init_params = Lux.fmap(Base.Fix1(broadcast, fdtype), pinn.init_params)
init_params = _ComponentArray(init_params)
init_params = ComponentArray(init_params)

datasets = map(Base.Fix1(broadcast, fdtype), datasets)
datasets = init_params isa AbstractGPUComponentVector ?
map(Base.Fix1(adapt, CuArray), datasets) : datasets
pde_and_bcs_loss_function = build_loss_function(pde_system, pinn, strategy,
derivative, derivative_bc,
fdtype)
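
Note that `discretize` no longer adapts the datasets to `CuArray` itself; with this PR the assembled `OptimizationProblem` is moved to the GPU in one step via `gpu_device()` (see `SophonLuxCUDAExt.jl` above).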
53 changes: 21 additions & 32 deletions src/pde/pinn_types.jl
@@ -2,7 +2,7 @@
PINN(chain, rng::AbstractRNG=Random.default_rng())
PINN(rng::AbstractRNG=Random.default_rng(); kwargs...)
A container for a neural network, its states and its initial parameters. Call `gpu` and `cpu` to move the neural network to the GPU and CPU respectively.
A container for a neural network, its states and its initial parameters.
The default element type of the parameters is `Float64`.
## Fields
@@ -120,40 +120,29 @@ end

const NTofChainState{names} = NamedTuple{names, <:Tuple{Vararg{ChainState}}}

function Lux.cpu(cs::ChainState)
Lux.@set! cs.state = cpu(cs.state)
return cs
end

function Lux.gpu(cs::ChainState)
Lux.@set! cs.state = adapt(CuArray, cs.state)
return cs
end

function Lux.cpu(cs::NamedTuple{names, <:Tuple{Vararg{ChainState}}}) where {names}
return map(cs) do c
return cpu(c)
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
ldev = Symbol("Lux$(dev)Device")
ladaptor = Symbol("Lux$(dev)Adaptor")
@eval begin
function (device::$ldev)(cs::ChainState)
Lux.@set! cs.state = device(cs.state)
return cs
end

function (device::$ldev)(cs::NTofChainState{names}) where {names}
return map(cs) do c
return device(c)
end
end

function (device::$ldev)(pinn::PINN)
Lux.@set! pinn.phi = device(pinn.phi)
Lux.@set! pinn.init_params = adapt($(ladaptor)(), pinn.init_params)
return pinn
end
end
end
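
The hand-written `Lux.cpu`/`Lux.gpu` methods deleted in this hunk are replaced by this generated set of device-call overloads, one per Lux device type. A minimal sketch of the resulting API, assuming a working CUDA setup (the network settings are placeholders):

```julia
using Lux, LuxCUDA, Sophon

pinn = PINN(FullyConnected(2, 1, sin; hidden_dims=8, num_layers=3))
dev  = gpu_device()                 # LuxCUDADevice when CUDA is available
pinn_gpu = dev(pinn)                # uses the generated overloads above
pinn_cpu = cpu_device()(pinn_gpu)   # round-trip back to the CPU
```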

function Lux.gpu(cs::NamedTuple{names, <:Tuple{Vararg{ChainState}}}) where {names}
return map(cs) do c
return gpu(c)
end
end

function Lux.gpu(pinn::PINN)
Lux.@set! pinn.phi = gpu(pinn.phi)
Lux.@set! pinn.init_params = adapt(CuArray, pinn.init_params)
return pinn
end

function Lux.cpu(pinn::PINN)
Lux.@set! pinn.phi = cpu(pinn.phi)
Lux.@set! pinn.init_params = cpu(pinn.init_params)
return pinn
end

"""
using Sophon, ModelingToolkit, DomainSets
using DomainSets: ×
17 changes: 7 additions & 10 deletions src/pde/utils.jl
@@ -1,7 +1,3 @@
function isongpu(nt::NamedTuple)
return any(x -> x isa AbstractGPUArray, Lux.fcollect(nt))
end

function get_l2_loss_function(loss_function, dataset)
loss(θ) = mean(abs2, loss_function(dataset, θ))
return loss
@@ -14,27 +10,28 @@ This function is only used for the first order derivative.
"""
forwarddiff(phi, t, εs, order, θ) = ForwardDiff.gradient(sum ∘ Base.Fix2(phi, θ), t)

@inline maybe_adapt(x::AbstractGPUArray, ε_) = ChainRulesCore.@ignore_derivatives convert(CuArray, ε_)
@inline maybe_adapt(x, ε_) = ChainRulesCore.@ignore_derivatives ε_
@memoize maybe_convert(x::AbstractGPUArray, ε) = convert(parameterless_type(x), ε)
@memoize maybe_convert(x, ε) = ε
ChainRulesCore.@non_differentiable maybe_convert(x, ε)
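
Two things changed here (our gloss): `@memoize` caches the converted `ε` per argument pair, so the device-side copy is allocated once rather than on every residual evaluation, and `@non_differentiable` tells ChainRules-based ADs such as Zygote to treat the conversion as a constant — which is what the old inline `@ignore_derivatives` achieved.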

@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{1}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
ε = maybe_convert(x, ε_)
return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* (h / 2)
end

@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{2}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
ε = maybe_convert(x, ε_)
return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* h^2
end

@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{3}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
ε = maybe_convert(x, ε_)
return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) -
phi(x .- 2 .* ε, θ)) .* h^3 ./ 2
end

@inline function finitediff(phi, x, θ, ε_::AbstractVector{T}, h::T, ::Val{4}) where {T<:AbstractFloat}
ε = maybe_adapt(x, ε_)
ε = maybe_convert(x, ε_)
return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .-
4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* h^4
end
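
These are the standard central-difference stencils of orders 1–4; note that `h` is used as the reciprocal of the step encoded in `ε_` (the first-order rule multiplies by `h/2`, i.e. divides by `2ε`). A toy check under that assumption:

```julia
phi(x, θ) = sin.(x)   # stand-in "network"; θ is ignored here
θ = nothing
x = [0.3]
ε = [1e-4]            # perturbation along the differentiated coordinate
h = 1e4               # assumed reciprocal step, h == 1/ε
d1 = (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* (h / 2)
# d1[1] ≈ cos(0.3)
```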
21 changes: 0 additions & 21 deletions test/runtests.jl
@@ -288,27 +288,6 @@ rng = Random.default_rng()
@test_nowarn AdaptiveTraining((θ, p) -> p, 5)
@test_nowarn AdaptiveTraining(((θ, p) -> p, (θ, p) -> θ), (3, 4, 5))
end

#=
@testset "GPU" begin
@testset "single model" begin
pinn = PINN(DiscreteFourierFeature(2,1,2,2))
pinn = pinn |> gpu
@test getdata(pinn.init_params) isa CuArray
phi = pinn.phi
@test phi.state.weight isa CuArray
end
@testset "multiple models" begin
pinn = PINN(u = DiscreteFourierFeature(2,1,2,2),
v = DiscreteFourierFeature(2,1,2,2))
pinn = pinn |> gpu
@test getdata(pinn.init_params) isa CuArray
phi = pinn.phi
@test phi.u.state.weight isa CuArray
end
end
=#
end

@testset "BetaSampler" begin include("betasampler.jl") end