diff --git a/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp b/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp index 42a54250fe5..171e3486ba4 100644 --- a/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp +++ b/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp @@ -2,6 +2,7 @@ #include "ekat/ekat_assert.hpp" #include "ekat/util/ekat_units.hpp" #include "share/field/field_utils.hpp" +#include "physics/share/physics_constants.hpp" namespace scream { // ========================================================================================= @@ -39,14 +40,14 @@ void MLCorrection::set_grids( FieldLayout scalar3d_mid = m_grid->get_3d_scalar_layout(true); FieldLayout scalar3d_int = m_grid->get_3d_scalar_layout(false); FieldLayout vector3d_mid = m_grid->get_3d_vector_layout(true,2); + const auto m2 = m*m; if (not m_ML_correction_unit_test) { - const auto m2 = m*m; const auto s2 = s*s; auto Wm2 = W / m / m; auto nondim = m/m; add_field("phis", scalar2d, m2/s2, grid_name); - add_field("SW_flux_dn", scalar3d_int, Wm2, grid_name, ps); add_field("sfc_alb_dif_vis", scalar2d, nondim, grid_name); + add_field("SW_flux_dn", scalar3d_int, Wm2, grid_name, ps); add_field("sfc_flux_sw_net", scalar2d, Wm2, grid_name); add_field("sfc_flux_lw_dn", scalar2d, Wm2, grid_name); m_lat = m_grid->get_geometry_data("lat"); @@ -62,6 +63,10 @@ void MLCorrection::set_grids( add_field("T_mid", scalar3d_mid, K, grid_name, ps); add_field("qv", scalar3d_mid, Q, grid_name, "tracers", ps); add_field("horiz_winds", vector3d_mid, m/s, grid_name, ps); + /* Note: we also need to update the precipitation after ML commits any column drying */ + add_field("pseudo_density", scalar3d_mid, Pa, grid_name, ps); + add_field("precip_liq_surf_mass", scalar2d, kg/m2, grid_name); + add_field("precip_ice_surf_mass", scalar2d, kg/m2, grid_name); /* ----------------------- WARNING --------------------------------*/ add_group("tracers", grid_name, 1, Bundling::Required); } @@ -86,16 +91,23 @@ void MLCorrection::run_impl(const double dt) { // use model time to infer solar zenith angle for the ML prediction auto current_ts = timestamp(); std::string datetime_str = current_ts.get_date_string() + " " + current_ts.get_time_string(); + + const auto &phis = get_field_in("phis").get_view(); + const auto &sfc_alb_dif_vis = get_field_in("sfc_alb_dif_vis").get_view(); + const auto &qv = get_field_out("qv").get_view(); const auto &T_mid = get_field_out("T_mid").get_view(); - const auto &phis = get_field_in("phis").get_view(); const auto &SW_flux_dn = get_field_out("SW_flux_dn").get_view(); - const auto &sfc_alb_dif_vis = get_field_in("sfc_alb_dif_vis").get_view(); const auto &sfc_flux_sw_net = get_field_out("sfc_flux_sw_net").get_view(); const auto &sfc_flux_lw_dn = get_field_out("sfc_flux_lw_dn").get_view(); const auto &u = get_field_out("horiz_winds").get_component(0).get_view(); const auto &v = get_field_out("horiz_winds").get_component(1).get_view(); + // For precipitation adjustment we need to track the change in column integrated 'qv' + host_view2d_type qv_told("",qv.extent(0),qv.extent(1)); + Kokkos::deep_copy(qv_told,qv); + + auto h_lat = m_lat.get_view(); auto h_lon = m_lon.get_view(); @@ -135,6 +147,69 @@ void MLCorrection::run_impl(const double dt) { ML_model_tq, ML_model_uv, ML_model_sfc_fluxes, datetime_str); pybind11::gil_scoped_release no_gil; ekat::enable_fpes(fpe_mask); + + // Now back out the qv change abd apply it to precipitation, only if Tq ML is turned on + if (m_ML_model_path_tq != "None") { + using PC = scream::physics::Constants; + using KT = KokkosTypes; + using MT = typename KT::MemberType; + using ESU = ekat::ExeSpaceUtils; + const auto &pseudo_density = get_field_in("pseudo_density").get_view(); + const auto &precip_liq_surf_mass = get_field_out("precip_liq_surf_mass").get_view(); + const auto &precip_ice_surf_mass = get_field_out("precip_ice_surf_mass").get_view(); + constexpr Real g = PC::gravit; + const auto num_levs = m_num_levs; + const auto policy = ESU::get_default_team_policy(m_num_cols, m_num_levs); + + const auto &qv_tnew = get_field_in("qv").get_view(); + Kokkos::parallel_for("Compute WVP diff", policy, + KOKKOS_LAMBDA(const MT& team) { + const int icol = team.league_rank(); + auto qold_icol = ekat::subview(qv_told,icol); + auto qnew_icol = ekat::subview(qv_tnew,icol); + auto rho_icol = ekat::subview(pseudo_density,icol); + Real net_column_moistening = 0; + // Compute WaterVaporPath Difference + Kokkos::parallel_reduce(Kokkos::TeamVectorRange(team, num_levs), + [&] (const int& ilev, Real& lsum) { + lsum += (qnew_icol(ilev)-qold_icol(ilev)) * rho_icol(ilev) / g; + },net_column_moistening); + team.team_barrier(); + // Adjust Precipitation + // - Note, we subtract the water vapor path because positive precip represents + // a descrease in qv. + auto tot_precip = precip_liq_surf_mass(icol)+precip_ice_surf_mass(icol); + if (tot_precip>0) { + // adjust precip by weighted avg of both phases + Kokkos::single(Kokkos::PerTeam(team), [&] { + auto liq_frac = precip_liq_surf_mass(icol)/tot_precip; + auto ice_frac = precip_ice_surf_mass(icol)/tot_precip; + precip_liq_surf_mass(icol) -= liq_frac*net_column_moistening; + precip_ice_surf_mass(icol) -= ice_frac*net_column_moistening; + }); + } else { + // Apply all the adjustment to a single phase based on surface temperature + Kokkos::single(Kokkos::PerTeam(team), [&] { + auto T_icol = ekat::subview(T_mid,icol); + if (T_icol(m_num_levs-1)>273.15) { + precip_liq_surf_mass(icol) -= net_column_moistening; + } else { + precip_ice_surf_mass(icol) -= net_column_moistening; + } + }); + } + if (precip_liq_surf_mass(icol)<0) { + Kokkos::single(Kokkos::PerTeam(team), [&] { + precip_liq_surf_mass(icol) = 0.0; + }); + } + if (precip_ice_surf_mass(icol)<0) { + Kokkos::single(Kokkos::PerTeam(team), [&] { + precip_ice_surf_mass(icol) = 0.0; + }); + } + }); + } } // =========================================================================================