Sketch DN decay logic and fix some obvious issues in the DN xs logic
austinschneider committed Jun 6, 2024
1 parent 88b287e commit 7bde698
Showing 3 changed files with 193 additions and 93 deletions.
52 changes: 22 additions & 30 deletions resources/CrossSections/DarkNewsTables/DarkNewsCrossSection.py
@@ -39,30 +39,37 @@ def __init__(

def load_from_table(self, table_dir):
# Make the table directory where we will store cross section tables
table_dir_exists = False
if os.path.exists(table_dir):
# print("Directory '%s' already exists"%table_dir)
table_dir_exists = True
else:
if not os.path.exists(table_dir):
try:
os.makedirs(table_dir, exist_ok=False)
# print("Directory '%s' created successfully" % table_dir)
except OSError as error:
raise RuntimeError("Directory '%s' cannot be created" % table_dir)

# Look in table dir and check whether total/differential xsec tables exist
if table_dir_exists:
total_xsec_file = os.path.join(table_dir, "total_cross_sections.npy")
if os.path.exists(total_xsec_file):
self.total_cross_section_table = np.load(total_xsec_file)
diff_xsec_file = os.path.join(
table_dir, "differential_cross_sections.npy"
)
if os.path.exists(diff_xsec_file):
self.differential_cross_section_table = np.load(diff_xsec_file)
total_xsec_file = os.path.join(table_dir, "total_cross_sections.npy")
if os.path.exists(total_xsec_file):
self.total_cross_section_table = np.load(total_xsec_file)
diff_xsec_file = os.path.join(
table_dir, "differential_cross_sections.npy"
)
if os.path.exists(diff_xsec_file):
self.differential_cross_section_table = np.load(diff_xsec_file)

self.configure()

def save_to_table(self, table_dir, total=True, diff=True):
if total:
self._redefine_interpolation_objects(total=True)
with open(
os.path.join(table_dir, "total_cross_sections.npy"), "wb"
) as f:
np.save(f, self.total_cross_section_table)
if diff:
self._redefine_interpolation_objects(diff=True)
with open(
os.path.join(table_dir, "differential_cross_sections.npy"), "wb"
) as f:
np.save(f, self.differential_cross_section_table)

# serialization method
def get_representation(self):
@@ -313,21 +320,6 @@ def FillInterpolationTables(self, total=True, diff=True, factor=0.8, Emax=None):
self._redefine_interpolation_objects(total=total, diff=diff)
return num_added_points

# Saves the tables for the scipy interpolation objects
def SaveInterpolationTables(self, table_dir, total=True, diff=True):
if total:
self._redefine_interpolation_objects(total=True)
with open(
os.path.join(table_dir, "total_cross_sections.npy"), "wb"
) as f:
np.save(f, self.total_cross_section_table)
if diff:
self._redefine_interpolation_objects(diff=True)
with open(
os.path.join(table_dir, "differential_cross_sections.npy"), "wb"
) as f:
np.save(f, self.differential_cross_section_table)

def GetPossiblePrimaries(self):
return [Particle.ParticleType(self.ups_case.nu_projectile.pdgid)]

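Side note on the new save_to_table / load_from_table pair above: both lean on plain NumPy .npy files. A minimal, self-contained sketch of that round trip (the directory name and table contents below are placeholders, not values from this commit):

import os
import numpy as np

# Placeholder directory and table contents, for illustration only; the real
# tables are filled in by PyDarkNewsCrossSection.
table_dir = "example_xsec_tables"
os.makedirs(table_dir, exist_ok=True)
total_cross_section_table = np.array([[1.0e1, 1.2e-38], [1.0e2, 3.4e-38]])

# Mirrors save_to_table: write the table to a .npy file.
with open(os.path.join(table_dir, "total_cross_sections.npy"), "wb") as f:
    np.save(f, total_cross_section_table)

# Mirrors load_from_table: reload the table only if the file exists.
total_xsec_file = os.path.join(table_dir, "total_cross_sections.npy")
if os.path.exists(total_xsec_file):
    reloaded = np.load(total_xsec_file)
    assert np.allclose(reloaded, total_cross_section_table)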
52 changes: 19 additions & 33 deletions resources/CrossSections/DarkNewsTables/DarkNewsDecay.py
@@ -27,49 +27,35 @@ def __init__(self, dec_case):
self.total_width = None

def load_from_table(self, table_dir):
if table_dir is None:
print(
"No table_dir specified; will sample from new VEGAS integrator for each decay"
)
print("WARNING: this will siginficantly slow down event generation")
return

# Make the table directory where we will store the decay integrators
table_dir_exists = False
if os.path.exists(table_dir):
# print("Directory '%s' already exists"%table_dir)
table_dir_exists = True
else:
if not os.path.exists(table_dir):
try:
os.makedirs(table_dir, exist_ok=False)
print("Directory '%s' created successfully" % table_dir)
except OSError as error:
print("Directory '%s' cannot be created" % table_dir)
exit(0)
raise RuntimeError("Directory '%s' cannot be created" % table_dir)

# Try to find the decay integrator
int_file = os.path.join(table_dir, "decay_integrator.pkl")
if os.path.isfile(int_file):
with open(int_file, "rb") as ifile:
self.decay_integrator = pickle.load(ifile)
# Try to find the normalization information
norm_file = os.path.join(table_dir, "decay_norm.json")
if os.path.isfile(norm_file):
with open(
norm_file,
) as nfile:
self.decay_norm = json.load(nfile)
decay_file = os.path.join(table_dir, "decay.pkl")
if os.path.isfile(decay_file):
with open(decay_file, "rb") as f:
data = pickle.load(f)
self.decay_norm = data["decay_norm"]
self.decay_integrator = data["decay_integrator"]

def save_to_table(self, table_dir):
with open(os.path.join(table_dir, "decay.pkl"), "wb") as f:
pickle.dump(
{"decay_integrator": self.decay_integrator, "decay_norm": self.decay_norm},
f,
)

# serialization method
def get_representation(self):
return {"decay_integrator":self.decay_integrator,
"decay_norm":self.decay_norm,
"dec_case":self.dec_case,
"PS_samples":self.PS_samples,
"PS_weights":self.PS_weights,
"PS_weights_CDF":self.PS_weights_CDF,
"total_width":self.total_width,
return {"decay_integrator": self.decay_integrator,
"decay_norm": self.decay_norm,
"dec_case": self.dec_case,
"PS_samples": self.PS_samples,
"PS_weights": self.PS_weights,
"PS_weights_CDF": self.PS_weights_CDF,
"total_width": self.total_width,
}

def SetIntegratorAndNorm(self, decay_norm, decay_integrator):
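Side note on the new decay.pkl handling above: save_to_table and load_from_table are meant to round-trip the VEGAS integrator and its normalization through a single pickle file. A minimal sketch, with plain dicts standing in for the real DarkNews objects:

import os
import pickle

# Placeholder directory and objects, for illustration only.
table_dir = "example_decay_tables"
os.makedirs(table_dir, exist_ok=True)
decay_norm = {"norm": 1.0}                 # stands in for the normalization info
decay_integrator = {"name": "vegas_stub"}  # stands in for the VEGAS integrator

# Mirrors save_to_table: dump both pieces into one pickle file.
with open(os.path.join(table_dir, "decay.pkl"), "wb") as f:
    pickle.dump({"decay_integrator": decay_integrator, "decay_norm": decay_norm}, f)

# Mirrors load_from_table: read the pickle back and unpack by key.
with open(os.path.join(table_dir, "decay.pkl"), "rb") as f:
    data = pickle.load(f)
assert data["decay_norm"] == decay_norm
assert data["decay_integrator"] == decay_integrator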
182 changes: 152 additions & 30 deletions resources/CrossSections/DarkNewsTables/processes.py
@@ -11,6 +11,7 @@
f"DarkNewsTables-v{siren.utilities.darknews_version()}", must_exist=False
)


def GetDetectorModelTargets(detector_model):
"""
Determines the targets that exist inside the detector model
@@ -117,15 +118,13 @@ def attempt_to_load_cross_section(
table_subdir = os.path.join(table_dir, subdir)
if os.path.isdir(table_subdir):
try:
cross_section = append(
load_cross_section_from_table(
models,
ups_key,
table_subdir,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
cross_section = load_cross_section_from_table(
models,
ups_key,
table_subdir,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
loaded = True
except Exception as e:
@@ -138,14 +137,12 @@
table_subdir = os.path.join(table_dir, subdir)
if os.path.isdir(table_subdir):
try:
cross_section = append(
load_cross_section_from_pickle(
ups_key,
table_subdir,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
cross_section = load_cross_section_from_pickle(
ups_key,
table_subdir,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
loaded = True
except Exception as e:
Expand All @@ -156,14 +153,12 @@ def attempt_to_load_cross_section(
break
elif p == "normal":
try:
cross_sections = append(
load_cross_section(
models,
ups_key,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
cross_section = load_cross_section(
models,
ups_key,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
loaded = True
except Exception as e:
@@ -201,6 +196,125 @@ def load_cross_sections(
return cross_sections


def load_decay(
model_container,
decay_key,
):
if decay_key not in model_container.dec_cases:
raise KeyError(
f'Decay key "{decay_key}" not present in model_container.dec_cases'
)
decay_model = model_container.dec_cases[decay_key]
return PyDarkNewsDecay(
decay_model,
)


def load_decay_from_table(
model_container,
decay_key,
table_dir,
):
subdir = "_".join(["Decay"] + [str(x) for x in decay_key])
table_subdir = os.path.join(table_dir, subdir)

decay = load_decay(
model_container,
decay_key,
)
decay.load_from_table(table_subdir)
return decay


def load_decay_from_pickle(
decay_key,
table_dir,
):
subdir = "_".join(["Decay"] + [str(x) for x in decay_key])
table_subdir = os.path.join(table_dir, subdir)
fname = os.path.join(table_subdir, "dec_object.pkl")
with open(fname, "rb") as f:
dec_obj = pickle.load(f)
return dec_obj
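# Illustrative sketch, not part of this commit: the two loaders above assume a
# common on-disk layout, with one "Decay_<key>" subdirectory per decay channel
# under table_dir. The decay key tuple below is a made-up placeholder, not a key
# produced by DarkNews' ModelContainer.
import os

example_table_dir = "DarkNewsTables"
example_dec_key = ("N4", "nu_e", "e+", "e-")

example_subdir = "_".join(["Decay"] + [str(x) for x in example_dec_key])
example_table_subdir = os.path.join(example_table_dir, example_subdir)
# e.g. "DarkNewsTables/Decay_N4_nu_e_e+_e-"

# load_decay_from_table looks for decay.pkl (written by PyDarkNewsDecay.save_to_table),
# while load_decay_from_pickle expects a fully pickled decay object in dec_object.pkl.
print(os.path.join(example_table_subdir, "decay.pkl"))
print(os.path.join(example_table_subdir, "dec_object.pkl"))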


def attempt_to_load_decay(
models,
dec_key,
table_dir,
preferences,
):
if len(preferences) == 0:
raise ValueError("preferences must have at least one entry")

subdir = "_".join(["Decay"] + [str(x) for x in dec_key])
loaded = False
decay = None
for p in preferences:
if p == "table":
table_subdir = os.path.join(table_dir, subdir)
if os.path.isdir(table_subdir):
try:
decay = load_decay_from_table(
models,
dec_key,
table_subdir,
)
loaded = True
except Exception as e:
print("Encountered exception while loading DN decay from table")
raise e from None
break
elif p == "pickle":
table_subdir = os.path.join(table_dir, subdir)
if os.path.isdir(table_subdir):
try:
decay = load_decay_from_pickle(
dec_key,
table_dir,
)
loaded = True
except Exception as e:
print("Encountered exception while loading DN decay from pickle")
raise e from None
break
elif p == "normal":
try:
decay = load_decay(
models,
dec_key,
)
loaded = True
except Exception as e:
print("Encountered exception while loading DN decay normally")
raise e from None
break

if not loaded:
raise RuntimeError("Not able to load DN decay with any strategy")
return decay
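# Illustrative sketch, not part of this commit: attempt_to_load_decay mirrors the
# fallback order in attempt_to_load_cross_section: try each strategy from
# preferences in order, skip strategies whose inputs are unavailable, re-raise on
# a genuine failure, and only error out if nothing loaded. A stripped-down,
# self-contained version of that control flow (the loader callables are stand-ins
# for load_decay_from_table / load_decay_from_pickle / load_decay):
def load_with_preferences(loaders, preferences):
    if len(preferences) == 0:
        raise ValueError("preferences must have at least one entry")
    result = None
    loaded = False
    for p in preferences:
        loader = loaders.get(p)
        if loader is None:
            # Strategy unavailable (e.g. no table directory on disk); try the next one.
            continue
        try:
            result = loader()
            loaded = True
        except Exception as e:
            print("Encountered exception while loading with strategy '%s'" % p)
            raise e from None
        break
    if not loaded:
        raise RuntimeError("Not able to load with any strategy")
    return result

# Example: the "table" strategy is unavailable here, so the "normal" strategy is used.
print(load_with_preferences({"normal": lambda: "loaded normally"}, ["table", "pickle", "normal"]))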


def load_decays(
model_kwargs,
table_dir=None,
preferences=None,
):
if preferences is None:
preferences = ["table", "pickle", "normal"]

models = ModelContainer(**model_kwargs)

if table_dir is None:
table_dir = ""

decays = []
for dec_key, dec_case in models.dec_cases.items():
decays.append(attempt_to_load_decay(models, dec_key, table_dir, preferences))

return decays


def load_processes(
primary_type=None,
target_types=None,
@@ -245,11 +359,19 @@ def load_processes(
}

cross_sections = load_cross_sections(
model_kwargs,
table_dir=None,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
model_kwargs,
table_dir=None,
tolerance=tolerance,
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)

decays = load_decays(
model_kwargs,
table_dir=None,
)

if fill_tables_at_start:
