Skip to content

Commit

Permalink
add in several missing marginalization combinations (gwastro#4376)
Browse files Browse the repository at this point in the history
* allow polarization marginalization with earth rotation

* add in missing marg combos

* fixes

* fixes

* fixes

* cleanup

* cc

* cc
  • Loading branch information
ahnitz authored and PRAVEEN-mnl committed Jun 19, 2023
1 parent f49620a commit e188b2d
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 25 deletions.
93 changes: 68 additions & 25 deletions pycbc/inference/models/relbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from .relbin_cpu import (likelihood_parts, likelihood_parts_v,
likelihood_parts_multi, likelihood_parts_multi_v,
likelihood_parts_det, likelihood_parts_vector,
likelihood_parts_v_pol,
likelihood_parts_v_time,
likelihood_parts_v_pol_time,
likelihood_parts_vectorp, snr_predictor,
likelihood_parts_vectort,
snr_predictor_dom)
from .tools import DistMarg

Expand Down Expand Up @@ -340,6 +344,7 @@ def combine_layout(self):

def setup_antenna(self, earth_rotation, mode, fedges):
# Calculate the times to evaluate fp/fc
self.earth_rotation = earth_rotation
if earth_rotation is not False:
logging.info("Enabling frequency-dependent earth rotation")
from pycbc.waveform.spa_tmplt import spa_length_in_time
Expand All @@ -364,15 +369,32 @@ def setup_antenna(self, earth_rotation, mode, fedges):

@property
def likelihood_function(self):
self.lformat = None
if self.marginalize_vector_params:
p = self.current_params

for k in ['ra', 'dec', 'tc']:
if k in p and not numpy.isscalar(p[k]):
vmarg = set(k for k in self.marginalize_vector_params
if not numpy.isscalar(p[k]))

if self.earth_rotation:
if set(['tc', 'polarization']).issubset(vmarg):
self.lformat = 'earth_time_pol'
return likelihood_parts_v_pol_time
elif set(['polarization']).issubset(vmarg):
self.lformat = 'earth_pol'
return likelihood_parts_v_pol
elif set(['tc']).issubset(vmarg):
self.lformat = 'earth_time'
return likelihood_parts_v_time
else:
if set(['ra', 'dec', 'tc']).issubset(vmarg):
return likelihood_parts_vector

if 'polarization' in p and not numpy.isscalar(p['polarization']):
return likelihood_parts_vectorp
elif set(['tc', 'polarization']).issubset(vmarg):
return likelihood_parts_vector
elif set(['tc']).issubset(vmarg):
return likelihood_parts_vectort
elif set(['polarization']).issubset(vmarg):
return likelihood_parts_vectorp

return self.lik

Expand Down Expand Up @@ -531,14 +553,21 @@ def _loglr(self):
dt = det.time_delay_from_earth_center(p["ra"], p["dec"], times)
dtc = p["tc"] + dt - end_time - self.ta[ifo]

f = (fp + 1.0j * fc) * pol_phase
fp = f.real.copy()
fc = f.imag.copy()
filter_i, norm_i = lik(freqs, fp, fc, dtc,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
self._current_wf_parts[ifo] = (fp, fc, dtc, hp, hc, h00)
if self.lformat == 'earth_pol':
filter_i, norm_i = lik(freqs, fp, fc, dtc, pol_phase,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
else:
f = (fp + 1.0j * fc) * pol_phase
fp = f.real.copy()
fc = f.imag.copy()
filter_i, norm_i = lik(freqs, fp, fc, dtc,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
self._current_wf_parts[ifo] = (fp, fc, dtc, hp, hc, h00)

filt += filter_i
norm += norm_i
loglr = self.marginalize_loglr(filt, norm)
Expand Down Expand Up @@ -649,14 +678,14 @@ def get_snr(self, wfs):
dtc = self.tstart[ifo] - self.end_time[ifo] - self.ta[ifo]

snr = snr_predictor(self.fedges[ifo],
dtc, delta_t,
self.num_samples[ifo],
dtc - delta_t * 2.0, delta_t,
self.num_samples[ifo] + 4,
wfs[ifo][0], wfs[ifo][1],
self.h00_sparse[ifo],
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
snrs[ifo] = TimeSeries(snr, delta_t=delta_t,
epoch=self.tstart[ifo])
epoch=self.tstart[ifo] - delta_t * 2.0)
return snrs

def _loglr(self):
Expand Down Expand Up @@ -697,16 +726,30 @@ def _loglr(self):
det = self.det[ifo]
fp, fc = det.antenna_pattern(p["ra"], p["dec"],
0, times)
dt = det.time_delay_from_earth_center(p["ra"], p["dec"], times)
dtc = p["tc"] + dt - end_time - self.ta[ifo]
times = det.time_delay_from_earth_center(p["ra"], p["dec"], times)
dtc = p["tc"] - end_time - self.ta[ifo]

f = (fp + 1.0j * fc) * pol_phase
fp = f.real.copy()
fc = f.imag.copy()
filter_i, norm_i = lik(freqs, fp, fc, dtc,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
if self.lformat == 'earth_time_pol':
filter_i, norm_i = lik(
freqs, fp, fc, times, dtc, pol_phase,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
else:
f = (fp + 1.0j * fc) * pol_phase
fp = f.real.copy()
fc = f.imag.copy()
if self.lformat == 'earth_time':
filter_i, norm_i = lik(
freqs, fp, fc, times, dtc,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
else:
filter_i, norm_i = lik(freqs, fp, fc, times + dtc,
hp, hc, h00,
sdat['a0'], sdat['a1'],
sdat['b0'], sdat['b1'])
filt += filter_i
norm += norm_i
loglr = self.marginalize_loglr(filt, norm)
Expand Down
212 changes: 212 additions & 0 deletions pycbc/inference/models/relbin_cpu.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cdef extern from "complex.h":
double norm(double complex)
double complex conj(double complex)
double real(double complex)
double imag(double complex)

import numpy
cimport cython, numpy
Expand Down Expand Up @@ -180,6 +181,171 @@ cpdef likelihood_parts_v(double [::1] freqs,
x0 = x0n
return conj(hd), hh

# Used where the antenna response may be frequency varying
# and there is a polarization vector marginalization
@cython.boundscheck(False) # Deactivate bounds checking
@cython.wraparound(False) # Deactivate negative indexing.
@cython.cdivision(True) # Disable checking for dividing by zero
cpdef likelihood_parts_v_pol(double [::1] freqs,
double[::1] fp,
double[::1] fc,
double[::1] dtc,
double complex[::1] pol_phase,
double complex[::1] hp,
double complex[::1] hc,
double complex[::1] h00,
double complex[::1] a0,
double complex[::1] a1,
double [::1] b0,
double [::1] b1,
) :
cdef size_t i
cdef double complex hd=0, r0, r0n, r1, x0, x0n, x1, fp2, fc2
cdef double hh=0

N = freqs.shape[0]
num_samples = pol_phase.shape[0]

cdef numpy.ndarray[numpy.complex128_t, ndim=1] hdv = numpy.empty(num_samples, dtype=numpy.complex128)
cdef numpy.ndarray[numpy.float64_t, ndim=1] hhv = numpy.empty(num_samples, dtype=numpy.float64)

for j in range(num_samples):
hh = 0
hd = 0
for i in range(N):

f = (fp[i] + 1.0j * fc[i]) * pol_phase[j]
fp2 = real(f)
fc2 = imag(f)

r0n = (exp(-2.0j * 3.141592653 * dtc[i] * freqs[i])
* (fp2 * hp[i] + fc2 * hc[i])) / h00[i]
r1 = r0n - r0

x0n = norm(r0n)
x1 = x0n - x0

if i > 0:
hd += a0[i-1] * r0 + a1[i-1] * r1
hh += real(b0[i-1] * x0 + b1[i-1] * x1)

r0 = r0n
x0 = x0n

hdv[j] = conj(hd)
hhv[j] = hh
return hdv, hhv

# Used where the antenna response may be frequency varying
# and there is a polarization vector marginalization
@cython.boundscheck(False) # Deactivate bounds checking
@cython.wraparound(False) # Deactivate negative indexing.
@cython.cdivision(True) # Disable checking for dividing by zero
cpdef likelihood_parts_v_time(double [::1] freqs,
double[::1] fp,
double[::1] fc,
double[::1] times,
double[::1] dtc,
double complex[::1] hp,
double complex[::1] hc,
double complex[::1] h00,
double complex[::1] a0,
double complex[::1] a1,
double [::1] b0,
double [::1] b1,
) :
cdef size_t i
cdef double complex hd=0, r0, r0n, r1, x0, x0n, x1
cdef double hh=0, ttime;

N = freqs.shape[0]
num_samples = dtc.shape[0]

cdef numpy.ndarray[numpy.complex128_t, ndim=1] hdv = numpy.empty(num_samples, dtype=numpy.complex128)
cdef numpy.ndarray[numpy.float64_t, ndim=1] hhv = numpy.empty(num_samples, dtype=numpy.float64)

for j in range(num_samples):
hh = 0
hd = 0
for i in range(N):
# This allows for multiple time offsets
ttime = times[i] + dtc[j]
r0n = (exp(-2.0j * 3.141592653 * ttime * freqs[i])
* (fp[i] * hp[i] + fc[i] * hc[i])) / h00[i]
r1 = r0n - r0

x0n = norm(r0n)
x1 = x0n - x0

if i > 0:
hd += a0[i-1] * r0 + a1[i-1] * r1
hh += real(b0[i-1] * x0 + b1[i-1] * x1)

r0 = r0n
x0 = x0n

hdv[j] = conj(hd)
hhv[j] = hh
return hdv, hhv

# Used where the antenna response may be frequency varying
# and there is a polarization vector marginalization
@cython.boundscheck(False) # Deactivate bounds checking
@cython.wraparound(False) # Deactivate negative indexing.
@cython.cdivision(True) # Disable checking for dividing by zero
cpdef likelihood_parts_v_pol_time(double [::1] freqs,
double[::1] fp,
double[::1] fc,
double[::1] times,
double[::1] dtc,
double complex[::1] pol_phase,
double complex[::1] hp,
double complex[::1] hc,
double complex[::1] h00,
double complex[::1] a0,
double complex[::1] a1,
double [::1] b0,
double [::1] b1,
) :
cdef size_t i
cdef double complex hd=0, r0, r0n, r1, x0, x0n, x1, fp2, fc2
cdef double hh=0, ttime;

N = freqs.shape[0]
num_samples = pol_phase.shape[0]

cdef numpy.ndarray[numpy.complex128_t, ndim=1] hdv = numpy.empty(num_samples, dtype=numpy.complex128)
cdef numpy.ndarray[numpy.float64_t, ndim=1] hhv = numpy.empty(num_samples, dtype=numpy.float64)

for j in range(num_samples):
hh = 0
hd = 0
for i in range(N):

f = (fp[i] + 1.0j * fc[i]) * pol_phase[j]
fp2 = real(f)
fc2 = imag(f)

# This allows for multiple time offsets
ttime = times[i] + dtc[j]
r0n = (exp(-2.0j * 3.141592653 * ttime * freqs[i])
* (fp2 * hp[i] + fc2 * hc[i])) / h00[i]
r1 = r0n - r0

x0n = norm(r0n)
x1 = x0n - x0

if i > 0:
hd += a0[i-1] * r0 + a1[i-1] * r1
hh += real(b0[i-1] * x0 + b1[i-1] * x1)

r0 = r0n
x0 = x0n

hdv[j] = conj(hd)
hhv[j] = hh
return hdv, hhv

# Standard likelihood but simultaneously handling multiple sky or time points
@cython.boundscheck(False) # Deactivate bounds checking
@cython.wraparound(False) # Deactivate negative indexing.
Expand Down Expand Up @@ -225,6 +391,52 @@ cpdef likelihood_parts_vector(double [::1] freqs,
hdv[j] = conj(hd)
hhv[j] = hh
return hdv, hhv

# Standard likelihood but simultaneously handling multiple time points
@cython.boundscheck(False) # Deactivate bounds checking
@cython.wraparound(False) # Deactivate negative indexing.
@cython.cdivision(True) # Disable checking for dividing by zero
cpdef likelihood_parts_vectort(double [::1] freqs,
double fp,
double fc,
double[::1] dtc,
double complex[::1] hp,
double complex[::1] hc,
double complex[::1] h00,
double complex[::1] a0,
double complex[::1] a1,
double [::1] b0,
double [::1] b1,
) :
cdef size_t i
cdef double complex hd, r0, r0n, r1, x0, x0n, x1
cdef double hh
N = freqs.shape[0]
num_samples = dtc.shape[0]

cdef numpy.ndarray[numpy.complex128_t, ndim=1] hdv = numpy.empty(num_samples, dtype=numpy.complex128)
cdef numpy.ndarray[numpy.float64_t, ndim=1] hhv = numpy.empty(num_samples, dtype=numpy.float64)

for j in range(num_samples):
hd = 0
hh = 0
for i in range(N):
r0n = (exp(-2.0j * 3.141592653 * dtc[j] * freqs[i])
* (fp * hp[i] + fc * hc[i])) / h00[i]
r1 = r0n - r0

x0n = norm(r0n)
x1 = x0n - x0

if i > 0:
hd += a0[i-1] * r0 + a1[i-1] * r1
hh += real(b0[i-1] * x0 + b1[i-1] * x1)

r0 = r0n
x0 = x0n
hdv[j] = conj(hd)
hhv[j] = hh
return hdv, hhv

# Like above, but if only polarization is marginalized
# this is a slow implementation and the loop should be inverted /
Expand Down

0 comments on commit e188b2d

Please sign in to comment.