Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Oct 13, 2024
1 parent d69c05a commit d062ee1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
4 changes: 1 addition & 3 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def __init__(
],
)


def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]:
if prior.composite:
if isinstance(prior.base_prior, list):
Expand All @@ -410,9 +411,6 @@ def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]:
return output





# ====================== Things below may need rework ======================

# @jaxtyped(typechecker=typechecker)
Expand Down
41 changes: 29 additions & 12 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,18 +584,35 @@ def y(x):
)

key = jax.random.PRNGKey(0)
initial_position = []
for _ in range(popsize):
flag = True
while flag:
key = jax.random.split(key)[1]
guess = prior.sample(key, 1)
for transform in sample_transforms:
guess = transform.forward(guess)
guess = jnp.array([i for i in guess.values()]).T[0]
flag = not jnp.all(jnp.isfinite(guess))
initial_position.append(guess)
initial_position = jnp.array(initial_position)
initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan
while not jax.tree.reduce(
jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)
).all():
non_finite_index = jnp.where(
jnp.any(
~jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
),
axis=1,
)
)[0]

key, subkey = jax.random.split(key)
guess = prior.sample(subkey, popsize)
for transform in sample_transforms:
guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(
jax.tree.leaves({key: guess[key] for key in parameter_names})
).T
finite_guess = jnp.where(
jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[
non_finite_index[:common_length]
].set(guess[:common_length])

rng_key, optimized_positions, summary = optimizer.optimize(
jax.random.PRNGKey(12094), y, initial_position
)
Expand Down
6 changes: 0 additions & 6 deletions src/jimgw/single_event/prior.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import jaxtyped

from jimgw.prior import (
Prior,
CombinePrior,
UniformPrior,
PowerLawPrior,
SinePrior,
)



@jaxtyped(typechecker=typechecker)
class UniformComponentChirpMassPrior(PowerLawPrior):
"""
Expand Down

0 comments on commit d062ee1

Please sign in to comment.