Skip to content

Commit

Permalink
TEMP: Initial attempts at getting Vloop BC working
Browse files Browse the repository at this point in the history
This includes a hack to support outputting different BCs in the output
file. Previously, the simulation could only use one type of BC for the
whole run (ie either grad or value constraint). By wrapping the output in
a jnp.array(), BCs that are None get turned into NaN, which is compatible
with tree_map. Hence, this change allows you to have `grad_constraint =
[XXX, None, None, ...]` and `value_constraint = [None, XXX, YYY, ...]`
which is useful for testing the Vloop BC.
  • Loading branch information
theo-brown committed Nov 26, 2024
1 parent 017044b commit 3b7f33c
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
26 changes: 22 additions & 4 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def _calculate_psi_grad_constraint_from_Ip_tot(
/ (geo.g2g3_over_rhon_face[-1] * geo.F_face[-1])
)


def _psi_value_constraint_from_Vloop(
dt: jax.Array,
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
Expand All @@ -576,6 +577,7 @@ def _psi_value_constraint_from_Vloop(
+ dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right * dt
)


def _init_psi_and_current(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: Geometry,
Expand All @@ -601,8 +603,14 @@ def _init_psi_and_current(
Returns:
Refined core profiles.
"""
use_Vloop_bound_right = (
dynamic_runtime_params_slice.profile_conditions.Vloop_bound_right is not None
)

# Retrieving psi from the profile conditions.
if dynamic_runtime_params_slice.profile_conditions.psi is not None:
# TODO: do we need to support the case where psi is given, but Vloop_bound_right
# is used to set the BC rather than Ip_tot?
psi = cell_variable.CellVariable(
value=dynamic_runtime_params_slice.profile_conditions.psi,
right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot(
Expand Down Expand Up @@ -631,7 +639,12 @@ def _init_psi_and_current(
right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice,
geo,
),
)
if not use_Vloop_bound_right
else None,
right_face_constraint=geo.psi_from_Ip[-1]
if use_Vloop_bound_right
else None,
dr=geo.drho_norm,
)
core_profiles = dataclasses.replace(core_profiles, psi=psi)
Expand Down Expand Up @@ -954,10 +967,15 @@ def compute_boundary_conditions(
right_face_constraint=jnp.array(nimp_bound_right),
),
'psi': dict(
right_face_grad_constraint=_calculate_psi_grad_constraint_from_Ip_tot(
right_face_grad_constraint=(
_calculate_psi_grad_constraint_from_Ip_tot(
dynamic_runtime_params_slice_t,
geo,
),
)
if dynamic_runtime_params_slice_t.profile_conditions.Vloop_bound_right
is None
else None
),
right_face_constraint=(
_psi_value_constraint_from_Vloop(
dynamic_runtime_params_slice_t,
Expand All @@ -969,7 +987,7 @@ def compute_boundary_conditions(
else None
),
),
}
}


# pylint: disable=invalid-name
Expand Down
90 changes: 90 additions & 0 deletions torax/examples/vloop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
CONFIG = {
"runtime_params": {
"profile_conditions": {
'Ip_tot': 15.0,
"Vloop_bound_right": 0.0,
"psi": dict(
zip(
[
0.02,
0.06,
0.1,
0.14,
0.18,
0.22,
0.26,
0.3,
0.34,
0.38,
0.42,
0.46,
0.5,
0.54,
0.58,
0.62,
0.66,
0.7,
0.74,
0.78,
0.82,
0.86,
0.9,
0.94,
0.98,
],
[
4.27722812e-02,
3.94379591e-01,
1.09263271e00,
2.10681692e00,
3.41030703e00,
4.97594319e00,
6.77607994e00,
8.79231633e00,
1.10530604e01,
1.36745883e01,
1.67709665e01,
2.02615760e01,
2.39018070e01,
2.74979635e01,
3.09738348e01,
3.43088066e01,
3.74947223e01,
4.05261312e01,
4.33998473e01,
4.61154118e01,
4.86753469e01,
5.10851399e01,
5.33529513e01,
5.54890272e01,
5.75047356e01,
],
)
),
"set_pedestal": False,
}
},
"geometry": {
"geometry_type": "circular",
},
"sources": {
"j_bootstrap": {},
"generic_current_source": {},
"generic_particle_source": {},
"gas_puff_source": {},
"pellet_source": {},
"generic_ion_el_heat_source": {},
"fusion_heat_source": {},
"qei_source": {},
"ohmic_heat_source": {},
},
"transport": {
"transport_model": "constant",
},
"stepper": {
"stepper_type": "linear",
},
"time_step_calculator": {
"calculator_type": "chi",
},
}
8 changes: 6 additions & 2 deletions torax/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ToraxSimOutputs:
PSI = "psi"
PSIDOT = "psidot"
PSI_RIGHT_GRAD_BC = "psi_right_grad_bc"
PSI_RIGHT_BC = "psi_right_bc"
NE = "ne"
NE_RIGHT_BC = "ne_right_bc"
NI = "ni"
Expand Down Expand Up @@ -195,9 +196,9 @@ def __init__(
post_processed_output = [
state.post_processed_outputs for state in sim_outputs.sim_history
]
stack = lambda *ys: jnp.stack(ys)
stack = lambda *ys: jnp.stack(jnp.array(ys))
self.core_profiles: state.CoreProfiles = jax.tree_util.tree_map(
stack, *core_profiles
stack, *core_profiles, is_leaf=lambda x: x is None,
)
self.core_sources: source_profiles.SourceProfiles = jax.tree_util.tree_map(
stack, *core_sources
Expand Down Expand Up @@ -263,6 +264,9 @@ def _get_core_profiles(
xr_dict[PSI_RIGHT_GRAD_BC] = (
self.core_profiles.psi.right_face_grad_constraint
)
xr_dict[PSI_RIGHT_BC] = (
self.core_profiles.psi.right_face_constraint
)
xr_dict[PSIDOT] = self.core_profiles.psidot.value
xr_dict[NE] = self.core_profiles.ne.value
xr_dict[NE_RIGHT_BC] = self.core_profiles.ne.right_face_constraint
Expand Down

0 comments on commit 3b7f33c

Please sign in to comment.