Skip to content

Commit

Permalink
replace vectorize with vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Sep 20, 2024
1 parent c25048f commit 4822962
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
41 changes: 30 additions & 11 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):

def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan

while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all():
non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0]
initial_position = (
jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan
)

while not jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
).all():
non_finite_index = jnp.where(
jnp.any(
~jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
),
axis=1,
)
)[0]

key, subkey = jax.random.split(key)
guess = self.prior.sample(subkey, self.sampler.n_chains)
for transform in self.sample_transforms:
guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T
finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0]
guess = jnp.array(
jax.tree.leaves({key: guess[key] for key in self.parameter_names})
).T
finite_guess = jnp.where(
jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length])
initial_position = initial_position.at[
non_finite_index[:common_length]
].set(guess[:common_length])
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
Expand Down Expand Up @@ -157,7 +176,7 @@ def print_summary(self, transform: bool = True):
training_chain = self.add_name(training_chain)
if transform:
for sample_transform in reversed(self.sample_transforms):
training_chain = sample_transform.backward(training_chain)
training_chain = jax.vmap(sample_transform.backward)(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
Expand All @@ -167,7 +186,7 @@ def print_summary(self, transform: bool = True):
production_chain = self.add_name(production_chain)
if transform:
for sample_transform in reversed(self.sample_transforms):
production_chain = sample_transform.backward(production_chain)
production_chain = jax.vmap(sample_transform.backward)(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]
Expand Down Expand Up @@ -223,10 +242,10 @@ def get_samples(self, training: bool = False) -> dict:
else:
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = chains.transpose(2, 0, 1)
chains = chains.reshape(-1, self.prior.n_dim)
chains = self.add_name(chains)
for sample_transform in reversed(self.sample_transforms):
chains = sample_transform.backward(chains)
chains = jax.vmap(sample_transform.backward)(chains)
return chains

def plot(self):
Expand Down
13 changes: 4 additions & 9 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
tc_min: Float,
tc_max: Float,
):
name_mapping = [["t_c"], ["t_det_unbounded"]]
name_mapping = (["t_c"], ["t_det_unbounded"])
conditional_names = ["ra", "dec"]
super().__init__(name_mapping, conditional_names)

Expand All @@ -159,7 +159,6 @@ def __init__(
assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1]
assert "ra" in conditional_names and "dec" in conditional_names

@jnp.vectorize
def time_delay(ra, dec, gmst):
return self.ifo.delay_from_geocenter(ra, dec, gmst)

Expand All @@ -181,9 +180,7 @@ def named_transform(x):

def named_inverse_transform(x):

time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)(
x["ra"], x["dec"], self.gmst
)
time_shift = self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst)

t_det_min = self.tc_min + time_shift
t_det_max = self.tc_max + time_shift
Expand Down Expand Up @@ -228,7 +225,7 @@ def __init__(
gps_time: Float,
ifo: GroundBased2G,
):
name_mapping = [["phase_c"], ["phase_det"]]
name_mapping = (["phase_c"], ["phase_det"])
conditional_names = ["ra", "dec", "psi", "iota"]
super().__init__(name_mapping, conditional_names)

Expand All @@ -245,7 +242,6 @@ def __init__(
and "iota" in conditional_names
)

@jnp.vectorize
def _calc_R_det_arg(ra, dec, psi, iota, gmst):
p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0
c_iota_term = jnp.cos(iota)
Expand Down Expand Up @@ -303,7 +299,7 @@ def __init__(
dL_min: Float,
dL_max: Float,
):
name_mapping = [["d_L"], ["d_hat_unbounded"]]
name_mapping = (["d_L"], ["d_hat_unbounded"])
conditional_names = ["M_c", "ra", "dec", "psi", "iota"]
super().__init__(name_mapping, conditional_names)

Expand All @@ -323,7 +319,6 @@ def __init__(
and "M_c" in conditional_names
)

@jnp.vectorize
def _calc_R_dets(ra, dec, psi, iota):
p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0
c_iota_term = jnp.cos(iota)
Expand Down

0 comments on commit 4822962

Please sign in to comment.