Skip to content

Commit

Permalink
Merge pull request #48 from kazewong/46-restyling-the-code-to-work-wi…
Browse files Browse the repository at this point in the history
…th-pre-commit

46 restyling the code to work with pre commit
  • Loading branch information
kazewong authored Dec 6, 2023
2 parents 282f409 + 37f0b96 commit 733d802
Show file tree
Hide file tree
Showing 12 changed files with 478 additions and 288 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
files: src/
repos:
- repo: https://github.com/ambv/black
rev: 23.9.1
Expand All @@ -12,7 +13,7 @@ repos:
rev: v1.1.338
hooks:
- id: pyright
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions]
additional_dependencies: [beartype, jax, jaxtyping, pytest, typing_extensions, flowMC, ripplegw, gwpy, astropy]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.1
hooks:
Expand Down
1 change: 1 addition & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ignore = ["F722"]
9 changes: 4 additions & 5 deletions src/jimgw/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from astropy.constants import c,au,G,pc
from astropy.constants import c, pc # type: ignore TODO: fix astropy stubs
from astropy.units import year as yr
from astropy.cosmology import WMAP9 as cosmo

Msun = 4.9255e-6
year = (1*yr).cgs.value
Mpc = 1e6*pc.value/c.value
year = (1 * yr).cgs.value # type: ignore
Mpc = 1e6 * pc.value / c.value
euler_gamma = 0.577215664901532860606512090082
MR_sun = 1.476625061404649406193430731479084713e3
C_SI = 299792458.0
Expand All @@ -13,4 +12,4 @@
EARTH_SEMI_MINOR_AXIS = 6356752.314 # in m

DAYSID_SI = 86164.09053133354
DAYJUL_SI = 86400.0
DAYJUL_SI = 86400.0
4 changes: 1 addition & 3 deletions src/jimgw/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import equinox as eqx
from abc import ABC, abstractmethod
from jaxtyping import Array

class Data(ABC):

class Data(ABC):
@abstractmethod
def __init__(self):
raise NotImplementedError
Expand Down
155 changes: 105 additions & 50 deletions src/jimgw/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import numpy as np
import requests
from gwpy.timeseries import TimeSeries
from jaxtyping import Array, PRNGKeyArray
from jaxtyping import Array, PRNGKeyArray, Float
from scipy.interpolate import interp1d
from scipy.signal.windows import tukey

from jimgw.constants import *
from jimgw.constants import EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS, C_SI
from jimgw.wave import Polarization

DEG_TO_RAD = jnp.pi / 180
Expand Down Expand Up @@ -39,39 +39,51 @@ class Detector(ABC):

name: str

data: Float[Array, " n_sample"]
psd: Float[Array, " n_sample"]

@abstractmethod
def load_data(self, data):
raise NotImplementedError

@abstractmethod
def fd_response(self, frequency: Array, h: Array, params: dict) -> Array:
def fd_response(
self,
frequency: Float[Array, " n_sample"],
h: dict[str, Float[Array, " n_sample"]],
params: dict,
) -> Float[Array, " n_sample"]:
"""
Modulate the waveform in the sky frame by the detector response
in the frequency domain."""
pass

@abstractmethod
def td_response(self, time: Array, h: Array, params: dict) -> Array:
def td_response(
self,
time: Float[Array, " n_sample"],
h: dict[str, Float[Array, " n_sample"]],
params: dict,
) -> Float[Array, " n_sample"]:
"""
Modulate the waveform in the sky frame by the detector response
in the time domain."""
pass


class GroundBased2G(Detector):

polarization_mode: list[Polarization]
frequencies: Array = None
data: Array = None
psd: Array = None

latitude: float = 0
longitude: float = 0
xarm_azimuth: float = 0
yarm_azimuth: float = 0
xarm_tilt: float = 0
yarm_tilt: float = 0
elevation: float = 0
frequencies: Float[Array, " n_sample"]
data: Float[Array, " n_sample"]
psd: Float[Array, " n_sample"]

latitude: Float = 0
longitude: Float = 0
xarm_azimuth: Float = 0
yarm_azimuth: Float = 0
xarm_tilt: Float = 0
yarm_tilt: Float = 0
elevation: Float = 0

