Skip to content

Commit

Permalink
propagating detector changes to likelihood
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisi committed Oct 18, 2024
1 parent 67c2625 commit c114ba0
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 198 deletions.
119 changes: 86 additions & 33 deletions src/jimgw/single_event/data.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,36 @@
__include__ = ["Data", "PowerSpectrum"]

from abc import ABC, abstractmethod
from abc import ABC

import jax
import jax.numpy as jnp
import numpy as np
import requests
from gwpy.timeseries import TimeSeries
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped, Complex
from typing import Optional, Any
from beartype import beartype as typechecker
from jaxtyping import Array, Float, Complex, PRNGKeyArray
from typing import Optional
# from beartype import beartype as typechecker
from scipy.interpolate import interp1d
import scipy.signal as sig
from scipy.signal.windows import tukey

from jimgw.constants import C_SI, EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS
from jimgw.single_event.wave import Polarization
import logging
import jax


DEG_TO_RAD = jnp.pi / 180

# TODO: Need to expand this list. Currently it is only O3.
asd_file_dict = {
"H1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-H1-C01_CLEAN_SUB60HZ-1251752040.0_sensitivity_strain_asd.txt",
"L1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-L1-C01_CLEAN_SUB60HZ-1240573680.0_sensitivity_strain_asd.txt",
"V1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-V1_sensitivity_strain_asd.txt",
"H1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-H1-C01_CLEAN_SUB60HZ-1251752040.0_sensitivity_strain_asd.txt", # noqa: E501
"L1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-L1-C01_CLEAN_SUB60HZ-1240573680.0_sensitivity_strain_asd.txt", # noqa: E501
"V1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-V1_sensitivity_strain_asd.txt", # noqa: E501
}


class Data(ABC):
"""
Base class for all data. The time domain data are considered the primary
entitiy; the Fourier domain data are derived from an FFT after applying a
window. The structure is set up so that :attr:`td` and :attr:`fd` are always
Fourier conjugates of each other: the one-sided Fourier series is complete
up to the Nyquist frequency
"""Base class for all data. The time domain data are considered the primary
entity; the Fourier domain data are derived from an FFT after applying a
window. The structure is set up so that :attr:`td` and :attr:`fd` are
always Fourier conjugates of each other: the one-sided Fourier series is
complete up to the Nyquist frequency.
"""
name: str

Expand Down Expand Up @@ -108,13 +103,14 @@ def __init__(self, td: Float[Array, " n_time"] = jnp.array([]),
self.delta_t = delta_t
self.epoch = epoch
if window is None:
self.window = jnp.ones_like(self.td)
self.set_tukey_window()
else:
self.window = window
self.name = name or ''

def __repr__(self):
return f"{self.__class__.__name__}(name='{self.name}', delta_t={self.delta_t}, epoch={self.epoch})"
return f"{self.__class__.__name__}(name='{self.name}', " + \
f"delta_t={self.delta_t}, epoch={self.epoch})"

def __bool__(self) -> bool:
"""Check if the data is empty."""
Expand Down Expand Up @@ -215,7 +211,8 @@ def from_gwosc(cls,

data_td = TimeSeries.fetch_open_data(ifo, gps_start_time, gps_end_time,
cache=cache, **kws)
return cls(data_td.value, data_td.dt.value, data_td.epoch.value, ifo)
return cls(data_td.value, data_td.dt.value, data_td.epoch.value, ifo) # type: ignore # noqa: E501


class PowerSpectrum(ABC):
name: str
Expand All @@ -240,7 +237,7 @@ def duration(self) -> Float:
@property
def sampling_frequency(self) -> Float:
"""Sampling frequency of the data in Hz."""
return self.frequencies[-1] * 2
return self.frequencies[-1] * 2

def __init__(self, values: Float[Array, " n_freq"] = jnp.array([]),
frequencies: Float[Array, " n_freq"] = jnp.array([]),
Expand All @@ -252,10 +249,15 @@ def __init__(self, values: Float[Array, " n_freq"] = jnp.array([]),
self.name = name or ''

def __repr__(self) -> str:
return f"{self.__class__.__name__}(name='{self.name}', frequencies={self.frequencies})"
return f"{self.__class__.__name__}(name='{self.name}', " + \
f"frequencies={self.frequencies})"

def __bool__(self) -> bool:
"""Check if the power spectrum is empty."""
return len(self.values) > 0

def slice(self, f_min: float, f_max: float) -> \
tuple[Float[Array, " n_sample"], Float[Array, " n_sample"]]:
def frequency_slice(self, f_min: float, f_max: float) -> \
tuple[Float[Array, " n_sample"], Float[Array, " n_sample"]]:
"""Slice the power spectrum.
Arguments
Expand All @@ -270,24 +272,75 @@ def slice(self, f_min: float, f_max: float) -> \
psd_slice: PowerSpectrum
Sliced power spectrum.
"""
values = self.values[(self.frequencies >= f_min) &
(self.frequencies <= f_max)]
frequencies = self.frequencies[(self.frequencies >= f_min) &
(self.frequencies <= f_max)]
return values, frequencies
mask = (self.frequencies >= f_min) & (self.frequencies <= f_max)
return self.values[mask], self.frequencies[mask]

def interpolate(self, f: Float[Array, " n_sample"]) -> "PowerSpectrum":
def interpolate(self, f: Float[Array, " n_sample"],
kind: str = 'cubic', **kws) -> "PowerSpectrum":
"""Interpolate the power spectrum to a new set of frequencies.
Arguments
---------
f: array
Frequencies to interpolate the power spectrum to.
kind: str, optional
Interpolation method (default: 'cubic')
**kws: dict, optional
Keyword arguments for `scipy.interpolate.interp1d`
Returns
-------
psd_interp: array
Interpolated power spectrum.
"""
interp = interp1d(self.frequencies, self.values, kind='cubic')
interp = interp1d(self.frequencies, self.values, kind=kind, **kws)
return PowerSpectrum(interp(f), f, self.name)

def simulate_data(
self,
key: PRNGKeyArray,
# freqs: Float[Array, " n_sample"],
# h_sky: dict[str, Float[Array, " n_sample"]],
# params: dict[str, Float],
# psd_file: str = "",
) -> Complex[Array, " n_sample"]:
"""
Inject a signal into the detector data.
Parameters
----------
key : PRNGKeyArray
JAX PRNG key.
h_sky : dict[str, Float[Array, " n_sample"]]
Array of waveforms in the sky frame. The key is the polarization
mode.
params : dict[str, Float]
Dictionary of parameters.
psd_file : str
Path to the PSD file.
Returns
-------
None
"""
key, subkey = jax.random.split(key, 2)
var = self.values / (4 * self.delta_f)
noise_real = jax.random.normal(key, shape=var.shape) * jnp.sqrt(var)
noise_imag = jax.random.normal(subkey, shape=var.shape) * jnp.sqrt(var)
return noise_real + 1j * noise_imag

# WIP: this should be moved to Detector class

# align_time = jnp.exp(
# -1j * 2 * jnp.pi * freqs * (params["epoch"] + params["t_c"])
# )
# signal = self.fd_response(freqs, h_sky, params) * align_time
# self.data = signal + noise_real + 1j * noise_imag

# # also calculate the optimal SNR and match filter SNR
# optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real)
# match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR

# print(f"For detector {self.name}:")
# print(f"The injected optimal SNR is {optimal_SNR}")
# print(f"The injected match filter SNR is {match_filter_SNR}")
Loading

0 comments on commit c114ba0

Please sign in to comment.