From 131014826d29bafc91ff3945608980bf4b3ba698 Mon Sep 17 00:00:00 2001 From: Alvaro Cea Date: Wed, 27 Nov 2024 15:53:58 +0000 Subject: [PATCH] sharding gust free flying working --- examples/BUG/modelgeneration.org | 12 ++-- feniax/intrinsic/args.py | 9 +-- feniax/intrinsic/argshard.py | 43 ++++++++++++ feniax/intrinsic/dq_dynamic.py | 28 ++++++-- feniax/intrinsic/dynamicShard.py | 107 +++++++++++++++++++++++++++++ feniax/intrinsic/postprocess.py | 2 +- feniax/plotools/upyvista.py | 40 +++++------ feniax/systems/intrinsic_system.py | 21 ++++++ 8 files changed, 227 insertions(+), 35 deletions(-) diff --git a/examples/BUG/modelgeneration.org b/examples/BUG/modelgeneration.org index b3105dc..ee33d6b 100644 --- a/examples/BUG/modelgeneration.org +++ b/examples/BUG/modelgeneration.org @@ -2009,7 +2009,7 @@ Running Nastran using the tailored functions in run_nastra.sh which moves output inp.system.aero.D = f"./AERO/{Dhj_file}.npy" inp.system.aero.gust_profile = "mc" inp.system.aero.gust.intensity = 20 - inp.system.aero.gust.length = 250. + inp.system.aero.gust.length = 150. inp.system.aero.gust.step = 0.1 inp.system.aero.gust.shift = 0. inp.system.aero.gust.panels_dihedral = f"./AERO/Dihedral_{label_dlm}.npy" @@ -2756,18 +2756,18 @@ make gust video <> inp.driver.sol_path = pathlib.Path( - f"./results/gust1_{sol}Shard") + f"./results/gust2_{sol}Shard") inp.system.aero.gust.fixed_discretisation = [150, u_inf] # Shard inputs - inputflow = dict(length=np.linspace(25,265,13), - intensity= np.linspace(0.1, 3, 11), - rho_inf = [0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75] + inputflow = dict(length=[25, 75, 150, 240], #np.linspace(25,265,13), + intensity= [1, 10, 15, 20] #np.linspace(0.1, 3, 11), + #rho_inf = [0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75] ) inp.system.shard = dict(input_type="gust1", inputs=inputflow) num_gpus = 8 - solstatic1shard = feniax.feniax_shardmain.main(input_dict=inp, device_count=num_gpus) + solgust21shard = feniax.feniax_shardmain.main(input_dict=inp, device_count=num_gpus) #+end_src diff --git a/feniax/intrinsic/args.py b/feniax/intrinsic/args.py index ef4b15e..c0328a8 100644 --- a/feniax/intrinsic/args.py +++ b/feniax/intrinsic/args.py @@ -690,6 +690,7 @@ def arg_20g546( *args, **kwargs, ): + eta_0 = kwargs["eta_0"] gamma1 = sol.data.couplings.gamma1 gamma2 = sol.data.couplings.gamma2 @@ -711,20 +712,20 @@ def arg_20g546( omega, phi1l, states, + aero.poles, num_modes, num_poles, + gust.x, + c_ref, aero.A0hat, aero.A1hat, aero.A2hatinv, aero.A3hat, u_inf, - c_ref, - aero.poles, - gust.x, F1g, Flg, ) - + @catter2library def arg_20G546( diff --git a/feniax/intrinsic/argshard.py b/feniax/intrinsic/argshard.py index cbb2560..62f548b 100644 --- a/feniax/intrinsic/argshard.py +++ b/feniax/intrinsic/argshard.py @@ -90,3 +90,46 @@ def arg_20g21( ) ) + +def arg_20g546( + sol: solution.IntrinsicSolution, + system: intrinsicmodal.Dsystem, + fem: intrinsicmodal.Dfem, + *args, + **kwargs, +): + + eta_0 = kwargs["eta_0"] + phi1l = sol.data.modes.phi1l + phi2l = sol.data.modes.phi2l + psi2l = sol.data.modes.psi2l + X_xdelta = sol.data.modes.X_xdelta + omega = sol.data.modes.omega + X_xdelta = sol.data.modes.X_xdelta + C0ab = sol.data.modes.C0ab + gamma1 = sol.data.couplings.gamma1 + gamma2 = sol.data.couplings.gamma2 + states = system.states + c_ref = system.aero.c_ref + num_modes = fem.num_modes + num_poles = system.aero.num_poles + poles = system.aero.poles + A = system.aero.A + D = system.aero.D + xgust = system.aero.gust.time + collocation_x = system.aero.gust.collocation_points[:,0] + return (phi1l, phi2l, psi2l, X_xdelta, C0ab, A, D, c_ref, + ( + eta_0, + gamma1, + gamma2, + omega, + phi1l, + states, + poles, + num_modes, + num_poles, + xgust, + ) + + ) diff --git a/feniax/intrinsic/dq_dynamic.py b/feniax/intrinsic/dq_dynamic.py index 847013c..e9a4143 100644 --- a/feniax/intrinsic/dq_dynamic.py +++ b/feniax/intrinsic/dq_dynamic.py @@ -352,23 +352,43 @@ def dq_20g273(t, q, *args): def dq_20g546(t, q, *args): """Gust response free flight, q0 obtained via integrator q1.""" + # ( + # eta_0, + # gamma1, + # gamma2, + # omega, + # phi1l, + # states, + # num_modes, + # num_poles, + # A0hat, + # A1hat, + # A2hatinv, + # A3hat, + # u_inf, + # c_ref, + # poles, + # xgust, + # F1gust, + # Flgust, + # ) = args[0] ( eta_0, gamma1, gamma2, omega, - phi1l, + phi1l, states, + poles, num_modes, num_poles, + xgust, + c_ref, A0hat, A1hat, A2hatinv, A3hat, u_inf, - c_ref, - poles, - xgust, F1gust, Flgust, ) = args[0] diff --git a/feniax/intrinsic/dynamicShard.py b/feniax/intrinsic/dynamicShard.py index 5594bd0..714eb9b 100644 --- a/feniax/intrinsic/dynamicShard.py +++ b/feniax/intrinsic/dynamicShard.py @@ -151,3 +151,110 @@ def _main_20g21_3(inp): main_vmap = jax.vmap(_main_20g21_3) results = main_vmap(inputs) return results + +@partial(jax.jit, static_argnames=["config"]) +def main_20g546_3( + inputs, # + q0, + config, + args, + **kwargs, +): + + X = config.fem.X + phi1l, phi2l, psi2l, X_xdelta, C0ab, A, D, c_ref, _dqargs = args + states = _dqargs[5] + q1_index = states["q1"] + q2_index = states["q2"] + # q1_index = states["q1"] + # q2_index = states["q2"] + + collocation_points = config.system.aero.gust.collocation_points + gust_shift = config.system.aero.gust.shift + dihedral = config.system.aero.gust.panels_dihedral + fshape_span = igust._get_spanshape(config.system.aero.gust.shape) + # gust_totaltime = config.system.aero.gust.totaltime + time_gust = config.system.aero.gust.time + # @jax.jit + def _main_20g546_3(inp): + + rho_inf = inp[0] + u_inf = inp[1] + gust_length = inp[2] + gust_intensity = inp[3] + q_inf = 0.5 * rho_inf * u_inf**2 + gust_totaltime = gust_length / u_inf + + A0hat = q_inf * A[0] + A1hat = c_ref * rho_inf * u_inf / 4 * A[1] + A2hat = c_ref**2 * rho_inf / 8 * A[2] + A3hat = q_inf * A[3:] + A2hatinv = jnp.linalg.inv(jnp.eye(len(A2hat)) - A2hat) + D0hat = q_inf * D[0] + D1hat = c_ref * rho_inf * u_inf / 4 * D[1] + D2hat = c_ref**2 * rho_inf / 8 * D[2] + D3hat = q_inf * D[3:] + + gust, gust_dot, gust_ddot = igust._downwashRogerMc( + u_inf, + gust_length, + gust_intensity, + gust_shift, + collocation_points, + dihedral, # normals, + time_gust, + gust_totaltime, + fshape_span, + ) + Q_w, Q_wdot, Q_wddot, Q_wsum, Ql_wdot = igust._getGAFs( + D0hat, # NbxNm + D1hat, + D2hat, + D3hat, + gust, + gust_dot, + gust_ddot, + ) + + args_inp = (c_ref, + A0hat, + A1hat, + A2hatinv, + A3hat, + u_inf, + Q_wsum, + Ql_wdot + ) + + dq_args = _dqargs + args_inp + states_puller, eqsolver = sollibs.factory( + config.system.solver_library, config.system.solver_function + ) + + sol = eqsolver( + dq_dynamic.dq_20g546, + dq_args, + config.system.solver_settings, + q0=q0, + t0=config.system.t0, + t1=config.system.t1, + tn=config.system.tn, + dt=config.system.dt, + t=config.system.t, + ) + # jax.debug.breakpoint() + q = states_puller(sol) + q1 = q[:, q1_index] + q2 = q[:, q2_index] + tn = len(q) + # X2, X3, ra, Cab = isys.recover_staticfields(q2, tn, X, + # phi2l, psi2l, X_xdelta, C0ab, config.fem) + X1, X2, X3, ra, Cab = isys.recover_fieldsRB( + q1, q2, tn, config.system.dt, X, phi1l, phi2l, psi2l, X_xdelta, C0ab, config + ) + + return dict(q=q, X1=X1,X2=X2, X3=X3, ra=ra, Cab=Cab) + + main_vmap = jax.vmap(_main_20g546_3) + results = main_vmap(inputs) + return results diff --git a/feniax/intrinsic/postprocess.py b/feniax/intrinsic/postprocess.py index e266a9d..4e5f88e 100644 --- a/feniax/intrinsic/postprocess.py +++ b/feniax/intrinsic/postprocess.py @@ -45,7 +45,7 @@ def velocity_ra(): ... def strains_ra(): ... - +@jax.jit def integrate_node0(X1, dt, ra_n0, Rab_n0): v_average = (X1[:-1, :3] + X1[1:, :3]) / 2 theta_average = (X1[:-1, 3:6] + X1[1:, 3:6]) / 2 * dt diff --git a/feniax/plotools/upyvista.py b/feniax/plotools/upyvista.py index 640b040..9d141f3 100644 --- a/feniax/plotools/upyvista.py +++ b/feniax/plotools/upyvista.py @@ -42,31 +42,31 @@ def render_mesh(points, lines): return mesh -import pyvista as pv +# import pyvista as pv -points = np.random.rand(100, 3) -mesh = pv.PolyData(points) -mesh.plot(point_size=10, style='points', color='tan') +# points = np.random.rand(100, 3) +# mesh = pv.PolyData(points) +# mesh.plot(point_size=10, style='points', color='tan') -polydata = pv.PolyData(points) +# polydata = pv.PolyData(points) -# Add the vectors as point data -polydata["vectors"] = vectors +# # Add the vectors as point data +# polydata["vectors"] = vectors -# Create glyphs to represent vectors -glyphs = polydata.glyph(orient="vectors", scale=False, factor=0.3) +# # Create glyphs to represent vectors +# glyphs = polydata.glyph(orient="vectors", scale=False, factor=0.3) -# Plot the vector field -plotter = pv.Plotter() -plotter.add_mesh(glyphs, color='red') -plotter.add_mesh(polydata, color='blue', point_size=5, render_points_as_spheres=True) -plotter.show() +# # Plot the vector field +# plotter = pv.Plotter() +# plotter.add_mesh(glyphs, color='red') +# plotter.add_mesh(polydata, color='blue', point_size=5, render_points_as_spheres=True) +# plotter.show() -mesh = pyvista.PolyData(v, cells) -mesh.save(folder_path / f"collocation_{k}.ply", binary=False) +# mesh = pyvista.PolyData(v, cells) +# mesh.save(folder_path / f"collocation_{k}.ply", binary=False) -X=config.fem.X, -time=range(len(inp.system.t)), -ra=sol.staticsystem_sys1.ra[i], -Rab=sol.staticsystem_sys1.Cab[i], +# X=config.fem.X, +# time=range(len(inp.system.t)), +# ra=sol.staticsystem_sys1.ra[i], +# Rab=sol.staticsystem_sys1.Cab[i], diff --git a/feniax/systems/intrinsic_system.py b/feniax/systems/intrinsic_system.py index e7c824a..26d5d8c 100644 --- a/feniax/systems/intrinsic_system.py +++ b/feniax/systems/intrinsic_system.py @@ -38,6 +38,27 @@ def recover_fields(q1, q2, tn, X, phi1l, phi2l, psi2l, X_xdelta, C0ab, config): return X1, X2, X3, ra, Cab +@partial(jax.jit, static_argnames=["config", "tn", "dt"]) +def recover_fieldsRB(q1, q2, tn, dt, X, phi1l, phi2l, psi2l, X_xdelta, C0ab, config): + ra_n0 = X[0] + Rab_n0 = jnp.eye(3) + X1 = postprocess.compute_velocities(phi1l, q1) + + X2 = postprocess.compute_internalforces(phi2l, q2) + X3 = postprocess.compute_strains(psi2l, q2) + + Cab0, ra0 = postprocess.integrate_node0( + X1[:, :, 0], dt, ra_n0, Rab_n0 + ) + Cab, ra = postprocess.integrate_strains_t( + ra0, + Cab0, + X3, + X_xdelta, + C0ab, + config, + ) + return X1, X2, X3, ra, Cab @partial(jax.jit, static_argnames=["config", "tn"]) def recover_staticfields(q2, tn, X, phi2l, psi2l, X_xdelta, C0ab, config):