Skip to content

Commit

Permalink
forward leg tested and finised
Browse files Browse the repository at this point in the history
  • Loading branch information
darioizzo committed Dec 8, 2023
1 parent c51dbbc commit 4515d0c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 37 deletions.
84 changes: 54 additions & 30 deletions benchmark/notebooks/grad.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -62,26 +62,6 @@
"veff = isp * 9.80665\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27993600.0"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tof"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -91,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -104,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -182,12 +162,12 @@
"# grad ms\n",
"grad_ms=0\n",
"for i in range(nseg):\n",
" grad_ms += Mc[i+1]@Iv@dDv[i][:, -2:-1]\n",
" grad_ms += (Mc[i+1]@Iv)@dDv[i][:, -2:-1]\n",
"\n",
"# grad xs\n",
"grad_xs = Mc[0]\n",
"\n",
"# Assembling te return value\n",
"# Assembling the return value\n",
"# grad will contain the gradient of the final posvelm with respect to the throttles and the tof\n",
"grad = np.hstack((grad_u, grad_tof))\n",
"grad = np.vstack((grad, dm[-1][:,:-1]))\n",
Expand All @@ -201,22 +181,66 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7, 16)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad.shape"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.53166067],\n",
" [0.22184716],\n",
" [0.57794741]])"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a[:,2:]"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
"array([[0.23424618, 0.3926254 , 0.53166067],\n",
" [0.23142296, 0.43140307, 0.22184716],\n",
" [0.32659941, 0.15505997, 0.57794741]])"
]
},
"execution_count": 6,
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(m)"
"a"
]
},
{
Expand Down
67 changes: 60 additions & 7 deletions src/leg/sims_flanagan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <xtensor/xbuilder.hpp>
#include <xtensor/xio.hpp>
#include <xtensor/xmath.hpp>
#include <xtensor/xview.hpp>

#include <kep3/core_astro/constants.hpp>
#include <kep3/core_astro/propagate_lagrangian.hpp>
Expand All @@ -33,6 +34,7 @@ namespace kep3::leg
using kep3::linalg::_dot;
using kep3::linalg::mat13;
using kep3::linalg::mat61;
using kep3::linalg::mat63;
using kep3::linalg::mat66;

void _check_tof(double tof)
Expand Down Expand Up @@ -333,7 +335,8 @@ std::pair<std::array<double, 49>, std::vector<double>> sims_flanagan::compute_mc
xt::xarray<double> dtof = xt::zeros<double>({1u, nseg_fwd * 3u + 2u});
std::vector<mat13> Dv(nseg_fwd);
std::vector<xt::xarray<double>> dDv(nseg_fwd, xt::zeros<double>({3u, nseg_fwd * 3u + 2u}));
std::vector<mat66> M(nseg_fwd+1);
std::vector<mat66> M(nseg_fwd + 1); // The STMs
std::vector<mat66> Mc(nseg_fwd + 1); // Mc will contain [Mn@..@M0,Mn@..@M1, Mn]
std::vector<mat61> f(nseg_fwd + 1, xt::zeros<double>({6u, 1u}));
// Initialize values
m[0] = m_ms;
Expand All @@ -347,7 +350,7 @@ std::pair<std::array<double, 49>, std::vector<double>> sims_flanagan::compute_mc
}
dm[0](0, nseg_fwd * 3u) = 1.;
dtof(0, nseg_fwd * 3u + 1) = 1.;
// We compute the mass schedule and related quantities
// 1 - We compute the mass schedule and related gradients
for (decltype(m_throttles.size()) i = 0; i < nseg_fwd; ++i) {
Dv[i] = c / m[i] * u[i];
double un = std::sqrt(u[i](0, 0) * u[i](0, 0) + u[i](0, 1) * u[i](0, 1) + u[i](0, 2) * u[i](0, 2));
Expand All @@ -359,14 +362,14 @@ std::pair<std::array<double, 49>, std::vector<double>> sims_flanagan::compute_mc
m[i + 1] = m[i] * std::exp(-Dvn * a);
dm[i + 1] = -m[i + 1] * a * dDvn + std::exp(-Dvn * a) * dm[i];
}
// We compute the various STMs
// 2 - We compute the various STMs
std::array<std::array<double, 3>, 2> rv_it(get_rvs());
fmt::print("{}", dt);
std::optional<std::array<double, 36>> M_it;
for (decltype(m_throttles.size()) i = 0; i < nseg_fwd + 1; ++i) {
auto dur = dt;
if (i==0 || i==nseg_fwd) {
dur=dt/2;
if (i == 0 || i == nseg_fwd) {
dur = dt / 2;
}
std::tie(rv_it, M_it) = kep3::propagate_lagrangian(rv_it, dur, m_mu, true);
// Now we have the STM in M_it, but its a vector, we must operate on an xtensor object instead.
Expand All @@ -381,10 +384,60 @@ std::pair<std::array<double, 49>, std::vector<double>> sims_flanagan::compute_mc
rv_it[1][2] += Dv[i](0, 2);
}
}
// 3 - We now need to apply the chain rule to assemble the gradients we want (i.e. not w.r.t DV but w.r.t. u etc...)
mat63 Iv = xt::zeros<double>({6u, 3u}); // This is the gradient of x (rv) w.r.t. v
Iv(3, 0) = 1.;
Iv(4, 1) = 1.;
Iv(5, 2) = 1.;
Mc[nseg_fwd] = M[nseg_fwd]; // Mc will contain [Mn@..@M0,Mn@..@M1, Mn]
for (decltype(m_throttles.size()) i = 1; i < nseg_fwd + 1; ++i) {
Mc[nseg_fwd - i] = _dot(Mc[nseg_fwd - i + 1], M[nseg_fwd - i]);
}
// grad_tof
// First the d/dtof term - example: (0.5 * f3 + M3 @ f2 + M3 @ M2 @ f1 + 0.5 * M3 @ M2 @ M1 @ f0) / N
mat61 grad_tof = 0.5 * f[nseg_fwd];
for (decltype(m_throttles.size()) i = 0; i < nseg_fwd - 1; ++i) {
grad_tof += _dot(Mc[i + 2], f[i + 1]);
}
grad_tof += 0.5 * _dot(Mc[1], f[0]);
grad_tof /= nseg_fwd;
// Then we add the d/Dvi * dDvi/dtof - example: M3 @ Iv @ dDv2 + M3 @ M2 @ Iv @ dDv1 + M3 @ M2 @ M1 @ Iv @ dDv0
for (decltype(m_throttles.size()) i = 0; i < nseg_fwd; ++i) {
grad_tof += xt::linalg::dot(
_dot(Mc[i + 1], Iv), xt::eval(xt::view(dDv[i], xt::all(), xt::range(nseg_fwd * 3 + 1, nseg_fwd * 3 + 2))));
}
// grad_u
xt::xarray<double> grad_u = xt::zeros<double>({6u, nseg_fwd * 3u});
for (decltype(m_throttles.size()) i = 0u; i < nseg_fwd; ++i) {
grad_u
+= xt::linalg::dot(_dot(Mc[i + 1], Iv), xt::eval(xt::view(dDv[i], xt::all(), xt::range(0, nseg_fwd * 3))));
}
// grad_ms
xt::xarray<double> grad_ms = xt::zeros<double>({6u, 1u});
for (decltype(m_throttles.size()) i = 0u; i < nseg_fwd; ++i) {
grad_ms += xt::linalg::dot(_dot(Mc[i + 1], Iv),
xt::eval(xt::view(dDv[i], xt::all(), xt::range(nseg_fwd * 3, nseg_fwd * 3 + 1))));
}
// grad_xs
mat66 grad_xs = Mc[0];