def __init__(self, name: str, **kwargs) -> None:
self.name = name
Expand All @@ -86,22 +98,32 @@ def __init__(self, name: str, **kwargs) -> None:
modes = kwargs.get("mode", "pc")

self.polarization_mode = [Polarization(m) for m in modes]
self.frequencies = jnp.array([])
self.data = jnp.array([])
self.psd = jnp.array([])

@staticmethod
def _get_arm(lat, lon, tilt, azimuth):
def _get_arm(
lat: Float, lon: Float, tilt: Float, azimuth: Float
) -> Float[Array, " 3"]:
"""
Construct detector-arm vectors in Earth-centric Cartesian coordinates.
Parameters
---------
lat : float
lat : Float
vertex latitude in rad.
lon : float
lon : Float
vertex longitude in rad.
tilt : float
tilt : Float
arm tilt in rad.
azimuth : float
azimuth : Float
arm azimuth in rad.
Returns
-------
arm : Float[Array, " 3"]
detector arm vector in Earth-centric Cartesian coordinates.
"""
e_lon = jnp.array([-jnp.sin(lon), jnp.cos(lon), 0])
e_lat = jnp.array(
Expand All @@ -118,9 +140,16 @@ def _get_arm(lat, lon, tilt, azimuth):
)

@property
def arms(self):
def arms(self) -> tuple[Float[Array, " 3"], Float[Array, " 3"]]:
"""
Detector arm vectors (x, y).
Returns
-------
x : Float[Array, " 3"]
x-arm vector.
y : Float[Array, " 3"]
y-arm vector.
"""
x = self._get_arm(
self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth
Expand All @@ -131,9 +160,14 @@ def arms(self):
return x, y

@property
def tensor(self):
def tensor(self) -> Float[Array, " 3, 3"]:
"""
Detector tensor defining the strain measurement.
Returns
-------
tensor : Float[Array, " 3, 3"]
detector tensor.
"""
# TODO: this could easily be generalized for other detector geometries
arm1, arm2 = self.arms
Expand All @@ -142,11 +176,16 @@ def tensor(self):
)

@property
def vertex(self):
def vertex(self) -> Float[Array, " 3"]:
"""
Detector vertex coordinates in the reference celestial frame. Based
on arXiv:gr-qc/0008066 Eqs. (B11-B13) except for a typo in the
definition of the local radius; see Section 2.1 of LIGO-T980044-10.
Returns
-------
vertex : Float[Array, " 3"]
detector vertex coordinates.
"""
# get detector and Earth parameters
lat = self.latitude
Expand All @@ -164,33 +203,33 @@ def vertex(self):

