diff --git a/example/GW150914.py b/example/GW150914.py deleted file mode 100644 index 559b5b7c..00000000 --- a/example/GW150914.py +++ /dev/null @@ -1,138 +0,0 @@ -import time - -import jax -import jax.numpy as jnp - -from jimgw.jim import Jim -from jimgw.prior import Composite, Unconstrained_Uniform -from jimgw.single_event.detector import H1, L1 -from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomD -from flowMC.strategy.optimization import optimization_Adam - -jax.config.update("jax_enable_x64", True) - -########################################### -########## First we grab data ############# -########################################### - -total_time_start = time.time() - -# first, fetch a 4s segment centered on GW150914 -gps = 1126259462.4 -duration = 4 -post_trigger_duration = 2 -start_pad = duration - post_trigger_duration -end_pad = post_trigger_duration -fmin = 20.0 -fmax = 1024.0 - -ifos = ["H1", "L1"] - -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) -s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( - [ - Mc_prior, - q_prior, - s1z_prior, - s2z_prior, - dL_prior, - t_c_prior, - phase_c_prior, - cos_iota_prior, - psi_prior, - ra_prior, - sin_dec_prior, - ] -) -likelihood = TransientLikelihoodFD( - [H1, L1], - waveform=RippleIMRPhenomD(), - trigger_time=gps, - duration=4, - post_trigger_duration=2, -) - - -mass_matrix = jnp.eye(11) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} - -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) - -import optax -n_epochs = 20 -n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) - - -jim = Jim( - likelihood, - prior, - n_loop_training=n_loop_training, - n_loop_production=20, - n_local_steps=10, - n_global_steps=1000, - n_chains=500, - n_epochs=n_epochs, - learning_rate=learning_rate, - n_max_examples=30000, - n_flow_samples=100000, - momentum=0.9, - batch_size=30000, - use_global=True, - train_thinning=1, - output_thinning=10, - local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer,"default"], -) - -jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_IMRPhenomD.py b/example/GW150914_IMRPhenomD.py new file mode 100644 index 00000000..66619ddc --- /dev/null +++ b/example/GW150914_IMRPhenomD.py @@ -0,0 +1,133 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.single_event.detector import H1, L1 +from jimgw.single_event.likelihood import TransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +eta_min, eta_max = 0.2, 0.25 +# m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +# m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) +eta_prior = UniformPrior(eta_min, eta_max, parameter_names=["eta"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( + [ + Mc_prior, + eta_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + iota_prior, + psi_prior, + ra_prior, + dec_prior, + ] +) + +sample_transforms = [ + # ComponentMassesToChirpMassMassRatioTransform, + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["eta"], ["eta_unbounded"]), original_lower_bound=eta_min, original_upper_bound=eta_max), + BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]) , original_lower_bound=1.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]) , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]), original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + # ComponentMassesToChirpMassSymmetricMassRatioTransform, +] + +likelihood = TransientLikelihoodFD( + ifos, + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +n_epochs = 30 +n_loop_training = 20 +learning_rate = 1e-4 + + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_sample=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], + verbose=True +) + +jim.sample(jax.random.PRNGKey(42)) +# jim.get_samples() +# jim.print_summary() \ No newline at end of file diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 88ffe52b..5c678b22 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -28,21 +28,20 @@ ] ) - run = SingleEventRun( seed=0, detectors=["H1", "L1"], priors={ - "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, + "M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0}, "q": {"name": "MassRatio"}, - "s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, - "s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, - "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0}, - "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05}, - "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0}, + "t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, "cos_iota": {"name": "CosIota"}, - "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi}, - "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, "sin_dec": {"name": "SinDec"}, }, waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, @@ -90,3 +89,9 @@ ) run_manager = SingleEventPERunManager(run=run) +run_manager.jim.sample(jax.random.PRNGKey(42)) + +# plot the corner plot and diagnostic plot +run_manager.plot_corner() +run_manager.plot_diagnostic() + diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index fae0bc98..2f0086ac 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,18 +104,19 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): if initial_position.size == 0: - initial_guess = [] - for _ in range(self.sampler.n_chains): - flag = True - while flag: - key = jax.random.split(key)[1] - guess = self.prior.sample(key, 1) - for transform in self.sample_transforms: - guess = transform.forward(guess) - guess = jnp.array([i for i in guess.values()]).T[0] - flag = not jnp.all(jnp.isfinite(guess)) - initial_guess.append(guess) - initial_position = jnp.array(initial_guess) + initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan + + while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): + non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + + key, subkey = jax.random.split(key) + guess = self.prior.sample(subkey, self.sampler.n_chains) + for transform in self.sample_transforms: + guess = jax.vmap(transform.forward)(guess) + guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T + finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + common_length = min(len(finite_guess), len(non_finite_index)) + initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3f65166d..7b720c30 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -4,6 +4,8 @@ import jax import jax.numpy as jnp import matplotlib.pyplot as plt +import corner +import numpy as np import yaml from astropy.time import Time from jaxlib.xla_extension import ArrayImpl @@ -71,7 +73,8 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + path: str = "./experiment" + injection_parameters: dict[str, float] = field(default_factory=lambda: {}) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -123,6 +126,9 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError + if self.run.injection and not self.run.injection_parameters: + raise ValueError("Injection mode requires injection parameters.") + local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) @@ -150,6 +156,7 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: waveform = self.initialize_waveform() name = self.run.likelihood_parameters["name"] assert isinstance(name, str), "Likelihood name must be a string." + assert name in likelihood_presets, f"Likelihood {name} not recognized." if self.run.injection: freqs = jnp.linspace( self.run.data_parameters["f_min"], @@ -351,3 +358,73 @@ def plot_data(self, path: str): plt.ylabel("Amplitude") plt.legend() plt.savefig(path) + + def sample(self): + self.jim.sample(jax.random.PRNGKey(self.run.seed)) + + def get_samples(self): + return self.jim.get_samples() + + def plot_corner(self, path: str = "corner.jpeg", **kwargs): + """ + plot corner plot of the samples. + """ + plot_datapoint = kwargs.get("plot_datapoints", False) + title_quantiles = kwargs.get("title_quantiles", [0.16, 0.5, 0.84]) + show_titles = kwargs.get("show_titles", True) + title_fmt = kwargs.get("title_fmt", ".2E") + use_math_text = kwargs.get("use_math_text", True) + + samples = self.jim.get_samples() + param_names = list(samples.keys()) + samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T + corner.corner( + samples, + labels=param_names, + plot_datapoints=plot_datapoint, + title_quantiles=title_quantiles, + show_titles=show_titles, + title_fmt=title_fmt, + use_math_text=use_math_text, + **kwargs, + ) + plt.savefig(path) + plt.close() + + def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): + """ + plot diagnostic plot of the samples. + """ + summary = self.jim.Sampler.get_sampler_state(training=True) + chains, log_prob, local_accs, global_accs, loss_vals = summary.values() + log_prob = np.array(log_prob) + + plt.figure(figsize=(10, 10)) + axs = [plt.subplot(2, 2, i + 1) for i in range(4)] + plt.sca(axs[0]) + plt.title("log probability") + plt.plot(log_prob.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[1]) + plt.title("NF loss") + plt.plot(loss_vals.reshape(-1)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[2]) + plt.title("Local Acceptance") + plt.plot(local_accs.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[3]) + plt.title("Global Acceptance") + plt.plot(global_accs.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + plt.tight_layout() + + plt.savefig(path) + plt.close() diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 084fe368..da7864ba 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -8,6 +8,7 @@ ConditionalBijectiveTransform, BijectiveTransform, NtoNTransform, + reverse_bijective_transform, ) from jimgw.single_event.utils import ( m1_m2_to_Mc_q, @@ -24,111 +25,56 @@ @jaxtyped(typechecker=typechecker) -class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): - """ - Transform chirp mass and mass ratio to component masses - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. +class SpinToCartesianSpinTransform(NtoNTransform): """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - assert ( - "m_1" in name_mapping[0] - and "m_2" in name_mapping[0] - and "M_c" in name_mapping[1] - and "q" in name_mapping[1] - ) - - def named_transform(x): - Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - self.transform_func = named_transform - - def named_inverse_transform(x): - m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + Spin to Cartesian spin transformation """ - Transform mass ratio to symmetric mass ratio - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ + freq_ref: Float def __init__( self, - name_mapping: tuple[list[str], list[str]], + freq_ref: Float, ): - super().__init__(name_mapping) - assert ( - "m_1" in name_mapping[0] - and "m_2" in name_mapping[0] - and "M_c" in name_mapping[1] - and "eta" in name_mapping[1] + name_mapping = ( + ["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], + ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"], ) + super().__init__(name_mapping) + + self.freq_ref = freq_ref def named_transform(x): - Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) - return {"M_c": Mc, "eta": eta} + iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) + return { + "iota": iota, + "s1_x": s1x, + "s1_y": s1y, + "s1_z": s1z, + "s2_x": s2x, + "s2_y": s2y, + "s2_z": s2z, + } self.transform_func = named_transform - def named_inverse_transform(x): - m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0] - - self.transform_func = lambda x: {"eta": q_to_eta(x["q"])} - self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])} - @jaxtyped(typechecker=typechecker) class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ Transform sky frame to detector frame sky position - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ gmst: Float @@ -137,10 +83,10 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): def __init__( self, - name_mapping: tuple[list[str], list[str]], gps_time: Float, - ifos: GroundBased2G, + ifos: list[GroundBased2G], ): + name_mapping = (["ra", "dec"], ["zenith", "azimuth"]) super().__init__(name_mapping) self.gmst = ( @@ -150,13 +96,6 @@ def __init__( self.rotation = euler_rotation(delta_x) self.rotation_inv = jnp.linalg.inv(self.rotation) - assert ( - "ra" in name_mapping[0] - and "dec" in name_mapping[0] - and "zenith" in name_mapping[1] - and "azimuth" in name_mapping[1] - ) - def named_transform(x): zenith, azimuth = ra_dec_to_zenith_azimuth( x["ra"], x["dec"], self.gmst, self.rotation @@ -501,3 +440,48 @@ def named_transform(x): } self.transform_func = named_transform + + +def named_m1_m2_to_Mc_q(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} + + +def named_Mc_q_to_m1_m2(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) +ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 + +def named_m1_m2_to_Mc_eta(x): + Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) + return {"M_c": Mc, "eta": eta} + +def named_Mc_eta_to_m1_m2(x): + m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) + return {"m_1": m1, "m_2": m2} + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +def named_q_to_eta(x): + return {"eta": q_to_eta(x["q"])} +def named_eta_to_q(x): + return {"q": eta_to_q(x["eta"])} +MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) +MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta +MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q + + +ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassMassRatioTransform +) +ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassSymmetricMassRatioTransform +) +SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform( + MassRatioToSymmetricMassRatioTransform +) diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a15bd7bf..fb35bf27 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -276,6 +276,165 @@ def eta_to_q(eta: Float) -> Float: return temp - (temp**2 - 1) ** 0.5 +def spin_to_cartesian_spin( + thetaJN: Float, + phiJL: Float, + theta1: Float, + theta2: Float, + phi12: Float, + chi1: Float, + chi2: Float, + M_c: Float, + eta: Float, + fRef: Float, + phiRef: Float, +) -> tuple[Float, Float, Float, Float, Float, Float, Float]: + """ + Transforming the spin parameters + + The code is based on the approach used in LALsimulation: + https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html + + Parameters: + ------- + thetaJN: Float + Zenith angle between the total angular momentum and the line of sight + phiJL: Float + Difference between total and orbital angular momentum azimuthal angles + theta1: Float + Zenith angle between the spin and orbital angular momenta for the primary object + theta2: Float + Zenith angle between the spin and orbital angular momenta for the secondary object + phi12: Float + Difference between the azimuthal angles of the individual spin vector projections + onto the orbital plane + chi1: Float + Primary object aligned spin: + chi2: Float + Secondary object aligned spin: + M_c: Float + The chirp mass + eta: Float + The symmetric mass ratio + fRef: Float + The reference frequency + phiRef: Float + Binary phase at a reference frequency + + Returns: + ------- + iota: Float + Zenith angle between the orbital angular momentum and the line of sight + S1x: Float + The x-component of the primary spin + S1y: Float + The y-component of the primary spin + S1z: Float + The z-component of the primary spin + S2x: Float + The x-component of the secondary spin + S2y: Float + The y-component of the secondary spin + S2z: Float + The z-component of the secondary spin + """ + + def rotate_y(angle, vec): + """ + Rotate the vector (x, y, z) about y-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + def rotate_z(angle, vec): + """ + Rotate the vector (x, y, z) about z-axis + """ + cos_angle = jnp.cos(angle) + sin_angle = jnp.sin(angle) + rotation_matrix = jnp.array( + [[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]] + ) + rotated_vec = jnp.dot(rotation_matrix, vec) + return rotated_vec + + LNh = jnp.array([0.0, 0.0, 1.0]) + + s1hat = jnp.array( + [ + jnp.sin(theta1) * jnp.cos(phiRef), + jnp.sin(theta1) * jnp.sin(phiRef), + jnp.cos(theta1), + ] + ) + s2hat = jnp.array( + [ + jnp.sin(theta2) * jnp.cos(phi12 + phiRef), + jnp.sin(theta2) * jnp.sin(phi12 + phiRef), + jnp.cos(theta2), + ] + ) + + temp = 1 / eta / 2 - 1 + q = temp - (temp**2 - 1) ** 0.5 + m1, m2 = Mc_q_to_m1m2(M_c, q) + v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef) + + Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0)) + s1 = m1 * m1 * chi1 * s1hat + s2 = m2 * m2 * chi2 * s2hat + J = s1 + s2 + jnp.array([0.0, 0.0, Lmag]) + + Jhat = J / jnp.linalg.norm(J) + theta0 = jnp.arccos(Jhat[2]) + phi0 = jnp.arctan2(Jhat[1], Jhat[0]) + + # Rotation 1: + s1hat = rotate_z(-phi0, s1hat) + s2hat = rotate_z(-phi0, s2hat) + + # Rotation 2: + LNh = rotate_y(-theta0, LNh) + s1hat = rotate_y(-theta0, s1hat) + s2hat = rotate_y(-theta0, s2hat) + + # Rotation 3: + LNh = rotate_z(phiJL - jnp.pi, LNh) + s1hat = rotate_z(phiJL - jnp.pi, s1hat) + s2hat = rotate_z(phiJL - jnp.pi, s2hat) + + # Compute iota + N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)]) + iota = jnp.arccos(jnp.dot(N, LNh)) + + thetaLJ = jnp.arccos(LNh[2]) + phiL = jnp.arctan2(LNh[1], LNh[0]) + + # Rotation 4: + s1hat = rotate_z(-phiL, s1hat) + s2hat = rotate_z(-phiL, s2hat) + N = rotate_z(-phiL, N) + + # Rotation 5: + s1hat = rotate_y(-thetaLJ, s1hat) + s2hat = rotate_y(-thetaLJ, s2hat) + N = rotate_y(-thetaLJ, N) + + # Rotation 6: + phiN = jnp.arctan2(N[1], N[0]) + s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat) + s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat) + + S1 = s1hat * chi1 + S2 = s2hat * chi2 + return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] + + def euler_rotation(delta_x: Float[Array, " 3"]): """ Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index f7d4c702..39c55642 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -87,7 +87,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -124,7 +124,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], @@ -516,3 +516,19 @@ def __init__( ) for i in range(len(name_mapping[1])) } + + +def reverse_bijective_transform( + original_transform: BijectiveTransform, +) -> BijectiveTransform: + + reversed_name_mapping = ( + original_transform.name_mapping[1], + original_transform.name_mapping[0], + ) + reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping) + reversed_transform.transform_func = original_transform.inverse_transform_func + reversed_transform.inverse_transform_func = original_transform.transform_func + reversed_transform.__repr__ = lambda: f"Reversed{repr(original_transform)}" + + return reversed_transform diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a7f7ef0e --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,2 @@ +outdir/ +figures/ diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index e1eee9ac..946b0735 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -1,3 +1,7 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + import jax import jax.numpy as jnp @@ -10,6 +14,8 @@ from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam +from flowMC.utils.postprocessing import plot_summary +import optax jax.config.update("jax_enable_x64", True) @@ -62,7 +68,7 @@ ) sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), @@ -72,13 +78,13 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -89,7 +95,6 @@ post_trigger_duration=2, ) - mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) @@ -101,7 +106,6 @@ n_loop_training = 1 learning_rate = 1e-4 - jim = Jim( likelihood, prior, @@ -115,7 +119,7 @@ n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30, - n_flow_samples=100, + n_flow_sample=100, momentum=0.9, batch_size=100, use_global=True, @@ -127,4 +131,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py index bf97efdb..cbac1788 100644 --- a/test/integration/test_GW150914_D_heterodyne.py +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -62,7 +62,7 @@ ) sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), @@ -72,13 +72,13 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = HeterodynedTransientLikelihoodFD( @@ -132,4 +132,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index c9d83a5e..9892058d 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -89,8 +89,8 @@ ] likelihood_transforms = [ - SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), + SpinToCartesianSpinTransform(freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -139,4 +139,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_mass_transforms.py b/test/integration/test_mass_transforms.py new file mode 100644 index 00000000..65b95244 --- /dev/null +++ b/test/integration/test_mass_transforms.py @@ -0,0 +1,106 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + +import numpy as np +import matplotlib.pyplot as plt +import corner +import jax +import jax.numpy as jnp +from jaxtyping import Float + +from jimgw.prior import UniformPrior, CombinePrior +from jimgw.single_event.transforms import ChirpMassMassRatioToComponentMassesTransform +from jimgw.base import LikelihoodBase +from jimgw.jim import Jim + +params = {"axes.grid": True, + "text.usetex" : True, + "font.family" : "serif", + "ytick.color" : "black", + "xtick.color" : "black", + "axes.labelcolor" : "black", + "axes.edgecolor" : "black", + "font.serif" : ["Computer Modern Serif"], + "xtick.labelsize": 16, + "ytick.labelsize": 16, + "axes.labelsize": 16, + "legend.fontsize": 16, + "legend.title_fontsize": 16, + "figure.titlesize": 16} + +plt.rcParams.update(params) + +# Improved corner kwargs +default_corner_kwargs = dict(bins=40, + smooth=1., + show_titles=False, + label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), + color="blue", + # quantiles=[], + # levels=[0.9], + plot_density=True, + plot_datapoints=False, + fill_contours=True, + max_n_ticks=4, + min_n_ticks=3, + truth_color = "red", + save=False) + +# Likelihood for this test: + +class MyLikelihood(LikelihoodBase): + """Simple toy likelihood: Gaussian centered on the true component masses""" + + true_m1: Float + true_m2: Float + + def __init__(self, + true_m1: Float, + true_m2: Float): + + self.true_m1 = true_m1 + self.true_m2 = true_m2 + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + m1, m2 = params['m_1'], params['m_2'] + m1_std = 0.1 + m2_std = 0.1 + return -0.5 * (((m1 - self.true_m1) / m1_std)**2 + ((m2 - self.true_m2) / m2_std)**2) + +# Setup +true_m1 = 1.6 +true_m2 = 1.4 +true_mc = (true_m1 * true_m2)**(3/5) / (true_m1 + true_m2)**(1/5) +true_q = true_m2 / true_m1 + +# Priors +eps = 0.5 # half of width of the chirp mass prior +mc_prior = UniformPrior(true_mc - eps, true_mc + eps, parameter_names=['M_c']) +q_prior = UniformPrior(0.125, 1.0, parameter_names=['q']) +combine_prior = CombinePrior([mc_prior, q_prior]) + +# Likelihood and transform +likelihood = MyLikelihood(true_m1, true_m2) +mass_transform = ChirpMassMassRatioToComponentMassesTransform + +print("Checking mass_transform repr") +print(repr(mass_transform)) + +# Other stuff we have to give to Jim to make it work +step = 5e-3 +local_sampler_arg = {"step_size": step * jnp.eye(2)} + +# Jim: +jim = Jim(likelihood, + combine_prior, + likelihood_transforms=[mass_transform], + n_chains = 10, + parameter_names=['M_c', 'q'], + n_loop_training=2, + n_loop_production=2, + local_sampler_arg=local_sampler_arg) + +jim.sample(jax.random.PRNGKey(0)) +jim.print_summary() \ No newline at end of file diff --git a/test/unit/test_transform.py b/test/unit/test_transform.py new file mode 100644 index 00000000..14162f79 --- /dev/null +++ b/test/unit/test_transform.py @@ -0,0 +1,48 @@ +# import numpy as np +# import jax.numpy as jnp + +# class TestTransform: +# def test_sky_location_transform(self): +# from bilby.gw.utils import zenith_azimuth_to_ra_dec as bilby_earth_to_sky +# from bilby.gw.detector.networks import InterferometerList + +# from jimgw.single_event.utils import zenith_azimuth_to_ra_dec as jimgw_earth_to_sky +# from jimgw.single_event.detector import detector_preset +# from astropy.time import Time + +# ifos = ["H1", "L1"] +# geocent_time = 1000000000 + +# import matplotlib.pyplot as plt + +# for zenith in np.linspace(0, np.pi, 10): +# for azimuth in np.linspace(0, 2*np.pi, 10): +# bilby_sky_location = np.array(bilby_earth_to_sky(zenith, azimuth, geocent_time, InterferometerList(ifos))) +# jimgw_sky_location = np.array(jimgw_earth_to_sky(zenith, azimuth, Time(geocent_time, format="gps").sidereal_time("apparent", "greenwich").rad, detector_preset[ifos[0]].vertex - detector_preset[ifos[1]].vertex)) +# assert np.allclose(bilby_sky_location, jimgw_sky_location, atol=1e-4) + +# def test_spin_transform(self): +# from bilby.gw.conversion import bilby_to_lalsimulation_spins as bilby_spin_transform +# from bilby.gw.conversion import symmetric_mass_ratio_to_mass_ratio, chirp_mass_and_mass_ratio_to_component_masses + +# from jimgw.single_event.utils import spin_to_cartesian_spin as jimgw_spin_transform + +# for _ in range(100): +# thetaJN = jnp.array(np.random.uniform(0, np.pi)) +# phiJL = jnp.array(np.random.uniform(0, np.pi)) +# theta1 = jnp.array(np.random.uniform(0, np.pi)) +# theta2 = jnp.array(np.random.uniform(0, np.pi)) +# phi12 = jnp.array(np.random.uniform(0, np.pi)) +# chi1 = jnp.array(np.random.uniform(0, 1)) +# chi2 = jnp.array(np.random.uniform(0, 1)) +# M_c = jnp.array(np.random.uniform(1, 100)) +# eta = jnp.array(np.random.uniform(0.1, 0.25)) +# fRef = jnp.array(np.random.uniform(10, 1000)) +# phiRef = jnp.array(np.random.uniform(0, 2*np.pi)) + +# q = symmetric_mass_ratio_to_mass_ratio(eta) +# m1, m2 = chirp_mass_and_mass_ratio_to_component_masses(M_c, q) +# MsunInkg = 1.9884e30 +# bilby_spin = jnp.array(bilby_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, m1*MsunInkg, m2*MsunInkg, fRef, phiRef)) +# jimgw_spin = jnp.array(jimgw_spin_transform(thetaJN, phiJL, theta1, theta2, phi12, chi1, chi2, M_c, eta, fRef, phiRef)) +# assert np.allclose(bilby_spin, jimgw_spin, atol=1e-4)