diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 4203df63..11b33839 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -564,6 +564,7 @@ def _calculate_psi_grad_constraint_from_Ip_tot( / (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1]) ) + def _psi_value_constraint_from_Vloop( dt: jax.Array, dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice, @@ -576,6 +577,7 @@ def _psi_value_constraint_from_Vloop( + dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt ) + def _init_psi_and_current( dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, geo: Geometry, @@ -601,8 +603,14 @@ def _init_psi_and_current( Returns: Refined core profiles. """ + use_Vloop_bound_right = ( + dynamic_runtime_params_slice.profile_conditions.Vloop_bound_right is not None + ) + # Retrieving psi from the profile conditions. if dynamic_runtime_params_slice.profile_conditions.psi is not None: + # TODO: do we need to support the case where psi is given, but Vloop_bound_right + # is used to set the BC rather than Ip_tot? psi = cell_variable.CellVariable( value=dynamic_runtime_params_slice.profile_conditions.psi, right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( @@ -631,7 +639,12 @@ def _init_psi_and_current( right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice, geo, - ), + ) + if not use_Vloop_bound_right + else None, + right_face_constraint=geo.psi_from_Ip[-1] + if use_Vloop_bound_right + else None, dr=geo.drho_norm, ) core_profiles = dataclasses.replace(core_profiles, psi=psi) @@ -954,10 +967,15 @@ def compute_boundary_conditions( right_face_constraint=jnp.array(nimp_bound_right), ), 'psi': dict( - right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot( + right_face_grad_constraint=( + _calculate_psi_grad_constraint_from_Ip_tot( dynamic_runtime_params_slice_t, geo, - ), + ) + if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right + is None + else None + ), right_face_constraint=( _psi_value_constraint_from_Vloop( dynamic_runtime_params_slice_t, @@ -969,7 +987,7 @@ def compute_boundary_conditions( else None ), ), - } + } # pylint: disable=invalid-name diff --git a/torax/examples/vloop.py b/torax/examples/vloop.py new file mode 100644 index 00000000..a58ebe6c --- /dev/null +++ b/torax/examples/vloop.py @@ -0,0 +1,90 @@ +CONFIG = { + "runtime_params": { + "profile_conditions": { + 'Ip_tot': 15.0, + "Vloop_bound_right": 0.0, + "psi": dict( + zip( + [ + 0.02, + 0.06, + 0.1, + 0.14, + 0.18, + 0.22, + 0.26, + 0.3, + 0.34, + 0.38, + 0.42, + 0.46, + 0.5, + 0.54, + 0.58, + 0.62, + 0.66, + 0.7, + 0.74, + 0.78, + 0.82, + 0.86, + 0.9, + 0.94, + 0.98, + ], + [ + 4.27722812e-02, + 3.94379591e-01, + 1.09263271e00, + 2.10681692e00, + 3.41030703e00, + 4.97594319e00, + 6.77607994e00, + 8.79231633e00, + 1.10530604e01, + 1.36745883e01, + 1.67709665e01, + 2.02615760e01, + 2.39018070e01, + 2.74979635e01, + 3.09738348e01, + 3.43088066e01, + 3.74947223e01, + 4.05261312e01, + 4.33998473e01, + 4.61154118e01, + 4.86753469e01, + 5.10851399e01, + 5.33529513e01, + 5.54890272e01, + 5.75047356e01, + ], + ) + ), + "set_pedestal": False, + } + }, + "geometry": { + "geometry_type": "circular", + }, + "sources": { + "j_bootstrap": {}, + "generic_current_source": {}, + "generic_particle_source": {}, + "gas_puff_source": {}, + "pellet_source": {}, + "generic_ion_el_heat_source": {}, + "fusion_heat_source": {}, + "qei_source": {}, + "ohmic_heat_source": {}, + }, + "transport": { + "transport_model": "constant", + }, + "stepper": { + "stepper_type": "linear", + }, + "time_step_calculator": { + "calculator_type": "chi", + }, +} diff --git a/torax/output.py b/torax/output.py index f5d01c25..5c37176b 100644 --- a/torax/output.py +++ b/torax/output.py @@ -60,6 +60,7 @@ class ToraxSimOutputs: PSI = "psi" PSIDOT = "psidot" PSI_RIGHT_GRAD_BC = "psi_right_grad_bc" +PSI_RIGHT_BC = "psi_right_bc" NE = "ne" NE_RIGHT_BC = "ne_right_bc" NI = "ni" @@ -195,9 +196,9 @@ def __init__( post_processed_output = [ state.post_processed_outputs for state in sim_outputs.sim_history ] - stack = lambda *ys: jnp.stack(ys) + stack = lambda *ys: jnp.stack(jnp.array(ys)) self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map( - stack, *core_profiles + stack, *core_profiles, is_leaf=lambda x: x is None, ) self.core_sources: source_profiles.SourceProfiles = jax.tree_util.tree_map( stack, *core_sources @@ -263,6 +264,9 @@ def _get_core_profiles( xr_dict[PSI_RIGHT_GRAD_BC] = ( self.core_profiles.psi.right_face_grad_constraint ) + xr_dict[PSI_RIGHT_BC] = ( + self.core_profiles.psi.right_face_constraint + ) xr_dict[PSIDOT] = self.core_profiles.psidot.value xr_dict[NE] = self.core_profiles.ne.value xr_dict[NE_RIGHT_BC] = self.core_profiles.ne.right_face_constraint