update PV2 example
kazewong committed Sep 20, 2024
1 parent 41abe69 commit 84ff0b7
Showing 1 changed file with 77 additions and 48 deletions.
125 changes: 77 additions & 48 deletions example/GW150914_IMRPhenomPV2.py
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import (
    SkyFrameToDetectorFrameSkyPositionTransform,
    ComponentMassesToChirpMassMassRatioTransform,
    SpinToCartesianSpinTransform,
    MassRatioToSymmetricMassRatioTransform,
)
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
fmin = 20.0
fmax = 1024.0

ifos = ["H1", "L1"]
ifos = [H1, L1]

H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
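# Each detector is given 4 s of strain around the GW150914 trigger (2 s on either
# side of gps), band-limited to 20-1024 Hz; psd_pad and tukey_alpha control the PSD
# estimate and window taper (a hedged reading of load_data's arguments as used here).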

# Mass prior
M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"])
q_prior = UniformPrior(q_min, q_max, parameter_names=["q"])

prior = prior + [Mc_prior, q_prior]
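# The intrinsic masses are sampled as chirp mass M_c (10-80 M_sun) and mass ratio q
# in [0.125, 1]; component masses can be recovered with the imported helper, roughly
# m1, m2 = Mc_q_to_m1_m2(M_c, q) (an assumed usage sketch, not part of this example).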

# Spin prior
theta_jn_prior = SinePrior(parameter_names=["theta_jn"])
phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"])
tilt_1_prior = SinePrior(parameter_names=["tilt_1"])
tilt_2_prior = SinePrior(parameter_names=["tilt_2"])
phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"])
a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"])
a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"])


prior = prior + [
    theta_jn_prior,
    phi_jl_prior,
    tilt_1_prior,
    tilt_2_prior,
    phi_12_prior,
    a_1_prior,
    a_2_prior,
]
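# Spins are sampled in the angular parameterization used for precessing systems
# (theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2); SpinToCartesianSpinTransform
# below converts these to the Cartesian spin components the waveform expects.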

# Extrinsic prior
    dec_prior,
]


prior = CombinePrior(prior)

# Defining Transforms

sample_transforms = [
    # ComponentMassesToChirpMassMassRatioTransform,
    BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
    BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
    BoundToUnbound(name_mapping = (["theta_jn"], ["theta_jn_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    BoundToUnbound(name_mapping = (["phi_jl"], ["phi_jl_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
    BoundToUnbound(name_mapping = (["tilt_1"], ["tilt_1_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    BoundToUnbound(name_mapping = (["tilt_2"], ["tilt_2_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    BoundToUnbound(name_mapping = (["phi_12"], ["phi_12_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
    BoundToUnbound(name_mapping = (["a_1"], ["a_1_unbounded"]), original_lower_bound=0.0, original_upper_bound=1.0),
    BoundToUnbound(name_mapping = (["a_2"], ["a_2_unbounded"]), original_lower_bound=0.0, original_upper_bound=1.0),
    BoundToUnbound(name_mapping = (["d_L"], ["d_L_unbounded"]), original_lower_bound=10.0, original_upper_bound=2000.0),
    BoundToUnbound(name_mapping = (["t_c"], ["t_c_unbounded"]), original_lower_bound=-0.05, original_upper_bound=0.05),
    BoundToUnbound(name_mapping = (["phase_c"], ["phase_c_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
    BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
    BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
    BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]
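# Every BoundToUnbound entry maps one bounded parameter onto the real line, and the
# sky transform re-expresses (ra, dec) as detector-frame (zenith, azimuth), so the
# samplers and the normalizing flow operate on an unconstrained space.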

likelihood_transforms = [
    MassRatioToSymmetricMassRatioTransform,
    SpinToCartesianSpinTransform(freq_ref=20.0),
]
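# likelihood_transforms convert sampled parameters into the ones the waveform model
# uses: q -> eta (symmetric mass ratio) and the spin angles -> Cartesian spins at the
# 20 Hz reference frequency, applied just before the likelihood is evaluated.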


likelihood = TransientLikelihoodFD(
    [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2
)
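# Frequency-domain transient likelihood over a 4 s analysis segment, 2 s of which sits
# after the trigger time, using the data and PSDs loaded into H1 and L1 above.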


mass_matrix = jnp.eye(prior.n_dim)
# mass_matrix = mass_matrix.at[1, 1].set(1e-3)
# mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 1e-3}
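# The local sampler's step size is the identity "mass matrix" scaled by 1e-3; the
# commented lines above show how individual parameters could be given smaller steps.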

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)

import optax

n_epochs = 20
n_loop_training = 100
total_epochs = n_epochs * n_loop_training
start = total_epochs // 10
learning_rate = optax.polynomial_schedule(
    1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start
)

jim = Jim(
    likelihood,
    prior,
    n_loop_training=n_loop_training,
    n_loop_production=20,
    n_local_steps=10,
    n_global_steps=1000,
    n_chains=500,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    n_max_examples=30000,
    n_flow_sample=100000,
    momentum=0.9,
    batch_size=30000,
    use_global=True,
    keep_quantile=0.0,
    train_thinning=1,
    output_thinning=10,
    local_sampler_arg=local_sampler_arg,
    # strategies=[Adam_optimizer, "default"],
)
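# A minimal sketch of launching the run, assuming the usual jimgw interface
# (jim.sample and jim.get_samples); adjust to the installed version.
import jax

jim.sample(jax.random.PRNGKey(42))
samples = jim.get_samples()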

# import numpy as np
