Skip to content

Commit

Permalink
docs: hotfix for functors breakage
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 16, 2024
1 parent 0c45cf2 commit eecae90
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions examples/OptimizationIntegration/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,14 @@ end
# the `gdev` device to move the data to the GPU on each iteration.

# By default `gdev` will move all objects to the GPU. But we don't want to move the time
# vector to the GPU. So we will wrap it in a struct.
# vector to the GPU. So we will wrap it in a struct and mark it as a leaf using
# MLDataDevices.isleaf
struct TimeWrapper{T}
t::T
end

MLDataDevices.isleaf(::TimeWrapper) = true

Base.length(t::TimeWrapper) = length(t.t)

Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])
Expand Down Expand Up @@ -103,7 +106,8 @@ function train_model(dataloader)
u0 = u_batch[:, 1]
dudt(u, p, t) = smodel(u, p)
prob = ODEProblem(dudt, u0, (t_batch[1], t_batch[end]), θ)
pred = convert(AbstractArray, solve(prob, Tsit5(); saveat=t_batch))
sol = solve(prob, Tsit5(); sensealg=InterpolatingAdjoint(), saveat=t_batch)
pred = stack(sol.u)
return MSELoss()(pred, u_batch)
end

Expand Down

0 comments on commit eecae90

Please sign in to comment.