Skip to content

Commit

Permalink
likelihood api
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisi committed Oct 18, 2024
1 parent c114ba0 commit e032436
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 42 deletions.
21 changes: 18 additions & 3 deletions example/GW150914_IMRPhenomPV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
from jimgw.single_event import data as jd

jax.config.update("jax_enable_x64", True)

Expand All @@ -36,16 +37,30 @@
total_time_start = time.time()

# first, fetch a 4s segment centered on GW150914
# for the analysis
gps = 1126259462.4
start = gps - 2
end = gps + 2

# fetch 4096s of data to estimate the PSD (to be
# careful we should avoid the on-source segment,
# but we don't do this in this example)
psd_start = gps - 2048
psd_end = gps + 2048

# define frequency integration bounds for the likelihood
fmin = 20.0
fmax = 1024.0
fmax = 1000.0

ifos = [H1, L1]

H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
for ifo in ifos:
data = jd.Data.from_gwosc(ifo.name, start, end)
ifo.set_data(data)

psd_data = jd.Data.from_gwosc(ifo.name, psd_start, psd_end)
psd_fftlength = data.duration * data.sampling_frequency
ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength))

waveform = RippleIMRPhenomPv2(f_ref=20)

Expand Down
2 changes: 1 addition & 1 deletion src/jimgw/single_event/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(self, td: Float[Array, " n_time"] = jnp.array([]),
window: array, optional
Window function to apply to the data before FFT (default: None)
"""
self.name = name or ''
self.td = td
self.fd = jnp.zeros(self.n_freq)
self.delta_t = delta_t
Expand All @@ -106,7 +107,6 @@ def __init__(self, td: Float[Array, " n_time"] = jnp.array([]),
self.set_tukey_window()
else:
self.window = window
self.name = name or ''

def __repr__(self):
return f"{self.__class__.__name__}(name='{self.name}', " + \
Expand Down
101 changes: 63 additions & 38 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,42 @@ def __init__(
self,
detectors: list[Detector],
waveform: Waveform,
f_min: float = 0,
f_max: float = float("inf"),
trigger_time: float = 0,
post_trigger_duration: float = 2,
f_min: Float = 0,
f_max: Float = float("inf"),
trigger_time: Float = 0,
post_trigger_duration: Float = 2,
**kwargs,
) -> None:
self.detectors = detectors

# TODO: we can probably make this a bit more elegant
# make sure data has a Fourier representation
for det in detectors:
if not det.data.has_fd:
logging.info("Computing FFT with default window")
det.data.fft()
det.set_frequency_bounds(f_min, f_max)

freqs = [d.data.frequency_slice(f_min, f_max)[1] for d in detectors]
assert all([
(freqs[0]
== freq).all() # noqa: W503
for freq in freqs]
), "The detectors must have the same frequency grid"

# collect the data, psd and frequencies for the requested band
freqs = []
datas = []
psds = []
for detector in detectors:
data, freq_0 = detector.data.frequency_slice(f_min, f_max)
psd, freq_1 = detector.psd.frequency_slice(f_min, f_max)
freqs.append(freq_0)
datas.append(data)
psds.append(psd)
# make sure the psd and data are consistent
assert (freq_0 == freq_1).all(), \
f"The {detector.name} data and PSD must have same frequencies"

# make sure all detectors are consistent
assert all([(freqs[0] == freq).all() for freq in freqs]), \
"The detectors must have the same frequency grid"

self.frequencies = freqs[0] # type: ignore
self.datas = [d.data.frequency_slice(f_min, f_max)[0] for d in detectors]
self.psds = [d.psd.frequency_slice(f_min, f_max)[0] for d in detectors]

self.waveform = waveform
self.trigger_time = trigger_time
self.gmst = (
Expand All @@ -85,15 +99,15 @@ def __init__(
if self.marginalization == "phase-time":
self.param_func = lambda x: {**x, "phase_c": 0.0, "t_c": 0.0}
self.likelihood_function = phase_time_marginalized_likelihood
print("Marginalizing over phase and time")
logging.info("Marginalizing over phase and time")
elif self.marginalization == "time":
self.param_func = lambda x: {**x, "t_c": 0.0}
self.likelihood_function = time_marginalized_likelihood
print("Marginalizing over time")
logging.info("Marginalizing over time")
elif self.marginalization == "phase":
self.param_func = lambda x: {**x, "phase_c": 0.0}
self.likelihood_function = phase_marginalized_likelihood
print("Marginalizing over phase")
logging.info("Marginalizing over phase")

if "time" in self.marginalization:
fs = kwargs["sampling_rate"]
Expand Down Expand Up @@ -136,22 +150,19 @@ def __init__(

@property
def epoch(self):
"""
The epoch of the data.
"""The epoch of the data.
"""
return self.duration - self.post_trigger_duration

@property
def ifos(self):
"""
The interferometers for the likelihood.
"""The interferometers for the likelihood.
"""
return [detector.name for detector in self.detectors]

def evaluate(self, params: dict[str, Float], data: dict) -> Float:
# TODO: Test whether we need to pass data in or with class changes is fine.
"""
Evaluate the likelihood for a given set of parameters.
"""Evaluate the likelihood for a given set of parameters.
"""
frequencies = self.frequencies
params["gmst"] = self.gmst
Expand All @@ -169,6 +180,8 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float:
waveform_sky,
self.detectors,
frequencies,
self.datas,
self.psds,
align_time,
**self.kwargs,
)
Expand Down Expand Up @@ -203,9 +216,10 @@ def __init__(
self,
detectors: list[Detector],
waveform: Waveform,
f_min: Float = 0,
f_max: Float = float("inf"),
n_bins: int = 100,
trigger_time: float = 0,
duration: float = 4,
post_trigger_duration: float = 2,
popsize: int = 100,
n_steps: int = 2000,
Expand All @@ -217,10 +231,10 @@ def __init__(
**kwargs,
) -> None:
super().__init__(
detectors, waveform, trigger_time, duration, post_trigger_duration
detectors, waveform, f_min, f_max, trigger_time, post_trigger_duration
)

print("Initializing heterodyned likelihood..")
logging.info("Initializing heterodyned likelihood..")

# A different waveform may be supplied as the reference waveform; if none is provided, the same waveform is used
if reference_waveform is None:
Expand Down Expand Up @@ -299,7 +313,7 @@ def __init__(
print("The eta of the reference parameter is close to 0.25")
print(f"The eta is adjusted to {self.ref_params['eta']}")

print("Constructing reference waveforms..")
logging.info("Constructing reference waveforms..")

self.ref_params["gmst"] = self.gmst
# adjust the params due to different marginalization scheme
Expand Down Expand Up @@ -647,15 +661,16 @@ def original_likelihood(
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
datas: list[Float[Array, " n_dim"]],
psds: list[Float[Array, " n_dim"]],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
for detector, data, psd in zip(detectors, datas, psds):
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
data = detector.fd_data_slice
psd = detector.psd_slice
# NOTE: do we want to take the slice outside the likelihood?
match_filter_SNR = (
4 * jnp.sum((jnp.conj(h_dec) * data) / psd * df).real
)
Expand All @@ -670,18 +685,22 @@ def phase_marginalized_likelihood(
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
datas: list[Float[Array, " n_dim"]],
psds: list[Float[Array, " n_dim"]],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
complex_d_inner_h = 0.0
df = freqs[1] - freqs[0]
for detector in detectors:
f_min = freqs[0]
f_max = freqs[-1]
for detector, data, psd in zip(detectors, datas, psds):
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_d_inner_h += 4 * jnp.sum(
(jnp.conj(h_dec) * detector.data) / detector.psd * df
(jnp.conj(h_dec) * data) / psd * df
)
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real
log_likelihood += -optimal_SNR / 2

log_likelihood += log_i0(jnp.absolute(complex_d_inner_h))
Expand All @@ -694,17 +713,21 @@ def time_marginalized_likelihood(
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
datas: list[Float[Array, " n_dim"]],
psds: list[Float[Array, " n_dim"]],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
df = freqs[1] - freqs[0]
f_min = freqs[0]
f_max = freqs[-1]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
for detector, data, psd in zip(detectors, datas, psds):
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
complex_h_inner_d += 4 * h_dec * jnp.conj(data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
Expand Down Expand Up @@ -743,17 +766,19 @@ def phase_time_marginalized_likelihood(
h_sky: dict[str, Float[Array, " n_dim"]],
detectors: list[Detector],
freqs: Float[Array, " n_dim"],
datas: list[Float[Array, " n_dim"]],
psds: list[Float[Array, " n_dim"]],
align_time: Float,
**kwargs,
) -> Float:
log_likelihood = 0.0
df = freqs[1] - freqs[0]
# using <h|d> instead of <d|h>
complex_h_inner_d = jnp.zeros_like(freqs)
for detector in detectors:
for detector, data, psd in zip(detectors, datas, psds):
h_dec = detector.fd_response(freqs, h_sky, params) * align_time
complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
complex_h_inner_d += 4 * h_dec * jnp.conj(data) / psd * df
optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real
log_likelihood += -optimal_SNR / 2

# fetch the tc range tc_array, lower padding and higher padding
Expand Down

0 comments on commit e032436

Please sign in to comment.