def load_data(
self,
trigger_time: float,
trigger_time: Float,
gps_start_pad: int,
gps_end_pad: int,
f_min: float,
f_max: float,
f_min: Float,
f_max: Float,
psd_pad: int = 16,
tukey_alpha: float = 0.2,
tukey_alpha: Float = 0.2,
gwpy_kwargs: dict = {"cache": True},
) -> None:
"""
Load data from the detector.
Parameters
----------
trigger_time : float
trigger_time : Float
The GPS time of the trigger.
gps_start_pad : int
The amount of time before the trigger to fetch data.
gps_end_pad : int
The amount of time after the trigger to fetch data.
f_min : float
f_min : Float
The minimum frequency to fetch data.
f_max : float
f_max : Float
The maximum frequency to fetch data.
psd_pad : int
The amount of time to pad the PSD data.
tukey_alpha : float
tukey_alpha : Float
The alpha parameter for the Tukey window.
"""
Expand All @@ -202,6 +241,7 @@ def load_data(
trigger_time + gps_end_pad,
**gwpy_kwargs
)
assert isinstance(data_td, TimeSeries), "Data is not a TimeSeries object."
segment_length = data_td.duration.value
n = len(data_td)
delta_t = data_td.dt.value
Expand All @@ -217,6 +257,9 @@ def load_data(
psd_data_td = TimeSeries.fetch_open_data(
self.name, start_psd, end_psd, **gwpy_kwargs
)
assert isinstance(
psd_data_td, TimeSeries
), "PSD data is not a TimeSeries object."
psd = psd_data_td.psd(
fftlength=segment_length
).value # TODO: Check whether this is sright.
Expand All @@ -227,9 +270,15 @@ def load_data(
self.data = data[(freq > f_min) & (freq < f_max)]
self.psd = psd[(freq > f_min) & (freq < f_max)]

def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array:
def fd_response(
self,
frequency: Float[Array, " n_sample"],
h_sky: dict[str, Float[Array, " n_sample"]],
params: dict[str, Float],
) -> Array:
"""
Modulate the waveform in the sky frame by the detector response in the frequency domain.
"""
Modulate the waveform in the sky frame by the detector response in the frequency domain."""
ra, dec, psi, gmst = params["ra"], params["dec"], params["psi"], params["gmst"]
antenna_pattern = self.antenna_pattern(ra, dec, psi, gmst)
timeshift = self.delay_from_geocenter(ra, dec, gmst)
Expand All @@ -244,10 +293,11 @@ def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array:

def td_response(self, time: Array, h: Array, params: Array) -> Array:
"""
Modulate the waveform in the sky frame by the detector response in the time domain."""
pass
Modulate the waveform in the sky frame by the detector response in the time domain.
"""
raise NotImplementedError

def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
def delay_from_geocenter(self, ra: Float, dec: Float, gmst: Float) -> Float:
"""
Calculate time delay between two detectors in geocentric
coordinates based on XLALArrivaTimeDiff in TimeDelay.c
Expand All @@ -256,16 +306,16 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
Parameters
---------
ra : float
ra : Float
right ascension of the source in rad.
dec : float
dec : Float
declination of the source in rad.
gmst : float
gmst : Float
Greenwich mean sidereal time in rad.
Returns
-------
float: time delay from Earth center.
Float: time delay from Earth center.
"""
delta_d = -self.vertex
gmst = jnp.mod(gmst, 2 * jnp.pi)
Expand All @@ -280,7 +330,7 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float:
)
return jnp.dot(omega, delta_d) / C_SI

def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dict:
def antenna_pattern(self, ra: Float, dec: Float, psi: Float, gmst: Float) -> dict:
"""
Computes {name} antenna patterns for {modes} polarizations
at the specified sky location, orientation and GMST.
Expand All @@ -291,13 +341,13 @@ def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dic
Parameters
---------
ra : float
ra : Float
source right ascension in radians.
dec : float
dec : Float
source declination in radians.
psi : float
psi : Float
source polarization angle in radians.
gmst : float
gmst : Float
Greenwich mean sidereal time (GMST) in radians.
modes : str
string of polarizations to include, defaults to tensor modes: 'pc'.
Expand All @@ -324,7 +374,7 @@ def inject_signal(
freqs: Array,
h_sky: dict,
params: dict,
psd_file: str = None,
psd_file: str = "",
) -> None:
""" """
self.frequencies = freqs
Expand All @@ -339,8 +389,10 @@ def inject_signal(
signal = self.fd_response(freqs, h_sky, params) * align_time
self.data = signal + noise_real + 1j * noise_imag

def load_psd(self, freqs: Array, psd_file: str = None) -> None:
if psd_file is None:
def load_psd(
self, freqs: Float[Array, " n_sample"], psd_file: str = ""
) -> Float[Array, " n_sample"]:
if psd_file == "":
print("Grabbing GWTC-2 PSD for " + self.name)
url = psd_file_dict[self.name]
data = requests.get(url)
Expand All @@ -349,7 +401,10 @@ def load_psd(self, freqs: Array, psd_file: str = None) -> None:
else:
f, asd_vals = np.loadtxt(psd_file, unpack=True)
psd_vals = asd_vals**2
psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs)
assert isinstance(f, Float[Array, "n_sample"])
assert isinstance(psd_vals, Float[Array, "n_sample"])
psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) # type: ignore
psd = jnp.array(psd)
return psd


Expand Down
Loading

0 comments on commit 733d802

Please sign in to comment.