Skip to content

Commit

Permalink
Move frequency and weights to Zodipy initializer and away from methods
Browse files Browse the repository at this point in the history
  • Loading branch information
MetinSa committed Apr 29, 2024
1 parent 5d49729 commit f43b69e
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 223 deletions.
55 changes: 35 additions & 20 deletions tests/_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
DrawFn,
SearchStrategy,
booleans,
builds,
composite,
datetimes,
floats,
Expand All @@ -28,6 +27,7 @@

import zodipy
from zodipy._line_of_sight import COMPONENT_CUTOFFS
from zodipy.model_registry import model_registry

MIN_FREQ = u.Quantity(10, u.GHz)
MAX_FREQ = u.Quantity(0.1, u.micron).to(u.GHz, equivalencies=u.spectral())
Expand Down Expand Up @@ -151,17 +151,18 @@ def angles(draw: DrawFn, lonlat: bool = False) -> tuple[u.Quantity[u.deg], u.Qua


@composite
def freqs(draw: DrawFn, model: zodipy.Zodipy) -> u.Quantity[u.GHz] | u.Quantity[u.micron]:
if model.extrapolate:
def freqs(
draw: DrawFn,
min_freq: u.Quantity,
max_freq: u.Quantity,
extrapolate: bool,
) -> u.Quantity[u.GHz] | u.Quantity[u.micron]:
if extrapolate:
return draw(sampled_from(FREQ_LOG_RANGE).map(np.exp).map(partial(u.Quantity, unit=u.GHz)))

min_freq = model._ipd_model.spectrum[0]
max_freq = model._ipd_model.spectrum[-1]
freq_range = np.geomspace(np.log(min_freq.value), np.log(max_freq.value), N_FREQS)
freq_strategy = (
sampled_from(freq_range.tolist())
.map(np.exp)
.map(partial(u.Quantity, unit=model._ipd_model.spectrum.unit))
sampled_from(freq_range.tolist()).map(np.exp).map(partial(u.Quantity, unit=min_freq.unit))
)

return np.clip(draw(freq_strategy), min_freq, max_freq)
Expand Down Expand Up @@ -241,18 +242,32 @@ def any_obs(draw: DrawFn, model: zodipy.Zodipy) -> str:
return draw(sampled_from(model.supported_observers))


MODEL_STRATEGY_MAPPINGS: dict[str, SearchStrategy[Any]] = {
"model": sampled_from(AVAILABLE_MODELS),
"gauss_quad_degree": integers(min_value=1, max_value=200),
"extrapolate": booleans(),
}


@composite
def zodipy_models(draw: DrawFn, **static_params: dict[str, Any]) -> zodipy.Zodipy:
strategies = MODEL_STRATEGY_MAPPINGS.copy()
for key in static_params:
if key in strategies:
strategies.pop(key)
extrapolate = static_params.pop("extrapolate", draw(booleans()))
model = static_params.pop("model", draw(sampled_from(AVAILABLE_MODELS)))
ipd_model = model_registry.get_model(model)
min_freq = ipd_model.spectrum.min()
max_freq = ipd_model.spectrum.max()

do_bp = static_params.pop("bandpass_integrate", None)
if do_bp is not None:
frequencies = draw(random_freqs(bandpass=True))
w = draw(weights(frequencies))
else:
frequencies = static_params.pop(
"freq", draw(freqs(min=min_freq, max=max_freq, extrapolate=extrapolate))
)
w = None

return draw(builds(partial(zodipy.Zodipy, **static_params), **strategies))
gauss_quad_degree = static_params.pop(
"gauss_quad_degree", draw(integers(min_value=1, max_value=200))
)

return zodipy.Zodipy(
freq=frequencies,
model=model,
weights=w,
gauss_quad_degree=gauss_quad_degree,
extrapolate=extrapolate,
)
Loading

0 comments on commit f43b69e

Please sign in to comment.