diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 3db3f122..87406ca8 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -397,6 +397,7 @@ def __init__( ], ) + def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: if prior.composite: if isinstance(prior.base_prior, list): @@ -410,9 +411,6 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]: return output - - - # ====================== Things below may need rework ====================== # @jaxtyped(typechecker=typechecker) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 9e775b33..96e11e62 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -584,18 +584,35 @@ def y(x): ) key = jax.random.PRNGKey(0) - initial_position = [] - for _ in range(popsize): - flag = True - while flag: - key = jax.random.split(key)[1] - guess = prior.sample(key, 1) - for transform in sample_transforms: - guess = transform.forward(guess) - guess = jnp.array([i for i in guess.values()]).T[0] - flag = not jnp.all(jnp.isfinite(guess)) - initial_position.append(guess) - initial_position = jnp.array(initial_position) + initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan + while not jax.tree.reduce( + jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position) + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] + + key, subkey = jax.random.split(key) + guess = prior.sample(subkey, popsize) + for transform in sample_transforms: + guess = jax.vmap(transform.forward)(guess) + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] + common_length = min(len(finite_guess), len(non_finite_index)) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) + rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position ) diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py index 76ca6376..194262f0 100644 --- a/src/jimgw/single_event/prior.py +++ b/src/jimgw/single_event/prior.py @@ -1,17 +1,11 @@ -import jax.numpy as jnp from beartype import beartype as typechecker from jaxtyping import jaxtyped from jimgw.prior import ( - Prior, - CombinePrior, - UniformPrior, PowerLawPrior, - SinePrior, ) - @jaxtyped(typechecker=typechecker) class UniformComponentChirpMassPrior(PowerLawPrior): """