Skip to content

Commit

Permalink
Add tests for skycoord api
Browse files Browse the repository at this point in the history
  • Loading branch information
MetinSa committed Apr 26, 2024
1 parent 0bf63f8 commit b484d74
Show file tree
Hide file tree
Showing 11 changed files with 403 additions and 346 deletions.
53 changes: 41 additions & 12 deletions tests/_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from math import log2
from typing import Any, Callable, Sequence

import astropy.coordinates as coords
import astropy.units as u
import healpy as hp
import numpy as np
import numpy.typing as npt
from astropy.coordinates import HeliocentricMeanEcliptic, get_body
from astropy.time import Time
from hypothesis.extra.numpy import arrays
from hypothesis.strategies import (
DrawFn,
Expand Down Expand Up @@ -66,15 +65,39 @@ def quantities(


@composite
def time(draw: DrawFn) -> Time:
return draw(datetimes(min_value=MIN_DATE, max_value=MAX_DATE).map(Time))
def times(draw: DrawFn) -> times.Time:
return draw(datetimes(min_value=MIN_DATE, max_value=MAX_DATE).map(times.Time))


@composite
def nside(draw: Callable[[SearchStrategy[int]], int]) -> int:
def nsides(draw: Callable[[SearchStrategy[int]], int]) -> int:
return draw(integers(min_value=MIN_NSIDE_EXP, max_value=MAX_NSIDE_EXP).map(partial(pow, 2)))


@composite
def frames(draw: DrawFn) -> type[coords.BaseCoordinateFrame]:
return draw(sampled_from([coords.ICRS, coords.Galactic, coords.HeliocentricMeanEcliptic]))


@composite
def sky_coords(draw: DrawFn) -> coords.SkyCoord:
theta_strategy = floats(min_value=0, max_value=360)
phi_strategy = floats(min_value=-90, max_value=90)

shape = draw(integers(min_value=1, max_value=MAX_ANGELS_LEN))

theta_array_strategy = arrays(dtype=float, shape=shape, elements=theta_strategy).map(
partial(u.Quantity, unit=u.deg)
)
phi_array_strategy = arrays(dtype=float, shape=shape, elements=phi_strategy).map(
partial(u.Quantity, unit=u.deg)
)
frame = draw(frames())
lon = draw(theta_array_strategy)
lat = draw(phi_array_strategy)
return coords.SkyCoord(lon, lat, frame=frame)


@composite
def pixels(draw: DrawFn, nside: int) -> int | list[int] | npt.NDArray[np.integer]:
npix = hp.nside2npix(nside)
Expand Down Expand Up @@ -110,7 +133,7 @@ def angles(draw: DrawFn, lonlat: bool = False) -> tuple[u.Quantity[u.deg], u.Qua


@composite
def freq(draw: DrawFn, model: zodipy.Zodipy) -> u.Quantity[u.GHz] | u.Quantity[u.micron]:
def freqs(draw: DrawFn, model: zodipy.Zodipy) -> u.Quantity[u.GHz] | u.Quantity[u.micron]:
if model.extrapolate:
return draw(sampled_from(FREQ_LOG_RANGE).map(np.exp).map(partial(u.Quantity, unit=u.GHz)))

Expand All @@ -127,7 +150,7 @@ def freq(draw: DrawFn, model: zodipy.Zodipy) -> u.Quantity[u.GHz] | u.Quantity[u


@composite
def random_freq(draw: DrawFn, unit: u.Unit | None = None, bandpass: bool = False) -> u.Quantity:
def random_freqs(draw: DrawFn, unit: u.Unit | None = None, bandpass: bool = False) -> u.Quantity:
random_freq = draw(
sampled_from(FREQ_LOG_RANGE).map(np.exp).map(partial(u.Quantity, unit=u.GHz))
)
Expand Down Expand Up @@ -166,15 +189,21 @@ def normalize_array(


@composite
def obs(draw: DrawFn, model: zodipy.Zodipy, obs_time: Time) -> str:
def get_obs_dist(obs: str, obs_time: Time) -> u.Quantity[u.AU]:
def obs_positions(draw: DrawFn, model: zodipy.Zodipy, obs_time: times.Time) -> str:
def get_obs_dist(obs: str, obs_time: times.Time) -> u.Quantity[u.AU]:
if obs == "semb-l2":
obs_pos = (
get_body("earth", obs_time).transform_to(HeliocentricMeanEcliptic).cartesian.xyz
coords.get_body("earth", obs_time)
.transform_to(coords.HeliocentricMeanEcliptic)
.cartesian.xyz
)
obs_pos += 0.01 * u.AU
else:
obs_pos = get_body(obs, obs_time).transform_to(HeliocentricMeanEcliptic).cartesian.xyz
obs_pos = (
coords.get_body(obs, obs_time)
.transform_to(coords.HeliocentricMeanEcliptic)
.cartesian.xyz
)
return u.Quantity(np.linalg.norm(obs_pos.value), u.AU)

los_dist_cut = min(
Expand Down Expand Up @@ -202,7 +231,7 @@ def any_obs(draw: DrawFn, model: zodipy.Zodipy) -> str:


@composite
def model(draw: DrawFn, **static_params: dict[str, Any]) -> zodipy.Zodipy:
def zodipy_models(draw: DrawFn, **static_params: dict[str, Any]) -> zodipy.Zodipy:
strategies = MODEL_STRATEGY_MAPPINGS.copy()
for key in static_params:
if key in strategies:
Expand Down
Loading

0 comments on commit b484d74

Please sign in to comment.