diff --git a/src/exojax/atm/atmprof.py b/src/exojax/atm/atmprof.py index 331e929e..b28e7256 100644 --- a/src/exojax/atm/atmprof.py +++ b/src/exojax/atm/atmprof.py @@ -47,7 +47,9 @@ def pressure_layer_logspace(log_pressure_top=-8., return pressure, dParr, k -def pressure_upper_logspace(pressure, pressure_decrease_rate, reference_point=0.5): +def pressure_upper_logspace(pressure, + pressure_decrease_rate, + reference_point=0.5): """computes pressure at the upper point of the layers Args: @@ -58,10 +60,12 @@ def pressure_upper_logspace(pressure, pressure_decrease_rate, reference_point=0. Returns: _type_: pressure at the upper point (\overline{P}_i) """ - return (pressure_decrease_rate**reference_point)*pressure + return (pressure_decrease_rate**reference_point) * pressure -def pressure_lower_logspace(pressure, pressure_decrease_rate, reference_point=0.5): +def pressure_lower_logspace(pressure, + pressure_decrease_rate, + reference_point=0.5): """computes pressure at the lower point of the layers Args: @@ -72,10 +76,13 @@ def pressure_lower_logspace(pressure, pressure_decrease_rate, reference_point=0. Returns: _type_: pressure at the lower point (underline{P}_i) """ - return (pressure_decrease_rate**(reference_point-1.0))*pressure + return (pressure_decrease_rate**(reference_point - 1.0)) * pressure -def pressure_boundary_logspace(pressure, pressure_decrease_rate, reference_point=0.5, numpy=False): +def pressure_boundary_logspace(pressure, + pressure_decrease_rate, + reference_point=0.5, + numpy=False): """computes pressure at the boundary of the layers (Nlayer + 1) Args: @@ -87,10 +94,10 @@ def pressure_boundary_logspace(pressure, pressure_decrease_rate, reference_point Returns: _type_: pressure at the boundary (Nlayer + 1) """ - pressure_bottom_boundary = ( - pressure_decrease_rate**(reference_point-1.0))*pressure[-1] - pressure_upper = pressure_upper_logspace( - pressure, pressure_decrease_rate, reference_point) + pressure_bottom_boundary = (pressure_decrease_rate + **(reference_point - 1.0)) * pressure[-1] + pressure_upper = pressure_upper_logspace(pressure, pressure_decrease_rate, + reference_point) if numpy: return np.append(pressure_upper, pressure_bottom_boundary) else: diff --git a/src/exojax/spec/rtransfer.py b/src/exojax/spec/rtransfer.py index eef5a5cc..8af5278d 100644 --- a/src/exojax/spec/rtransfer.py +++ b/src/exojax/spec/rtransfer.py @@ -143,6 +143,7 @@ def initialize_gaussian_quadrature(nstream): return mus, weights + @jit def rtrun_emis_pureabs_ibased_linsap(dtau, source_matrix_boundary, mus, weights): @@ -167,30 +168,28 @@ def rtrun_emis_pureabs_ibased_linsap(dtau, source_matrix_boundary, mus, source_matrix_boundary_p1 = jnp.roll(source_matrix_boundary, -1, axis=0) # S_{n+1} - # NOT IMPLEMENTED YET # need to replace the last element of the above # #scan part muws = [mus, weights] - - + def f(carry_fmu, muw): mu, w = muw - dtau_per_mu = dtau/mu - trans = jnp.exp(-dtau_per_mu) # hat{T} + dtau_per_mu = dtau / mu + trans = jnp.exp(-dtau_per_mu) # hat{T} beta, gamma = coeffs_linsap(dtau_per_mu, trans) - + #adds coeffs at the bottom of the layers - beta = jnp.vstack([beta,jnp.ones(Nnus)]) - gamma = jnp.vstack([gamma,jnp.zeros(Nnus)]) + beta = jnp.vstack([beta, jnp.ones(Nnus)]) + gamma = jnp.vstack([gamma, jnp.zeros(Nnus)]) dI = beta * source_matrix_boundary + gamma * source_matrix_boundary_p1 - intensity_for_mu = jnp.sum(dI * - jnp.cumprod(jnp.vstack([jnp.ones(Nnus), trans]), axis=0), - axis=0) - + intensity_for_mu = jnp.sum( + dI * jnp.cumprod(jnp.vstack([jnp.ones(Nnus), trans]), axis=0), + axis=0) + carry_fmu = carry_fmu + 2.0 * mu * w * intensity_for_mu return carry_fmu, None @@ -198,6 +197,7 @@ def f(carry_fmu, muw): spec, _ = scan(f, jnp.zeros(Nnus), muws) return spec + def coeffs_linsap(dtau_per_mu, trans): """coefficients of the linsap @@ -233,9 +233,9 @@ def rtrun_trans_pureabs(dtau_chord, radius_lower): If you would like to compute the transit depth, devide the output by the square of stellar radius """ - deltaRp2 = 2.0 * jnp.trapz( - (1.0 - jnp.exp(-dtau_chord)) * radius_lower[::-1, None], - x=radius_lower[::-1], + deltaRp2 = -2.0 * jnp.trapz( + (1.0 - jnp.exp(-dtau_chord)) * radius_lower[:, None], + x=radius_lower, axis=0) return deltaRp2 + radius_lower[-1]**2