Skip to content

Commit

Permalink
Revert "Change box.jl, reapply formatter"
Browse files Browse the repository at this point in the history
This reverts commit 61d4b3b.
  • Loading branch information
simsurace committed Mar 18, 2024
1 parent 61d4b3b commit a4c62f3
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 83 deletions.
173 changes: 95 additions & 78 deletions examples/box.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ using Enzyme
# once.

struct ModelParameters

## handy to have constants
day::Float64
year::Float64
Expand Down Expand Up @@ -192,43 +193,56 @@ function compute_density(state, params)
end

## lastly, a function that takes one step forward
## Input: T_now = [T1(t), T2(t), ..., S3(t)]
## T_old = [T1(t-dt), ..., S3(t-dt)]
## Input: state_now = [T1(t), T2(t), ..., S3(t)]
## state_old = [T1(t-dt), ..., S3(t-dt)]
## u = transport(t)
## dt = time step
## Output: T_new = [T1(t+dt), ..., S3(t+dt)]

function compute_update(T_now, T_old, u, params, dt)
Ṫ_now = zeros(6)
T_new = zeros(6)
## Output: state_new = [T1(t+dt), ..., S3(t+dt)]

V = params.boxvol
Tstar = params.Tstar
function compute_update(state_now, state_old, u, params, dt)
dstate_now_dt = zeros(6)
state_new = zeros(6)

## first computing the time derivatives of the various temperatures and salinities
if u > 0
Ṫ_now[1] = u * (T_now[3] - T_now[1]) / V[1] + params.gamma * (Tstar[1] - T_now[1])
Ṫ_now[2] = u * (T_now[1] - T_now[2]) / V[2] + params.gamma * (Tstar[2] - T_now[2])
Ṫ_now[3] = u * (T_now[2] - T_now[3]) / V[3]

Ṫ_now[4] = u * (T_now[6] - T_now[4]) / V[1] + params.FW[1] / V[1]
Ṫ_now[5] = u * (T_now[4] - T_now[5]) / V[2] + params.FW[2] / V[2]
Ṫ_now[6] = u * (T_now[5] - T_now[6]) / V[3]
dstate_now_dt[1] =
u * (state_now[3] - state_now[1]) / params.boxvol[1] +
params.gamma * (params.Tstar[1] - state_now[1])
dstate_now_dt[2] =
u * (state_now[1] - state_now[2]) / params.boxvol[2] +
params.gamma * (params.Tstar[2] - state_now[2])
dstate_now_dt[3] = u * (state_now[2] - state_now[3]) / params.boxvol[3]

dstate_now_dt[4] =
u * (state_now[6] - state_now[4]) / params.boxvol[1] +
params.FW[1] / params.boxvol[1]
dstate_now_dt[5] =
u * (state_now[4] - state_now[5]) / params.boxvol[2] +
params.FW[2] / params.boxvol[2]
dstate_now_dt[6] = u * (state_now[5] - state_now[6]) / params.boxvol[3]

elseif u <= 0
Ṫ_now[1] = u * (T_now[2] - T_now[1]) / V[1] + params.gamma * (Tstar[1] - T_now[1])
Ṫ_now[2] = u * (T_now[3] - T_now[2]) / V[2] + params.gamma * (Tstar[2] - T_now[2])
Ṫ_now[3] = u * (T_now[1] - T_now[3]) / V[3]

Ṫ_now[4] = u * (T_now[5] - T_now[4]) / V[1] + params.FW[1] / V[1]
Ṫ_now[5] = u * (T_now[6] - T_now[5]) / V[2] + params.FW[2] / V[2]
Ṫ_now[6] = u * (T_now[4] - T_now[6]) / V[3]
dstate_now_dt[1] =
u * (state_now[2] - state_now[1]) / params.boxvol[1] +
params.gamma * (params.Tstar[1] - state_now[1])
dstate_now_dt[2] =
u * (state_now[3] - state_now[2]) / params.boxvol[2] +
params.gamma * (params.Tstar[2] - state_now[2])
dstate_now_dt[3] = u * (state_now[1] - state_now[3]) / params.boxvol[3]

dstate_now_dt[4] =
u * (state_now[5] - state_now[4]) / params.boxvol[1] +
params.FW[1] / params.boxvol[1]
dstate_now_dt[5] =
u * (state_now[6] - state_now[5]) / params.boxvol[2] +
params.FW[2] / params.boxvol[2]
dstate_now_dt[6] = u * (state_now[4] - state_now[6]) / params.boxvol[3]
end

## update fldnew using a version of Euler's method
T_new .= T_old + 2.0 * dt * Ṫ_now
state_new .= state_old + 2.0 * dt * dstate_now_dt

return T_new
return state_new
end

# ## Define forward functions
Expand All @@ -239,27 +253,28 @@ end
# Let's start with the standard forward function. This is just going to be used
# to store the states at every timestep:

function integrate(T_now, T_old, dt, M, parameters)
function integrate(state_now, state_old, dt, M, parameters)

## Because of the adjoint problem we're setting up, we need to store both the states before
## and after the Robert filter smoother has been applied
states_before = [T_old]
states_after = [T_old]
states_before = [state_old]
states_after = [state_old]

for t in 1:M
rho = compute_density(T_now, parameters)
rho = compute_density(state_now, parameters)
u = compute_transport(rho, parameters)
T_new = compute_update(T_now, T_old, u, parameters, dt)
state_new = compute_update(state_now, state_old, u, parameters, dt)

## Applying the Robert filter smoother (needed for stability)
T_new_smoothed = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old)
state_new_smoothed =
state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)

push!(states_after, T_new_smoothed)
push!(states_before, T_new)
push!(states_after, state_new_smoothed)
push!(states_before, state_new)

## cycle the "now, new, old" states
T_old = T_new_smoothed
T_now = T_new
state_old = state_new_smoothed
state_now = state_new
end

return states_after, states_before
Expand All @@ -269,17 +284,18 @@ end
# that runs a single step of the model forward rather than the whole integration.
# This would allow us to save as many of the adjoint variables as we wish when running the adjoint method,
# although for the example we'll discuss later we technically only need one of them
function one_step_forward(T_now, T_old, out_now, out_old, parameters, dt)
T_new_smoothed = zeros(6)
rho = compute_density(T_now, parameters) ## compute density
function one_step_forward(state_now, state_old, out_now, out_old, parameters, dt)
state_new_smoothed = zeros(6)
rho = compute_density(state_now, parameters) ## compute density
u = compute_transport(rho, parameters) ## compute transport
T_new = compute_update(T_now, T_old, u, parameters, dt) ## compute new state values
state_new = compute_update(state_now, state_old, u, parameters, dt) ## compute new state values

## Robert filter smoother
T_new_smoothed[:] = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old)
state_new_smoothed[:] =
state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)

out_old[:] = T_new_smoothed
out_now[:] = T_new
out_old[:] = state_new_smoothed
out_now[:] = state_new

return nothing
end
Expand All @@ -306,8 +322,8 @@ states_after_smoother, states_before_smoother = integrate(
)

