Skip to content

Commit

Permalink
add additional input in sensitivity api
Browse files Browse the repository at this point in the history
  • Loading branch information
m-bossart committed May 2, 2024
1 parent 94e9f65 commit 37d9ff9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/base/sensitivity_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Gets a function for taking gradients with respect to parameters.
# Arguments
- `sim::Simulation` : Initialized simulation object
- `device_parameter_pairs::Vector{Tuple{String, Type{T}, Symbol}}` : Tuple used to identify the parameter, via the device name, as a `String`, the type of the Device or DynamicComponent, and the parameter as a `Symbol`.
- `f::function` : User provided function with input a simulation and output a scalar value. This function can include executing the simulation and post-processing of results
- `f::function` : User provided function with two inputs: a simulation and an additional input which can be used for data (```f(sim::Simulation, data::Any)```) The output must be a scalar value. This function can include executing the simulation and post-processing of results.
# Example
```julia
Expand All @@ -36,12 +36,12 @@ function get_parameter_sensitivity_function!(sim, device_param_pairs, f)
@assert sim.status == BUILT
sim.initialize_level = sim_level
sim.enable_sensitivity = true
sensitivity_function = (p) ->
sensitivity_function = (p, data) ->
begin
sim.inputs = deepcopy(sim.inputs_init)
set_parameters!(sim, indices, p)
reset!(sim)
return f(sim)
return f(sim, data)
end
return sensitivity_function
end
Expand Down
38 changes: 27 additions & 11 deletions test/test_case_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ ieee_9bus_sys = build_system(PSIDTestSystems, "psid_test_ieee_9bus")

s_device = get_component(Source, omib_sys, "InfBus")
s_change = SourceBusVoltageChange(1.0, s_device, :V_ref, 1.02)
#using PlotlyJS #for debug only
#using PlotlyJS

#NOTES ON SENSITIVITY ALGORITHMS FROM SCIMLSENSITIVITY
#ReverseDiffVJP and EnzymeVJP only options compatible with Hybrid DEs (DEs with callbacks)
Expand All @@ -40,7 +40,7 @@ s_change = SourceBusVoltageChange(1.0, s_device, :V_ref, 1.02)
t, δ_gt = get_state_series(res, ("generator-102-1", ))
for solver in [FBDF(), Rodas5(), QNDF()]
for tol in [1e-6, 1e-9]
function f(sim)
function f(sim, δ_gt)
execute!(
sim,
solver;
Expand All @@ -62,8 +62,16 @@ s_change = SourceBusVoltageChange(1.0, s_device, :V_ref, 1.02)
)
#p = PSID.get_parameter_sensitivity_values(sim, [("generator-102-1", SingleMass, :H)])
#@error Zygote.gradient(g, [3.15])[1][1]
@test isapprox(Zygote.gradient(g, [3.14])[1][1], -8.0, atol = 1.0)
@test isapprox(Zygote.gradient(g, [3.15])[1][1], 8.0, atol = 1.0)
@test isapprox(
Zygote.gradient((p) -> g(p, δ_gt), [3.14])[1][1],
-8.0,
atol = 1.0,
)
@test isapprox(
Zygote.gradient((p) -> g(p, δ_gt), [3.15])[1][1],
8.0,
atol = 1.0,
)
end
end
finally
Expand All @@ -88,7 +96,7 @@ end
res = read_results(sim)
t, δ_gt = get_state_series(res, ("generator-102-1", ))

function f(sim)
function f(sim, δ_gt)
execute!(
sim,
FBDF(; autodiff = true);
Expand Down Expand Up @@ -117,7 +125,7 @@ end
#push!(loss_values, l)
return false
end
optfun = OptimizationFunction((u, _) -> g(u), Optimization.AutoZygote())
optfun = OptimizationFunction((u, _) -> g(u, δ_gt), Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, [3.14])
sol = Optimization.solve(
optprob,
Expand Down Expand Up @@ -185,7 +193,7 @@ end
MethodOfSteps(QNDF(; autodiff = true)),
]
for tol in [1e-6]
function f(sim)
function f(sim, δ_gt)
execute!(
sim,
solver;
Expand All @@ -207,8 +215,16 @@ end
)
#p = PSID.get_parameter_sensitivity_values(sim, [("generator-102-1", SingleMass, :H)])
#display(Zygote.gradient(g, [3.14]))
@test isapprox(Zygote.gradient(g, [3.14])[1][1], -10.0, atol = 1.0)
@test isapprox(Zygote.gradient(g, [3.15])[1][1], 10.0, atol = 1.0)
@test isapprox(
Zygote.gradient((p) -> g(p, δ_gt), [3.14])[1][1],
-10.0,
atol = 1.0,
)
@test isapprox(
Zygote.gradient((p) -> g(p, δ_gt), [3.15])[1][1],
10.0,
atol = 1.0,
)
end
end
finally
Expand Down Expand Up @@ -263,7 +279,7 @@ end
res = read_results(sim)
t, δ_gt = get_state_series(res, ("generator-102-1", ))

function f(sim)
function f(sim, δ_gt)
execute!(
sim,
MethodOfSteps(Rodas5(; autodiff = true));
Expand Down Expand Up @@ -294,7 +310,7 @@ end
#push!(loss_values, l)
return false
end
optfun = OptimizationFunction((u, _) -> g(u), Optimization.AutoZygote())
optfun = OptimizationFunction((u, _) -> g(u, δ_gt), Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, [3.14])
sol = Optimization.solve(
optprob,
Expand Down

0 comments on commit 37d9ff9

Please sign in to comment.