Skip to content

Commit

Permalink
Merge pull request #422 from HajimeKawahara/transmission_bugfix
Browse files Browse the repository at this point in the history
r was inverse
  • Loading branch information
HajimeKawahara authored Sep 30, 2023
2 parents 9b202d1 + 0b55c81 commit 6d2ed21
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
25 changes: 16 additions & 9 deletions src/exojax/atm/atmprof.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def pressure_layer_logspace(log_pressure_top=-8.,
return pressure, dParr, k


def pressure_upper_logspace(pressure, pressure_decrease_rate, reference_point=0.5):
def pressure_upper_logspace(pressure,
pressure_decrease_rate,
reference_point=0.5):
"""computes pressure at the upper point of the layers
Args:
Expand All @@ -58,10 +60,12 @@ def pressure_upper_logspace(pressure, pressure_decrease_rate, reference_point=0.
Returns:
_type_: pressure at the upper point (\overline{P}_i)
"""
return (pressure_decrease_rate**reference_point)*pressure
return (pressure_decrease_rate**reference_point) * pressure


def pressure_lower_logspace(pressure, pressure_decrease_rate, reference_point=0.5):
def pressure_lower_logspace(pressure,
pressure_decrease_rate,
reference_point=0.5):
"""computes pressure at the lower point of the layers
Args:
Expand All @@ -72,10 +76,13 @@ def pressure_lower_logspace(pressure, pressure_decrease_rate, reference_point=0.
Returns:
_type_: pressure at the lower point (underline{P}_i)
"""
return (pressure_decrease_rate**(reference_point-1.0))*pressure
return (pressure_decrease_rate**(reference_point - 1.0)) * pressure


def pressure_boundary_logspace(pressure, pressure_decrease_rate, reference_point=0.5, numpy=False):
def pressure_boundary_logspace(pressure,
pressure_decrease_rate,
reference_point=0.5,
numpy=False):
"""computes pressure at the boundary of the layers (Nlayer + 1)
Args:
Expand All @@ -87,10 +94,10 @@ def pressure_boundary_logspace(pressure, pressure_decrease_rate, reference_point
Returns:
_type_: pressure at the boundary (Nlayer + 1)
"""
pressure_bottom_boundary = (
pressure_decrease_rate**(reference_point-1.0))*pressure[-1]
pressure_upper = pressure_upper_logspace(
pressure, pressure_decrease_rate, reference_point)
pressure_bottom_boundary = (pressure_decrease_rate
**(reference_point - 1.0)) * pressure[-1]
pressure_upper = pressure_upper_logspace(pressure, pressure_decrease_rate,
reference_point)
if numpy:
return np.append(pressure_upper, pressure_bottom_boundary)
else:
Expand Down
30 changes: 15 additions & 15 deletions src/exojax/spec/rtransfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def initialize_gaussian_quadrature(nstream):

return mus, weights


@jit
def rtrun_emis_pureabs_ibased_linsap(dtau, source_matrix_boundary, mus,
weights):
Expand All @@ -167,37 +168,36 @@ def rtrun_emis_pureabs_ibased_linsap(dtau, source_matrix_boundary, mus,
source_matrix_boundary_p1 = jnp.roll(source_matrix_boundary, -1,
axis=0) # S_{n+1}


# NOT IMPLEMENTED YET
# need to replace the last element of the above
#

#scan part
muws = [mus, weights]



def f(carry_fmu, muw):
mu, w = muw
dtau_per_mu = dtau/mu
trans = jnp.exp(-dtau_per_mu) # hat{T}
dtau_per_mu = dtau / mu
trans = jnp.exp(-dtau_per_mu) # hat{T}
beta, gamma = coeffs_linsap(dtau_per_mu, trans)

#adds coeffs at the bottom of the layers
beta = jnp.vstack([beta,jnp.ones(Nnus)])
gamma = jnp.vstack([gamma,jnp.zeros(Nnus)])
beta = jnp.vstack([beta, jnp.ones(Nnus)])
gamma = jnp.vstack([gamma, jnp.zeros(Nnus)])

dI = beta * source_matrix_boundary + gamma * source_matrix_boundary_p1
intensity_for_mu = jnp.sum(dI *
jnp.cumprod(jnp.vstack([jnp.ones(Nnus), trans]), axis=0),
axis=0)
intensity_for_mu = jnp.sum(
dI * jnp.cumprod(jnp.vstack([jnp.ones(Nnus), trans]), axis=0),
axis=0)

carry_fmu = carry_fmu + 2.0 * mu * w * intensity_for_mu

return carry_fmu, None

spec, _ = scan(f, jnp.zeros(Nnus), muws)
return spec


def coeffs_linsap(dtau_per_mu, trans):
"""coefficients of the linsap
Expand Down Expand Up @@ -233,9 +233,9 @@ def rtrun_trans_pureabs(dtau_chord, radius_lower):
If you would like to compute the transit depth, devide the output by the square of stellar radius
"""
deltaRp2 = 2.0 * jnp.trapz(
(1.0 - jnp.exp(-dtau_chord)) * radius_lower[::-1, None],
x=radius_lower[::-1],
deltaRp2 = -2.0 * jnp.trapz(
(1.0 - jnp.exp(-dtau_chord)) * radius_lower[:, None],
x=radius_lower,
axis=0)
return deltaRp2 + radius_lower[-1]**2

Expand Down

0 comments on commit 6d2ed21

Please sign in to comment.