Skip to content

Commit

Permalink
move prior functions to the right place. Also add magnitude options t…
Browse files Browse the repository at this point in the history
…o UniformSphere
  • Loading branch information
kazewong committed Oct 13, 2024
1 parent 66027c3 commit d69c05a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 37 deletions.
19 changes: 17 additions & 2 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class UniformSpherePrior(CombinePrior):
def __repr__(self):
return f"UniformSpherePrior(parameter_names={self.parameter_names})"

def __init__(self, parameter_names: list[str], **kwargs):
def __init__(self, parameter_names: list[str], max_mag: float = 1.0, **kwargs):
self.parameter_names = parameter_names
assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector"
self.parameter_names = [
Expand All @@ -341,7 +341,7 @@ def __init__(self, parameter_names: list[str], **kwargs):
]
super().__init__(
[
UniformPrior(0.0, 1.0, [self.parameter_names[0]]),
UniformPrior(0.0, max_mag, [self.parameter_names[0]]),
SinePrior([self.parameter_names[1]]),
UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]),
]
Expand Down Expand Up @@ -397,6 +397,21 @@ def __init__(
],
)

def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]:
if prior.composite:
if isinstance(prior.base_prior, list):
for subprior in prior.base_prior:
output = trace_prior_parent(subprior, output)
elif isinstance(prior.base_prior, Prior):
output = trace_prior_parent(prior.base_prior, output)
else:
output.append(prior)

return output





# ====================== Things below may need rework ======================

Expand Down
35 changes: 0 additions & 35 deletions src/jimgw/single_event/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,6 @@
)


@jaxtyped(typechecker=typechecker)
class UniformSpherePrior(CombinePrior):

def __repr__(self):
return f"UniformSpherePrior(parameter_names={self.parameter_names})"

def __init__(self, parameter_names: list[str], **kwargs):
self.parameter_names = parameter_names
assert self.n_dim == 1, "UniformSpherePrior only takes the name of the vector"
self.parameter_names = [
f"{self.parameter_names[0]}_mag",
f"{self.parameter_names[0]}_theta",
f"{self.parameter_names[0]}_phi",
]
super().__init__(
[
UniformPrior(0.0, 1.0, [self.parameter_names[0]]),
SinePrior([self.parameter_names[1]]),
UniformPrior(0.0, 2 * jnp.pi, [self.parameter_names[2]]),
]
)


@jaxtyped(typechecker=typechecker)
class UniformComponentChirpMassPrior(PowerLawPrior):
Expand All @@ -50,19 +28,6 @@ def __init__(self, xmin: float, xmax: float):
super().__init__(xmin, xmax, 1.0, ["M_c"])


def trace_prior_parent(prior: Prior, output: list[Prior] = []) -> list[Prior]:
if prior.composite:
if isinstance(prior.base_prior, list):
for subprior in prior.base_prior:
output = trace_prior_parent(subprior, output)
elif isinstance(prior.base_prior, Prior):
output = trace_prior_parent(prior.base_prior, output)
else:
output.append(prior)

return output


# ====================== Things below may need rework ======================


Expand Down

0 comments on commit d69c05a

Please sign in to comment.