Skip to content

Commit

Permalink
Fix bug where update_parameters would not work due to the new initi…
Browse files Browse the repository at this point in the history
…alization
  • Loading branch information
MetinSa committed Jul 3, 2024
1 parent 41470ff commit e266828
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
8 changes: 6 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import pytest
from astropy import time, units
from astropy import coordinates, time, units

from zodipy import Model, grid_number_density, model_registry

Expand Down Expand Up @@ -60,12 +60,16 @@ def test_get_parameters() -> None:
def test_update_model() -> None:
"""Tests that the model can be updated."""
model = Model(20 * units.micron, name="dirbe")

obstime = time.Time("2021-01-01T00:00:00")
skycoord = coordinates.SkyCoord(20, 30, unit=units.deg, obstime=obstime)
emission_before = model.evaluate(skycoord)
parameters = model.get_parameters()
comp = random.choice(list(parameters["comps"].keys()))
parameter = random.choice(list(parameters["comps"][comp]))
parameters["comps"][comp][parameter] = random.random()
model.update_parameters(parameters)
emission_after = model.evaluate(skycoord)
assert not np.allclose(emission_before, emission_after)


def test_get_model_raises_error() -> None:
Expand Down
53 changes: 35 additions & 18 deletions zodipy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,10 @@ def __init__(
normalized_weights = weights / integrate.trapezoid(weights, x)
else:
normalized_weights = None
self._b_nu_table = tabulate_blackbody_emission(x, normalized_weights)

# We interpolate the spectrally dependant zodiacal light parameters over the provided
# bandpass or delta frequency/wavelength.
interp_and_unpack_func = get_model_interp_func(self._ipd_model)
self._interped_comp_params, self._interped_shared_params = interp_and_unpack_func(
x, normalized_weights, self._ipd_model
)

self._ephemeris = ephemeris
self._x = x
self._normalized_weights = normalized_weights
self._b_nu_table = tabulate_blackbody_emission(self._x, self._normalized_weights)

quad_points, quad_weights = np.polynomial.legendre.leggauss(gauss_quad_degree)
self._integrate_leggauss = functools.partial(
Expand All @@ -106,14 +100,14 @@ def __init__(
weights=quad_weights,
)

# Build partial functions to be evaluated when simulating the zodiacal light. These partials
# are pre-populated functions that contains all non line-of-sight related parameters.
self._number_density_partials = get_partial_number_density_func(comps=self._ipd_model.comps)
self._shared_brightness_partial = functools.partial(
self._ipd_model.brightness_at_step_callable,
bp_interpolation_table=self._b_nu_table,
**self._interped_shared_params,
)
self._ephemeris = ephemeris

# Make mypy happy by declaring types of to-be initialized attributes.
self._number_density_partials: dict[ComponentLabel, functools.partial]
self._interped_comp_params: dict[ComponentLabel, dict]
self._interped_shared_params: dict

self._init_ipd_model_partials()

def evaluate(
self,
Expand Down Expand Up @@ -178,6 +172,7 @@ def evaluate(

number_density_partials = self._number_density_partials
shared_brightness_partial = self._shared_brightness_partial

dist_coords_to_cores = skycoord.size > nprocesses and nprocesses > 1
if instantaneous or not dist_coords_to_cores:
# Populate the instantaneous Earth and observer position in the partial functions.
Expand Down Expand Up @@ -276,6 +271,26 @@ def evaluate(
emission <<= units.MJy / units.sr
return emission if return_comps else emission.sum(axis=0)

def _init_ipd_model_partials(self) -> None:
"""Initialize the partial functions for the interplanetary dust model.
The spectrally dependant model parameters are interpolated over the provided bandpass or
delta frequency/wavelength. The partial functions are pre-populated functions that contains
all non line-of-sight related parameters.
"""
interp_and_unpack_func = get_model_interp_func(self._ipd_model)
dicts = interp_and_unpack_func(self._x, self._normalized_weights, self._ipd_model)
self._interped_comp_params = dicts[0]
self._interped_shared_params = dicts[1]

self._shared_brightness_partial = functools.partial(
self._ipd_model.brightness_at_step_callable,
bp_interpolation_table=self._b_nu_table,
**self._interped_shared_params,
)

self._number_density_partials = get_partial_number_density_func(comps=self._ipd_model.comps)

def get_parameters(self) -> dict:
"""Return a dictionary containing the interplanetary dust model parameters.
Expand All @@ -294,7 +309,8 @@ def update_parameters(self, parameters: dict) -> None:
Args:
parameters: Dictionary of parameters to update. The keys must be the names
of the parameters as defined in the model. To get the parameters dict
of an existing model, use `Zodipy("dirbe").get_parameters()`.
of an existing model, use the`get_parameters` method of an initialized
`zodipy.Model`.
"""
_dict = parameters.copy()
_dict["comps"] = {}
Expand All @@ -308,6 +324,7 @@ def update_parameters(self, parameters: dict) -> None:
_dict[key] = {ComponentLabel(k): v for k, v in value.items()}

self._ipd_model = self._ipd_model.__class__(**_dict)
self._init_ipd_model_partials()


def validate_user_input(skycoord: coords.SkyCoord, obspos: units.Quantity | str) -> time.Time:
Expand Down

0 comments on commit e266828

Please sign in to comment.