diff --git a/example/Multiple_event_runManager.py b/example/Multiple_event_runManager.py new file mode 100644 index 00000000..eb6d87ce --- /dev/null +++ b/example/Multiple_event_runManager.py @@ -0,0 +1,12 @@ +import jax +import jax.numpy as jnp + +from jimgw.single_event.runManager import MultipleEventPERunManager + +jax.config.update("jax_enable_x64", True) + +run_manager = MultipleEventPERunManager( + run_config_path="config", # the configuration file is stored in the config folder +) + +run_manager.run() diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 88ffe52b..26280ead 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -1,4 +1,3 @@ - import jax import jax.numpy as jnp @@ -12,58 +11,50 @@ mass_matrix = mass_matrix.at[5, 5].set(1e-3) mass_matrix = mass_matrix * 3e-3 local_sampler_arg = {"step_size": mass_matrix} -bounds = jnp.array( - [ - [10.0, 40.0], - [0.125, 1.0], - [-1.0, 1.0], - [-1.0, 1.0], - [0.0, 2000.0], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) - run = SingleEventRun( seed=0, detectors=["H1", "L1"], + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, priors={ - "M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0}, - "q": {"name": "MassRatio"}, - "s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, - "s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0}, - "d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0}, - "t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05}, - "phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, - "cos_iota": {"name": "CosIota"}, - "psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi}, - "ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi}, - "sin_dec": {"name": "SinDec"}, + "M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0}, + "s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0}, + "t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "iota": {"name": "SinePrior"}, + "psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "dec": {"name": "CosinePrior"}, }, waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, - jim_parameters={ - "n_loop_training": 10, - "n_loop_production": 10, - "n_local_steps": 150, - "n_global_steps": 150, - "n_chains": 500, - "n_epochs": 50, - "learning_rate": 0.001, - "n_max_examples": 45000, - "momentum": 0.9, - "batch_size": 50000, - "use_global": True, - "keep_quantile": 0.0, - "train_thinning": 1, - "output_thinning": 10, - "local_sampler_arg": local_sampler_arg, - }, - likelihood_parameters={"name": "HeterodynedTransientLikelihoodFD", "bounds": bounds}, + likelihood_parameters={"name": "TransientLikelihoodFD"}, + sample_transforms=[ + {"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,}, + {"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,}, + {"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,}, + {"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + ], + likelihood_transforms=[ + {"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]}, + ], injection=True, injection_parameters={ "M_c": 28.6, @@ -78,15 +69,28 @@ "ra": 1.2, "dec": 0.3, }, - data_parameters={ - "trigger_time": 1126259462.4, - "duration": 4, - "post_trigger_duration": 2, - "f_min": 20.0, - "f_max": 1024.0, - "tukey_alpha": 0.2, - "f_sampling": 4096.0, + jim_parameters={ + "n_loop_training": 100, + "n_loop_production": 20, + "n_local_steps": 10, + "n_global_steps": 1000, + "n_chains": 500, + "n_epochs": 30, + "learning_rate": 1e-4, + "n_max_examples": 30000, + "momentum": 0.9, + "batch_size": 30000, + "use_global": True, + "train_thinning": 1, + "output_thinning": 10, + "local_sampler_arg": local_sampler_arg, }, ) run_manager = SingleEventPERunManager(run=run) +run_manager.sample() + +# plot the corner plot and diagnostic plot +run_manager.plot_corner() +run_manager.plot_diagnostic() +run_manager.save_summary() diff --git a/src/jimgw/single_event/detector.py b/src/jimgw/single_event/detector.py index 6c3079cf..580fe6f0 100644 --- a/src/jimgw/single_event/detector.py +++ b/src/jimgw/single_event/detector.py @@ -373,7 +373,7 @@ def inject_signal( h_sky: dict[str, Float[Array, " n_sample"]], params: dict[str, Float], psd_file: str = "", - ) -> None: + ) -> tuple[Float, Float]: """ Inject a signal into the detector data. @@ -392,7 +392,7 @@ def inject_signal( Returns ------- - None + SNR """ self.frequencies = freqs self.psd = self.load_psd(freqs, psd_file) @@ -415,6 +415,8 @@ def inject_signal( print(f"The injected optimal SNR is {optimal_SNR}") print(f"The injected match filter SNR is {match_filter_SNR}") + return optimal_SNR, match_filter_SNR + @jaxtyped(typechecker=typechecker) def load_psd( self, freqs: Float[Array, " n_sample"], psd_file: str = "" diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3f65166d..de5d84f7 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -4,12 +4,18 @@ import jax import jax.numpy as jnp import matplotlib.pyplot as plt +import corner +import sys +import os +import numpy as np import yaml from astropy.time import Time from jaxlib.xla_extension import ArrayImpl from jaxtyping import Array, Float, PyTree -from jimgw import prior +from jimgw import prior, transforms +from jimgw.single_event import prior as single_event_prior +from jimgw.single_event import transforms as single_event_transforms from jimgw.base import RunManager from jimgw.jim import Jim from jimgw.single_event.detector import Detector, detector_preset @@ -23,44 +29,6 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): yaml.add_representer(ArrayImpl, jaxarray_representer) # type: ignore -prior_presets = { - "Unconstrained_Uniform": prior.Unconstrained_Uniform, - "Uniform": prior.Uniform, - "Sphere": prior.Sphere, - "AlignedSpin": prior.AlignedSpin, - "PowerLaw": prior.PowerLaw, - "Composite": prior.Composite, - "MassRatio": lambda **kwargs: prior.Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, - ), - "CosIota": lambda **kwargs: prior.Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos(params["cos_iota"]), - ) - }, - ), - "SinDec": lambda **kwargs: prior.Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin(params["sin_dec"]), - ) - }, - ), - "EarthFrame": prior.EarthFrame, -} - @dataclass class SingleEventRun: @@ -71,7 +39,8 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + path: str = "single_event_run" + injection_parameters: dict[str, float] = field(default_factory=lambda: {}) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -90,11 +59,18 @@ class SingleEventRun: "f_sampling": 4096.0, } ) + sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field( + default_factory=lambda: [] + ) + likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field( + default_factory=lambda: [] + ) class SingleEventPERunManager(RunManager): run: SingleEventRun jim: Jim + SNRs: list[float] @property def waveform(self): @@ -123,9 +99,21 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError + if self.run.injection and not self.run.injection_parameters: + raise ValueError("Injection mode requires injection parameters.") + local_prior = self.initialize_prior() - local_likelihood = self.initialize_likelihood(local_prior) - self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) + sample_transforms, likelihood_transforms = self.initialize_transforms() + local_likelihood = self.initialize_likelihood( + local_prior, sample_transforms, likelihood_transforms + ) + self.jim = Jim( + local_likelihood, + local_prior, + sample_transforms, + likelihood_transforms, + **self.run.jim_parameters, + ) def save(self, path: str): output_dict = asdict(self.run) @@ -135,11 +123,23 @@ def save(self, path: str): def load_from_path(self, path: str) -> SingleEventRun: with open(path, "r") as f: data = yaml.safe_load(f) + try: + data["jim_parameters"]["local_sampler_arg"]["step_size"] = jnp.array( + data["jim_parameters"]["local_sampler_arg"]["step_size"] + ) + except KeyError as e: + print("No local sampler argument provided in the configuration.") + return SingleEventRun(**data) ### Initialization functions ### - def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: + def initialize_likelihood( + self, + prior: prior.CombinePrior, + sample_transforms: transforms.Transform, + likelihood_transforms: transforms.Transform, + ) -> SingleEventLiklihood: """ Since prior contains information about types, naming and ranges of parameters, some of the likelihood class require the prior to be initialized, such as the @@ -150,6 +150,7 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: waveform = self.initialize_waveform() name = self.run.likelihood_parameters["name"] assert isinstance(name, str), "Likelihood name must be a string." + assert name in likelihood_presets, f"Likelihood {name} not recognized." if self.run.injection: freqs = jnp.linspace( self.run.data_parameters["f_min"], @@ -179,34 +180,87 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood: - self.run.data_parameters["post_trigger_duration"], } key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901)) + SNRs = [] for detector in detectors: - detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + optimal_SNR, _ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore + SNRs.append(optimal_SNR) key, subkey = jax.random.split(key) + self.SNRs = SNRs + return likelihood_presets[name]( detectors, waveform, prior=prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, **self.run.likelihood_parameters, **self.run.data_parameters, ) - def initialize_prior(self) -> prior.Prior: + def initialize_prior(self) -> prior.CombinePrior: priors = [] for name, parameters in self.run.priors.items(): - if parameters["name"] not in prior_presets: - raise ValueError(f"Prior {name} not recognized.") - if parameters["name"] == "EarthFrame": - priors.append( - prior.EarthFrame( - gps=self.run.data_parameters["trigger_time"], - ifos=self.run.detectors, + assert isinstance( + parameters, dict + ), "Prior parameters must be a dictionary." + assert "name" in parameters, "Prior name must be provided." + assert isinstance(parameters["name"], str), "Prior name must be a string." + try: + prior_class = getattr(single_event_prior, parameters["name"]) + except AttributeError: + try: + prior_class = getattr(prior, parameters["name"]) + except AttributeError: + raise ValueError(f"{parameters['name']} not recognized.") + parameters = parameters.copy() + parameters.pop("name") + priors.append(prior_class(parameter_names=[name], **parameters)) + return prior.CombinePrior(priors) + + def initialize_transforms( + self, + ) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]: + sample_transforms = [] + likelihood_transforms = [] + if self.run.sample_transforms: + for transform in self.run.sample_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance( + transform["name"], str + ), "Transform name must be a string." + try: + transform_class = getattr( + single_event_transforms, transform["name"] ) - ) - else: - priors.append( - prior_presets[parameters["name"]](naming=[name], **parameters) - ) - return prior.Composite(priors) + except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform = transform.copy() + transform.pop("name") + sample_transforms.append(transform_class(**transform)) + if self.run.likelihood_transforms: + for transform in self.run.likelihood_transforms: + assert isinstance(transform, dict), "Transform must be a dictionary." + assert "name" in transform, "Transform name must be provided." + assert isinstance( + transform["name"], str + ), "Transform name must be a string." + try: + transform_class = getattr( + single_event_transforms, transform["name"] + ) + except AttributeError: + try: + transform_class = getattr(transforms, transform["name"]) + except AttributeError: + raise ValueError(f"{transform['name']} not recognized.") + transform = transform.copy() + transform.pop("name") + likelihood_transforms.append(transform_class(**transform)) + return sample_transforms, likelihood_transforms def initialize_detector(self) -> list[Detector]: """ @@ -351,3 +405,166 @@ def plot_data(self, path: str): plt.ylabel("Amplitude") plt.legend() plt.savefig(path) + + def sample(self): + self.jim.sample(jax.random.PRNGKey(self.run.seed)) + + def get_samples(self): + return self.jim.get_samples() + + def plot_corner(self, path: str = "corner.jpeg", **kwargs): + """ + plot corner plot of the samples. + """ + plot_datapoint = kwargs.get("plot_datapoints", False) + title_quantiles = kwargs.get("title_quantiles", [0.16, 0.5, 0.84]) + show_titles = kwargs.get("show_titles", True) + title_fmt = kwargs.get("title_fmt", ".2E") + use_math_text = kwargs.get("use_math_text", True) + + samples = self.jim.get_samples() + param_names = list(samples.keys()) + samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T + corner.corner( + samples, + labels=param_names, + plot_datapoints=plot_datapoint, + title_quantiles=title_quantiles, + show_titles=show_titles, + title_fmt=title_fmt, + use_math_text=use_math_text, + **kwargs, + ) + plt.savefig(path) + plt.close() + + def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): + """ + plot diagnostic plot of the samples. + """ + summary = self.jim.sampler.get_sampler_state(training=True) + chains, log_prob, local_accs, global_accs, loss_vals = summary.values() + log_prob = np.array(log_prob) + + plt.figure(figsize=(10, 10)) + axs = [plt.subplot(2, 2, i + 1) for i in range(4)] + plt.sca(axs[0]) + plt.title("log probability") + plt.plot(log_prob.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[1]) + plt.title("NF loss") + plt.plot(loss_vals.reshape(-1)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[2]) + plt.title("Local Acceptance") + plt.plot(local_accs.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + + plt.sca(axs[3]) + plt.title("Global Acceptance") + plt.plot(global_accs.mean(0)) + plt.xlabel("iteration") + plt.xlim(0, None) + plt.tight_layout() + + plt.savefig(path) + plt.close() + + def save_summary(self, path: str = "", **kwargs): + if path == "": + path = self.run.path + "run_manager_summary.txt" + orig_stdout = sys.stdout + sys.stdout = open(path, "wt") + self.jim.print_summary() + if self.run.injection: + for detector, SNR in zip(self.detectors, self.SNRs): + print("SNR of detector " + detector + " is " + str(SNR)) + networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) + print("network SNR is", networkSNR) + sys.stdout.close() + sys.stdout = orig_stdout + + +class MultipleEventPERunManager: + """ + Class to manage multiple events run. + """ + + run_config_path: str + output_path: str + + def __init__(self, run_config_path: str, output_path: str = "output") -> None: + """ + Arguments: + run_config_path (str): load the run configuration from the path. + output_path (str, optional): save the output to this path. Defaults to "output". + """ + + self.run_config_path = run_config_path + self.output_path = output_path + + def run( + self, + plot_corner: bool = True, + plot_diagnostic: bool = True, + save_summary: bool = True, + ): + """ + Loop over all the configuration files in the run_config_path and run the PE for each configuration. + """ + + if plot_corner and not os.path.exists(self.output_path + "/corner_plots"): + os.makedirs(self.output_path + "/corner_plots") + if plot_diagnostic and not os.path.exists( + self.output_path + "/diagnostic_plots" + ): + os.makedirs(self.output_path + "/diagnostic_plots") + if save_summary and not os.path.exists(self.output_path + "/summaries"): + os.makedirs(self.output_path + "/summaries") + if not os.path.exists(self.output_path + "/error_log"): + os.makedirs(self.output_path + "/error_log") + + config_directory = os.fsencode(self.run_config_path) + for file in os.listdir(config_directory): + filename = os.fsdecode(file) + + try: + if filename.endswith(".yaml"): + config_path = os.path.join(self.run_config_path, filename) + run_manager = SingleEventPERunManager(path=config_path) + run_manager.sample() + + if plot_corner: + run_manager.plot_corner( + self.output_path + + "/corner_plots/" + + filename + + "_corner.jpeg" + ) + if plot_diagnostic: + run_manager.plot_diagnostic( + self.output_path + + "/diagnostic_plots/" + + filename + + "_diagnostic.jpeg" + ) + if save_summary: + run_manager.save_summary( + self.output_path + "/summaries/" + filename + "_summary.txt" + ) + + except Exception as e: + orig_stdout = sys.stdout + sys.stdout = open( + self.output_path + "/error_log/" + filename + "error_log.txt", "wt" + ) + print(f"Error in running {filename}. Error: {e}") + sys.stdout.close() + sys.stdout = orig_stdout + continue diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index a15bd7bf..62e8ad6d 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -393,6 +393,93 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F return ra, dec +def zenith_azimuth_to_ra_dec( + zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. + + Parameters + ---------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Copied and modified from bilby/gw/utils.py + + Returns + ------- + ra : Float + Right ascension. + dec : Float + Declination. + """ + theta, phi = angle_rotation(zenith, azimuth, rotation) + ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) + return ra, dec + + +def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: + """ + Transforming the right ascension ra and declination dec to the polar angle + theta and azimuthal angle phi. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + + Returns + ------- + theta : Float + Polar angle. + phi : Float + Azimuthal angle. + """ + phi = ra - gmst + theta = jnp.pi / 2 - dec + phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) + return theta, phi + + +def ra_dec_to_zenith_azimuth( + ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] +) -> tuple[Float, Float]: + """ + Transforming the right ascension and declination to the zenith angle and azimuthal angle. + + Parameters + ---------- + ra : Float + Right ascension. + dec : Float + Declination. + gmst : Float + Greenwich mean sidereal time. + rotation : Float[Array, " 3 3"] + The rotation matrix. + + Returns + ------- + zenith : Float + Zenith angle. + azimuth : Float + Azimuthal angle. + """ + theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) + zenith, azimuth = angle_rotation(theta, phi, rotation) + return zenith, azimuth + + def spin_to_cartesian_spin( thetaJN: Float, phiJL: Float, @@ -549,90 +636,3 @@ def rotate_z(angle, vec): S1 = s1hat * chi1 S2 = s2hat * chi2 return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2] - - -def zenith_azimuth_to_ra_dec( - zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"] -) -> tuple[Float, Float]: - """ - Transforming the azimuthal angle and zenith angle in Earth frame to right ascension and declination. - - Parameters - ---------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. - gmst : Float - Greenwich mean sidereal time. - rotation : Float[Array, " 3 3"] - The rotation matrix. - - Copied and modified from bilby/gw/utils.py - - Returns - ------- - ra : Float - Right ascension. - dec : Float - Declination. - """ - theta, phi = angle_rotation(zenith, azimuth, rotation) - ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) - return ra, dec - - -def ra_dec_to_theta_phi(ra: Float, dec: Float, gmst: Float) -> tuple[Float, Float]: - """ - Transforming the right ascension ra and declination dec to the polar angle - theta and azimuthal angle phi. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - - Returns - ------- - theta : Float - Polar angle. - phi : Float - Azimuthal angle. - """ - phi = ra - gmst - theta = jnp.pi / 2 - dec - phi = (phi + 2 * jnp.pi) % (2 * jnp.pi) - return theta, phi - - -def ra_dec_to_zenith_azimuth( - ra: Float, dec: Float, gmst: Float, rotation: Float[Array, " 3 3"] -) -> tuple[Float, Float]: - """ - Transforming the right ascension and declination to the zenith angle and azimuthal angle. - - Parameters - ---------- - ra : Float - Right ascension. - dec : Float - Declination. - gmst : Float - Greenwich mean sidereal time. - rotation : Float[Array, " 3 3"] - The rotation matrix. - - Returns - ------- - zenith : Float - Zenith angle. - azimuth : Float - Azimuthal angle. - """ - theta, phi = ra_dec_to_theta_phi(ra, dec, gmst) - zenith, azimuth = angle_rotation(theta, phi, rotation) - return zenith, azimuth diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a0783193 --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,3 @@ +*.txt +*.jpeg +*.jpg diff --git a/test/integration/test_single_event_run_manager.py b/test/integration/test_single_event_run_manager.py new file mode 100644 index 00000000..48bec85f --- /dev/null +++ b/test/integration/test_single_event_run_manager.py @@ -0,0 +1,96 @@ +import jax +import jax.numpy as jnp + +from jimgw.single_event.runManager import (SingleEventPERunManager, + SingleEventRun) + +jax.config.update("jax_enable_x64", True) + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +mass_matrix = mass_matrix * 3e-3 +local_sampler_arg = {"step_size": mass_matrix} + +run = SingleEventRun( + seed=0, + detectors=["H1", "L1"], + data_parameters={ + "trigger_time": 1126259462.4, + "duration": 4, + "post_trigger_duration": 2, + "f_min": 20.0, + "f_max": 1024.0, + "tukey_alpha": 0.2, + "f_sampling": 4096.0, + }, + priors={ + "M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0}, + "q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0}, + "s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0}, + "d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0}, + "t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05}, + "phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "iota": {"name": "SinePrior"}, + "psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi}, + "ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi}, + "dec": {"name": "CosinePrior"}, + }, + waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0}, + likelihood_parameters={"name": "TransientLikelihoodFD"}, + sample_transforms=[ + {"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,}, + {"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,}, + {"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,}, + {"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,}, + {"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,}, + {"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,}, + ], + likelihood_transforms=[ + {"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]}, + ], + injection=True, + injection_parameters={ + "M_c": 28.6, + "eta": 0.24, + "s1_z": 0.05, + "s2_z": 0.05, + "d_L": 440.0, + "t_c": 0.0, + "phase_c": 0.0, + "iota": 0.5, + "psi": 0.7, + "ra": 1.2, + "dec": 0.3, + }, + jim_parameters={ + "n_loop_training": 1, + "n_loop_production": 1, + "n_local_steps": 5, + "n_global_steps": 5, + "n_chains": 4, + "n_epochs": 2, + "learning_rate": 1e-4, + "n_max_examples": 30, + "momentum": 0.9, + "batch_size": 100, + "use_global": True, + "train_thinning": 1, + "output_thinning": 1, + "local_sampler_arg": local_sampler_arg, + }, +) + +run_manager = SingleEventPERunManager(run=run) +run_manager.sample() + +# plot the corner plot and diagnostic plot +run_manager.plot_corner() +run_manager.plot_diagnostic() +run_manager.save_summary()