diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b5595520..feda83d8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +files: src/ repos: - repo: https://github.com/ambv/black rev: 23.9.1 @@ -12,7 +13,7 @@ repos: rev: v1.1.338 hooks: - id: pyright - additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions] + additional_dependencies: [beartype, jax, jaxtyping, pytest, typing_extensions, flowMC, ripplegw, gwpy, astropy] - repo: https://github.com/nbQA-dev/nbQA rev: 1.7.1 hooks: diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..bf64e2d0 --- /dev/null +++ b/ruff.toml @@ -0,0 +1 @@ +ignore = ["F722"] \ No newline at end of file diff --git a/src/jimgw/constants.py b/src/jimgw/constants.py index ea3ae17c..d4b20605 100644 --- a/src/jimgw/constants.py +++ b/src/jimgw/constants.py @@ -1,10 +1,9 @@ -from astropy.constants import c,au,G,pc +from astropy.constants import c, pc # type: ignore TODO: fix astropy stubs from astropy.units import year as yr -from astropy.cosmology import WMAP9 as cosmo Msun = 4.9255e-6 -year = (1*yr).cgs.value -Mpc = 1e6*pc.value/c.value +year = (1 * yr).cgs.value # type: ignore +Mpc = 1e6 * pc.value / c.value euler_gamma = 0.577215664901532860606512090082 MR_sun = 1.476625061404649406193430731479084713e3 C_SI = 299792458.0 @@ -13,4 +12,4 @@ EARTH_SEMI_MINOR_AXIS = 6356752.314 # in m DAYSID_SI = 86164.09053133354 -DAYJUL_SI = 86400.0 \ No newline at end of file +DAYJUL_SI = 86400.0 diff --git a/src/jimgw/data.py b/src/jimgw/data.py index 9bc31efe..df31b7ea 100644 --- a/src/jimgw/data.py +++ b/src/jimgw/data.py @@ -1,9 +1,7 @@ -import equinox as eqx from abc import ABC, abstractmethod -from jaxtyping import Array -class Data(ABC): +class Data(ABC): @abstractmethod def __init__(self): raise NotImplementedError diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index d7580335..9d60c8df 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -5,11 +5,11 @@ import numpy as np import requests from gwpy.timeseries import TimeSeries -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, PRNGKeyArray, Float from scipy.interpolate import interp1d from scipy.signal.windows import tukey -from jimgw.constants import * +from jimgw.constants import EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS, C_SI from jimgw.wave import Polarization DEG_TO_RAD = jnp.pi / 180 @@ -39,19 +39,32 @@ class Detector(ABC): name: str + data: Float[Array, " n_sample"] + psd: Float[Array, " n_sample"] + @abstractmethod def load_data(self, data): raise NotImplementedError @abstractmethod - def fd_response(self, frequency: Array, h: Array, params: dict) -> Array: + def fd_response( + self, + frequency: Float[Array, " n_sample"], + h: dict[str, Float[Array, " n_sample"]], + params: dict, + ) -> Float[Array, " n_sample"]: """ Modulate the waveform in the sky frame by the detector response in the frequency domain.""" pass @abstractmethod - def td_response(self, time: Array, h: Array, params: dict) -> Array: + def td_response( + self, + time: Float[Array, " n_sample"], + h: dict[str, Float[Array, " n_sample"]], + params: dict, + ) -> Float[Array, " n_sample"]: """ Modulate the waveform in the sky frame by the detector response in the time domain.""" @@ -59,19 +72,18 @@ def td_response(self, time: Array, h: Array, params: dict) -> Array: class GroundBased2G(Detector): - polarization_mode: list[Polarization] - frequencies: Array = None - data: Array = None - psd: Array = None - - latitude: float = 0 - longitude: float = 0 - xarm_azimuth: float = 0 - yarm_azimuth: float = 0 - xarm_tilt: float = 0 - yarm_tilt: float = 0 - elevation: float = 0 + frequencies: Float[Array, " n_sample"] + data: Float[Array, " n_sample"] + psd: Float[Array, " n_sample"] + + latitude: Float = 0 + longitude: Float = 0 + xarm_azimuth: Float = 0 + yarm_azimuth: Float = 0 + xarm_tilt: Float = 0 + yarm_tilt: Float = 0 + elevation: Float = 0 def __init__(self, name: str, **kwargs) -> None: self.name = name @@ -86,22 +98,32 @@ def __init__(self, name: str, **kwargs) -> None: modes = kwargs.get("mode", "pc") self.polarization_mode = [Polarization(m) for m in modes] + self.frequencies = jnp.array([]) + self.data = jnp.array([]) + self.psd = jnp.array([]) @staticmethod - def _get_arm(lat, lon, tilt, azimuth): + def _get_arm( + lat: Float, lon: Float, tilt: Float, azimuth: Float + ) -> Float[Array, " 3"]: """ Construct detector-arm vectors in Earth-centric Cartesian coordinates. Parameters --------- - lat : float + lat : Float vertex latitude in rad. - lon : float + lon : Float vertex longitude in rad. - tilt : float + tilt : Float arm tilt in rad. - azimuth : float + azimuth : Float arm azimuth in rad. + + Returns + ------- + arm : Float[Array, " 3"] + detector arm vector in Earth-centric Cartesian coordinates. """ e_lon = jnp.array([-jnp.sin(lon), jnp.cos(lon), 0]) e_lat = jnp.array( @@ -118,9 +140,16 @@ def _get_arm(lat, lon, tilt, azimuth): ) @property - def arms(self): + def arms(self) -> tuple[Float[Array, " 3"], Float[Array, " 3"]]: """ Detector arm vectors (x, y). + + Returns + ------- + x : Float[Array, " 3"] + x-arm vector. + y : Float[Array, " 3"] + y-arm vector. """ x = self._get_arm( self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth @@ -131,9 +160,14 @@ def arms(self): return x, y @property - def tensor(self): + def tensor(self) -> Float[Array, " 3, 3"]: """ Detector tensor defining the strain measurement. + + Returns + ------- + tensor : Float[Array, " 3, 3"] + detector tensor. """ # TODO: this could easily be generalized for other detector geometries arm1, arm2 = self.arms @@ -142,11 +176,16 @@ def tensor(self): ) @property - def vertex(self): + def vertex(self) -> Float[Array, " 3"]: """ Detector vertex coordinates in the reference celestial frame. Based on arXiv:gr-qc/0008066 Eqs. (B11-B13) except for a typo in the definition of the local radius; see Section 2.1 of LIGO-T980044-10. + + Returns + ------- + vertex : Float[Array, " 3"] + detector vertex coordinates. """ # get detector and Earth parameters lat = self.latitude @@ -164,13 +203,13 @@ def vertex(self): def load_data( self, - trigger_time: float, + trigger_time: Float, gps_start_pad: int, gps_end_pad: int, - f_min: float, - f_max: float, + f_min: Float, + f_max: Float, psd_pad: int = 16, - tukey_alpha: float = 0.2, + tukey_alpha: Float = 0.2, gwpy_kwargs: dict = {"cache": True}, ) -> None: """ @@ -178,19 +217,19 @@ def load_data( Parameters ---------- - trigger_time : float + trigger_time : Float The GPS time of the trigger. gps_start_pad : int The amount of time before the trigger to fetch data. gps_end_pad : int The amount of time after the trigger to fetch data. - f_min : float + f_min : Float The minimum frequency to fetch data. - f_max : float + f_max : Float The maximum frequency to fetch data. psd_pad : int The amount of time to pad the PSD data. - tukey_alpha : float + tukey_alpha : Float The alpha parameter for the Tukey window. """ @@ -202,6 +241,7 @@ def load_data( trigger_time + gps_end_pad, **gwpy_kwargs ) + assert isinstance(data_td, TimeSeries), "Data is not a TimeSeries object." segment_length = data_td.duration.value n = len(data_td) delta_t = data_td.dt.value @@ -217,6 +257,9 @@ def load_data( psd_data_td = TimeSeries.fetch_open_data( self.name, start_psd, end_psd, **gwpy_kwargs ) + assert isinstance( + psd_data_td, TimeSeries + ), "PSD data is not a TimeSeries object." psd = psd_data_td.psd( fftlength=segment_length ).value # TODO: Check whether this is sright. @@ -227,9 +270,15 @@ def load_data( self.data = data[(freq > f_min) & (freq < f_max)] self.psd = psd[(freq > f_min) & (freq < f_max)] - def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array: + def fd_response( + self, + frequency: Float[Array, " n_sample"], + h_sky: dict[str, Float[Array, " n_sample"]], + params: dict[str, Float], + ) -> Array: + """ + Modulate the waveform in the sky frame by the detector response in the frequency domain. """ - Modulate the waveform in the sky frame by the detector response in the frequency domain.""" ra, dec, psi, gmst = params["ra"], params["dec"], params["psi"], params["gmst"] antenna_pattern = self.antenna_pattern(ra, dec, psi, gmst) timeshift = self.delay_from_geocenter(ra, dec, gmst) @@ -244,10 +293,11 @@ def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array: def td_response(self, time: Array, h: Array, params: Array) -> Array: """ - Modulate the waveform in the sky frame by the detector response in the time domain.""" - pass + Modulate the waveform in the sky frame by the detector response in the time domain. + """ + raise NotImplementedError - def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: + def delay_from_geocenter(self, ra: Float, dec: Float, gmst: Float) -> Float: """ Calculate time delay between two detectors in geocentric coordinates based on XLALArrivaTimeDiff in TimeDelay.c @@ -256,16 +306,16 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: Parameters --------- - ra : float + ra : Float right ascension of the source in rad. - dec : float + dec : Float declination of the source in rad. - gmst : float + gmst : Float Greenwich mean sidereal time in rad. Returns ------- - float: time delay from Earth center. + Float: time delay from Earth center. """ delta_d = -self.vertex gmst = jnp.mod(gmst, 2 * jnp.pi) @@ -280,7 +330,7 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: ) return jnp.dot(omega, delta_d) / C_SI - def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dict: + def antenna_pattern(self, ra: Float, dec: Float, psi: Float, gmst: Float) -> dict: """ Computes {name} antenna patterns for {modes} polarizations at the specified sky location, orientation and GMST. @@ -291,13 +341,13 @@ def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dic Parameters --------- - ra : float + ra : Float source right ascension in radians. - dec : float + dec : Float source declination in radians. - psi : float + psi : Float source polarization angle in radians. - gmst : float + gmst : Float Greenwich mean sidereal time (GMST) in radians. modes : str string of polarizations to include, defaults to tensor modes: 'pc'. @@ -324,7 +374,7 @@ def inject_signal( freqs: Array, h_sky: dict, params: dict, - psd_file: str = None, + psd_file: str = "", ) -> None: """ """ self.frequencies = freqs @@ -339,8 +389,10 @@ def inject_signal( signal = self.fd_response(freqs, h_sky, params) * align_time self.data = signal + noise_real + 1j * noise_imag - def load_psd(self, freqs: Array, psd_file: str = None) -> None: - if psd_file is None: + def load_psd( + self, freqs: Float[Array, " n_sample"], psd_file: str = "" + ) -> Float[Array, " n_sample"]: + if psd_file == "": print("Grabbing GWTC-2 PSD for " + self.name) url = psd_file_dict[self.name] data = requests.get(url) @@ -349,7 +401,10 @@ def load_psd(self, freqs: Array, psd_file: str = None) -> None: else: f, asd_vals = np.loadtxt(psd_file, unpack=True) psd_vals = asd_vals**2 - psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) + assert isinstance(f, Float[Array, "n_sample"]) + assert isinstance(psd_vals, Float[Array, "n_sample"]) + psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) # type: ignore + psd = jnp.array(psd) return psd diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 5df9cf48..02a9db83 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -5,7 +5,7 @@ from flowMC.utils.PRNG_keys import initialize_rng_keys from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer from jimgw.prior import Prior -from jaxtyping import Array +from jaxtyping import Array, Float, PRNGKeyArray import jax import jax.numpy as jnp from flowMC.sampler.flowHMC import flowHMC @@ -34,9 +34,10 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): self.posterior, True, local_sampler_arg ) # Remember to add routine to find automated mass matrix - flowHMC_params = kwargs.get("flowHMC_params", {}) - model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) + model = MaskedCouplingRQSpline( + self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1] + ) if len(flowHMC_params) > 0: global_sampler = flowHMC( self.posterior, @@ -51,46 +52,53 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): else: global_sampler = None - self.Sampler = Sampler( self.Prior.n_dim, rng_key_set, - None, + None, # type: ignore local_sampler, model, - global_sampler = global_sampler, - **kwargs) - + global_sampler=global_sampler, + **kwargs, + ) - def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 100, n_loops: int = 2000, seed = 92348): - bounds = jnp.array(bounds).T + def maximize_likelihood( + self, + bounds: Float[Array, " n_dim 2"], + set_nwalkers: int = 100, + n_loops: int = 2000, + seed=92348, + ): key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers initial_guess = self.Prior.sample(key, set_nwalkers) - y = lambda x: -self.posterior(x, None) - y = jax.jit(jax.vmap(y)) + def negative_posterior(x: Float[Array, " n_dim"]): + return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet + + negative_posterior = jax.jit(jax.vmap(negative_posterior)) print("Compiling likelihood function") - y(initial_guess) + negative_posterior(initial_guess) print("Done compiling") print("Starting the optimizer") optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) - state = optimizer.optimize(y, bounds, n_loops=n_loops) + _ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit def posterior(self, params: Array, data: dict): prior_params = self.Prior.add_name(params.T) - prior = self.Prior.log_prob(prior_params) - return self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior + prior = self.Prior.log_prob(prior_params) + return ( + self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior + ) - def sample(self, key: jax.random.PRNGKey, - initial_guess: Array = None): - if initial_guess is None: - initial_guess = self.Prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess.values()]).T - self.Sampler.sample(initial_guess, None) + def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): + if initial_guess is jnp.array([]): + initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + self.Sampler.sample(initial_guess, None) # type: ignore def print_summary(self): """ @@ -167,7 +175,7 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.Sampler.get_sampler_state(training=False)["chains"] - chains = self.Prior.add_name(chains.transpose(2, 0, 1), transform_name=True) + chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1))) return chains def plot(self): diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index bd37dcf3..488bb51b 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp import numpy as np +import numpy.typing as npt from astropy.time import Time from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer from jaxtyping import Array, Float @@ -44,7 +45,7 @@ def data(self): return self._data @abstractmethod - def evaluate(self, params) -> float: + def evaluate(self, params: dict[str, Float], data: dict) -> Float: """ Evaluate the likelihood for a given set of parameters. """ @@ -64,6 +65,15 @@ def __init__( post_trigger_duration: float = 2, ) -> None: self.detectors = detectors + assert jnp.all( + jnp.array( + [ + (self.detectors[0].frequencies == detector.frequencies).all() # type: ignore + for detector in self.detectors + ] + ) + ), "The detectors must have the same frequency grid" + self.frequencies = self.detectors[0].frequencies # type: ignore self.waveform = waveform self.trigger_time = trigger_time self.gmst = ( @@ -88,16 +98,13 @@ def ifos(self): """ return [detector.name for detector in self.detectors] - def evaluate( - self, params: Array, data: dict - ) -> ( - float - ): # TODO: Test whether we need to pass data in or with class changes is fine. + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + # TODO: Test whether we need to pass data in or with class changes is fine. """ Evaluate the likelihood for a given set of parameters. """ log_likelihood = 0 - frequencies = self.detectors[0].frequencies + frequencies = self.frequencies df = frequencies[1] - frequencies[0] params["gmst"] = self.gmst waveform_sky = self.waveform(frequencies, params) @@ -130,22 +137,30 @@ class HeterodynedTransientLikelihoodFD(TransientLikelihoodFD): freq_grid_low: Array # Heterodyned frequency grid freq_grid_center: Array # Heterodyned frequency grid at the center of the bin waveform_low_ref: dict[ - Array + str, Float[Array, " n_bin"] ] # Reference waveform at the low edge of the frequency bin, keyed by detector name waveform_center_ref: dict[ - Array + str, Float[Array, " n_bin"] ] # Reference waveform at the center of the frequency bin, keyed by detector name - A0_array: dict[Array] # A0 array for the likelihood, keyed by detector name - A1_array: dict[Array] # A1 array for the likelihood, keyed by detector name - B0_array: dict[Array] # B0 array for the likelihood, keyed by detector name - B1_array: dict[Array] # B1 array for the likelihood, keyed by detector name + A0_array: dict[ + str, Float[Array, " n_bin"] + ] # A0 array for the likelihood, keyed by detector name + A1_array: dict[ + str, Float[Array, " n_bin"] + ] # A1 array for the likelihood, keyed by detector name + B0_array: dict[ + str, Float[Array, " n_bin"] + ] # B0 array for the likelihood, keyed by detector name + B1_array: dict[ + str, Float[Array, " n_bin"] + ] # B1 array for the likelihood, keyed by detector name def __init__( self, detectors: list[Detector], waveform: Waveform, prior: Prior, - bounds: tuple[Array, Array], + bounds: Float[Array, " n_dim 2"], n_bins: int = 100, trigger_time: float = 0, duration: float = 4, @@ -161,16 +176,7 @@ def __init__( # Get the original frequency grid - assert jnp.all( - jnp.array( - [ - (self.detectors[0].frequencies == detector.frequencies).all() - for detector in self.detectors - ] - ) - ), "The detectors must have the same frequency grid" - - frequency_original = self.detectors[0].frequencies + frequency_original = self.frequencies # Get the grid of the relative binning scheme (contains the final endpoint) # and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( @@ -273,7 +279,7 @@ def __init__( self.B0_array[detector.name] = B0[mask_heterodyne_center] self.B1_array[detector.name] = B1[mask_heterodyne_center] - def evaluate(self, params: Array, data: dict) -> float: + def evaluate(self, params: dict[str, Float], data: dict) -> Float: log_likelihood = 0 frequencies_low = self.freq_grid_low frequencies_center = self.freq_grid_center @@ -313,15 +319,15 @@ def evaluate(self, params: Array, data: dict) -> float: return log_likelihood def evaluate_original( - self, params: Array, data: dict + self, params: dict[str, Float], data: dict ) -> ( - float + Float ): # TODO: Test whether we need to pass data in or with class changes is fine. """ Evaluate the likelihood for a given set of parameters. """ log_likelihood = 0 - frequencies = self.detectors[0].frequencies + frequencies = self.frequencies df = frequencies[1] - frequencies[0] params["gmst"] = self.gmst waveform_sky = self.waveform(frequencies, params) @@ -349,10 +355,10 @@ def evaluate_original( @staticmethod def max_phase_diff( - f: Float[Array, "n_dim"], + f: npt.NDArray[np.float_], f_low: float, f_high: float, - chi: float = 1, + chi: Float = 1.0, ): """ Compute the maximum phase difference between the frequencies in the array. @@ -381,8 +387,8 @@ def max_phase_diff( return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) def make_binning_scheme( - self, freqs: Float[Array, "n_dim"], n_bins: int, chi: float = 1 - ) -> tuple[Float[Array, "n_bins+1"], Float[Array, "n_bins"]]: + self, freqs: npt.NDArray[np.float_], n_bins: int, chi: float = 1 + ) -> tuple[Float[Array, " n_bins+1"], Float[Array, " n_bins"]]: """ Make a binning scheme based on the maximum phase difference between the frequencies in the array. @@ -410,7 +416,7 @@ def make_binning_scheme( for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1): f_bins = np.append(f_bins, bin_f(i)) f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 - return f_bins, f_bins_center + return jnp.array(f_bins), jnp.array(f_bins_center) @staticmethod def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): @@ -453,23 +459,18 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): def maximize_likelihood( self, - bounds: tuple[Array, Array], + bounds: Float[Array, " n_dim 2"], prior: Prior, popsize: int = 100, n_loops: int = 2000, ): - bounds = jnp.array(bounds).T - popsize = popsize # TODO remove this? - def y(x): - return -self.evaluate_original( - prior.transform(prior.add_name(x)), None - ) + return -self.evaluate_original(prior.transform(prior.add_name(x)), {}) y = jax.jit(jax.vmap(y)) print("Starting the optimizer") optimizer = EvolutionaryOptimizer(len(bounds), popsize=popsize, verbose=True) - state = optimizer.optimize(y, bounds, n_loops=n_loops) + _ = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return prior.transform(prior.add_name(best_fit)) diff --git a/src/jimgw/model.py b/src/jimgw/model.py index e18c86ef..add21c2b 100644 --- a/src/jimgw/model.py +++ b/src/jimgw/model.py @@ -1,9 +1,9 @@ import equinox as eqx -from abc import ABC, abstractmethod +from abc import abstractmethod from jaxtyping import Array -class Model(eqx.Module): +class Model(eqx.Module): params: dict @abstractmethod diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 0ce77d09..d96eacab 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float +from jaxtyping import Array, Float, Int, PRNGKeyArray from typing import Callable, Union from dataclasses import field @@ -17,13 +17,15 @@ class Prior(Distribution): """ naming: list[str] - transforms: dict[tuple[str, Callable]] = field(default_factory=dict) + transforms: dict[str, tuple[str, Callable]] = field(default_factory=dict) @property def n_dim(self): return len(self.naming) - def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = {}): + def __init__( + self, naming: list[str], transforms: dict[str, tuple[str, Callable]] = {} + ): """ Parameters ---------- @@ -48,7 +50,7 @@ def make_lambda(name): # which will make lambda reference the last value of the variable name self.transforms[name] = (name, make_lambda(name)) - def transform(self, x: dict) -> dict: + def transform(self, x: dict[str, Float]) -> dict[str, Float]: """ Apply the transforms to the parameters. @@ -67,50 +69,54 @@ def transform(self, x: dict) -> dict: output[value[0]] = value[1](x) return output - def add_name(self, x: Array) -> dict: + def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: """ Turn an array into a dictionary Parameters ---------- x : Array - An array of parameters. Shape (n_dim, n_sample). + An array of parameters. Shape (n_dim,). """ return dict(zip(self.naming, x)) - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: raise NotImplementedError - def logpdf(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Array]) -> Float: raise NotImplementedError class Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 + xmin: Float = 0.0 + xmax: Float = 1.0 def __init__( self, - xmin: float, - xmax: float, + xmin: Float, + xmax: Float, naming: list[str], - transforms: dict[tuple[str, Callable]] = {}, + transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(xmin, float), "xmin must be a float" - assert isinstance(xmax, float), "xmax must be a float" + assert isinstance(xmin, Float), "xmin must be a Float" + assert isinstance(xmax, Float), "xmax must be a Float" assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: """ Sample from a uniform distribution. Parameters ---------- - rng_key : jax.random.PRNGKey + rng_key : PRNGKeyArray A random key to use for sampling. n_samples : int The number of samples to draw. @@ -126,7 +132,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: ) return self.add_name(samples[None]) - def log_prob(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Array]) -> Float: variable = x[self.naming[0]] output = jnp.where( (variable >= self.xmax) | (variable <= self.xmin), @@ -137,44 +143,57 @@ def log_prob(self, x: dict) -> Float: class Unconstrained_Uniform(Prior): - xmin: float = 0.0 - xmax: float = 1.0 - to_range: Callable = lambda x: x + xmin: Float = 0.0 + xmax: Float = 1.0 def __init__( self, - xmin: float, - xmax: float, + xmin: Float, + xmax: Float, naming: list[str], - transforms: dict[tuple[str, Callable]] = {}, + transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(xmin, float), "xmin must be a float" - assert isinstance(xmax, float), "xmax must be a float" + assert isinstance(xmin, Float), "xmin must be a Float" + assert isinstance(xmax, Float), "xmax must be a Float" assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin local_transform = self.transforms - self.to_range = ( - lambda x: (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) - + self.xmin - ) def new_transform(param): - param[self.naming[0]] = self.to_range(param) + param[self.naming[0]] = self.to_range(param[self.naming[0]]) return local_transform[self.naming[0]][1](param) self.transforms = { self.naming[0]: (local_transform[self.naming[0]][0], new_transform) } - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + def to_range(self, x: dict[str, Float]) -> Float: + """ + Transform the parameters to the range of the prior. + + Parameters + ---------- + x : dict + A dictionary of parameters. Names should match the ones in the prior. + + Returns + ------- + x : dict + A dictionary of parameters with the transforms applied. + """ + return (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) + self.xmin + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: """ Sample from a uniform distribution. Parameters ---------- - rng_key : jax.random.PRNGKey + rng_key : PRNGKeyArray A random key to use for sampling. n_samples : int The number of samples to draw. @@ -189,7 +208,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: samples = jnp.log(samples / (1 - samples)) return self.add_name(samples[None]) - def log_prob(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Float]) -> Float: variable = x[self.naming[0]] return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) @@ -223,7 +242,9 @@ def __init__(self, naming: str): ), } - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: rng_keys = jax.random.split(rng_key, 3) theta = jnp.arccos( jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0) @@ -232,7 +253,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - def log_prob(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Float]) -> Float: return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) @@ -252,18 +273,18 @@ class Alignedspin(Prior): See (A7) of https://arxiv.org/abs/1805.10457. """ - amax: float = 0.99 + amax: Float = 0.99 chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) def __init__( self, - amax: float, + amax: Float, naming: list[str], - transforms: dict[tuple[str, Callable]] = {}, + transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(amax, float), "xmin must be a float" + assert isinstance(amax, Float), "xmin must be a Float" assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" self.amax = amax @@ -273,7 +294,9 @@ def __init__( self.chi_axis = chi_axis self.cdf_vals = cdf_vals - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: """ Sample from the Alignedspin distribution. @@ -291,7 +314,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: Parameters ---------- - rng_key : jax.random.PRNGKey + rng_key : PRNGKeyArray A random key to use for sampling. n_samples : int The number of samples to draw. @@ -326,7 +349,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: return self.add_name(samples[None]) - def log_prob(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Float]) -> Float: variable = x[self.naming[0]] log_p = jnp.where( (variable >= self.amax) | (variable <= -self.amax), @@ -343,23 +366,23 @@ class Powerlaw(Prior): p(x) ~ x^{\alpha} """ - xmin: float = 0.0 - xmax: float = 1.0 - alpha: int = 0.0 - normalization: float = 1.0 + xmin: Float = 0.0 + xmax: Float = 1.0 + alpha: Float = 0.0 + normalization: Float = 1.0 def __init__( self, - xmin: float, - xmax: float, - alpha: Union[int, float], + xmin: Float, + xmax: Float, + alpha: Union[Int, Float], naming: list[str], - transforms: dict[tuple[str, Callable]] = {}, + transforms: dict[str, tuple[str, Callable]] = {}, ): super().__init__(naming, transforms) - assert isinstance(xmin, float), "xmin must be a float" - assert isinstance(xmax, float), "xmax must be a float" - assert isinstance(alpha, (int, float)), "alpha must be a int or a float" + assert isinstance(xmin, Float), "xmin must be a Float" + assert isinstance(xmax, Float), "xmax must be a Float" + assert isinstance(alpha, (Int, Float)), "alpha must be a int or a Float" if alpha < 0.0: assert xmin > 0.0, "With negative alpha, xmin must > 0" assert self.n_dim == 1, "Powerlaw needs to be 1D distributions" @@ -373,13 +396,15 @@ def __init__( self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha) ) - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: """ Sample from a power-law distribution. Parameters ---------- - rng_key : jax.random.PRNGKey + rng_key : PRNGKeyArray A random key to use for sampling. n_samples : int The number of samples to draw. @@ -401,7 +426,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: ) ** (1.0 / (1.0 + self.alpha)) return self.add_name(samples[None]) - def log_prob(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Float]) -> Float: variable = x[self.naming[0]] log_in_range = jnp.where( (variable >= self.xmax) | (variable <= self.xmin), @@ -416,7 +441,7 @@ class Composite(Prior): priors: list[Prior] = field(default_factory=list) def __init__( - self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {} + self, priors: list[Prior], transforms: dict[str, tuple[str, Callable]] = {} ): naming = [] self.transforms = {} @@ -427,14 +452,16 @@ def __init__( self.naming = naming self.transforms.update(transforms) - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: output = {} for prior in self.priors: rng_key, subkey = jax.random.split(rng_key) output.update(prior.sample(subkey, n_samples)) return output - def log_prob(self, x: dict) -> Float: + def log_prob(self, x: dict[str, Float]) -> Float: output = 0.0 for prior in self.priors: output += prior.log_prob(x) diff --git a/src/jimgw/utils.py b/src/jimgw/utils.py index 2473ab7b..aba78a02 100644 --- a/src/jimgw/utils.py +++ b/src/jimgw/utils.py @@ -1,51 +1,142 @@ import jax.numpy as jnp from jax import jit +from jaxtyping import Float, Array + @jit -def inner_product(h1, h2, frequency, PSD): - """ - Do PSD interpolation outside the inner product loop to speed up the evaluation - """ - #psd_interp = jnp.interp(frequency, PSD_frequency, PSD) - df = frequency[1] - frequency[0] - integrand = jnp.conj(h1)* h2 / PSD - return 4. * jnp.real(jnp.trapz(integrand,dx=df)) +def inner_product( + h1: Float[Array, " n_sample"], + h2: Float[Array, " n_sample"], + frequency: Float[Array, " n_sample"], + psd: Float[Array, " n_sample"], +) -> Float: + """ + Evaluating the inner product of two waveforms h1 and h2 with the psd. + + Do psd interpolation outside the inner product loop to speed up the evaluation + + Parameters + ---------- + h1 : Float[Array, "n_sample"] + First waveform. Can be complex. + h2 : Float[Array, "n_sample"] + Second waveform. Can be complex. + frequency : Float[Array, "n_sample"] + Frequency array. + psd : Float[Array, "n_sample"] + Power spectral density. + + Returns + ------- + Float + Inner product of h1 and h2 with the psd. + """ + # psd_interp = jnp.interp(frequency, psd_frequency, psd) + df = frequency[1] - frequency[0] + integrand = jnp.conj(h1) * h2 / psd + return 4.0 * jnp.real(jnp.trapz(integrand, dx=df)) + @jit -def m1m2_to_Mq(m1: float,m2: float): - """ - Transforming the primary mass m1 and secondary mass m2 to the Total mass M - and mass ratio q. - - Args: - m1: Primary mass of the binary. - m2: Secondary mass of the binary. - - Returns: - A tuple containing both the total mass M and mass ratio q. - """ - M_tot = jnp.log(m1+m2) - q = jnp.log(m2/m1)-jnp.log(1-m2/m1) - return M_tot, q +def m1m2_to_Mq(m1: Float, m2: Float): + """ + Transforming the primary mass m1 and secondary mass m2 to the Total mass M + and mass ratio q. + + Parameters + ---------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + + Returns + ------- + M_tot : Float + Total mass. + q : Float + Mass ratio. + """ + M_tot = jnp.log(m1 + m2) + q = jnp.log(m2 / m1) - jnp.log(1 - m2 / m1) + return M_tot, q + @jit -def Mq_to_m1m2(trans_M_tot,trans_q): - M_tot = jnp.exp(trans_M_tot) - q = 1./(1+jnp.exp(-trans_q)) - m1 = M_tot/(1+q) - m2 = m1*q - return m1, m2 +def Mq_to_m1m2(trans_M_tot: Float, trans_q: Float): + """ + Transforming the Total mass M and mass ratio q to the primary mass m1 and + secondary mass m2. + + Parameters + ---------- + M_tot : Float + Total mass. + q : Float + Mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + M_tot = jnp.exp(trans_M_tot) + q = 1.0 / (1 + jnp.exp(-trans_q)) + m1 = M_tot / (1 + q) + m2 = m1 * q + return m1, m2 + @jit -def Mc_q_to_m1m2(Mc,q): - eta = q/(1+q)**2 - M_tot = Mc/eta**(3./5) - m1 = M_tot/(1+q) - m2 = m1*q - return m1, m2 - -def ra_dec_to_theta_phi(ra, dec, gmst): +def Mc_q_to_m1m2(Mc: Float, q: Float) -> tuple[Float, Float]: + """ + Transforming the chirp mass Mc and mass ratio q to the primary mass m1 and + secondary mass m2. + + Parameters + ---------- + Mc : Float + Chirp mass. + q : Float + Mass ratio. + + Returns + ------- + m1 : Float + Primary mass. + m2 : Float + Secondary mass. + """ + eta = q / (1 + q) ** 2 + M_tot = Mc / eta ** (3.0 / 5) + m1 = M_tot / (1 + q) + m2 = m1 * q + return m1, m2 + + +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ phi = ra - gmst theta = jnp.pi / 2 - dec return theta, phi - diff --git a/src/jimgw/wave.py b/src/jimgw/wave.py index ddc6cd4c..63acd56d 100644 --- a/src/jimgw/wave.py +++ b/src/jimgw/wave.py @@ -1,14 +1,13 @@ # Credit some part of the source code from bilby import jax.numpy as jnp -from jimgw.constants import * import equinox as eqx -from jaxtyping import Array +from jaxtyping import Array, Float -KNOWN_POLS = 'pcxybl' +KNOWN_POLS = "pcxybl" -class Polarization(eqx.Module): +class Polarization(eqx.Module): name: str """Object defining a given polarization mode, with utilities to produce corresponding tensor in an Earth centric frame. @@ -19,48 +18,53 @@ class Polarization(eqx.Module): one of 'p' (plus), 'c' (cross), 'x' (vector x), 'y' (vector y), 'b' (breathing), or 'l' (longitudinal). """ - def __init__(self, name): + + def __init__(self, name: str): self.name = name.lower() if self.name not in KNOWN_POLS: e = f"unknown mode '{self.name}'; must be one of: {KNOWN_POLS}" raise ValueError(e) - def tensor_from_basis(self, x: Array, y: Array) -> Array: + def tensor_from_basis( + self, x: Float[Array, " 3"], y: Float[Array, " 3"] + ) -> Float[Array, " 3, 3"]: """Constructor to obtain polarization tensor from waveframe basis defined by orthonormal vectors (x, y) in arbitrary Cartesian coordinates. """ - if self.name == 'p': - return jnp.einsum('i,j->ij', x, x) - jnp.einsum('i,j->ij', y, y) - elif self.name == 'c': - return jnp.einsum('i,j->ij', x, y) + jnp.einsum('i,j->ij', y, x) - elif self.name == 'x': + if self.name == "p": + return jnp.einsum("i,j->ij", x, x) - jnp.einsum("i,j->ij", y, y) + elif self.name == "c": + return jnp.einsum("i,j->ij", x, y) + jnp.einsum("i,j->ij", y, x) + elif self.name == "x": z = jnp.cross(x, y) - return jnp.einsum('i,j->ij', x, z) + jnp.einsum('i,j->ij', z, x) - elif self.name == 'y': + return jnp.einsum("i,j->ij", x, z) + jnp.einsum("i,j->ij", z, x) + elif self.name == "y": z = jnp.cross(x, y) - return jnp.einsum('i,j->ij', y, z) + jnp.einsum('i,j->ij', z, y) - elif self.name == 'b': - return jnp.einsum('i,j->ij', x, x) + jnp.einsum('i,j->ij', y, y) - elif self.name == 'l': + return jnp.einsum("i,j->ij", y, z) + jnp.einsum("i,j->ij", z, y) + elif self.name == "b": + return jnp.einsum("i,j->ij", x, x) + jnp.einsum("i,j->ij", y, y) + elif self.name == "l": z = jnp.cross(x, y) - return jnp.einsum('i,j->ij', z, z) + return jnp.einsum("i,j->ij", z, z) else: raise ValueError(f"unrecognized polarization {self.name}") - def tensor_from_sky(self, ra: float, dec: float, psi: float, gmst: float) -> Array: + def tensor_from_sky( + self, ra: Float, dec: Float, psi: Float, gmst: Float + ) -> Float[Array, " 3, 3"]: """Computes {name} polarization tensor in celestial coordinates from sky location and orientation parameters. Arguments --------- - ra : float + ra : Float right ascension in radians. - dec : float + dec : Float declination in radians. - psi : float + psi : Float polarization angle in radians. - gmst : float + gmst : Float Greenwhich mean standard time (GMST) in radians. Returns @@ -68,20 +72,19 @@ def tensor_from_sky(self, ra: float, dec: float, psi: float, gmst: float) -> Arr tensor : array 3x3 polarization tensor. """ - gmst = jnp.mod(gmst, 2*jnp.pi) + gmst = jnp.mod(gmst, 2 * jnp.pi) phi = ra - gmst theta = jnp.pi / 2 - dec - u = jnp.array([jnp.cos(phi) * jnp.cos(theta), - jnp.cos(theta) * jnp.sin(phi), - -jnp.sin(theta)]) + u = jnp.array( + [ + jnp.cos(phi) * jnp.cos(theta), + jnp.cos(theta) * jnp.sin(phi), + -jnp.sin(theta), + ] + ) v = jnp.array([-jnp.sin(phi), jnp.cos(phi), 0]) m = -u * jnp.sin(psi) - v * jnp.cos(psi) n = -u * jnp.cos(psi) + v * jnp.sin(psi) return self.tensor_from_basis(m, n) - - - - - diff --git a/src/jimgw/waveform.py b/src/jimgw/waveform.py index 11220021..582267cf 100644 --- a/src/jimgw/waveform.py +++ b/src/jimgw/waveform.py @@ -1,39 +1,42 @@ -from jaxtyping import Array +from jaxtyping import Array, Float from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc -import jax.numpy as jnp from abc import ABC +import jax.numpy as jnp class Waveform(ABC): def __init__(self): return NotImplemented - def __call__(self, axis: Array, params: Array) -> dict: + def __call__( + self, axis: Float[Array, " n_dim"], params: dict[str, Float] + ) -> dict[str, Float[Array, " n_dim"]]: return NotImplemented class RippleIMRPhenomD(Waveform): - f_ref: float def __init__(self, f_ref: float = 20.0): self.f_ref = f_ref - def __call__(self, frequency: Array, params: dict) -> dict: + def __call__( + self, frequency: Float[Array, " n_dim"], params: dict[str, Float] + ) -> dict[str, Float[Array, " n_dim"]]: output = {} - ra = params["ra"] - dec = params["dec"] - theta = [ - params["M_c"], - params["eta"], - params["s1_z"], - params["s2_z"], - params["d_L"], - 0, - params["phase_c"], - params["iota"], - ] + theta = jnp.array( + [ + params["M_c"], + params["eta"], + params["s1_z"], + params["s2_z"], + params["d_L"], + 0, + params["phase_c"], + params["iota"], + ] + ) hp, hc = gen_IMRPhenomD_hphc(frequency, theta, self.f_ref) output["p"] = hp output["c"] = hc @@ -41,28 +44,31 @@ def __call__(self, frequency: Array, params: dict) -> dict: class RippleIMRPhenomPv2(Waveform): - f_ref: float def __init__(self, f_ref: float = 20.0): self.f_ref = f_ref - def __call__(self, frequency: Array, params: dict) -> dict: + def __call__( + self, frequency: Float[Array, " n_dim"], params: dict[str, Float] + ) -> dict[str, Float[Array, " n_dim"]]: output = {} - theta = [ - params["M_c"], - params["eta"], - params['s1_x'], - params['s1_y'], - params["s1_z"], - params['s2_x'], - params['s2_y'], - params["s2_z"], - params["d_L"], - 0, - params["phase_c"], - params["iota"], - ] + theta = jnp.array( + [ + params["M_c"], + params["eta"], + params["s1_x"], + params["s1_y"], + params["s1_z"], + params["s2_x"], + params["s2_y"], + params["s2_z"], + params["d_L"], + 0, + params["phase_c"], + params["iota"], + ] + ) hp, hc = gen_IMRPhenomPv2_hphc(frequency, theta, self.f_ref) output["p"] = hp output["c"] = hc