Skip to content

Commit

Permalink
Update tests for time-ordered observations
Browse files Browse the repository at this point in the history
  • Loading branch information
MetinSa committed Jul 25, 2024
1 parent 596f3a5 commit 92518a8
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 25 deletions.
76 changes: 58 additions & 18 deletions tests/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import datetime

import numpy as np
import numpy.typing as npt
from astropy import time, units
from astropy.coordinates import SkyCoord
from hypothesis import strategies as st
Expand All @@ -10,8 +13,9 @@
from zodipy.model_registry import model_registry
from zodipy.zodiacal_light_model import ZodiacalLightModel

MIN_DATE = datetime.datetime(year=1900, month=1, day=1)
MAX_DATE = datetime.datetime(year=2100, month=1, day=1)
MIN_DATE = time.Time(datetime.datetime(year=1960, month=1, day=1))
MAX_DATE = time.Time(datetime.datetime(year=2080, month=1, day=1))
TEST_SAMPRATE = 1 / 86400 # 1 sec


@st.composite
Expand Down Expand Up @@ -55,33 +59,60 @@ def models(draw: st.DrawFn) -> Model:


@st.composite
def obstimes(draw: st.DrawFn) -> time.Time:
def obstime_inst(draw: st.DrawFn) -> time.Time:
"""Return a strategy for generating astropy Time objects."""
return draw(st.datetimes(min_value=MIN_DATE, max_value=MAX_DATE).map(time.Time))
t0 = draw(st.integers(min_value=MIN_DATE.mjd, max_value=MAX_DATE.mjd))
return time.Time(t0, format="mjd")


@st.composite
def obspos_xyz(draw: st.DrawFn) -> units.Quantity:
def obstime_tod(draw: st.DrawFn, size: int) -> time.Time:
"""Return a strategy for generating astropy Time objects."""
t0 = draw(st.integers(min_value=MIN_DATE.mjd, max_value=MAX_DATE.mjd))
return time.Time(np.linspace(t0, t0 + TEST_SAMPRATE * size, size), format="mjd")


@st.composite
def get_obspos_vec(draw: st.DrawFn, size: int) -> units.Quantity:
"""Return a strategy for generating a heliocentric ecliptic position."""
shape = (3, size) if size != 1 else 3
positive_elements = st.floats(min_value=0.1, max_value=1)
sign = st.sampled_from([-1, 1])
elements = st.builds(lambda x, y: x * y, positive_elements, sign)
vector = draw(arrays(dtype=float, shape=3, elements=elements))
normalized_vector = vector / np.linalg.norm(vector)
vector = draw(arrays(dtype=float, shape=shape, elements=elements))
normalized_vector = vector / np.linalg.norm(vector, axis=0)
magnitude = draw(st.floats(min_value=0.8, max_value=2))
return units.Quantity(normalized_vector * magnitude, unit=units.AU)


@st.composite
def obspos_xyz_inst(draw: st.DrawFn) -> units.Quantity:
"""Return a strategy for generating a heliocentric ecliptic position."""
return draw(get_obspos_vec(1))


@st.composite
def obspos_xyz_tod(draw: st.DrawFn, size: int) -> units.Quantity:
"""Return a strategy for generating a heliocentric ecliptic position."""
return draw(get_obspos_vec(size))


@st.composite
def obspos_str(draw: st.DrawFn) -> str:
"""Return a strategy for generating a heliocentric ecliptic position."""
return draw(st.sampled_from(["earth", "mars", "moon", "semb-l2"]))


@st.composite
def obs(draw: st.DrawFn) -> tuple[str, units.Quantity]:
def obs_inst(draw: st.DrawFn) -> str | units.Quantity:
"""Return a strategy for generating a heliocentric ecliptic position."""
return draw(st.one_of(obspos_str(), obspos_xyz()))
return draw(st.one_of(obspos_str(), obspos_xyz_inst()))


@st.composite
def obs_tod(draw: st.DrawFn, size: int) -> str | units.Quantity:
"""Return a strategy for generating a heliocentric ecliptic position."""
return draw(st.one_of(obspos_str(), obspos_xyz_tod(size)))


@st.composite
Expand All @@ -91,18 +122,27 @@ def frames(draw: st.DrawFn) -> units.Quantity:


