diff --git a/ext/SophonLuxCUDAExt.jl b/ext/SophonLuxCUDAExt.jl index b309ef28..6705bea6 100644 --- a/ext/SophonLuxCUDAExt.jl +++ b/ext/SophonLuxCUDAExt.jl @@ -1,8 +1,8 @@ module SophonLuxCUDAExt -using Lux, LuxCUDA, Sophon, ModelingToolkit +using Lux, LuxCUDA, Sophon, Optimization -function (::LuxCUDADevice)(prob::Union{ModelingToolkit.PDESystem, Sophon.PDESystem}) +function (::LuxCUDADevice)(prob::OptimizationProblem) u0 = adapt(CuArray, prob.u0) p = [adapt(CuArray, prob.p[i]) for i in 1:length(prob.p)] prob = remake(prob, u0=u0, p=p)