-
-
Notifications
You must be signed in to change notification settings - Fork 204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allowing a function to be called multiple times with different inputs #627
base: master
Are you sure you want to change the base?
Changes from 43 commits
ed280a6
1c0f0d0
63e0ddc
4e7b1b8
8a612dc
83e2475
13df657
b17f92b
e885f45
74a2749
d0df2a3
41a75f6
c5d9960
64b56de
55fa847
b7e3d7a
fb199e4
2572dbf
3e36fbe
d115eae
abb85a8
c7d3dc5
d9da546
18338d3
cee31db
ea1c3b0
be3abf1
308454c
09b6cf6
6e4206b
55d142a
a9b6b47
f815469
5889a1b
424a7ef
d581889
edcb1a7
b07ae13
7a1e0b5
7f527c7
530d50e
48c8b04
e4f1536
238b315
44f3a28
fc7d36c
550ab40
4dcf2a8
00f07fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,35 +10,81 @@ Take expressions in the form: | |
|
||
to | ||
|
||
:((cord, θ, phi, derivative, u)->begin | ||
#= ... =# | ||
#= ... =# | ||
begin | ||
(θ1, θ2) = (θ[1:33], θ"[34:66]) | ||
(phi1, phi2) = (phi[1], phi[2]) | ||
let (x, y) = (cord[1], cord[2]) | ||
[(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, θ1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, θ2))) - 0, | ||
(+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, θ2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, θ1))) - 0] | ||
end | ||
end | ||
end) | ||
|
||
for Flux.Chain, and | ||
|
||
:((cord, θ, phi, derivative, u)->begin | ||
#= ... =# | ||
#= ... =# | ||
begin | ||
(u1, u2) = (θ.depvar.u1, θ.depvar.u2) | ||
(phi1, phi2) = (phi[1], phi[2]) | ||
let (x, y) = (cord[1], cord[2]) | ||
[(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, u1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, u1))) - 0, | ||
(+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, u2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, u2))) - 0] | ||
end | ||
end | ||
end) | ||
|
||
for Lux.AbstractExplicitLayer | ||
:((cord, θ, phi, derivative, integral, u, p)->begin | ||
#= ... =# | ||
#= ... =# | ||
begin | ||
(θ1, θ2) = (θ[1:205], θ[206:410]) | ||
(phi1, phi2) = (phi[1], phi[2]) | ||
let (x, y) = (cord[[1], :], cord[[2], :]) | ||
begin | ||
cord2 = vcat(x, y) | ||
cord1 = vcat(x, y) | ||
end | ||
(+).((*).(4, derivative(phi2, u, _vcat(x, y), [[0.0, ε]], 1, θ2)), derivative(phi1, u, _vcat(x, y), [[ε, 0.0]], 1, θ1)) .- 0 | ||
end | ||
end | ||
end) | ||
|
||
for Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0, and | ||
|
||
:((cord, θ, phi, derivative, integral, u, p)->begin | ||
#= ... =# | ||
#= ... =# | ||
begin | ||
(θ1, θ2) = (θ[1:205], θ[206:410]) | ||
(phi1, phi2) = (phi[1], phi[2]) | ||
let (x, y) = (cord[[1], :], cord[[2], :]) | ||
begin | ||
cord2 = vcat(x, y) | ||
cord1 = vcat(x, y) | ||
end | ||
(+).((*).(9, derivative(phi1, u, _vcat(x, y), [[0.0, ε]], 1, θ1)), derivative(phi2, u, _vcat(x, y), [[ε, 0.0]], 1, θ2)) .- 0 | ||
end | ||
end | ||
end) | ||
|
||
for Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0 (i.e., separate loss functions are created for each equation) | ||
|
||
with Flux.Chain; and | ||
|
||
:((cord, θ, phi, derivative, integral, u, p)->begin | ||
#= ... =# | ||
#= ... =# | ||
begin | ||
(θ1, θ2) = (θ.depvar.u1, θ.depvar.u2) | ||
(phi1, phi2) = (phi[1], phi[2]) | ||
let (x, y) = (cord[[1], :], cord[[2], :]) | ||
begin | ||
cord2 = vcat(x, y) | ||
cord1 = vcat(x, y) | ||
end | ||
(+).((*).(4, derivative(phi2, u, _vcat(x, y), [[0.0, ε]], 1, θ2)), derivative(phi1, u, _vcat(x, y), [[ε, 0.0]], 1, θ1)) .- 0 | ||
end | ||
end | ||
end) | ||
|
||
for Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0 and | ||
|
||
:((cord, θ, phi, derivative, integral, u, p)->begin | ||
#= ... =# | ||
#= ... =# | ||
begin | ||
(θ1, θ2) = (θ.depvar.u1, θ.depvar.u2) | ||
(phi1, phi2) = (phi[1], phi[2]) | ||
let (x, y) = (cord[[1], :], cord[[2], :]) | ||
begin | ||
cord2 = vcat(x, y) | ||
cord1 = vcat(x, y) | ||
end | ||
(+).((*).(9, derivative(phi1, u, _vcat(x, y), [[0.0, ε]], 1, θ1)), derivative(phi2, u, _vcat(x, y), [[ε, 0.0]], 1, θ2)) .- 0 | ||
end | ||
end | ||
end) | ||
|
||
for Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0 | ||
|
||
with Lux.Chain | ||
""" | ||
function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; | ||
eq_params = SciMLBase.NullParameters(), | ||
|
@@ -61,7 +107,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; | |
this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input) | ||
this_eq_indvars = unique(vcat(values(this_eq_pair)...)) | ||
else | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the docstring above. What does the code look like now? |
||
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], | ||
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => filter(arg -> !isempty(find_thing_in_expr(integrand, | ||
arg)), | ||
dict_depvar_input[intvars]), | ||
integrating_depvars)) | ||
this_eq_indvars = transformation_vars isa Nothing ? | ||
unique(vcat(values(this_eq_pair)...)) : transformation_vars | ||
|
@@ -142,17 +190,10 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; | |
vcat_expr = Expr(:block, :($(eq_pair_expr...))) | ||
vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename | ||
|
||
if strategy isa QuadratureTraining | ||
indvars_ex = get_indvars_ex(bc_indvars) | ||
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex | ||
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), | ||
build_expr(:tuple, right_arg_pairs)) | ||
else | ||
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] | ||
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex | ||
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), | ||
build_expr(:tuple, right_arg_pairs)) | ||
end | ||
indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] | ||
left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex | ||
vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), | ||
build_expr(:tuple, right_arg_pairs)) | ||
|
||
if !(dict_transformation_vars isa Nothing) | ||
transformation_expr_ = Expr[] | ||
|
@@ -256,7 +297,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D | |
hcat(vec(map(points -> collect(points), | ||
Iterators.product(bc_data...)))...)) | ||
|
||
pde_train_sets = map(pde_args) do bt | ||
pde_train_sets = map(pde_vars) do bt | ||
span = map(b -> get(dict_var_span_, b, b), bt) | ||
_set = adapt(eltypeθ, | ||
hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) | ||
|
@@ -292,7 +333,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, | |
dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) | ||
dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) | ||
|
||
pde_args = get_argument(eqs, dict_indvars, dict_depvars) | ||
pde_args = get_variables(eqs, dict_indvars, dict_depvars) | ||
|
||
pde_lower_bounds = map(pde_args) do pd | ||
span = map(p -> get(dict_lower_bound, p, p), pd) | ||
|
@@ -325,19 +366,33 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str | |
] for d in domains]) | ||
|
||
# pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] | ||
pde_args = get_argument(eqs, dict_indvars, dict_depvars) | ||
pde_bounds = map(pde_args) do pde_arg | ||
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
pde_vars = get_variables(eqs, dict_indvars, dict_depvars) | ||
pde_bounds = map(pde_vars) do pde_var | ||
if !isempty(pde_var) | ||
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_var) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
else | ||
[eltypeθ(0.0)], [eltypeθ(0.0)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what case is this handling? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I remember correctly, it was the case of something like |
||
end | ||
end | ||
|
||
bound_args = get_argument(bcs, dict_indvars, dict_depvars) | ||
bcs_bounds = map(bound_args) do bound_arg | ||
bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
dx_bcs = 1 / strategy.bcs_points | ||
dict_span_bcs = Dict([Symbol(d.variables) => [ | ||
infimum(d.domain) + dx_bcs, | ||
supremum(d.domain) - dx_bcs, | ||
] for d in domains]) | ||
bound_vars = get_variables(bcs, dict_indvars, dict_depvars) | ||
bcs_bounds = map(bound_vars) do bound_var | ||
if !isempty(bound_var) | ||
bds = mapreduce(s -> get(dict_span_bcs, s, fill(s, 2)), hcat, bound_var) | ||
bds = eltypeθ.(bds) | ||
bds[1, :], bds[2, :] | ||
else | ||
[eltypeθ(0.0)], [eltypeθ(0.0)] | ||
end | ||
end | ||
|
||
return pde_bounds, bcs_bounds | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are those made and then not used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They're not used in this case any more. I think they may be used in the integral case, but that might not be true either. I can look through the different cases to see if they are ever used and remove them if not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah it looks like deprecated code now so it would be good to just remove it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, currently,
cord1 = vcat(...)
is being used for integral equations only and in my efforts to see if it's possible to remove it, I've found something I broke that wasn't being tested for the integral equations, so I'm can work more on fixing that next week. In particular, I'll look for a fix that removes any need for those lines.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another thing I realized as I was working on this was that$\int u(\sin x) dx$ . However, $u'(\sin x)$ , not $\frac{d}{dx}\left[ u(\sin x) \right] = u'(\sin x) \cos x$ . They're interpreted this way because that's how the numeric integral and numeric derivative functions were already written. However, it feels a little inconsistent with the way the integral was interpreted; it's instead consistent with an interpretation of $U(\sin x)$ , where $U$ is an antiderivative of $u$ .
Ix(u(sin(x))
will now be interpreted asDx(u(sin(x))
is (under my current changes) being interpreted asIx(u(sin(x))
asIt feels to me like$\int u(\sin x) dx$ and $\frac{d}{dx}\left[ u(\sin x) \right]$ , but then I don't know how you would actually specify $U(\sin x)$ or $u'(\sin x)$ , or if you should even be allowed to. (I'm fine not letting people use $U(\sin x)$ since it's not uniquely defined, but it feels like they should be able to use $u'(\sin x)$ .)
Ix(u(sin(x))
isDx(u(sin(x))
isThoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not correct and would not play nicely. It should give the same result as what happens when basic symbolic interactions are done: