Skip to content

Commit

Permalink
Update SophonLuxCUDAExt.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
YichengDWu committed Nov 25, 2023
1 parent 6b8a28c commit 8558045
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ext/SophonLuxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module SophonLuxCUDAExt

using Lux, LuxCUDA, Sophon, ModelingToolkit
using Lux, LuxCUDA, Sophon, Optimization

function (::LuxCUDADevice)(prob::Union{ModelingToolkit.PDESystem, Sophon.PDESystem})
function (::LuxCUDADevice)(prob::OptimizationProblem)
u0 = adapt(CuArray, prob.u0)
p = [adapt(CuArray, prob.p[i]) for i in 1:length(prob.p)]
prob = remake(prob, u0=u0, p=p)
Expand Down

0 comments on commit 8558045

Please sign in to comment.