Skip to content

Commit

Permalink
update typing hint, pyright is having a stroke, bypassing
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Dec 14, 2023
1 parent 8f6bdda commit fbbbd01
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 22 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
files: src/
repos:
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 23.12.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.1.7'
rev: 'v0.1.8'
hooks:
- id: ruff
args: ["--fix"]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.339
rev: v1.1.340
hooks:
- id: pyright
additional_dependencies: [beartype, jax, jaxtyping, pytest, typing_extensions, flowMC, ripplegw, gwpy, astropy]
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[build-system]
requires = ["setuptools","wheel"]
build-backend = "setuptools.build_meta"
build-backend = "setuptools.build_meta"

[tool.pyright]
reportIncompatibleMethodOverride = "warning"
10 changes: 5 additions & 5 deletions src/jimgw/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import requests
from gwpy.timeseries import TimeSeries
from jaxtyping import Array, PRNGKeyArray, Float
from jaxtyping import Array, PRNGKeyArray, Float, jaxtyped
from scipy.interpolate import interp1d
from scipy.signal.windows import tukey

Expand Down Expand Up @@ -160,13 +160,13 @@ def arms(self) -> tuple[Float[Array, " 3"], Float[Array, " 3"]]:
return x, y

@property
def tensor(self) -> Float[Array, " 3, 3"]:
def tensor(self) -> Float[Array, " 3 3"]:
"""
Detector tensor defining the strain measurement.
Returns
-------
tensor : Float[Array, " 3, 3"]
tensor : Float[Array, " 3 3"]
detector tensor.
"""
# TODO: this could easily be generalized for other detector geometries
Expand Down Expand Up @@ -389,6 +389,7 @@ def inject_signal(
signal = self.fd_response(freqs, h_sky, params) * align_time
self.data = signal + noise_real + 1j * noise_imag

@jaxtyped
def load_psd(
self, freqs: Float[Array, " n_sample"], psd_file: str = ""
) -> Float[Array, " n_sample"]:
Expand All @@ -401,8 +402,7 @@ def load_psd(
else:
f, asd_vals = np.loadtxt(psd_file, unpack=True)
psd_vals = asd_vals**2
assert isinstance(f, Float[Array, "n_sample"])
assert isinstance(psd_vals, Float[Array, "n_sample"])

psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) # type: ignore
psd = jnp.array(psd)
return psd
Expand Down
2 changes: 1 addition & 1 deletion src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def posterior(self, params: Array, data: dict):
)

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess is jnp.array([]):
if initial_guess.size == 0:
initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.Sampler.sample(initial_guess, None) # type: ignore
Expand Down
7 changes: 6 additions & 1 deletion src/jimgw/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(
self.freq_grid_low = freq_grid[:-1]

print("Finding reference parameters..")

self.ref_params = self.maximize_likelihood(
bounds=bounds, prior=prior, popsize=popsize, n_loops=n_loops
)
Expand Down Expand Up @@ -474,3 +474,8 @@ def y(x):
_ = optimizer.optimize(y, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]
return prior.transform(prior.add_name(best_fit))


class PopulationLikelihood(LikelihoodBase):
events: Float[Array, " n_events n_samples n_dim"]
reference_pop: Float[Array, " n_det n_dim"]
14 changes: 5 additions & 9 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
from flowMC.nfmodel.base import Distribution
from jaxtyping import Array, Float, Int, PRNGKeyArray
from jaxtyping import Array, Float, Int, PRNGKeyArray, jaxtyped
from typing import Callable, Union
from dataclasses import field

Expand Down Expand Up @@ -90,6 +90,7 @@ def log_prob(self, x: dict[str, Array]) -> Float:
raise NotImplementedError


@jaxtyped
class Uniform(Prior):
xmin: Float = 0.0
xmax: Float = 1.0
Expand All @@ -102,8 +103,6 @@ def __init__(
transforms: dict[str, tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(xmin, Float), "xmin must be a Float"
assert isinstance(xmax, Float), "xmax must be a Float"
assert self.n_dim == 1, "Uniform needs to be 1D distributions"
self.xmax = xmax
self.xmin = xmin
Expand Down Expand Up @@ -142,6 +141,7 @@ def log_prob(self, x: dict[str, Array]) -> Float:
return output + jnp.log(1.0 / (self.xmax - self.xmin))


@jaxtyped
class Unconstrained_Uniform(Prior):
xmin: Float = 0.0
xmax: Float = 1.0
Expand All @@ -154,8 +154,6 @@ def __init__(
transforms: dict[str, tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(xmin, Float), "xmin must be a Float"
assert isinstance(xmax, Float), "xmax must be a Float"
assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions"
self.xmax = xmax
self.xmin = xmin
Expand Down Expand Up @@ -257,6 +255,7 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]]))


@jaxtyped
class Alignedspin(Prior):

"""
Expand Down Expand Up @@ -284,7 +283,6 @@ def __init__(
transforms: dict[str, tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(amax, Float), "xmin must be a Float"
assert self.n_dim == 1, "Alignedspin needs to be 1D distributions"
self.amax = amax

Expand Down Expand Up @@ -359,6 +357,7 @@ def log_prob(self, x: dict[str, Float]) -> Float:
return log_p


@jaxtyped
class Powerlaw(Prior):

"""
Expand All @@ -380,9 +379,6 @@ def __init__(
transforms: dict[str, tuple[str, Callable]] = {},
):
super().__init__(naming, transforms)
assert isinstance(xmin, Float), "xmin must be a Float"
assert isinstance(xmax, Float), "xmax must be a Float"
assert isinstance(alpha, (Int, Float)), "alpha must be a int or a Float"
if alpha < 0.0:
assert xmin > 0.0, "With negative alpha, xmin must > 0"
assert self.n_dim == 1, "Powerlaw needs to be 1D distributions"
Expand Down
4 changes: 2 additions & 2 deletions src/jimgw/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, name: str):

def tensor_from_basis(
self, x: Float[Array, " 3"], y: Float[Array, " 3"]
) -> Float[Array, " 3, 3"]:
) -> Float[Array, " 3 3"]:
"""Constructor to obtain polarization tensor from waveframe basis
defined by orthonormal vectors (x, y) in arbitrary Cartesian
coordinates.
Expand All @@ -52,7 +52,7 @@ def tensor_from_basis(

def tensor_from_sky(
self, ra: Float, dec: Float, psi: Float, gmst: Float
) -> Float[Array, " 3, 3"]:
) -> Float[Array, " 3 3"]:
"""Computes {name} polarization tensor in celestial
coordinates from sky location and orientation parameters.
Expand Down

0 comments on commit fbbbd01

Please sign in to comment.