@st.composite
def sky_coords(draw: st.DrawFn) -> SkyCoord:
"""Return a strategy for generating astropy SkyCoord objects."""
def get_lonlat(draw: st.DrawFn) -> tuple[npt.NDArray, npt.NDArray]:
"""Return a strategy for generating longitude and latitude arrays."""
theta_strategy = st.floats(min_value=0, max_value=360)
phi_strategy = st.floats(min_value=-90, max_value=90)
ncoords = draw(st.integers(min_value=1, max_value=1000))

theta_array_strategy = arrays(dtype=float, shape=ncoords, elements=theta_strategy)
phi_array_strategy = arrays(dtype=float, shape=ncoords, elements=phi_strategy)
return SkyCoord(
draw(theta_array_strategy),
draw(phi_array_strategy),
unit=(units.deg),
frame=draw(frames()),
obstime=draw(obstimes()),
)
return draw(theta_array_strategy), draw(phi_array_strategy)


@st.composite
def sky_coord_inst(draw: st.DrawFn) -> SkyCoord:
"""Return a strategy for generating astropy SkyCoord objects."""
lon, lat = draw(get_lonlat())
return SkyCoord(lon, lat, unit=(units.deg), frame=draw(frames()), obstime=draw(obstime_inst()))


@st.composite
def sky_coord_tod(draw: st.DrawFn) -> SkyCoord:
"""Return a strategy for generating astropy SkyCoord objects."""
lon, lat = draw(get_lonlat())
obstime = draw(obstime_tod(lon.size))
return SkyCoord(lon, lat, unit=(units.deg), frame=draw(frames()), obstime=obstime)
41 changes: 34 additions & 7 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
from astropy import coordinates as coords
from astropy import time, units
from astropy.coordinates import SkyCoord
from hypothesis import given, settings
from hypothesis import HealthCheck, given, settings
from hypothesis.strategies import DataObject, data
from numpy.testing import assert_array_equal

from zodipy import Model

from .dirbe_tabulated import DAYS, DIRBE_START_DAY, LAT, LON, TABULATED_DIRBE_EMISSION
from .strategies import models, obs, sky_coords
from .strategies import (
models,
obs_inst,
obs_tod,
sky_coord_inst,
sky_coord_tod,
)

np.random.seed(42)

Expand Down Expand Up @@ -43,14 +50,29 @@ def test_compare_to_dirbe_idl() -> None:


@settings(deadline=None)
@given(models(), sky_coords(), obs())
def test_evaluate(
@given(models(), sky_coord_inst(), obs_inst())
def test_evaluate_inst(
model: Model,
sky_coord: SkyCoord,
obs: units.Quantity | str,
obspos: units.Quantity,
) -> None:
"""Test that the evaluate function works for valid user input."""
emission = model.evaluate(sky_coord, obspos=obs)
emission = model.evaluate(sky_coord, obspos=obspos)
assert emission.size == sky_coord.size
assert isinstance(emission, units.Quantity)
assert emission.unit == units.MJy / units.sr


@settings(deadline=None, suppress_health_check=[HealthCheck.data_too_large])
@given(data(), models(), sky_coord_tod())
def test_evaluate_tod(
data: DataObject,
model: Model,
sky_coord: SkyCoord,
) -> None:
"""Test that the evaluate function works for valid user input."""
obspos = data.draw(obs_tod(sky_coord.size))
emission = model.evaluate(sky_coord, obspos=obspos)
assert emission.size == sky_coord.size
assert isinstance(emission, units.Quantity)
assert emission.unit == units.MJy / units.sr
Expand Down Expand Up @@ -138,7 +160,12 @@ def test_input_shape() -> None:
with pytest.raises(ValueError):
test_model.evaluate(SkyCoord(20, 30, unit=units.deg, obstime=TEST_TIME), obspos=obsposes)

# obstime > obspos
# obstimes > obspos
with pytest.raises(ValueError):
test_model.evaluate(
SkyCoord([20, 30, 40, 50], [30, 40, 30, 20], unit=units.deg, obstime=obstimes),
obspos=[1, 2, 3] * units.AU,
)
with pytest.raises(ValueError):
test_model.evaluate(
SkyCoord([20, 30, 40, 50], [30, 40, 30, 20], unit=units.deg, obstime=obstimes),
Expand Down

0 comments on commit 92518a8

Please sign in to comment.