Commit 76bb56a
update to adhere to python 3.9 and 3.11 convention
kazewong committed Dec 4, 2023
1 parent cb8d95b commit 76bb56a
Showing 1 changed file with 27 additions and 30 deletions.
57 changes: 27 additions & 30 deletions src/jimgw/prior.py
@@ -87,7 +87,6 @@ def logpdf(self, x: dict) -> Float:


class Uniform(Prior):
-
xmin: float = 0.0
xmax: float = 1.0

@@ -138,7 +137,6 @@ def log_prob(self, x: dict) -> Float:


class Unconstrained_Uniform(Prior):
-
xmin: float = 0.0
xmax: float = 1.0
to_range: Callable = lambda x: x
@@ -228,11 +226,9 @@ def __init__(self, naming: str):
def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
rng_keys = jax.random.split(rng_key, 3)
        theta = jnp.arccos(
-            jax.random.uniform(
-                rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0
-            )
+            jax.random.uniform(rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0)
)
-        phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi)
+        phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2 * jnp.pi)
mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1)
return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)
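
Note: the arccos construction above is the standard way to draw directions uniformly on a sphere. Sampling cos(theta) uniformly on [-1, 1], rather than theta itself, builds in the sin(theta) area element, so the poles are not over-sampled. A self-contained sketch of the same idea (variable names here are illustrative, not part of jimgw):

    import jax
    import jax.numpy as jnp

    key_theta, key_phi = jax.random.split(jax.random.PRNGKey(0))
    # Draw cos(theta) uniformly; taking arccos gives polar angles with the
    # correct sin(theta) weighting for a uniform distribution on the sphere.
    cos_theta = jax.random.uniform(key_theta, (1000,), minval=-1.0, maxval=1.0)
    theta = jnp.arccos(cos_theta)
    phi = jax.random.uniform(key_phi, (1000,), minval=0.0, maxval=2.0 * jnp.pi)

    # Cartesian unit vectors; row i is the (x, y, z) of sample i.
    xyz = jnp.stack(
        [jnp.sin(theta) * jnp.cos(phi), jnp.sin(theta) * jnp.sin(phi), jnp.cos(theta)],
        axis=1,
    )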

@@ -257,8 +253,8 @@ class Alignedspin(Prior):
"""

amax: float = 0.99
-    chi_axis: Array = jnp.linspace(0, 1, num=1000)
-    cdf_vals: Array = jnp.linspace(0, 1, num=1000)
+    chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000))
+    cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000))

def __init__(
self,
@@ -273,7 +269,7 @@ def __init__(

# build the interpolation table for the ppf of the one-sided distribution
chi_axis = jnp.linspace(1e-31, self.amax, num=1000)
-        cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.) / self.amax
+        cdf_vals = -chi_axis * (jnp.log(chi_axis / self.amax) - 1.0) / self.amax
self.chi_axis = chi_axis
self.cdf_vals = cdf_vals
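
Note: the table tabulates the CDF of the one-sided density p(chi) = -log(chi/amax)/amax on (0, amax]; integrating from 0 to chi gives exactly the expression above, F(chi) = -chi * (log(chi/amax) - 1)/amax, which rises monotonically from 0 to 1. With such a table, sampling reduces to a numerical inverse-transform step; a sketch of that step under these assumptions (not necessarily how the class itself consumes the table):

    import jax
    import jax.numpy as jnp

    amax = 0.99
    chi_axis = jnp.linspace(1e-31, amax, num=1000)
    cdf_vals = -chi_axis * (jnp.log(chi_axis / amax) - 1.0) / amax  # monotone, 0 to 1

    # Inverse-transform sampling: interpolate the inverse CDF at uniform draws.
    u = jax.random.uniform(jax.random.PRNGKey(1), (5,))
    chi = jnp.interp(u, cdf_vals, chi_axis)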

@@ -306,18 +302,16 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
Samples from the distribution. The keys are the names of the parameters.
"""
-        q_samples = jax.random.uniform(
-            rng_key, (n_samples,), minval=0., maxval=1.
-        )
+        q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
# 1. calculate the sign of chi from the q_samples
sign_samples = jnp.where(
q_samples >= 0.5,
-            jnp.zeros_like(q_samples) + 1.,
-            jnp.zeros_like(q_samples) - 1.,
+            jnp.zeros_like(q_samples) + 1.0,
+            jnp.zeros_like(q_samples) - 1.0,
)
# 2. remap q_samples
q_samples = jnp.where(
-            q_samples >=0.5,
+            q_samples >= 0.5,
2 * (q_samples - 0.5),
2 * (0.5 - q_samples),
)
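
Note: the two jnp.where calls split a single uniform draw into an independent sign and magnitude. Whether q_samples landed in the upper or lower half of [0, 1] is a fair coin for the sign, and stretching the chosen half by a factor of two yields a fresh Uniform(0, 1) variate. A small standalone illustration of the trick:

    import jax
    import jax.numpy as jnp

    u = jax.random.uniform(jax.random.PRNGKey(2), (8,))
    sign = jnp.where(u >= 0.5, 1.0, -1.0)  # fair coin: which half u fell in
    u_new = jnp.where(u >= 0.5, 2 * (u - 0.5), 2 * (0.5 - u))  # Uniform(0, 1) again
    # sign and u_new are independent: the half selected carries no
    # information about the position within that half.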
@@ -337,7 +331,7 @@ def log_prob(self, x: dict) -> Float:
log_p = jnp.where(
(variable >= self.amax) | (variable <= -self.amax),
jnp.zeros_like(variable) - jnp.inf,
-            jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2. / self.amax),
+            jnp.log(-jnp.log(jnp.absolute(variable) / self.amax) / 2.0 / self.amax),
)
return log_p
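
Note: inside the support the density used here is p(chi) = -log(|chi|/amax) / (2 amax), and a quick check confirms it is normalized:

    \int_{-a_{\max}}^{a_{\max}} \frac{-\ln(|\chi|/a_{\max})}{2 a_{\max}} \, d\chi
      = \frac{1}{a_{\max}} \int_{0}^{a_{\max}} -\ln\frac{\chi}{a_{\max}} \, d\chi
      = \frac{1}{a_{\max}} \Big[ \chi - \chi \ln\frac{\chi}{a_{\max}} \Big]_{0}^{a_{\max}}
      = 1.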

@@ -358,25 +352,26 @@ def __init__(
self,
xmin: float,
xmax: float,
-        alpha: int | float,
+        alpha: float,
naming: list[str],
transforms: dict[tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(xmin, float), "xmin must be a float"
assert isinstance(xmax, float), "xmax must be a float"
-        assert isinstance(alpha, (int, float)), "alpha must be a int or a float"
-        if alpha < 0.:
-            assert alpha < 0. or xmin > 0., "With negative alpha, xmin must > 0"
+        assert isinstance(alpha, (float)), "alpha must be a float"
+        if alpha < 0.0:
+            assert alpha < 0.0 or xmin > 0.0, "With negative alpha, xmin must > 0"
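
Note: restricting the annotation (and the isinstance check) to float keeps the signature valid on Python 3.9: the X | Y union syntax from PEP 604 is only evaluable in annotations at runtime from Python 3.10 onward, and PEP 484's numeric tower already lets callers pass an int where a float is annotated. The explicit 3.9-compatible alternatives would be (illustrative sketch, not code from this commit):

    from typing import Union

    def f(alpha: Union[int, float]) -> float:  # 3.9-compatible spelling of int | float
        return float(alpha)

    # Alternatively, `from __future__ import annotations` at the top of the
    # module defers annotation evaluation, letting `int | float` parse on 3.9.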
assert self.n_dim == 1, "Powerlaw needs to be 1D distributions"
self.xmax = xmax
self.xmin = xmin
self.alpha = alpha
if alpha == -1:
-            self.normalization = 1. / jnp.log(self.xmax / self.xmin)
+            self.normalization = 1.0 / jnp.log(self.xmax / self.xmin)
else:
-            self.normalization = (1 + self.alpha) / (self.xmax ** (1 + self.alpha) -
-                                                     self.xmin ** (1 + self.alpha))
+            self.normalization = (1 + self.alpha) / (
+                self.xmax ** (1 + self.alpha) - self.xmin ** (1 + self.alpha)
+            )
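
Note: the two branches are the standard power-law normalizations over [xmin, xmax]: for p(x) proportional to x^alpha,

    \int_{x_{\min}}^{x_{\max}} x^{\alpha} \, dx
      = \frac{x_{\max}^{1+\alpha} - x_{\min}^{1+\alpha}}{1+\alpha} \quad (\alpha \neq -1),
    \qquad
    \int_{x_{\min}}^{x_{\max}} \frac{dx}{x} = \ln\frac{x_{\max}}{x_{\min}} \quad (\alpha = -1),

and self.normalization stores the reciprocal of the appropriate integral.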

def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
"""
@@ -395,14 +390,15 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict:
Samples from the distribution. The keys are the names of the parameters.
"""
-        q_samples = jax.random.uniform(
-            rng_key, (n_samples,), minval=0., maxval=1.
-        )
+        q_samples = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
if self.alpha == -1:
samples = self.xmin * jnp.exp(q_samples * jnp.log(self.xmax / self.xmin))
else:
-            samples = (self.xmin ** (1. + self.alpha) + q_samples *
-                       (self.xmax ** (1. + self.alpha) - self.xmin ** (1. + self.alpha))) ** (1. / (1. + self.alpha))
+            samples = (
+                self.xmin ** (1.0 + self.alpha)
+                + q_samples
+                * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
+            ) ** (1.0 / (1.0 + self.alpha))
return self.add_name(samples[None])
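
Note: both branches are closed-form inverse-transform sampling. Setting the normalized CDF equal to a uniform draw u and solving for x gives

    x = \left[ x_{\min}^{1+\alpha} + u \left( x_{\max}^{1+\alpha} - x_{\min}^{1+\alpha} \right) \right]^{1/(1+\alpha)} \quad (\alpha \neq -1),
    \qquad
    x = x_{\min} \left( \frac{x_{\max}}{x_{\min}} \right)^{u} \quad (\alpha = -1),

which is exactly what the two code paths above compute.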

def log_prob(self, x: dict) -> Float:
@@ -417,10 +413,11 @@ def log_prob(self, x: dict) -> Float:


class Composite(Prior):
-
priors: list[Prior] = field(default_factory=list)

-    def __init__(self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}):
+    def __init__(
+        self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}
+    ):
naming = []
self.transforms = {}
for prior in priors:
