Skip to content

Commit

Permalink
fix Pv2 prior and transform
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Oct 7, 2024
1 parent 4cf6b27 commit 66027c3
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
40 changes: 16 additions & 24 deletions example/GW150914_IMRPhenomPV2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import (
SkyFrameToDetectorFrameSkyPositionTransform,
SpinToCartesianSpinTransform,
SphereSpinToCartesianSpinTransform,
MassRatioToSymmetricMassRatioTransform,
DistanceToSNRWeightedDistanceTransform,
GeocentricArrivalTimeToDetectorArrivalTimeTransform,
Expand Down Expand Up @@ -65,23 +65,14 @@
prior = prior + [Mc_prior, q_prior]

# Spin prior
a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"])
a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"])
theta_jn_prior = SinePrior(parameter_names=["theta_jn"])
phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"])
tilt_1_prior = SinePrior(parameter_names=["tilt_1"])
tilt_2_prior = SinePrior(parameter_names=["tilt_2"])
phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"])

s1_prior = UniformSpherePrior(parameter_names=["s1"])
s2_prior = UniformSpherePrior(parameter_names=["s2"])
iota_prior = SinePrior(parameter_names=["iota"])

prior = prior + [
a_1_prior,
a_2_prior,
theta_jn_prior,
phi_jl_prior,
tilt_1_prior,
tilt_2_prior,
phi_12_prior,
s1_prior,
s2_prior,
iota_prior,
]

# Extrinsic prior
Expand All @@ -106,20 +97,19 @@
# Defining Transforms

sample_transforms = [
SpinToCartesianSpinTransform(freq_ref=20.),
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = (["theta_jn"], ["theta_jn_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["phi_jl"], ["phi_jl_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["tilt_1"], ["tilt_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["tilt_2"], ["tilt_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["phi_12"], ["phi_12_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["a_1"], ["a_1_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = (["a_2"], ["a_2_unbounded"]) , original_lower_bound=0.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
Expand All @@ -128,6 +118,8 @@

likelihood_transforms = [
MassRatioToSymmetricMassRatioTransform,
SphereSpinToCartesianSpinTransform("s1"),
SphereSpinToCartesianSpinTransform("s2"),
]


Expand Down
2 changes: 1 addition & 1 deletion src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def get_samples(self, training: bool = False) -> dict:
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = chains.reshape(-1, self.prior.n_dim)
chains = self.add_name(chains)
chains = jax.vmap(self.add_name)(chains)
for sample_transform in reversed(self.sample_transforms):
chains = jax.vmap(sample_transform.backward)(chains)
return chains
Expand Down
42 changes: 42 additions & 0 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,48 @@ def named_transform(x):
self.transform_func = named_transform


@jaxtyped(typechecker=typechecker)
class SphereSpinToCartesianSpinTransform(BijectiveTransform):
"""
Spin to Cartesian spin transformation
"""

def __init__(
self,
label: str,
):
name_mapping = (
[label + "_mag", label + "_theta", label + "_phi"],
[label + "_x", label + "_y", label + "_z"],
)
super().__init__(name_mapping)

def named_transform(x):
mag, theta, phi = x[label + "_mag"], x[label + "_theta"], x[label + "_phi"]
x = mag * jnp.sin(theta) * jnp.cos(phi)
y = mag * jnp.sin(theta) * jnp.sin(phi)
z = mag * jnp.cos(theta)
return {
label + "_x": x,
label + "_y": y,
label + "_z": z,
}

def named_inverse_transform(x):
x, y, z = x[label + "_x"], x[label + "_y"], x[label + "_z"]
mag = jnp.sqrt(x**2 + y**2 + z**2)
theta = jnp.arccos(z / mag)
phi = jnp.arctan2(y, x)
return {
label + "_mag": mag,
label + "_theta": theta,
label + "_phi": phi,
}

self.transform_func = named_transform
self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
"""
Expand Down

0 comments on commit 66027c3

Please sign in to comment.