diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 93a937af..d96eacab 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp from flowMC.nfmodel.base import Distribution -from jaxtyping import Array, Float, Int +from jaxtyping import Array, Float, Int, PRNGKeyArray from typing import Callable, Union from dataclasses import field