## Run Enzyme one time on `one_step_forward``
dT_now = zeros(6)
dT_old = zeros(6)
dstate_now = zeros(6)
dstate_old = zeros(6)
out_now = zeros(6);
dout_now = ones(6);
out_old = zeros(6);
Expand All @@ -316,8 +332,8 @@ dout_old = ones(6);
autodiff(
Reverse,
one_step_forward,
Duplicated([Tbar; Sbar], dT_now),
Duplicated([Tbar; Sbar], dT_old),
Duplicated([Tbar; Sbar], dstate_now),
Duplicated([Tbar; Sbar], dstate_old),
Duplicated(out_now, dout_now),
Duplicated(out_old, dout_old),
Const(parameters),
Expand All @@ -338,47 +354,47 @@ autodiff(
@show states_before_smoother[2], states_after_smoother[2]

# we see that Enzyme has computed and stored exactly the output of the
# forward step. Next, let's look at `dT_now`:
# forward step. Next, let's look at `dstate_now`:

@show dT_now
@show dstate_now

# Just a few numbers, but this is what makes AD so nice: Enzyme has exactly computed
# the derivative of all outputs with respect to the input `T_now`, evaluated at
# `T_now`, and acted with this gradient on what we gave as `dout_now` (in our case,
# the derivative of all outputs with respect to the input `state_now`, evaluated at
# `state_now`, and acted with this gradient on what we gave as `dout_now` (in our case,
# all ones). Using AD notation for reverse mode, this is

# ```math
# \overline{\text{state\_now}} = \frac{\partial \text{out\_now}}{\partial \text{state\_now}}\right|_\text{state\_now} \overline{\text{out\_now} + \frac{\partial \text{out\_old}}{\partial \text{state\_now}}\right|_\text{state\_now} \overline{\text{out\_old}
# ```

# We note here that had we initialized `dT_now` and `dT_old` as something else, our results
# We note here that had we initialized `dstate_now` and `dstate_old` as something else, our results
# will change. Let's multiply them by two and see what happens.

dT_now_new = zeros(6)
dT_old_new = zeros(6)
dstate_now_new = zeros(6)
dstate_old_new = zeros(6)
out_now = zeros(6);
dout_now = 2 * ones(6);
out_old = zeros(6);
dout_old = 2 * ones(6);
autodiff(
Reverse,
one_step_forward,
Duplicated([Tbar; Sbar], dT_now_new),
Duplicated([Tbar; Sbar], dT_old_new),
Duplicated([Tbar; Sbar], dstate_now_new),
Duplicated([Tbar; Sbar], dstate_old_new),
Duplicated(out_now, dout_now),
Duplicated(out_old, dout_old),
Const(parameters),
Const(10 * parameters.day),
)

# Now checking `dT_now` and `dT_old` we see
# Now checking `dstate_now` and `dstate_old` we see

@show dT_now_new
@show dstate_now_new

# What happened? Enzyme is actually taking the computed gradient and acting on what we
# give as input to `dout_now` and `dout_old`. Checking this, we see

@show 2 * dT_now
@show 2 * dstate_now

# and they match the new results. This exactly matches what we'd expect to happen since
# we scaled `dout_now` by two.
Expand Down Expand Up @@ -423,26 +439,26 @@ function compute_adjoint_values(
dout_old = [1.0; 0.0; 0.0; 0.0; 0.0; 0.0]

for j in M:-1:1
dT_now = zeros(6)
dT_old = zeros(6)
dstate_now = zeros(6)
dstate_old = zeros(6)

autodiff(
Reverse,
one_step_forward,
Duplicated(states_before_smoother[j], dT_now),
Duplicated(states_after_smoother[j], dT_old),
Duplicated(states_before_smoother[j], dstate_now),
Duplicated(states_after_smoother[j], dstate_old),
Duplicated(zeros(6), dout_now),
Duplicated(zeros(6), dout_old),
Const(parameters),
Const(10 * parameters.day),
)

if j == 1
return dT_now, dT_old
return dstate_now, dstate_old
end

dout_now = copy(dT_now)
dout_old = copy(dT_old)
dout_now = copy(dstate_now)
dout_old = copy(dstate_old)
end
end

Expand All @@ -458,14 +474,14 @@ states_after_smoother, states_before_smoother = integrate(

# Next, we pass all of our states to the AD function to get back to the desired derivative:

dT_now, dT_old = compute_adjoint_values(
dstate_now, dstate_old = compute_adjoint_values(
states_before_smoother, states_after_smoother, M, parameters
)

# And we're done! We were interested in sensitivity to the initial salinity of box
# two, which will live in what we've called `dT_old`. Checking this value we see
# two, which will live in what we've called `dstate_old`. Checking this value we see

@show dT_old[5]
@show dstate_old[5]

# As it stands this is just a number, but a good check that Enzyme has computed what we want
# is to approximate the derivative with a Taylor series. Specifically,
Expand All @@ -492,26 +508,27 @@ use_to_check = states_after_smoother[M + 1]
diffs = []
step_sizes = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10]
for eps in step_sizes
T_new_smoothed = zeros(6)
state_new_smoothed = zeros(6)

initial_temperature = [20.0; 1.0; 1.0]
perturbed_initial_salinity = [35.5; 34.5; 34.5] + [0.0; eps; 0.0]

T_old = [initial_temperature; perturbed_initial_salinity]
T_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5]
state_old = [initial_temperature; perturbed_initial_salinity]
state_now = [20.0; 1.0; 1.0; 35.5; 34.5; 34.5]

for t in 1:M
rho = compute_density(T_now, parameters)
rho = compute_density(state_now, parameters)
u = compute_transport(rho, parameters)
T_new = compute_update(T_now, T_old, u, parameters, 10 * parameters.day)
state_new = compute_update(state_now, state_old, u, parameters, 10 * parameters.day)

T_new_smoothed[:] = T_now + parameters.rf_coeff * (T_new - 2.0 * T_now + T_old)
state_new_smoothed[:] =
state_now + parameters.rf_coeff * (state_new - 2.0 * state_now + state_old)

T_old = T_new_smoothed
T_now = T_new
state_old = state_new_smoothed
state_now = state_new
end

push!(diffs, (T_old[1] - use_to_check[1]) / eps)
push!(diffs, (state_old[1] - use_to_check[1]) / eps)
end

# Then checking what we found the derivative to be analytically:
Expand All @@ -521,7 +538,7 @@ end
# which comes very close to our calculated value. We can go further and check the
# percent difference to see

@show abs.(diffs .- dT_old[5]) ./ dT_old[5]
@show abs.(diffs .- dstate_old[5]) ./ dstate_old[5]

# and we get down to a percent difference on the order of ``1e^{-5}``, showing Enzyme calculated
# the correct derivative. Success!
2 changes: 1 addition & 1 deletion lib/EnzymeCore/ext/AdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed)
return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))
end

end
end
2 changes: 1 addition & 1 deletion lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ y = @view data[3:end]
@test_throws ErrorException Duplicated(d, y)

@test_throws ErrorException Active(data)
@test_throws ErrorException Active(d)
@test_throws ErrorException Active(d)
2 changes: 1 addition & 1 deletion src/rules/allocrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,4 @@ end
(LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)
)
)
end
end
2 changes: 1 addition & 1 deletion src/rules/parallelrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -900,4 +900,4 @@ function wait_rev(B, orig, gutils, tape)
debug_from_orig!(gutils, cal, orig)
callconv!(cal, callconv(orig))
return nothing
end
end
2 changes: 1 addition & 1 deletion src/typeanalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ end

# typedef bool (*CustomRuleType)(int /*direction*/, CTypeTree * /*return*/,
# CTypeTree * /*args*/, size_t /*numArgs*/,
# LLVMValueRef)=T
# LLVMValueRef)=T

0 comments on commit a4c62f3

Please sign in to comment.