diff --git a/examples/box.jl b/examples/box.jl index 535fa3d080..51e93bb6ab 100644 --- a/examples/box.jl +++ b/examples/box.jl @@ -72,6 +72,7 @@ using Enzyme # once. struct ModelParameters + ## handy to have constants day::Float64 year::Float64 @@ -192,43 +193,56 @@ function compute_density(state, params) end ## lastly, a function that takes one step forward -## Input: T_now = [T1(t), T2(t), ..., S3(t)] -## T_old = [T1(t-dt), ..., S3(t-dt)] +## Input: state_now = [T1(t), T2(t), ..., S3(t)] +## state_old = [T1(t-dt), ..., S3(t-dt)] ## u = transport(t) ## dt = time step -## Output: T_new = [T1(t+dt), ..., S3(t+dt)] - -function compute_update(T_now, T_old, u, params, dt) - Ṫ_now = zeros(6) - T_new = zeros(6) +## Output: state_new = [T1(t+dt), ..., S3(t+dt)] - V = params.boxvol - Tstar = params.Tstar +function compute_update(state_now, state_old, u, params, dt) + dstate_now_dt = zeros(6) + state_new = zeros(6) ## first computing the time derivatives of the various temperatures and salinities if u > 0 - Ṫ_now[1] = u * (T_now[3] - T_now[1]) / V[1] + params.gamma * (Tstar[1] - T_now[1]) - Ṫ_now[2] = u * (T_now[1] - T_now[2]) / V[2] + params.gamma * (Tstar[2] - T_now[2]) - Ṫ_now[3] = u * (T_now[2] - T_now[3]) / V[3] - - Ṫ_now[4] = u * (T_now[6] - T_now[4]) / V[1] + params.FW[1] / V[1] - Ṫ_now[5] = u * (T_now[4] - T_now[5]) / V[2] + params.FW[2] / V[2] - Ṫ_now[6] = u * (T_now[5] - T_now[6]) / V[3] + dstate_now_dt[1] = + u * (state_now[3] - state_now[1]) / params.boxvol[1] + + params.gamma * (params.Tstar[1] - state_now[1]) + dstate_now_dt[2] = + u * (state_now[1] - state_now[2]) / params.boxvol[2] + + params.gamma * (params.Tstar[2] - state_now[2]) + dstate_now_dt[3] = u * (state_now[2] - state_now[3]) / params.boxvol[3] + + dstate_now_dt[4] = + u * (state_now[6] - state_now[4]) / params.boxvol[1] + + params.FW[1] / params.boxvol[1] + dstate_now_dt[5] = + u * (state_now[4] - state_now[5]) / params.boxvol[2] + + params.FW[2] / params.boxvol[2] + dstate_now_dt[6] = u * 
(state_now[5] - state_now[6]) / params.boxvol[3] elseif u <= 0 - Ṫ_now[1] = u * (T_now[2] - T_now[1]) / V[1] + params.gamma * (Tstar[1] - T_now[1]) - Ṫ_now[2] = u * (T_now[3] - T_now[2]) / V[2] + params.gamma * (Tstar[2] - T_now[2]) - Ṫ_now[3] = u * (T_now[1] - T_now[3]) / V[3] - - Ṫ_now[4] = u * (T_now[5] - T_now[4]) / V[1] + params.FW[1] / V[1] - Ṫ_now[5] = u * (T_now[6] - T_now[5]) / V[2] + params.FW[2] / V[2] - Ṫ_now[6] = u * (T_now[4] - T_now[6]) / V[3] + dstate_now_dt[1] = + u * (state_now[2] - state_now[1]) / params.boxvol[1] + + params.gamma * (params.Tstar[1] - state_now[1]) + dstate_now_dt[2] = + u * (state_now[3] - state_now[2]) / params.boxvol[2] + + params.gamma * (params.Tstar[2] - state_now[2]) + dstate_now_dt[3] = u * (state_now[1] - state_now[3]) / params.boxvol[3] + + dstate_now_dt[4] = + u * (state_now[5] - state_now[4]) / params.boxvol[1] + + params.FW[1] / params.boxvol[1] + dstate_now_dt[5] = + u * (state_now[6] - state_now[5]) / params.boxvol[2] + + params.FW[2] / params.boxvol[2] + dstate_now_dt[6] = u * (state_now[4] - state_now[6]) / params.boxvol[3] end ## update fldnew using a version of Euler's method - T_new .= T_old + 2.0 * dt * Ṫ_now + state_new .= state_old + 2.0 * dt * dstate_now_dt - return T_new + return state_new end # ## Define forward functions @@ -239,27 +253,28 @@ end # Let's start with the standard forward function. 
This is just going to be used # to store the states at every timestep: -function integrate(T_now, T_old, dt, M, parameters) +function integrate(state_now, state_old, dt, M, parameters) ## Because of the adjoint problem we're setting up, we need to store both the states before ## and after the Robert filter smoother has been applied - states_before = [T_old] - states_after = [T_old] + states_before = [state_old] + states_after = [state_old] for t in 1:M - rho = compute_density(T_now, parameters) + rho = compute_density(state_now, parameters) u = compute_transport(rho, parameters) - T_new = compute_update(T_now, T_old, u, parameters, dt) + state_new = compute_update(state_now, state_old, u, parameters, dt) ## Applying the Robert filter smoother (needed for stability) - T_new_smoothed = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old) + state_new_smoothed = + state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old) - push!(states_after, T_new_smoothed) - push!(states_before, T_new) + push!(states_after, state_new_smoothed) + push!(states_before, state_new) ## cycle the "now, new, old" states - T_old = T_new_smoothed - T_now = T_new + state_old = state_new_smoothed + state_now = state_new end return states_after, states_before @@ -269,17 +284,18 @@ end # that runs a single step of the model forward rather than the whole integration. 
# This would allow us to save as many of the adjoint variables as we wish when running the adjoint method, # although for the example we'll discuss later we technically only need one of them -function one_step_forward(T_now, T_old, out_now, out_old, parameters, dt) - T_new_smoothed = zeros(6) - rho = compute_density(T_now, parameters) ## compute density +function one_step_forward(state_now, state_old, out_now, out_old, parameters, dt) + state_new_smoothed = zeros(6) + rho = compute_density(state_now, parameters) ## compute density u = compute_transport(rho, parameters) ## compute transport - T_new = compute_update(T_now, T_old, u, parameters, dt) ## compute new state values + state_new = compute_update(state_now, state_old, u, parameters, dt) ## compute new state values ## Robert filter smoother - T_new_smoothed[:] = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old) + state_new_smoothed[:] = + state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old) - out_old[:] = T_new_smoothed - out_now[:] = T_new + out_old[:] = state_new_smoothed + out_now[:] = state_new return nothing end @@ -306,8 +322,8 @@ states_after_smoother, states_before_smoother = integrate( ) ## Run Enzyme one time on `one_step_forward` -dT_now = zeros(6) -dT_old = zeros(6) +dstate_now = zeros(6) +dstate_old = zeros(6) out_now = zeros(6); dout_now = ones(6); out_old = zeros(6); @@ -316,8 +332,8 @@ dout_old = ones(6); autodiff( Reverse, one_step_forward, - Duplicated([Tbar; Sbar], dT_now), - Duplicated([Tbar; Sbar], dT_old), + Duplicated([Tbar; Sbar], dstate_now), + Duplicated([Tbar; Sbar], dstate_old), Duplicated(out_now, dout_now), Duplicated(out_old, dout_old), Const(parameters), @@ -338,24 +354,24 @@ autodiff( @show states_before_smoother[2], states_after_smoother[2] # we see that Enzyme has computed and stored exactly the output of the -# forward step. Next, let's look at `dT_now`: +# forward step.
Next, let's look at `dstate_now`: -@show dT_now +@show dstate_now # Just a few numbers, but this is what makes AD so nice: Enzyme has exactly computed -# the derivative of all outputs with respect to the input `T_now`, evaluated at -# `T_now`, and acted with this gradient on what we gave as `dout_now` (in our case, +# the derivative of all outputs with respect to the input `state_now`, evaluated at +# `state_now`, and acted with this gradient on what we gave as `dout_now` (in our case, # all ones). Using AD notation for reverse mode, this is # ```math # \overline{\text{state\_now}} = \left.\frac{\partial \text{out\_now}}{\partial \text{state\_now}}\right|_{\text{state\_now}} \overline{\text{out\_now}} + \left.\frac{\partial \text{out\_old}}{\partial \text{state\_now}}\right|_{\text{state\_now}} \overline{\text{out\_old}} # ``` -# We note here that had we initialized `dT_now` and `dT_old` as something else, our results +# We note here that had we initialized `dstate_now` and `dstate_old` as something else, our results # would change. Let's multiply `dout_now` and `dout_old` by two and see what happens. -dT_now_new = zeros(6) -dT_old_new = zeros(6) +dstate_now_new = zeros(6) +dstate_old_new = zeros(6) out_now = zeros(6); dout_now = 2 * ones(6); out_old = zeros(6); @@ -363,22 +379,22 @@ dout_old = 2 * ones(6); autodiff( Reverse, one_step_forward, - Duplicated([Tbar; Sbar], dT_now_new), - Duplicated([Tbar; Sbar], dT_old_new), + Duplicated([Tbar; Sbar], dstate_now_new), + Duplicated([Tbar; Sbar], dstate_old_new), Duplicated(out_now, dout_now), Duplicated(out_old, dout_old), Const(parameters), Const(10 * parameters.day), ) -# Now checking `dT_now` and `dT_old` we see +# Now checking `dstate_now` and `dstate_old` we see -@show dT_now_new +@show dstate_now_new # What happened? Enzyme is actually taking the computed gradient and acting on what we # give as input to `dout_now` and `dout_old`. Checking this, we see -@show 2 * dT_now +@show 2 * dstate_now # and they match the new results.
This exactly matches what we'd expect to happen since # we scaled `dout_now` by two. @@ -423,14 +439,14 @@ function compute_adjoint_values( dout_old = [1.0; 0.0; 0.0; 0.0; 0.0; 0.0] for j in M:-1:1 - dT_now = zeros(6) - dT_old = zeros(6) + dstate_now = zeros(6) + dstate_old = zeros(6) autodiff( Reverse, one_step_forward, - Duplicated(states_before_smoother[j], dT_now), - Duplicated(states_after_smoother[j], dT_old), + Duplicated(states_before_smoother[j], dstate_now), + Duplicated(states_after_smoother[j], dstate_old), Duplicated(zeros(6), dout_now), Duplicated(zeros(6), dout_old), Const(parameters), @@ -438,11 +454,11 @@ function compute_adjoint_values( ) if j == 1 - return dT_now, dT_old + return dstate_now, dstate_old end - dout_now = copy(dT_now) - dout_old = copy(dT_old) + dout_now = copy(dstate_now) + dout_old = copy(dstate_old) end end @@ -458,14 +474,14 @@ states_after_smoother, states_before_smoother = integrate( # Next, we pass all of our states to the AD function to get back to the desired derivative: -dT_now, dT_old = compute_adjoint_values( +dstate_now, dstate_old = compute_adjoint_values( states_before_smoother, states_after_smoother, M, parameters ) # And we're done! We were interested in sensitivity to the initial salinity of box -# two, which will live in what we've called `dT_old`. Checking this value we see +# two, which will live in what we've called `dstate_old`. Checking this value we see -@show dT_old[5] +@show dstate_old[5] # As it stands this is just a number, but a good check that Enzyme has computed what we want # is to approximate the derivative with a Taylor series. 
Specifically, @@ -492,26 +508,27 @@ use_to_check = states_after_smoother[M + 1] diffs = [] step_sizes = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10] for eps in step_sizes - T_new_smoothed = zeros(6) + state_new_smoothed = zeros(6) initial_temperature = [20.0; 1.0; 1.0] perturbed_initial_salinity = [35.5; 34.5; 34.5] + [0.0; eps; 0.0] - T_old = [initial_temperature; perturbed_initial_salinity] - T_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5] + state_old = [initial_temperature; perturbed_initial_salinity] + state_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5] for t in 1:M - rho = compute_density(T_now, parameters) + rho = compute_density(state_now, parameters) u = compute_transport(rho, parameters) - T_new = compute_update(T_now, T_old, u, parameters, 10 * parameters.day) + state_new = compute_update(state_now, state_old, u, parameters, 10 * parameters.day) - T_new_smoothed[:] = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old) + state_new_smoothed[:] = + state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old) - T_old = T_new_smoothed - T_now = T_new + state_old = state_new_smoothed + state_now = state_new end - push!(diffs, (T_old[1] - use_to_check[1]) / eps) + push!(diffs, (state_old[1] - use_to_check[1]) / eps) end # Then checking what we found the derivative to be analytically: @@ -521,7 +538,7 @@ end # which comes very close to our calculated value. We can go further and check the # relative difference to see -@show abs.(diffs .- dT_old[5]) ./ dT_old[5] +@show abs.(diffs .- dstate_old[5]) ./ dstate_old[5] # and we get down to a relative difference on the order of ``10^{-5}``, showing Enzyme calculated # the correct derivative. Success!
diff --git a/lib/EnzymeCore/ext/AdaptExt.jl b/lib/EnzymeCore/ext/AdaptExt.jl index e2adabd524..2e7a03cabd 100644 --- a/lib/EnzymeCore/ext/AdaptExt.jl +++ b/lib/EnzymeCore/ext/AdaptExt.jl @@ -18,4 +18,4 @@ function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) end -end +end \ No newline at end of file diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 749378cde7..cd16b2f8e8 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -25,4 +25,4 @@ y = @view data[3:end] @test_throws ErrorException Duplicated(d, y) @test_throws ErrorException Active(data) -@test_throws ErrorException Active(d) +@test_throws ErrorException Active(d) \ No newline at end of file diff --git a/src/rules/allocrules.jl b/src/rules/allocrules.jl index 9b6e2d6ff8..6785d650af 100644 --- a/src/rules/allocrules.jl +++ b/src/rules/allocrules.jl @@ -163,4 +163,4 @@ end (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef) ) ) -end +end \ No newline at end of file diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index da4827c817..12dafed918 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -900,4 +900,4 @@ function wait_rev(B, orig, gutils, tape) debug_from_orig!(gutils, cal, orig) callconv!(cal, callconv(orig)) return nothing -end +end \ No newline at end of file diff --git a/src/typeanalysis.jl b/src/typeanalysis.jl index 8ad34aadf2..d7050be4d1 100644 --- a/src/typeanalysis.jl +++ b/src/typeanalysis.jl @@ -22,4 +22,4 @@ end # typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/, # CTypeTree * /*args*/, size_t /*numArgs*/, -# LLVMValueRef)=T +# LLVMValueRef)=T \ No newline at end of file