From 3fde6d67f4f35898f1919d00718b124092fef5ec Mon Sep 17 00:00:00 2001 From: Jonathan Citrin Date: Thu, 7 Nov 2024 10:43:16 -0800 Subject: [PATCH] Generalize the transport_model _make_core_transport method. Now easier supports non-QuaLiKiz transport models like TGLF and TGLF-NN. PiperOrigin-RevId: 694171536 --- torax/transport_model/qlknn_wrapper.py | 2 + .../qualikiz_based_transport_model.py | 108 +++-------- torax/transport_model/qualikiz_wrapper.py | 2 + .../quasilinear_transport_model.py | 172 +++++++++++++++++- .../tests/qualikiz_based_transport_model.py | 2 + .../tests/quasilinear_transport_model.py | 4 + 6 files changed, 204 insertions(+), 86 deletions(-) diff --git a/torax/transport_model/qlknn_wrapper.py b/torax/transport_model/qlknn_wrapper.py index ae8d4224..6906933c 100644 --- a/torax/transport_model/qlknn_wrapper.py +++ b/torax/transport_model/qlknn_wrapper.py @@ -356,6 +356,8 @@ def _combined( transport=runtime_config_inputs.transport, geo=geo, core_profiles=core_profiles, + gradient_reference_length=geo.Rmaj, + gyrobohm_flux_reference_length=geo.Rmin, ) def __hash__(self) -> int: diff --git a/torax/transport_model/qualikiz_based_transport_model.py b/torax/transport_model/qualikiz_based_transport_model.py index 029e4d5f..28410a73 100644 --- a/torax/transport_model/qualikiz_based_transport_model.py +++ b/torax/transport_model/qualikiz_based_transport_model.py @@ -26,6 +26,7 @@ @chex.dataclass class RuntimeParams(quasilinear_transport_model.RuntimeParams): """Shared parameters for Qualikiz-based models.""" + # Collisionality multiplier. coll_mult: float = 1.0 # ensure that smag - alpha > -0.2 always, to compensate for no slab modes @@ -44,6 +45,7 @@ def make_provider( @chex.dataclass(frozen=True) class DynamicRuntimeParams(quasilinear_transport_model.DynamicRuntimeParams): """Shared parameters for Qualikiz-based models.""" + coll_mult: float avoid_big_negative_s: bool smag_alpha_correction: bool @@ -66,8 +68,6 @@ class QualikizInputs(quasilinear_transport_model.QuasilinearInputs): """Inputs to Qualikiz-based models.""" Zeff_face: chex.Array - Ani0: chex.Array - Ani1: chex.Array q: chex.Array smag: chex.Array x: chex.Array @@ -95,83 +95,30 @@ def _prepare_qualikiz_inputs( """Prepare Qualikiz inputs.""" constants = constants_module.CONSTANTS - Rmin = geo.Rmin - Rmaj = geo.Rmaj - # define radial coordinate as midplane average r # (typical assumption for transport models developed in circular geo) rmid = (geo.Rout - geo.Rin) * 0.5 rmid_face = (geo.Rout_face - geo.Rin_face) * 0.5 - temp_ion_var = core_profiles.temp_ion - temp_ion_face = temp_ion_var.face_value() - temp_ion_face_grad = temp_ion_var.face_grad(rmid) - temp_el_var = core_profiles.temp_el - temp_electron_face = temp_el_var.face_value() - temp_electron_face_grad = temp_el_var.face_grad(rmid) - # Careful, these are in n_ref units, not postprocessed to SI units yet - raw_ne = core_profiles.ne - raw_ne_face = raw_ne.face_value() - raw_ne_face_grad = raw_ne.face_grad(rmid) - raw_ni = core_profiles.ni - raw_ni_face = raw_ni.face_value() - raw_ni_face_grad = raw_ni.face_grad(rmid) - raw_nimp = core_profiles.nimp - raw_nimp_face = raw_nimp.face_value() - raw_nimp_face_grad = raw_nimp.face_grad(rmid) - - # True SI value versions - true_ne_face = raw_ne_face * nref - true_ni_face = raw_ni_face * nref - true_nimp_face = raw_nimp_face * nref - # gyrobohm diffusivity # (defined here with Lref=Rmin due to QLKNN training set normalization) - chiGB = ( - (core_profiles.Ai * constants.mp) ** 0.5 - / (constants.qe * geo.B0) ** 2 - * (temp_ion_face * constants.keV2J) ** 1.5 - / Rmin + chiGB = quasilinear_transport_model.calculate_chiGB( + core_profiles=core_profiles, + b_unit=geo.B0, + reference_length=geo.Rmin, ) # transport coefficients from the qlknn-hyper-10D model # (K.L. van de Plassche PoP 2020) - # TODO(b/335581689): make a unit test that tests this function directly - # with set_pedestal = False. Currently this is tested only via - # sim test7, which has set_pedestal=True. With set_pedestal=True, - # mutants of Ati[-1], Ate[-1], An[-1] all affect only chi[-1], but - # chi[-1] remains above config.transport.chimin for all mutants. - # The pedestal feature then clips chi[-1] to config.transport.chimin, so the - # mutants have no effect. - # set up input vectors (all as jax.numpy arrays on face grid) - # R/LTi profile from current timestep temp_ion - Ati = -Rmaj * temp_ion_face_grad / temp_ion_face - # to avoid divisions by zero - Ati = jnp.where(jnp.abs(Ati) < constants.eps, constants.eps, Ati) - - # R/LTe profile from current timestep temp_el - Ate = -Rmaj * temp_electron_face_grad / temp_electron_face - # to avoid divisions by zero - Ate = jnp.where(jnp.abs(Ate) < constants.eps, constants.eps, Ate) - - # R/Ln profiles from current timestep - # OK to use normalized version here, because nref in numer and denom - # cancels. - Ane = -Rmaj * raw_ne_face_grad / raw_ne_face - Ani0 = -Rmaj * raw_ni_face_grad / raw_ni_face - # To avoid divisions by zero in cases where Zeff=1. - Ani1 = jnp.where( - jnp.abs(raw_nimp_face) < constants.eps, - 0.0, - -Rmaj * raw_nimp_face_grad / raw_nimp_face, + # Calculate normalized logarithmic gradients + normalized_logarithmic_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles( + core_profiles=core_profiles, + radial_coordinate=rmid, + reference_length=geo.Rmaj, ) - # to avoid divisions by zero - Ane = jnp.where(jnp.abs(Ane) < constants.eps, constants.eps, Ane) - Ani0 = jnp.where(jnp.abs(Ani0) < constants.eps, constants.eps, Ani0) - Ani1 = jnp.where(jnp.abs(Ani1) < constants.eps, constants.eps, Ani1) # Calculate q and s. # Need to recalculate since in the nonlinear solver psi has intermediate @@ -187,13 +134,15 @@ def _prepare_qualikiz_inputs( ) # Inverse aspect ratio at LCFS. - epsilon_lcfs = rmid_face[-1] / Rmaj + epsilon_lcfs = rmid_face[-1] / geo.Rmaj # Local normalized radius. x = rmid_face / rmid_face[-1] x = jnp.where(jnp.abs(x) < constants.eps, constants.eps, x) # Ion to electron temperature ratio - Ti_Te = temp_ion_face / temp_electron_face + Ti_Te = ( + core_profiles.temp_ion.face_value() / core_profiles.temp_el.face_value() + ) # logarithm of normalized collisionality nu_star = physics.calc_nu_star( @@ -206,11 +155,12 @@ def _prepare_qualikiz_inputs( log_nu_star_face = jnp.log10(nu_star) # calculate alpha for magnetic shear correction (see S. van Mulders NF 2021) - factor_0 = 2 / geo.B0**2 * constants.mu0 * q**2 - alpha = factor_0 * ( - temp_electron_face * constants.keV2J * true_ne_face * (Ate + Ane) - + true_ni_face * temp_ion_face * constants.keV2J * (Ati + Ani0) - + true_nimp_face * temp_ion_face * constants.keV2J * (Ati + Ani1) + alpha = quasilinear_transport_model.calculate_alpha( + core_profiles=core_profiles, + nref=nref, + q=q, + b_unit=geo.B0, + normalized_logarithmic_gradients=normalized_logarithmic_gradients, ) # to approximate impact of Shafranov shift. From van Mulders Nucl. Fusion @@ -248,14 +198,14 @@ def _prepare_qualikiz_inputs( alpha - 0.2, smag, ) - normni = raw_ni_face / raw_ne_face + normni = core_profiles.ni.face_value() / core_profiles.ne.face_value() return QualikizInputs( Zeff_face=Zeff_face, - Ati=Ati, - Ate=Ate, - Ane=Ane, - Ani0=Ani0, - Ani1=Ani1, + Ati=normalized_logarithmic_gradients.Ati, + Ate=normalized_logarithmic_gradients.Ate, + Ane=normalized_logarithmic_gradients.Ane, + Ani0=normalized_logarithmic_gradients.Ani0, + Ani1=normalized_logarithmic_gradients.Ani1, q=q, smag=smag, x=x, @@ -263,8 +213,8 @@ def _prepare_qualikiz_inputs( log_nu_star_face=log_nu_star_face, normni=normni, chiGB=chiGB, - Rmaj=Rmaj, - Rmin=Rmin, + Rmaj=geo.Rmaj, + Rmin=geo.Rmin, alpha=alpha, epsilon_lcfs=epsilon_lcfs, ) diff --git a/torax/transport_model/qualikiz_wrapper.py b/torax/transport_model/qualikiz_wrapper.py index 3e55fd96..64f6b36f 100644 --- a/torax/transport_model/qualikiz_wrapper.py +++ b/torax/transport_model/qualikiz_wrapper.py @@ -242,6 +242,8 @@ def _extract_run_data( transport=transport, geo=geo, core_profiles=core_profiles, + gradient_reference_length=geo.Rmaj, + gyrobohm_flux_reference_length=geo.Rmin, ) diff --git a/torax/transport_model/quasilinear_transport_model.py b/torax/transport_model/quasilinear_transport_model.py index 72d79d4c..c3b83683 100644 --- a/torax/transport_model/quasilinear_transport_model.py +++ b/torax/transport_model/quasilinear_transport_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Base class for quasilinear models.""" +from __future__ import annotations import chex import jax @@ -19,14 +20,99 @@ from torax import constants as constants_module from torax import geometry from torax import state +from torax.fvm import cell_variable from torax.transport_model import runtime_params as runtime_params_lib from torax.transport_model import transport_model +def calculate_chiGB( # pylint: disable=invalid-name + core_profiles: state.CoreProfiles, + b_unit: chex.Numeric, + reference_length: chex.Numeric, +) -> chex.Array: + """Calculates the gyrobohm diffusivity. + + Args: + core_profiles: CoreProfiles object containing plasma profiles. + b_unit: Magnetic field strength. Different transport models have different + definitions of the specific magnetic field input. + reference_length: Reference length for normalization. + + Returns: + Gyrobohm diffusivity as a chex.Array. + """ + constants = constants_module.CONSTANTS + return ( + (core_profiles.Ai * constants.mp) ** 0.5 + / (constants.qe * b_unit) ** 2 + * (core_profiles.temp_ion.face_value() * constants.keV2J) ** 1.5 + / reference_length + ) + + +def calculate_alpha( + core_profiles: state.CoreProfiles, + nref: chex.Numeric, + q: chex.Array, + b_unit: chex.Numeric, + normalized_logarithmic_gradients: NormalizedLogarithmicGradients, +) -> chex.Array: + """Calculates the alpha_MHD parameter. + + alpha_MHD = Lref q^2 beta' , where beta' is the radial gradient of beta, the + ratio of plasma pressure to magnetic pressure, Lref a reference length, + and q is the safety factor. Lref is included within the + NormalizedLogarithmicGradients. + + Args: + core_profiles: CoreProfiles object containing plasma profiles. + nref: Reference density. + q: Safety factor. + b_unit: Magnetic field strength. Different transport models have different + definitions of the specific magnetic field input. + normalized_logarithmic_gradients: Normalized logarithmic gradients of plasma + profiles. + + Returns: + Alpha value as a chex.Array. + """ + constants = constants_module.CONSTANTS + + factor_0 = 2 / b_unit**2 * constants.mu0 * q**2 + alpha = factor_0 * ( + core_profiles.temp_el.face_value() + * constants.keV2J + * core_profiles.ne.face_value() + * nref + * ( + normalized_logarithmic_gradients.Ate + + normalized_logarithmic_gradients.Ane + ) + + core_profiles.ni.face_value() + * nref + * core_profiles.temp_ion.face_value() + * constants.keV2J + * ( + normalized_logarithmic_gradients.Ati + + normalized_logarithmic_gradients.Ani0 + ) + + core_profiles.nimp.face_value() + * nref + * core_profiles.temp_ion.face_value() + * constants.keV2J + * ( + normalized_logarithmic_gradients.Ati + + normalized_logarithmic_gradients.Ani1 + ) + ) + return alpha + + # pylint: disable=invalid-name @chex.dataclass class RuntimeParams(runtime_params_lib.RuntimeParams): """Shared parameters for Quasilinear models.""" + # effective D / effective V approach for particle transport DVeff: bool = False # minimum |R/Lne| below which effective V is used instead of effective D @@ -34,13 +120,14 @@ class RuntimeParams(runtime_params_lib.RuntimeParams): def make_provider( self, torax_mesh: geometry.Grid1D | None = None - ) -> 'RuntimeParamsProvider': + ) -> RuntimeParamsProvider: return RuntimeParamsProvider(**self.get_provider_kwargs(torax_mesh)) @chex.dataclass(frozen=True) class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams): """Shared parameters for Quasilinear models.""" + DVeff: bool An_min: float @@ -55,15 +142,76 @@ def build_dynamic_params(self, t: chex.Numeric) -> DynamicRuntimeParams: return DynamicRuntimeParams(**self.get_dynamic_params_kwargs(t)) +@chex.dataclass(frozen=True) +class NormalizedLogarithmicGradients: + """Normalized logarithmic gradients of plasma profiles.""" + + Ati: chex.Array + Ate: chex.Array + Ane: chex.Array + Ani0: chex.Array + Ani1: chex.Array + + @classmethod + def from_profiles( + cls, + core_profiles: state.CoreProfiles, + radial_coordinate: jnp.ndarray, + reference_length: jnp.ndarray, + ) -> NormalizedLogarithmicGradients: + """Calculates the normalized logarithmic gradients.""" + gradients = {} + for name, profile in { + "Ati": core_profiles.temp_ion, + "Ate": core_profiles.temp_el, + "Ane": core_profiles.ne, + "Ani0": core_profiles.ni, + "Ani1": core_profiles.nimp, + }.items(): + gradients[name] = cls._calculate_normalized_logarithmic_gradient( + var=profile, + radial_coordinate=radial_coordinate, + reference_length=reference_length, + ) + return cls(**gradients) + + @staticmethod + def _calculate_normalized_logarithmic_gradient( + var: cell_variable.CellVariable, + radial_coordinate: jax.Array, + reference_length: jax.Array, + ) -> jax.Array: + """Calculates the normalized logarithmic gradient of a CellVariable on the face grid.""" + + # var ~ 0 is only possible for ions (e.g. zero impurity density), and we + # guard against possible division by zero. + result = jnp.where( + jnp.abs(var.face_value()) < constants_module.CONSTANTS.eps, + constants_module.CONSTANTS.eps, + -reference_length * var.face_grad(radial_coordinate) / var.face_value(), + ) + + # to avoid divisions by zero elsewhere in TORAX, if the gradient is zero + result = jnp.where( + jnp.abs(result) < constants_module.CONSTANTS.eps, + constants_module.CONSTANTS.eps, + result, + ) + return result + + @chex.dataclass(frozen=True) class QuasilinearInputs: """Variables required to convert outputs to TORAX CoreTransport outputs.""" + chiGB: chex.Array Rmin: chex.Array Rmaj: chex.Array Ati: chex.Array Ate: chex.Array Ane: chex.Array + Ani0: chex.Array + Ani1: chex.Array class QuasilinearTransportModel(transport_model.TransportModel): @@ -78,27 +226,37 @@ def _make_core_transport( transport: DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + gradient_reference_length: chex.Numeric, + gyrobohm_flux_reference_length: chex.Numeric, ) -> state.CoreTransport: """Converts model output to CoreTransport.""" constants = constants_module.CONSTANTS # conversion to SI units (note that n is normalized here) + + # Convert the electron heat flux from GB (pfe) to SI units. pfe_SI = ( pfe * core_profiles.ne.face_value() * quasilinear_inputs.chiGB - / quasilinear_inputs.Rmin + / gyrobohm_flux_reference_length ) # chi outputs in SI units. - # chi in GB units is Q[GB]/(a/LT) , Lref=Rmin in Q[GB]. - # max/min clipping included + # chi[GB] = -Q[GB]/(Lref/LT), chi is heat diffusivity, Q is heat flux, + # where Lref is the gyrobohm normalization length, LT the logarithmic + # gradient length (unnormalized). For normalized_logarithmic_gradients like + # the gradient normalization length can in principle be different from the + # gyrobohm flux reference length. e.g. in QuaLiKiz Ati = -Rmaj/LTi, but the + # gyrobohm flux reference length is Rmin. + # In case they are indeed different we rescale the normalized logarithmic + # gradient by the ratio of the two reference lengths. chi_face_ion = ( - ((quasilinear_inputs.Rmaj / quasilinear_inputs.Rmin) * qi) + ((gradient_reference_length / gyrobohm_flux_reference_length) * qi) / quasilinear_inputs.Ati ) * quasilinear_inputs.chiGB chi_face_el = ( - ((quasilinear_inputs.Rmaj / quasilinear_inputs.Rmin) * qe) + ((gradient_reference_length / gyrobohm_flux_reference_length) * qe) / quasilinear_inputs.Ate ) * quasilinear_inputs.chiGB @@ -135,7 +293,7 @@ def Dscaled_approach() -> tuple[jax.Array, jax.Array]: pfe_SI / core_profiles.ne.face_value() - quasilinear_inputs.Ane * d_face_el - / quasilinear_inputs.Rmaj + / gradient_reference_length * geo.g1_over_vpr2_face * geo.rho_b**2 ) / (geo.g0_over_vpr_face * geo.rho_b) diff --git a/torax/transport_model/tests/qualikiz_based_transport_model.py b/torax/transport_model/tests/qualikiz_based_transport_model.py index 5727da6a..b8d63a33 100644 --- a/torax/transport_model/tests/qualikiz_based_transport_model.py +++ b/torax/transport_model/tests/qualikiz_based_transport_model.py @@ -186,6 +186,8 @@ def _call_implementation( transport=transport, geo=geo, core_profiles=core_profiles, + gradient_reference_length=geo.Rmaj, + gyrobohm_flux_reference_length=geo.Rmin, ) diff --git a/torax/transport_model/tests/quasilinear_transport_model.py b/torax/transport_model/tests/quasilinear_transport_model.py index 592eff6c..e0effb05 100644 --- a/torax/transport_model/tests/quasilinear_transport_model.py +++ b/torax/transport_model/tests/quasilinear_transport_model.py @@ -140,6 +140,8 @@ def _call_implementation( Ati=jnp.array(1.1), Ate=jnp.array(1.2), Ane=jnp.array(1.3), + Ani0=jnp.array(1.4), + Ani1=jnp.array(1.5), ) transport = dynamic_runtime_params_slice.transport # Assert required for pytype. @@ -155,6 +157,8 @@ def _call_implementation( transport=transport, geo=geo, core_profiles=core_profiles, + gradient_reference_length=3.0, + gyrobohm_flux_reference_length=1.0, )