Skip to content

Commit

Permalink
pep flake
Browse files Browse the repository at this point in the history
  • Loading branch information
maxisi committed Oct 28, 2024
1 parent e032436 commit a84c8e0
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions example/GW150914_IMRPhenomPV2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import optax
import time

import jax
Expand Down Expand Up @@ -57,7 +58,7 @@
for ifo in ifos:
data = jd.Data.from_gwosc(ifo.name, start, end)
ifo.set_data(data)

psd_data = jd.Data.from_gwosc(ifo.name, psd_start, psd_end)
psd_fftlength = data.duration * data.sampling_frequency
ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength))
Expand Down Expand Up @@ -111,23 +112,39 @@
# Defining Transforms

sample_transforms = [
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
DistanceToSNRWeightedDistanceTransform(
gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(
gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(
tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["M_c"], [
"M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping=(["q"], ["q_unbounded"]),
original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping=(["s1_phi"], [
"s1_phi_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["s2_phi"], [
"s2_phi_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["iota"], ["iota_unbounded"]),
original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["s1_theta"], [
"s1_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["s2_theta"], [
"s2_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["s1_mag"], [
"s1_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping=(["s2_mag"], [
"s2_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping=(["phase_det"], [
"phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["psi"], ["psi_unbounded"]),
original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["zenith"], [
"zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["azimuth"], [
"azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
Expand All @@ -147,9 +164,9 @@
# mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 1e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)
Adam_optimizer = optimization_Adam(
n_steps=3000, learning_rate=0.01, noise_level=1)

import optax

n_epochs = 20
n_loop_training = 100
Expand Down

0 comments on commit a84c8e0

Please sign in to comment.