// Allocate the return values
std::array<double, 49> grad_rvm{}; // The mismatch constraints gradient w.r.t. extended state r,v,m
std::vector<double> grad(nseg * 3 + 1, 0.); // The mismatch constraints gradient w.r.t. throttles and tof
std::array<double, 49> grad_rvm{}; // The mismatch constraints gradient w.r.t. extended state r,v,m
std::vector<double> grad((nseg_fwd * 3lu + 1) * 7,
0.); // The mismatch constraints gradient w.r.t. throttles and tof
// Copying in the computed derivatives
// a) xgrad (the xtensof gradient w.r.t. throttles and tof)
auto xgrad_rvm = xt::adapt(grad_rvm, {7u, 7u});
auto xgrad = xt::adapt(grad, {7u, nseg_fwd * 3 + 1u});
xt::view(xgrad, xt::range(0u, 6u), xt::range(0u, nseg_fwd * 3u)) = grad_u;
xt::view(xgrad, xt::range(0u, 6u), xt::range(nseg_fwd * 3, nseg_fwd * 3 + 1)) = grad_tof;
xt::view(xgrad, xt::range(6u, 7u), xt::all()) = xt::view(dm[nseg_fwd], xt::all(), xt::range(0u, nseg_fwd * 3 + 1));
// At this point since the variable order is u,m,tof we have put dmf/dms in rather than dms/dtof. So we fix this.
xgrad(6u, nseg_fwd * 3) = dm[nseg_fwd](0, nseg_fwd * 3 + 1);
// b) xgrad_rvm (the xtensor gradient w.r.t. the initial conditions)
xt::view(xgrad_rvm, xt::range(0, 6), xt::range(0, 6)) = grad_xs;
xt::view(xgrad_rvm, xt::range(0, 6), xt::range(6, 7)) = grad_ms;
xgrad_rvm(6, 6) = dm[nseg_fwd](0, nseg_fwd * 3);
return {grad_rvm, grad};
}

Expand Down

0 comments on commit 4515d0c

Please sign in to comment.