diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index c58fd662..30cbfb35 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -25,4 +25,4 @@ jobs: python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi python -m pip install . - - uses: pre-commit/action@v3.0.0 + - uses: pre-commit/action@v3.0.1 diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2f0086ac..3047c0a4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): if initial_position.size == 0: - initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan - - while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + initial_position = ( + jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan + ) + + while not jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] key, subkey = jax.random.split(key) guess = self.prior.sample(subkey, self.sampler.n_chains) for transform in self.sample_transforms: guess = jax.vmap(transform.forward)(guess) - guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T - finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in self.parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] common_length = min(len(finite_guess), len(non_finite_index)) - initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( @@ -157,7 +176,7 @@ def print_summary(self, transform: bool = True): training_chain = self.add_name(training_chain) if transform: for sample_transform in reversed(self.sample_transforms): - training_chain = sample_transform.backward(training_chain) + training_chain = jax.vmap(sample_transform.backward)(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] @@ -167,7 +186,7 @@ def print_summary(self, transform: bool = True): production_chain = self.add_name(production_chain) if transform: for sample_transform in reversed(self.sample_transforms): - production_chain = sample_transform.backward(production_chain) + production_chain = jax.vmap(sample_transform.backward)(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -223,10 +242,10 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.sampler.get_sampler_state(training=False)["chains"] - chains = chains.transpose(2, 0, 1) + chains = chains.reshape(-1, self.prior.n_dim) chains = self.add_name(chains) for sample_transform in reversed(self.sample_transforms): - chains = sample_transform.backward(chains) + chains = jax.vmap(sample_transform.backward)(chains) return chains def plot(self): diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 00e6ce6b..9e775b33 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -26,6 +26,15 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None: self.waveform = waveform +class ZeroLikelihood(LikelihoodBase): + + def __init__(self): + pass + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + return 0.0 + + class TransientLikelihoodFD(SingleEventLiklihood): def __init__( self, diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 0cfae761..1bf8f3a7 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -5,6 +5,7 @@ from jimgw.single_event.detector import GroundBased2G from jimgw.transforms import ( + ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform, reverse_bijective_transform, @@ -111,34 +112,318 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform + +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalTimeToDetectorArrivalTimeTransform( + ConditionalBijectiveTransform +): + """ + Transform the geocentric arrival time to detector arrival time + + In the geocentric convention, the arrival time of the signal at the + center of Earth is gps_time + t_c + + In the detector convention, the arrival time of the signal at the + detecotr is gps_time + time_delay_from_geo_to_det + t_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + ifo: GroundBased2G + tc_min: Float + tc_max: Float + + def __init__( + self, + gps_time: Float, + ifo: GroundBased2G, + tc_min: Float, + tc_max: Float, + ): + name_mapping = (["t_c"], ["t_det_unbounded"]) + conditional_names = ["ra", "dec"] + super().__init__(name_mapping, conditional_names) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + self.tc_min = tc_min + self.tc_max = tc_max + + assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] + assert "ra" in conditional_names and "dec" in conditional_names + + def time_delay(ra, dec, gmst): + return self.ifo.delay_from_geocenter(ra, dec, gmst) + + def named_transform(x): + + time_shift = time_delay(x["ra"], x["dec"], self.gmst) + + t_det = x["t_c"] + time_shift + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + + y = (t_det - t_det_min) / (t_det_max - t_det_min) + t_det_unbounded = jnp.log(y / (1.0 - y)) + return { + "t_det_unbounded": t_det_unbounded, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + + time_shift = self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) + + t_det_min = self.tc_min + time_shift + t_det_max = self.tc_max + time_shift + t_det = (t_det_max - t_det_min) / ( + 1.0 + jnp.exp(-x["t_det_unbounded"]) + ) + t_det_min + + t_c = t_det - time_shift + + return { + "t_c": t_c, + } + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( + ConditionalBijectiveTransform +): + """ + Transform the geocentric arrival phase to detector arrival phase + + In the geocentric convention, the arrival phase of the signal at the + center of Earth is phase_c / 2 (in ripple, phase_c is the orbital phase) + + In the detector convention, the arrival phase of the signal at the + detecotr is phase_det = phase_c / 2 + arg R_det + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + ifo: GroundBased2G + + def __init__( + self, + gps_time: Float, + ifo: GroundBased2G, + ): + name_mapping = (["phase_c"], ["phase_det"]) + conditional_names = ["ra", "dec", "psi", "iota"] + super().__init__(name_mapping, conditional_names) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifo = ifo + + assert "phase_c" in name_mapping[0] and "phase_det" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + ) + + def _calc_R_det_arg(ra, dec, psi, iota, gmst): + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + + antenna_pattern = self.ifo.antenna_pattern(ra, dec, psi, gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + + return jnp.angle(p_mode_term - 1j * c_mode_term) + + def named_transform(x): + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_det = R_det_arg + x["phase_c"] / 2.0 + return { + "phase_det": phase_det % (2.0 * jnp.pi), + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + R_det_arg = _calc_R_det_arg( + x["ra"], x["dec"], x["psi"], x["iota"], self.gmst + ) + phase_c = -R_det_arg + x["phase_det"] * 2.0 + return { + "phase_c": phase_c % (2.0 * jnp.pi), + } + + self.inverse_transform_func = named_inverse_transform + + +@jaxtyped(typechecker=typechecker) +class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform): + """ + Transform the luminosity distance to network SNR weighted distance + + Parameters + ---------- + name_mapping : tuple[list[str], list[str]] + The name mapping between the input and output dictionary. + + """ + + gmst: Float + ifos: list[GroundBased2G] + dL_min: Float + dL_max: Float + + def __init__( + self, + gps_time: Float, + ifos: list[GroundBased2G], + dL_min: Float, + dL_max: Float, + ): + name_mapping = (["d_L"], ["d_hat_unbounded"]) + conditional_names = ["M_c", "ra", "dec", "psi", "iota"] + super().__init__(name_mapping, conditional_names) + + self.gmst = ( + Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad + ) + self.ifos = ifos + self.dL_min = dL_min + self.dL_max = dL_max + + assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1] + assert ( + "ra" in conditional_names + and "dec" in conditional_names + and "psi" in conditional_names + and "iota" in conditional_names + and "M_c" in conditional_names + ) + + def _calc_R_dets(ra, dec, psi, iota): + p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 + c_iota_term = jnp.cos(iota) + R_dets2 = 0.0 + for ifo in self.ifos: + antenna_pattern = ifo.antenna_pattern(ra, dec, psi, self.gmst) + p_mode_term = p_iota_term * antenna_pattern["p"] + c_mode_term = c_iota_term * antenna_pattern["c"] + R_dets2 += p_mode_term**2 + c_mode_term**2 + + return jnp.sqrt(R_dets2) + + def named_transform(x): + d_L, M_c = ( + x["d_L"], + x["M_c"], + ) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + d_hat = scale_factor * d_L + + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max + + y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min) + d_hat_unbounded = jnp.log(y / (1.0 - y)) + + return { + "d_hat_unbounded": d_hat_unbounded, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + d_hat_unbounded, M_c = ( + x["d_hat_unbounded"], + x["M_c"], + ) + R_dets = _calc_R_dets(x["ra"], x["dec"], x["psi"], x["iota"]) + + scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets + + d_hat_min = scale_factor * self.dL_min + d_hat_max = scale_factor * self.dL_max + + d_hat = (d_hat_max - d_hat_min) / ( + 1.0 + jnp.exp(-d_hat_unbounded) + ) + d_hat_min + d_L = d_hat / scale_factor + return { + "d_L": d_L, + } + + self.inverse_transform_func = named_inverse_transform + + def named_m1_m2_to_Mc_q(x): Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) return {"M_c": Mc, "q": q} + def named_Mc_q_to_m1_m2(x): m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "q"]) +) ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q -ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = ( + named_Mc_q_to_m1_m2 +) + def named_m1_m2_to_Mc_eta(x): Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) return {"M_c": Mc, "eta": eta} + def named_Mc_eta_to_m1_m2(x): m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) -ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta -ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "eta"]) +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = ( + named_m1_m2_to_Mc_eta +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = ( + named_Mc_eta_to_m1_m2 +) + def named_q_to_eta(x): return {"eta": q_to_eta(x["q"])} + + def named_eta_to_q(x): return {"q": eta_to_q(x["eta"])} + + MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 3ad51e62..39c55642 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -61,8 +61,6 @@ def forward(self, x: dict[str, Float]) -> dict[str, Float]: class NtoNTransform(NtoMTransform): - transform_func: Callable[[dict[str, Float]], dict[str, Float]] - @property def n_dim(self) -> int: return len(self.name_mapping[0]) @@ -164,6 +162,71 @@ def backward(self, y: dict[str, Float]) -> dict[str, Float]: return y_copy +class ConditionalBijectiveTransform(BijectiveTransform): + + conditional_names: list[str] + + def __init__( + self, + name_mapping: tuple[list[str], list[str]], + conditional_names: list[str], + ): + super().__init__(name_mapping) + self.conditional_names = conditional_names + + def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: + x_copy = x.copy() + transform_params = dict((key, x_copy[key]) for key in self.name_mapping[0]) + transform_params.update( + dict((key, x_copy[key]) for key in self.conditional_names) + ) + output_params = self.transform_func(transform_params) + jacobian = jax.jacfwd(self.transform_func)(transform_params) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[0]} + for key1 in self.name_mapping[1] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) + jax.tree.map( + lambda key: x_copy.pop(key), + self.name_mapping[0], + ) + jax.tree.map( + lambda key: x_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return x_copy, jacobian + + def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: + y_copy = y.copy() + transform_params = dict((key, y_copy[key]) for key in self.name_mapping[1]) + transform_params.update( + dict((key, y_copy[key]) for key in self.conditional_names) + ) + output_params = self.inverse_transform_func(transform_params) + jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) + jacobian_copy = { + key1: {key2: jacobian[key1][key2] for key2 in self.name_mapping[1]} + for key1 in self.name_mapping[0] + } + jacobian = jnp.array(jax.tree.leaves(jacobian_copy)) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) + jax.tree.map( + lambda key: y_copy.pop(key), + self.name_mapping[1], + ) + jax.tree.map( + lambda key: y_copy.update({key: output_params[key]}), + list(output_params.keys()), + ) + return y_copy, jacobian + + @jaxtyped(typechecker=typechecker) class ScaleTransform(BijectiveTransform): scale: Float @@ -355,17 +418,25 @@ class SingleSidedUnboundTransform(BijectiveTransform): """ + original_lower_bound: Float + def __init__( self, name_mapping: tuple[list[str], list[str]], + original_lower_bound: Float, ): super().__init__(name_mapping) + self.original_lower_bound = jnp.atleast_1d(original_lower_bound) + self.transform_func = lambda x: { - name_mapping[1][i]: jnp.exp(x[name_mapping[0][i]]) + name_mapping[1][i]: jnp.log( + x[name_mapping[0][i]] - self.original_lower_bound[i] + ) for i in range(len(name_mapping[0])) } self.inverse_transform_func = lambda x: { - name_mapping[0][i]: jnp.log(x[name_mapping[1][i]]) + name_mapping[0][i]: jnp.exp(x[name_mapping[1][i]]) + + self.original_lower_bound[i] for i in range(len(name_mapping[1])) } diff --git a/test/integration/test_extrinsic.py b/test/integration/test_extrinsic.py new file mode 100644 index 00000000..300f2132 --- /dev/null +++ b/test/integration/test_extrinsic.py @@ -0,0 +1,106 @@ +from astropy.time import Time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import ZeroLikelihood +from jimgw.transforms import BoundToUnbound, SingleSidedUnboundTransform +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 + +ifos = [H1, L1, V1] + +M_c_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + M_c_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + + +sample_transforms = [ + # all the user reparametrization transform + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + # all the bound to unbound transform + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["phase_det"], ["phase_det_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [] + +likelihood = ZeroLikelihood() + +mass_matrix = jnp.eye(len(prior.base_prior)) +#mass_matrix = mass_matrix.at[1, 1].set(1e-3) +#mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + +n_epochs = 2 +n_loop_training = 1 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=2, + n_global_steps=2, + n_chains=10, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=["default"], +) + +print("Start sampling") +key = jax.random.PRNGKey(42) +jim.sample(key) +jim.print_summary() +samples = jim.get_samples()