diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2f0086ac..3047c0a4 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): if initial_position.size == 0: - initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan - - while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + initial_position = ( + jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan + ) + + while not jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] key, subkey = jax.random.split(key) guess = self.prior.sample(subkey, self.sampler.n_chains) for transform in self.sample_transforms: guess = jax.vmap(transform.forward)(guess) - guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T - finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in self.parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] common_length = min(len(finite_guess), len(non_finite_index)) - initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( @@ -157,7 +176,7 @@ def print_summary(self, transform: bool = True): training_chain = self.add_name(training_chain) if transform: for sample_transform in reversed(self.sample_transforms): - training_chain = sample_transform.backward(training_chain) + training_chain = jax.vmap(sample_transform.backward)(training_chain) training_log_prob = train_summary["log_prob"] training_local_acceptance = train_summary["local_accs"] training_global_acceptance = train_summary["global_accs"] @@ -167,7 +186,7 @@ def print_summary(self, transform: bool = True): production_chain = self.add_name(production_chain) if transform: for sample_transform in reversed(self.sample_transforms): - production_chain = sample_transform.backward(production_chain) + production_chain = jax.vmap(sample_transform.backward)(production_chain) production_log_prob = production_summary["log_prob"] production_local_acceptance = production_summary["local_accs"] production_global_acceptance = production_summary["global_accs"] @@ -223,10 +242,10 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.sampler.get_sampler_state(training=False)["chains"] - chains = chains.transpose(2, 0, 1) + chains = chains.reshape(-1, self.prior.n_dim) chains = self.add_name(chains) for sample_transform in reversed(self.sample_transforms): - chains = sample_transform.backward(chains) + chains = jax.vmap(sample_transform.backward)(chains) return chains def plot(self): diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index d325590b..1bf8f3a7 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -145,7 +145,7 @@ def __init__( tc_min: Float, tc_max: Float, ): - name_mapping = [["t_c"], ["t_det_unbounded"]] + name_mapping = (["t_c"], ["t_det_unbounded"]) conditional_names = ["ra", "dec"] super().__init__(name_mapping, conditional_names) @@ -159,7 +159,6 @@ def __init__( assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1] assert "ra" in conditional_names and "dec" in conditional_names - @jnp.vectorize def time_delay(ra, dec, gmst): return self.ifo.delay_from_geocenter(ra, dec, gmst) @@ -181,9 +180,7 @@ def named_transform(x): def named_inverse_transform(x): - time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)( - x["ra"], x["dec"], self.gmst - ) + time_shift = self.ifo.delay_from_geocenter(x["ra"], x["dec"], self.gmst) t_det_min = self.tc_min + time_shift t_det_max = self.tc_max + time_shift @@ -228,7 +225,7 @@ def __init__( gps_time: Float, ifo: GroundBased2G, ): - name_mapping = [["phase_c"], ["phase_det"]] + name_mapping = (["phase_c"], ["phase_det"]) conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) @@ -245,7 +242,6 @@ def __init__( and "iota" in conditional_names ) - @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota) @@ -303,7 +299,7 @@ def __init__( dL_min: Float, dL_max: Float, ): - name_mapping = [["d_L"], ["d_hat_unbounded"]] + name_mapping = (["d_L"], ["d_hat_unbounded"]) conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) @@ -323,7 +319,6 @@ def __init__( and "M_c" in conditional_names ) - @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota): p_iota_term = (1.0 + jnp.cos(iota) ** 2) / 2.0 c_iota_term = jnp.cos(iota)