Skip to content

Commit

Permalink
Update hierarchical.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WuShichao authored Aug 28, 2023
1 parent 7423058 commit fef5a3f
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions pycbc/inference/models/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .base import BaseModel
from .relbin import RelativeTimeDom
from .relbin_cpu import snr_predictor_dom
from .tools import DistMarg
import tqdm

#
# =============================================================================
Expand Down Expand Up @@ -607,7 +607,7 @@ def _loglikelihood(self):
return logl


class MultibandRelativeTimeDom(HierarchicalModel, DistMarg):
class MultibandRelativeTimeDom(HierarchicalModel):
""" Hierarchical heterodyne likelihood for coherent multiband
parameter estimation which combines data from space-borne and
ground-based GW detectors coherently. Currently, this only
Expand All @@ -618,9 +618,7 @@ class MultibandRelativeTimeDom(HierarchicalModel, DistMarg):
space-borne and ground-based GW detectors, then transform all
the parameters into the same frame in the sub model level, use
`HierarchicalModel` to get the joint likelihood, and marginalize
over all the extrinsic parameters supported by `RelativeTimeDom`
on the multiband likelihood level.
over all the extrinsic parameters supported by `RelativeTimeDom`.
Note that LISA submodel only supports the `Relative` for now,
for ground-based detectors, please use `RelativeTimeDom`.
"""
Expand All @@ -630,22 +628,11 @@ def __init__(self, variable_params, submodels, **kwargs):
super().__init__(variable_params, submodels, **kwargs)

# We assume the ground-based submodel as the primary model.
lbl_list = list(self.submodels.keys())
lbl_primary = lbl_list[0]
self.primary_model = self.submodels[lbl_primary]
if self.primary_model.still_needs_det_response:
lbl_primary = lbl_list[1]
self.primary_model = self.submodels[lbl_primary]
self.primary_model = self.submodels[kwargs['primary_lbl'][0]]
self.other_models = self.submodels.copy()
self.other_models.pop(lbl_primary)
self.other_models.pop(kwargs['primary_lbl'][0])
self.other_models = list(self.other_models.values())

# marginalize on multiband likelihood level
marginalize_phase = self.primary_model.marginalize_phase
variable_params, kwargs = self.setup_marginalization(
variable_params,
marginalize_phase=marginalize_phase,
**kwargs)
self.other_models_labels = kwargs['others_lbls']

def write_metadata(self, fp, group=None):
"""Adds metadata to the output files
Expand Down Expand Up @@ -694,24 +681,35 @@ def _loglr(self):

# note that for SOBHB signals, ground-based detectors dominant SNR
# and accuracy of (tc, ra, dec)
sh_ground, hh_ground = self.primary_model.loglr(just_sh_hh=True)
sh_primary, hh_primary = self.primary_model.loglr(just_sh_hh=True)

nums = self.primary_model.vsamples
margin_params = self.primary_model.marginalize_vector_params.copy()
margin_params.pop('logw_partial')

# add likelihood contribution from space-borne detectors, we
# calculate sh/hh for each extrinsic parameter

logging.info("Calculating sh/hh for space-borne detectors")
parameter_set = self.primary_model.current_params
nums = self.primary_model.marginalize_vector_samples
sh_space = numpy.zeros(nums)
hh_space = numpy.zeros(nums)
for parameters in parameter_set:
i = numpy.where(parameter_set==parameters)[0][0]
self.other_models.update(**parameters)
sh_space[i], hh_space[i] = self.other_models.loglr(
just_sh_hh=True)

sh_total = sh_ground + sh_space
hh_total = hh_ground + hh_space
# calculate sh/hh for each marginalized parameter
logging.info("Calculating sh/hh for space-borne detector(s)")
sh_others = numpy.zeros(nums)
hh_others = numpy.zeros(nums)

for label_i, other_model in enumerate(self.other_models):
logging.info("============= %s =============",
self.other_models_labels[label_i])
current_params_other = other_model.current_params.copy()
# there are still some values in margin_params
for p in margin_params:
current_params_other.pop(p)
for i in tqdm.tqdm(range(nums)):
parameters = current_params_other.copy()
parameters.update(
{key: value[i] for key, value in margin_params.items()})
other_model.update(**parameters)
sh_others[i], hh_others[i] = other_model.loglr(
just_sh_hh=True)

sh_total = sh_primary + sh_others
hh_total = hh_primary + hh_others
loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)
return loglr

Expand Down Expand Up @@ -750,7 +748,9 @@ def from_config(cls, cp, **kwargs):
# circular imports, we import it here
from pycbc.inference.models import read_from_config
# get the submodels
submodel_lbls = shlex.split(cp.get('model', 'submodels'))
kwargs['primary_lbl'] = shlex.split(cp.get('model', 'primary_model'))
kwargs['others_lbls'] = shlex.split(cp.get('model', 'other_models'))
submodel_lbls = kwargs['primary_lbl'] + kwargs['others_lbls']
# sort parameters by model
vparam_map = map_params(hpiter(cp.options('variable_params'),
submodel_lbls))
Expand Down Expand Up @@ -821,11 +821,10 @@ def from_config(cls, cp, **kwargs):
subcp[prior_section_name] = cp[prior_section_name]

# initialize
kwargs['not_marginalize_submodel'] = True
submodel = read_from_config(subcp, **kwargs)
submodel = read_from_config(subcp)
submodels[lbl] = submodel
logging.info("")
# now load the model
logging.info("Loading multiband_relative_time_dom model")
return super(HierarchicalModel, cls).from_config(
cp, submodels=submodels)
cp, submodels=submodels, **kwargs)

0 comments on commit fef5a3f

Please sign in to comment.