diff --git a/examples/box.jl b/examples/box.jl
index 51e93bb6ab..535fa3d080 100644
--- a/examples/box.jl
+++ b/examples/box.jl
@@ -72,7 +72,6 @@ using Enzyme
 # once.
 
 struct ModelParameters
-
     ## handy to have constants
     day::Float64
     year::Float64
@@ -193,56 +192,43 @@ function compute_density(state, params)
 end
 
 ## lastly, a function that takes one step forward
-## Input: state_now = [T1(t), T2(t), ..., S3(t)]
-##        state_old = [T1(t-dt), ..., S3(t-dt)]
+## Input: T_now = [T1(t), T2(t), ..., S3(t)]
+##        T_old = [T1(t-dt), ..., S3(t-dt)]
 ##        u = transport(t)
 ##        dt = time step
-## Output: state_new = [T1(t+dt), ..., S3(t+dt)]
+## Output: T_new = [T1(t+dt), ..., S3(t+dt)]
+
+function compute_update(T_now, T_old, u, params, dt)
+    Ṫ_now = zeros(6)
+    T_new = zeros(6)
 
-function compute_update(state_now, state_old, u, params, dt)
-    dstate_now_dt = zeros(6)
-    state_new = zeros(6)
+    V = params.boxvol
+    Tstar = params.Tstar
 
     ## first computing the time derivatives of the various temperatures and salinities
     if u > 0
-        dstate_now_dt[1] =
-            u * (state_now[3] - state_now[1]) / params.boxvol[1] +
-            params.gamma * (params.Tstar[1] - state_now[1])
-        dstate_now_dt[2] =
-            u * (state_now[1] - state_now[2]) / params.boxvol[2] +
-            params.gamma * (params.Tstar[2] - state_now[2])
-        dstate_now_dt[3] = u * (state_now[2] - state_now[3]) / params.boxvol[3]
-
-        dstate_now_dt[4] =
-            u * (state_now[6] - state_now[4]) / params.boxvol[1] +
-            params.FW[1] / params.boxvol[1]
-        dstate_now_dt[5] =
-            u * (state_now[4] - state_now[5]) / params.boxvol[2] +
-            params.FW[2] / params.boxvol[2]
-        dstate_now_dt[6] = u * (state_now[5] - state_now[6]) / params.boxvol[3]
+        Ṫ_now[1] = u * (T_now[3] - T_now[1]) / V[1] + params.gamma * (Tstar[1] - T_now[1])
+        Ṫ_now[2] = u * (T_now[1] - T_now[2]) / V[2] + params.gamma * (Tstar[2] - T_now[2])
+        Ṫ_now[3] = u * (T_now[2] - T_now[3]) / V[3]
+
+        Ṫ_now[4] = u * (T_now[6] - T_now[4]) / V[1] + params.FW[1] / V[1]
+        Ṫ_now[5] = u * (T_now[4] - T_now[5]) / V[2] + params.FW[2] / V[2]
+        Ṫ_now[6] = u * (T_now[5] - T_now[6]) / V[3]
 
     elseif u <= 0
-        dstate_now_dt[1] =
-            u * (state_now[2] - state_now[1]) / params.boxvol[1] +
-            params.gamma * (params.Tstar[1] - state_now[1])
-        dstate_now_dt[2] =
-            u * (state_now[3] - state_now[2]) / params.boxvol[2] +
-            params.gamma * (params.Tstar[2] - state_now[2])
-        dstate_now_dt[3] = u * (state_now[1] - state_now[3]) / params.boxvol[3]
-
-        dstate_now_dt[4] =
-            u * (state_now[5] - state_now[4]) / params.boxvol[1] +
-            params.FW[1] / params.boxvol[1]
-        dstate_now_dt[5] =
-            u * (state_now[6] - state_now[5]) / params.boxvol[2] +
-            params.FW[2] / params.boxvol[2]
-        dstate_now_dt[6] = u * (state_now[4] - state_now[6]) / params.boxvol[3]
+        Ṫ_now[1] = u * (T_now[2] - T_now[1]) / V[1] + params.gamma * (Tstar[1] - T_now[1])
+        Ṫ_now[2] = u * (T_now[3] - T_now[2]) / V[2] + params.gamma * (Tstar[2] - T_now[2])
+        Ṫ_now[3] = u * (T_now[1] - T_now[3]) / V[3]
+
+        Ṫ_now[4] = u * (T_now[5] - T_now[4]) / V[1] + params.FW[1] / V[1]
+        Ṫ_now[5] = u * (T_now[6] - T_now[5]) / V[2] + params.FW[2] / V[2]
+        Ṫ_now[6] = u * (T_now[4] - T_now[6]) / V[3]
     end
 
     ## update fldnew using a version of Euler's method
-    state_new .= state_old + 2.0 * dt * dstate_now_dt
+    T_new .= T_old + 2.0 * dt * Ṫ_now
 
-    return state_new
+    return T_new
 end
 
 # ## Define forward functions
@@ -253,28 +239,27 @@ end
 # Let's start with the standard forward function. This is just going to be used
 # to store the states at every timestep:
 
-function integrate(state_now, state_old, dt, M, parameters)
+function integrate(T_now, T_old, dt, M, parameters)
 
     ## Because of the adjoint problem we're setting up, we need to store both the states before
     ## and after the Robert filter smoother has been applied
-    states_before = [state_old]
-    states_after = [state_old]
+    states_before = [T_old]
+    states_after = [T_old]
 
     for t in 1:M
-        rho = compute_density(state_now, parameters)
+        rho = compute_density(T_now, parameters)
         u = compute_transport(rho, parameters)
-        state_new = compute_update(state_now, state_old, u, parameters, dt)
+        T_new = compute_update(T_now, T_old, u, parameters, dt)
 
         ## Applying the Robert filter smoother (needed for stability)
-        state_new_smoothed =
-            state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)
+        T_new_smoothed = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old)
 
-        push!(states_after, state_new_smoothed)
-        push!(states_before, state_new)
+        push!(states_after, T_new_smoothed)
+        push!(states_before, T_new)
 
         ## cycle the "now, new, old" states
-        state_old = state_new_smoothed
-        state_now = state_new
+        T_old = T_new_smoothed
+        T_now = T_new
     end
 
     return states_after, states_before
@@ -284,18 +269,17 @@ end
 # that runs a single step of the model forward rather than the whole integration.
 # This would allow us to save as many of the adjoint variables as we wish when running the adjoint method,
 # although for the example we'll discuss later we technically only need one of them
-function one_step_forward(state_now, state_old, out_now, out_old, parameters, dt)
-    state_new_smoothed = zeros(6)
-    rho = compute_density(state_now, parameters) ## compute density
+function one_step_forward(T_now, T_old, out_now, out_old, parameters, dt)
+    T_new_smoothed = zeros(6)
+    rho = compute_density(T_now, parameters) ## compute density
     u = compute_transport(rho, parameters) ## compute transport
-    state_new = compute_update(state_now, state_old, u, parameters, dt) ## compute new state values
+    T_new = compute_update(T_now, T_old, u, parameters, dt) ## compute new state values
 
     ## Robert filter smoother
-    state_new_smoothed[:] =
-        state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)
+    T_new_smoothed[:] = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old)
 
-    out_old[:] = state_new_smoothed
-    out_now[:] = state_new
+    out_old[:] = T_new_smoothed
+    out_now[:] = T_new
 
     return nothing
 end
@@ -322,8 +306,8 @@ states_after_smoother, states_before_smoother = integrate(
 )
 
 ## Run Enzyme one time on `one_step_forward``
-dstate_now = zeros(6)
-dstate_old = zeros(6)
+dT_now = zeros(6)
+dT_old = zeros(6)
 out_now = zeros(6);
 dout_now = ones(6);
 out_old = zeros(6);
@@ -332,8 +316,8 @@ dout_old = ones(6);
 autodiff(
     Reverse,
     one_step_forward,
-    Duplicated([Tbar; Sbar], dstate_now),
-    Duplicated([Tbar; Sbar], dstate_old),
+    Duplicated([Tbar; Sbar], dT_now),
+    Duplicated([Tbar; Sbar], dT_old),
     Duplicated(out_now, dout_now),
     Duplicated(out_old, dout_old),
     Const(parameters),
@@ -354,24 +338,24 @@ autodiff(
 @show states_before_smoother[2], states_after_smoother[2]
 
 # we see that Enzyme has computed and stored exactly the output of the
-# forward step. Next, let's look at `dstate_now`:
+# forward step. Next, let's look at `dT_now`:
 
-@show dstate_now
+@show dT_now
 
 # Just a few numbers, but this is what makes AD so nice: Enzyme has exactly computed
-# the derivative of all outputs with respect to the input `state_now`, evaluated at
-# `state_now`, and acted with this gradient on what we gave as `dout_now` (in our case,
+# the derivative of all outputs with respect to the input `T_now`, evaluated at
+# `T_now`, and acted with this gradient on what we gave as `dout_now` (in our case,
 # all ones). Using AD notation for reverse mode, this is
 
 # ```math
 # \overline{\text{state\_now}} = \left.\frac{\partial \text{out\_now}}{\partial \text{state\_now}}\right|_\text{state\_now} \overline{\text{out\_now}} + \left.\frac{\partial \text{out\_old}}{\partial \text{state\_now}}\right|_\text{state\_now} \overline{\text{out\_old}}
 # ```
 
-# We note here that had we initialized `dstate_now` and `dstate_old` as something else, our results
+# We note here that had we initialized `dT_now` and `dT_old` as something else, our results
 # will change. Let's multiply them by two and see what happens.
 
-dstate_now_new = zeros(6)
-dstate_old_new = zeros(6)
+dT_now_new = zeros(6)
+dT_old_new = zeros(6)
 out_now = zeros(6);
 dout_now = 2 * ones(6);
 out_old = zeros(6);
@@ -379,22 +363,22 @@ dout_old = 2 * ones(6);
 autodiff(
     Reverse,
     one_step_forward,
-    Duplicated([Tbar; Sbar], dstate_now_new),
-    Duplicated([Tbar; Sbar], dstate_old_new),
+    Duplicated([Tbar; Sbar], dT_now_new),
+    Duplicated([Tbar; Sbar], dT_old_new),
     Duplicated(out_now, dout_now),
     Duplicated(out_old, dout_old),
     Const(parameters),
     Const(10 * parameters.day),
 )
 
-# Now checking `dstate_now` and `dstate_old` we see
+# Now checking `dT_now` and `dT_old` we see
 
-@show dstate_now_new
+@show dT_now_new
 
 # What happened? Enzyme is actually taking the computed gradient and acting on what we
 # give as input to `dout_now` and `dout_old`. Checking this, we see
 
-@show 2 * dstate_now
+@show 2 * dT_now
 
 # and they match the new results. This exactly matches what we'd expect to happen since
 # we scaled `dout_now` by two.
@@ -439,14 +423,14 @@ function compute_adjoint_values(
     dout_old = [1.0; 0.0; 0.0; 0.0; 0.0; 0.0]
 
     for j in M:-1:1
-        dstate_now = zeros(6)
-        dstate_old = zeros(6)
+        dT_now = zeros(6)
+        dT_old = zeros(6)
 
         autodiff(
             Reverse,
             one_step_forward,
-            Duplicated(states_before_smoother[j], dstate_now),
-            Duplicated(states_after_smoother[j], dstate_old),
+            Duplicated(states_before_smoother[j], dT_now),
+            Duplicated(states_after_smoother[j], dT_old),
             Duplicated(zeros(6), dout_now),
             Duplicated(zeros(6), dout_old),
             Const(parameters),
@@ -454,11 +438,11 @@ function compute_adjoint_values(
         )
 
         if j == 1
-            return dstate_now, dstate_old
+            return dT_now, dT_old
         end
 
-        dout_now = copy(dstate_now)
-        dout_old = copy(dstate_old)
+        dout_now = copy(dT_now)
+        dout_old = copy(dT_old)
     end
 end
 
@@ -474,14 +458,14 @@ states_after_smoother, states_before_smoother = integrate(
 
 # Next, we pass all of our states to the AD function to get back to the desired derivative:
 
-dstate_now, dstate_old = compute_adjoint_values(
+dT_now, dT_old = compute_adjoint_values(
     states_before_smoother, states_after_smoother, M, parameters
 )
 
 # And we're done! We were interested in sensitivity to the initial salinity of box
-# two, which will live in what we've called `dstate_old`. Checking this value we see
+# two, which will live in what we've called `dT_old`. Checking this value we see
 
-@show dstate_old[5]
+@show dT_old[5]
 
 # As it stands this is just a number, but a good check that Enzyme has computed what we want
 # is to approximate the derivative with a Taylor series. Specifically,
@@ -508,27 +492,26 @@ use_to_check = states_after_smoother[M + 1]
 diffs = []
 step_sizes = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]
 for eps in step_sizes
-    state_new_smoothed = zeros(6)
+    T_new_smoothed = zeros(6)
 
     initial_temperature = [20.0; 1.0; 1.0]
     perturbed_initial_salinity = [35.5; 34.5; 34.5] + [0.0; eps; 0.0]
 
-    state_old = [initial_temperature; perturbed_initial_salinity]
-    state_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5]
+    T_old = [initial_temperature; perturbed_initial_salinity]
+    T_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5]
 
     for t in 1:M
-        rho = compute_density(state_now, parameters)
+        rho = compute_density(T_now, parameters)
         u = compute_transport(rho, parameters)
-        state_new = compute_update(state_now, state_old, u, parameters, 10 * parameters.day)
+        T_new = compute_update(T_now, T_old, u, parameters, 10 * parameters.day)
 
-        state_new_smoothed[:] =
-            state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)
+        T_new_smoothed[:] = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old)
 
-        state_old = state_new_smoothed
-        state_now = state_new
+        T_old = T_new_smoothed
+        T_now = T_new
     end
 
-    push!(diffs, (state_old[1] - use_to_check[1]) / eps)
+    push!(diffs, (T_old[1] - use_to_check[1]) / eps)
 end
 
 # Then checking what we found the derivative to be analytically:
@@ -538,7 +521,7 @@ end
 # which comes very close to our calculated value. We can go further and check the
 # percent difference to see
 
-@show abs.(diffs .- dstate_old[5]) ./ dstate_old[5]
+@show abs.(diffs .- dT_old[5]) ./ dT_old[5]
 
 # and we get down to a percent difference on the order of ``1e^{-5}``, showing Enzyme calculated
 # the correct derivative. Success!
diff --git a/lib/EnzymeCore/ext/AdaptExt.jl b/lib/EnzymeCore/ext/AdaptExt.jl
index 2e7a03cabd..e2adabd524 100644
--- a/lib/EnzymeCore/ext/AdaptExt.jl
+++ b/lib/EnzymeCore/ext/AdaptExt.jl
@@ -18,4 +18,4 @@ function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed)
     return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))
 end
 
-end
\ No newline at end of file
+end
diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl
index cd16b2f8e8..749378cde7 100644
--- a/lib/EnzymeCore/test/runtests.jl
+++ b/lib/EnzymeCore/test/runtests.jl
@@ -25,4 +25,4 @@ y = @view data[3:end]
 
 @test_throws ErrorException Duplicated(d, y)
 @test_throws ErrorException Active(data)
-@test_throws ErrorException Active(d)
\ No newline at end of file
+@test_throws ErrorException Active(d)
diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl
index 6785d650af..9b6e2d6ff8 100644
--- a/src/rules/allocrules.jl
+++ b/src/rules/allocrules.jl
@@ -163,4 +163,4 @@ end
             (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)
         )
     )
-end
\ No newline at end of file
+end
diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl
index 12dafed918..da4827c817 100644
--- a/src/rules/parallelrules.jl
+++ b/src/rules/parallelrules.jl
@@ -900,4 +900,4 @@ function wait_rev(B, orig, gutils, tape)
     debug_from_orig!(gutils, cal, orig)
     callconv!(cal, callconv(orig))
     return nothing
-end
\ No newline at end of file
+end
diff --git a/src/typeanalysis.jl b/src/typeanalysis.jl
index d7050be4d1..8ad34aadf2 100644
--- a/src/typeanalysis.jl
+++ b/src/typeanalysis.jl
@@ -22,4 +22,4 @@ end
 
 # typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/,
 #                                CTypeTree * /*args*/, size_t /*numArgs*/,
-#                                LLVMValueRef)=T
\ No newline at end of file
+#                                LLVMValueRef)=T
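For reference, a minimal standalone sketch of the in-place `Duplicated` calling pattern that `one_step_forward` relies on in the diff above. It is not part of the patch; `toy_step`, `x`, `y`, and the expected values are made-up illustrations, and only the shape of the Enzyme calls mirrors the example.

```julia
using Enzyme

# Toy stand-in for `one_step_forward`: two inputs, two outputs written in place.
function toy_step(x, y, out1, out2)
    out1 .= 2.0 .* x .+ y   # plays the role of `out_now`
    out2 .= x .* y          # plays the role of `out_old`
    return nothing
end

x = [1.0, 2.0]; dx = zeros(2)
y = [3.0, 4.0]; dy = zeros(2)
out1 = zeros(2); dout1 = ones(2)   # seed the output adjoints with ones,
out2 = zeros(2); dout2 = ones(2)   # as done for `dout_now` / `dout_old` above

autodiff(
    Reverse,
    toy_step,
    Duplicated(x, dx),
    Duplicated(y, dy),
    Duplicated(out1, dout1),
    Duplicated(out2, dout2),
)

# dx now holds (∂out1/∂x)ᵀ dout1 + (∂out2/∂x)ᵀ dout2 = 2 .+ y
@show dx   # expected: [5.0, 6.0]
```

Seeding `dout1` and `dout2` with `2 .* ones(2)` instead doubles `dx`, the same scaling behaviour the example demonstrates with `dout_now = 2 * ones(6)`.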