Skip to content

Commit

Permalink
Fix imports. loader -> logger. Instance checking instead of type chec…
Browse files Browse the repository at this point in the history
…king. non-zero CDF check.
  • Loading branch information
austinschneider committed Aug 29, 2024
1 parent f6a3762 commit 45bf92a
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions resources/Processes/DarkNewsTables/DarkNewsDecay.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import numpy as np
import functools
import pickle

from siren import _util

base_path = os.path.dirname(os.path.abspath(__file__))
loader_file = os.path.join(base_path, "loader.py")
siren._util.load_module("loader", loader_file)
logger_file = os.path.join(base_path, "logger.py")
_util.load_module("logger", logger_file)

# SIREN methods
from siren.interactions import DarkNewsDecay
Expand Down Expand Up @@ -41,7 +43,7 @@ def load_from_table(self, table_dir):
self.decay_norm, self.decay_integrator = pickle.load(f)

def save_to_table(self, table_dir):
with open(os.path.join(table_dir, "decay.pkl") as f:
with open(os.path.join(table_dir, "decay.pkl")) as f:
pickle.dump(f, {
"decay_integrator": self.decay_integrator,
"decay_norm": self.decay_norm
Expand Down Expand Up @@ -94,7 +96,7 @@ def DifferentialDecayWidth(self, record):
# Momentum variables of HNL necessary for calculating decay phase space
PN = np.array(record.primary_momentum)

if type(self.dec_case) == FermionSinglePhotonDecay:
if isinstance(self.dec_case, FermionSinglePhotonDecay):
gamma_idx = 0
for secondary in record.signature.secondary_types:
if secondary == dataclasses.Particle.ParticleType.Gamma:
Expand All @@ -107,7 +109,7 @@ def DifferentialDecayWidth(self, record):
Pgamma = np.array(record.secondary_momenta[gamma_idx])
momenta = np.expand_dims(PN, 0), np.expand_dims(Pgamma, 0)

elif type(self.dec_case) == FermionDileptonDecay:
elif isinstance(self.dec_case, FermionDileptonDecay):
lepminus_idx = -1
lepplus_idx = -1
nu_idx = -1
Expand Down Expand Up @@ -144,9 +146,9 @@ def DifferentialDecayWidth(self, record):
return self.dec_case.differential_width(momenta)

def TotalDecayWidth(self, arg1):
if type(arg1) == dataclasses.InteractionRecord:
if isinstance(arg1, dataclasses.InteractionRecord):
primary = arg1.signature.primary_type
elif type(arg1) == dataclasses.Particle.ParticleType:
elif isinstance(arg1, dataclasses.Particle.ParticleType):
primary = arg1
else:
print("Incorrect function call to TotalDecayWidth!")
Expand All @@ -155,7 +157,7 @@ def TotalDecayWidth(self, arg1):
return 0
if self.total_width is None:
# Need to set the total width
if type(self.dec_case) == FermionDileptonDecay and (
if isinstance(self.dec_case, FermionDileptonDecay) and (
self.dec_case.vector_off_shell and self.dec_case.scalar_off_shell
):
# total width calculation requires evaluating an integral
Expand Down Expand Up @@ -194,9 +196,9 @@ def TotalDecayWidthForFinalState(self, record):
return ret

def DensityVariables(self):
if type(self.dec_case) == FermionSinglePhotonDecay:
if isinstance(self.dec_case, FermionSinglePhotonDecay):
return "cost"
elif type(self.dec_case) == FermionDileptonDecay:
elif isinstance(self.dec_case, FermionDileptonDecay):
if self.dec_case.vector_on_shell and self.dec_case.scalar_on_shell:
print("Can't have both the scalar and vector on shell")
exit(0)
Expand All @@ -223,6 +225,23 @@ def GetPSSample(self, random):
PSidx = np.argmax(x - self.PS_weights_CDF <= 0)
return self.PS_samples[:, PSidx]

def GetPSSample(self, random):
# Make the PS weight CDF if that hasn't been done
if self.PS_weights_CDF is None:
self.PS_weights_CDF = np.cumsum(self.PS_weights)

# Check that the CDF makes sense
total_weight = self.PS_weights_CDF[-1]
if total_weight == 0:
raise ValueError("Total weight is zero, cannot sample")

# Random number to determine
x = random.Uniform(0, total_weight)

# find first instance of a CDF entry greater than x
PSidx = np.argmax(x - self.PS_weights_CDF <= 0)
return self.PS_samples[:, PSidx]

def SampleRecordFromDarkNews(self, record, random):
# First, make sure we have PS samples and weights
if self.PS_samples is None or self.PS_weights is None:
Expand Down Expand Up @@ -254,7 +273,7 @@ def SampleRecordFromDarkNews(self, record, random):

secondaries = record.GetSecondaryParticleRecords()

if type(self.dec_case) == FermionSinglePhotonDecay:
if isinstance(self.dec_case, FermionSinglePhotonDecay):
gamma_idx = 0
for secondary in record.signature.secondary_types:
if secondary == dataclasses.Particle.ParticleType.Gamma:
Expand All @@ -269,7 +288,7 @@ def SampleRecordFromDarkNews(self, record, random):
secondaries[nu_idx].four_momentum = np.squeeze(four_momenta["P_decay_N_daughter"])
secondaries[nu_idx].mass = 0

elif type(self.dec_case) == FermionDileptonDecay:
elif isinstance(self.dec_case, FermionDileptonDecay):
lepminus_idx = -1
lepplus_idx = -1
nu_idx = -1
Expand Down

0 comments on commit 45bf92a

Please sign in to comment.