Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 4d SO3 optimizer #90

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
248 changes: 229 additions & 19 deletions sxs/waveforms/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from functools import partial



def align1d(wa, wb, t1, t2, n_brute_force=None):
"""Align waveforms by shifting in time

Expand Down Expand Up @@ -126,9 +125,7 @@ def _cost2d(δt_δϕ, args):

# Take the sqrt because least_squares squares the inputs...
diff = trapezoid(
np.sum(
abs(modes_A(t_reference + δt) * np.exp(1j * m * δϕ) * δΨ_factor - modes_B) ** 2, axis=1
),
np.sum(abs(modes_A(t_reference + δt) * np.exp(1j * m * δϕ) * δΨ_factor - modes_B) ** 2, axis=1),
t_reference,
)
return np.sqrt(diff / normalization)
Expand Down Expand Up @@ -218,13 +215,9 @@ def align2d(wa, wb, t1, t2, n_brute_force_δt=None, n_brute_force_δϕ=5, includ
if t2 <= t1:
raise ValueError(f"(t1,t2)=({t1}, {t2}) is out of order")
if wa.t[0] > t1 or wa.t[-1] < t2:
raise ValueError(
f"(t1,t2)=({t1}, {t2}) not contained in wa.t, which spans ({wa.t[0]}, {wa.t[-1]})"
)
raise ValueError(f"(t1,t2)=({t1}, {t2}) not contained in wa.t, which spans ({wa.t[0]}, {wa.t[-1]})")
if wb.t[0] > t1 or wb.t[-1] < t2:
raise ValueError(
f"(t1,t2)=({t1}, {t2}) not contained in wb.t, which spans ({wb.t[0]}, {wb.t[-1]})"
)
raise ValueError(f"(t1,t2)=({t1}, {t2}) not contained in wb.t, which spans ({wb.t[0]}, {wb.t[-1]})")

# Figure out time offsets to try
δt_lower = max(t1 - t2, wa.t[0] - t1)
Expand Down Expand Up @@ -254,17 +247,11 @@ def align2d(wa, wb, t1, t2, n_brute_force_δt=None, n_brute_force_δϕ=5, includ
wb.data[:, wb.index(L, M)] *= 0

# Define the cost function
modes_A = CubicSpline(
wa.t, wa[:, wa.index(2, -2) : wa.index(ell_max + 1, -(ell_max + 1))].data
)
modes_B = CubicSpline(
wb.t, wb[:, wb.index(2, -2) : wb.index(ell_max + 1, -(ell_max + 1))].data
)(t_reference)
modes_A = CubicSpline(wa.t, wa[:, wa.index(2, -2) : wa.index(ell_max + 1, -(ell_max + 1))].data)
modes_B = CubicSpline(wb.t, wb[:, wb.index(2, -2) : wb.index(ell_max + 1, -(ell_max + 1))].data)(t_reference)

normalization = trapezoid(
CubicSpline(
wb.t, wb[:, wb.index(2, -2) : wb.index(ell_max + 1, -(ell_max + 1))].norm ** 2
)(t_reference),
CubicSpline(wb.t, wb[:, wb.index(2, -2) : wb.index(ell_max + 1, -(ell_max + 1))].norm ** 2)(t_reference),
t_reference,
)

Expand Down Expand Up @@ -310,3 +297,226 @@ def align2d(wa, wb, t1, t2, n_brute_force_δt=None, n_brute_force_δϕ=5, includ
idx = np.argmin(abs(np.array([optimum.cost for optimum in optimums])))

return optimums[idx].cost, wa_primes[idx], optimums[idx]


def _cost4d(δt_δSO3, args):
    """Residual used by `align4d` for one (time offset, rotation) candidate.

    Parameters
    ----------
    δt_δSO3 : sequence of 4 floats
        The time offset followed by the three components that are passed to
        `np.quaternion` and exponentiated to obtain the rotation.
    args : sequence
        `(modes_A, modes_B, t_reference, normalization)`, where `modes_A` is
        a callable evaluated at the shifted times, `modes_B` is the array it
        is compared against, `t_reference` holds the sample times of the
        integral, and `normalization` divides the integrated difference.

    Returns
    -------
    float
        Square root of the normalized, time-integrated squared difference.

    """
    from .. import WaveformModes

    modes_A, modes_B, t_reference, normalization = args

    time_offset = δt_δSO3[0]
    log_rotor = np.quaternion(*δt_δSO3[1:])
    rotor = np.exp(log_rotor)

    shifted_modes = modes_A(t_reference + time_offset)

    # Recover ell_max from the column count; this inverts
    # n_modes = (ell_max + 1)**2 - 4, the count for modes starting at ell=2.
    n_modes = shifted_modes.shape[1]
    ell_max = int(np.sqrt(n_modes + 4)) - 1

    candidate = WaveformModes(
        input_array=(shifted_modes),
        time=t_reference,
        time_axis=0,
        modes_axis=1,
        ell_min=2,
        ell_max=ell_max,
        spin_weight=-2,
    ).rotate(rotor.components)

    # Sum the squared mode differences at each time, then integrate over time
    squared_residual = np.sum(abs(candidate.data - modes_B) ** 2, axis=1)
    integrated = trapezoid(squared_residual, t_reference)

    # Take the sqrt because least_squares squares the inputs...
    return np.sqrt(integrated / normalization)
moble marked this conversation as resolved.
Show resolved Hide resolved


def align4d(
    wa,
    wb,
    t1,
    t2,
    n_brute_force_δt=None,
    n_brute_force_δSO3=None,
    max_δt=None,
    max_δSO3=None,
    include_modes=None,
    nprocs=None,
):
    """Align waveforms by optimizing over a time translation and an SO(3) rotation.

    This function determines the optimal transformation to apply to `wa` by
    minimizing the averaged (over time) L² norm (over the sphere) of the difference
    of the waveforms.

    The integral is taken from time `t1` to `t2`.

    Note that the input waveforms are assumed to be initially aligned at least well
    enough that:

      1) the time span from `t1` to `t2` in the two waveforms will overlap at
         least slightly after the second waveform is shifted in time; and
      2) waveform `wb` contains all the times corresponding to `t1` to `t2` in
         waveform `wa`.

    The first of these can usually be assured by simply aligning the peaks prior to
    calling this function:

        wa.t -= wa.max_norm_time() - wb.max_norm_time()

    The second assumption will be satisfied as long as `t1` is not too close to the
    beginning of `wb` and `t2` is not too close to the end.

    Parameters
    ----------
    wa : WaveformModes
    wb : WaveformModes
        Waveforms to be aligned
    t1 : float
    t2 : float
        Beginning and end of integration interval
    n_brute_force_δt : int, optional
        Number of evenly spaced δt values between (t1-t2) and (t2-t1) to sample
        for the initial guess.  By default, this is just the maximum number of
        time steps in the range (t1, t2) in the input waveforms.  If this is
        too small, an incorrect local minimum may be found.
    n_brute_force_δSO3 : int, optional
        Number of evenly spaced values sampled for each of the rotation angle
        and the two angles (theta, phi) giving the direction of the rotation
        axis, so that `n_brute_force_δSO3**3` rotations (minus the identity)
        are tried for each δt.  By default, this is 5.
    max_δt : float, optional
        Maximum magnitude of δt to allow; the time offset is restricted to
        the range [-max_δt, max_δt].  By default, there is no restriction.
    max_δSO3 : float, optional
        Max δtheta (away from z) to allow for the rotation axis when choosing
        the initial guess.  By default, this is π.
    include_modes: list, optional
        A list containing the (ell, m) modes to be included in the L² norm.
    nprocs: int, optional
        Number of cpus to use. Default is maximum number.
        If -1 is provided, then no multiprocessing is performed.

    Returns
    -------
    error: float
        Cost of scipy.optimize.least_squares
        This is 0.5 ||wa - wb||² / ||wb||²
    wa_prime: WaveformModes
        Resulting waveform after transforming `wa` using `optimum`
    optimum: OptimizeResult
        Result of scipy.optimize.least_squares

    Raises
    ------
    ValueError
        If `t2 <= t1`, if (t1, t2) is not contained in both waveforms, or if
        the brute-force grid of initial guesses is empty.

    Notes
    -----
    Choosing the time interval is usually the most difficult choice to make when
    aligning waveforms.  Assuming you want to align during inspiral, the times must
    span sufficiently long that the waveforms' norm (equivalently, orbital
    frequency changes) significantly from `t1` to `t2`.  This means that you cannot
    always rely on a specific number of orbits, for example.  Also note that
    neither number should be too close to the beginning or end of either waveform,
    to provide some "wiggle room".

    """
    from scipy.interpolate import CubicSpline
    from scipy.optimize import least_squares
    from .. import WaveformModes

    # Keep the untouched input for building the output; work on copies so that
    # zeroing excluded modes below never modifies the caller's data
    wa_orig = wa
    wa = wa.copy()
    wb = wb.copy()

    # Check that (t1, t2) makes sense and is actually contained in both waveforms
    if t2 <= t1:
        raise ValueError(f"(t1,t2)=({t1}, {t2}) is out of order")
    if wa.t[0] > t1 or wa.t[-1] < t2:
        raise ValueError(f"(t1,t2)=({t1}, {t2}) not contained in wa.t, which spans ({wa.t[0]}, {wa.t[-1]})")
    if wb.t[0] > t1 or wb.t[-1] < t2:
        raise ValueError(f"(t1,t2)=({t1}, {t2}) not contained in wb.t, which spans ({wb.t[0]}, {wb.t[-1]})")

    if max_δt is None:
        max_δt = np.inf

    # Figure out time offsets to try.  The lower limit is clamped from below
    # by -max_δt and the upper limit from above by +max_δt.
    δt_lower = max(-max_δt, t1 - t2, wa.t[0] - t1)
    δt_upper = min(max_δt, t2 - t1, wa.t[-1] - t2)

    # We'll start by brute forcing, sampling time offsets evenly at as many
    # points as there are time steps in (t1,t2) in the input waveforms
    if n_brute_force_δt is None:
        n_brute_force_δt = max(
            int(np.count_nonzero((wa.t >= t1) & (wa.t <= t2))),
            int(np.count_nonzero((wb.t >= t1) & (wb.t <= t2))),
        )
    δt_brute_force = np.linspace(δt_lower, δt_upper, num=n_brute_force_δt)

    if max_δSO3 is None:
        max_δSO3 = np.pi

    if n_brute_force_δSO3 is None:
        n_brute_force_δSO3 = 5

    # pick (angle, theta, phi) such that exp(q) corresponds to the expected (angle, theta, phi)
    δSO3_candidates = [
        [
            angle / 2 * np.sin(theta) * np.cos(phi),
            angle / 2 * np.sin(theta) * np.sin(phi),
            angle / 2 * np.cos(theta),
        ]
        for phi in np.linspace(0.0, 2 * np.pi, num=n_brute_force_δSO3, endpoint=False)
        for theta in np.linspace(0.0, max_δSO3, num=n_brute_force_δSO3, endpoint=True)
        for angle in np.linspace(0.0, 2 * np.pi, num=n_brute_force_δSO3, endpoint=False)
    ]

    # Candidates with zero norm are skipped; the test does not depend on δt,
    # so it is evaluated once here rather than once per time offset.
    δSO3_brute_force = [δSO3 for δSO3 in δSO3_candidates if np.quaternion(*δSO3).norm() != 0]

    δt_δSO3_brute_force = [[δt, *δSO3] for δt in δt_brute_force for δSO3 in δSO3_brute_force]
    if not δt_δSO3_brute_force:
        raise ValueError(
            "No initial guesses to try: "
            f"n_brute_force_δt={n_brute_force_δt} and n_brute_force_δSO3={n_brute_force_δSO3} "
            "produce an empty brute-force grid"
        )

    t_reference = wa.t[np.argmin(abs(wa.t - t1)) : np.argmin(abs(wa.t - t2)) + 1]

    # Remove certain modes, if requested
    ell_max = min(wa.ell_max, wb.ell_max)
    if include_modes is not None:
        for L in range(2, ell_max + 1):
            for M in range(-L, L + 1):
                if (L, M) not in include_modes:
                    wa.data[:, wa.index(L, M)] *= 0
                    wb.data[:, wb.index(L, M)] *= 0

    # Define the cost function, restricted to the modes with 2 <= ell <= ell_max
    wa_modes = wa[:, wa.index(2, -2) : wa.index(ell_max + 1, -(ell_max + 1))]
    wb_modes = wb[:, wb.index(2, -2) : wb.index(ell_max + 1, -(ell_max + 1))]

    modes_A = CubicSpline(wa.t, wa_modes.data)
    modes_B = CubicSpline(wb.t, wb_modes.data)(t_reference)

    normalization = trapezoid(
        CubicSpline(wb.t, wb_modes.norm ** 2)(t_reference),
        t_reference,
    )

    # Optimize by brute force with multiprocessing
    cost_wrapper = partial(_cost4d, args=[modes_A, modes_B, t_reference, normalization])

    if nprocs != -1:
        if nprocs is None:
            nprocs = mp.cpu_count()
        # The context manager releases the workers even if a cost evaluation raises
        with mp.Pool(processes=nprocs) as pool:
            cost_brute_force = pool.map(cost_wrapper, δt_δSO3_brute_force)
    else:
        cost_brute_force = [cost_wrapper(guess) for guess in δt_δSO3_brute_force]

    δt_δSO3 = δt_δSO3_brute_force[np.argmin(cost_brute_force)]

    # Optimize explicitly.  The three rotation parameters are the components
    # handed to `np.quaternion` before exponentiation, and the brute-force
    # guesses above have components anywhere in (-π, π).  The bounds must
    # therefore be symmetric about zero; otherwise least_squares rejects an
    # initial guess with a negative component as infeasible.
    optimum = least_squares(
        cost_wrapper,
        δt_δSO3,
        bounds=[(δt_lower, -np.pi, -np.pi, -np.pi), (δt_upper, np.pi, np.pi, np.pi)],
        max_nfev=50000,
    )
    δt = optimum.x[0]
    δSO3 = np.exp(np.quaternion(*optimum.x[1:]))

    # Build the result from the original input, so that modes zeroed out for
    # the cost function are still present in the returned waveform
    wa_prime = WaveformModes(
        input_array=(wa_orig[:, wa_orig.index(2, -2) : wa_orig.index(ell_max + 1, -(ell_max + 1))].data),
        time=wa_orig.t - δt,
        time_axis=0,
        modes_axis=1,
        ell_min=2,
        ell_max=ell_max,
        spin_weight=-2,
    )
    wa_prime = wa_prime.rotate(δSO3.components)

    return optimum.cost, wa_prime, optimum
Loading