diff --git a/README.md b/README.md index 8e8f417..a99a92e 100644 --- a/README.md +++ b/README.md @@ -33,23 +33,20 @@ from astropy.time import Time import zodipy # Initialize a zodiacal light model at a wavelength/frequency or over a bandpass -model = zodipy.Model(25*u.micron, name="DIRBE") +model = zodipy.Model(25*u.micron) # Use Astropy's `SkyCoord` to specify coordinate lon = [10, 10.1, 10.2] * u.deg lat = [90, 89, 88] * u.deg -skycoord = SkyCoord( - lon, - lat, - obstime=Time("2022-01-01 12:00:00"), - frame="galactic", -) +obstimes = Time(["2022-01-01 12:00:00", "2022-01-01 12:01:00", "2022-01-01 12:02:00"]) + +skycoord = SkyCoord(lon, lat, obstime=obstimes, frame="galactic") # Compute the zodiacal light as seen from Earth emission = model.evaluate(skycoord, obspos="earth") print(emission) -#> [27.52410841 27.66581351 27.81270207] MJy / sr +#> [27.52410841 27.66572294 27.81251906] MJy / sr ``` ## Related scientific papers diff --git a/docs/index.md b/docs/index.md index d43efae..01d130b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -26,23 +26,20 @@ from astropy.time import Time import zodipy # Initialize a zodiacal light model at a wavelength/frequency or over a bandpass -model = zodipy.Model(25*u.micron, name="DIRBE") +model = zodipy.Model(25*u.micron) # Use Astropy's `SkyCoord` to specify coordinate lon = [10, 10.1, 10.2] * u.deg lat = [90, 89, 88] * u.deg -skycoord = SkyCoord( - lon, - lat, - obstime=Time("2022-01-01 12:00:00"), - frame="galactic", -) +obstimes = Time(["2022-01-01 12:00:00", "2022-01-01 12:01:00", "2022-01-01 12:02:00"]) + +skycoord = SkyCoord(lon, lat, obstime=obstimes, frame="galactic") # Compute the zodiacal light as seen from Earth emission = model.evaluate(skycoord, obspos="earth") print(emission) -#> [27.52410841 27.66581351 27.81270207] MJy / sr +#> [27.52410841 27.66572294 27.81251906] MJy / sr ``` For more information on using ZodiPy, see [the usage section](usage.md). diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index fd5891e..74a6203 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -95,7 +95,7 @@ def test_evaluate_invalid_obspos() -> None: def test_output_shape() -> None: """Test that the return_comps function works for valid user input.""" - n_comps = test_model._interplanetary_dust_model.ncomps + n_comps = test_model._ipd_model.ncomps assert test_model.evaluate(TEST_SKY_COORD, return_comps=True).shape == (n_comps, 1) assert test_model.evaluate(TEST_SKY_COORD, return_comps=False).shape == (1,) @@ -168,7 +168,7 @@ def test_return_comps() -> None: assert_array_equal(emission_comps.sum(axis=0), emission) -def test_multiprocessing_nproc() -> None: +def test_multiprocessing_nproc_inst() -> None: """Test that the multiprocessing works with n_proc > 1.""" model = Model(x=20 * units.micron) @@ -198,3 +198,27 @@ def test_multiprocessing_nproc() -> None: emission = model.evaluate(skycoord, nprocesses=1) assert_array_equal(emission_multi, emission) + + +def test_multiprocessing_nproc_time() -> None: + """Test that the multiprocessing works with n_proc > 1.""" + model = Model(x=20 * units.micron) + + lon = np.random.randint(low=0, high=360, size=10000) + lat = np.random.randint(low=-90, high=90, size=10000) + obstime = np.linspace(0, 300, 10000) + TEST_TIME.mjd + skycoord = SkyCoord( + lon, + lat, + unit=units.deg, + obstime=time.Time(obstime, format="mjd"), + ) + emission_multi = model.evaluate(skycoord, nprocesses=4) + emission = model.evaluate(skycoord, nprocesses=1) + assert_array_equal(emission_multi, emission) + + # model = Model(x=75 * units.micron, name="rrm-experimental") + # emission_multi = model.evaluate(skycoord, nprocesses=4) + # emission = model.evaluate(skycoord, nprocesses=1) + + # assert_array_equal(emission_multi, emission) diff --git a/zodipy/blackbody.py b/zodipy/blackbody.py index b7c2346..e0a9e68 100644 --- a/zodipy/blackbody.py +++ b/zodipy/blackbody.py @@ -9,14 +9,14 @@ MIN_TEMP = 40 * units.K MAX_TEMP = 550 * units.K N_TEMPS = 100 -temperatures = np.linspace(MIN_TEMP, MAX_TEMP, N_TEMPS) -blackbody = BlackBody(temperatures) +TEMPERATURES = np.linspace(MIN_TEMP, MAX_TEMP, N_TEMPS) +blackbody = BlackBody(TEMPERATURES) def get_dust_grain_temperature( R: npt.NDArray[np.float64], T_0: float, delta: float ) -> npt.NDArray[np.float64]: - """Return the dust grain temperature given a radial distance from the Sun. + """Return the dust grain temperature given at a radial distance from the Sun. Args: R: Radial distance from the sun in ecliptic heliocentric coordinates [AU / 1AU]. @@ -43,7 +43,7 @@ def tabulate_blackbody_emission( return np.asarray( [ - temperatures.to_value(units.K), + TEMPERATURES.to_value(units.K), tabulated_blackbody_emission.to_value(units.MJy / units.sr), ] ) diff --git a/zodipy/bodies.py b/zodipy/bodies.py index d4e4568..dedef00 100644 --- a/zodipy/bodies.py +++ b/zodipy/bodies.py @@ -47,7 +47,7 @@ def get_interpolated_body_xyz( """Return interpolated Earth positions in the heliocentric frame.""" dt = (1 * units.hour).to_value(units.day) t0, t1 = obstimes[0].mjd, obstimes[-1].mjd - times = time.Time(np.arange(t0, max(t0 + 365, t1), dt), format="mjd") + times = time.Time(np.arange(t0, max(t0 + 366, t1) + dt, dt), format="mjd") bodypos = ( coords.get_body(body, times, ephemeris=ephemeris) diff --git a/zodipy/brightness.py b/zodipy/brightness.py index df438b4..e587eb9 100644 --- a/zodipy/brightness.py +++ b/zodipy/brightness.py @@ -9,7 +9,7 @@ from zodipy.scattering import get_phase_function, get_scattering_angle if TYPE_CHECKING: - from zodipy.number_density import ComponentNumberDensityCallable + from zodipy.number_density import NumberDensityFunc """ Function that return the zodiacal emission at a step along all lines of sight given @@ -25,7 +25,7 @@ def kelsall_brightness_at_step( X_obs: npt.NDArray[np.float64], u_los: npt.NDArray[np.float64], bp_interpolation_table: npt.NDArray[np.float64], - get_density_function: ComponentNumberDensityCallable, + number_density_func: NumberDensityFunc, T_0: float, delta: float, emissivity: np.float64, @@ -53,7 +53,7 @@ def kelsall_brightness_at_step( phase_function = get_phase_function(scattering_angle, C1, C2, C3) emission += albedo * solar_flux * phase_function - return emission * get_density_function(X_helio) * 0.5 * (stop - start) + return emission * number_density_func(X_helio) * 0.5 * (stop - start) def rrm_brightness_at_step( @@ -63,7 +63,7 @@ def rrm_brightness_at_step( X_obs: npt.NDArray[np.float64], u_los: npt.NDArray[np.float64], bp_interpolation_table: npt.NDArray[np.float64], - get_density_function: ComponentNumberDensityCallable, + number_density_func: NumberDensityFunc, T_0: float, delta: float, calibration: np.float64, @@ -80,4 +80,4 @@ def rrm_brightness_at_step( temperature = get_dust_grain_temperature(R_helio, T_0, delta) blackbody_emission = np.interp(temperature, *bp_interpolation_table) - return blackbody_emission * get_density_function(X_helio) * calibration * 0.5 * (stop - start) + return blackbody_emission * number_density_func(X_helio) * calibration * 0.5 * (stop - start) diff --git a/zodipy/line_of_sight.py b/zodipy/line_of_sight.py index b373f60..07878a4 100644 --- a/zodipy/line_of_sight.py +++ b/zodipy/line_of_sight.py @@ -21,8 +21,8 @@ ComponentLabel.BAND1: (R_0, R_JUPITER), ComponentLabel.BAND2: (R_0, R_JUPITER), ComponentLabel.BAND3: (R_0, R_JUPITER), - ComponentLabel.RING: (R_0, R_EARTH + 0.2), - ComponentLabel.FEATURE: (R_0, R_EARTH + 0.2), + ComponentLabel.RING: (R_EARTH - 0.2, R_EARTH + 0.2), + ComponentLabel.FEATURE: (R_EARTH - 0.3, R_EARTH + 0.3), } RRM_CUTOFFS: dict[ComponentLabel, tuple[float | np.float64, float | np.float64]] = { @@ -53,12 +53,12 @@ def integrate_leggauss( - fn: Callable[[float], npt.NDArray[np.float64]], + func: Callable[[float], npt.NDArray[np.float64]], points: npt.NDArray[np.float64], weights: npt.NDArray[np.float64], ) -> npt.NDArray[np.float64]: """Integrate a function using Gauss-Laguerre quadrature.""" - return np.squeeze(sum(fn(x) * w for x, w in zip(points, weights))) + return np.squeeze(sum(func(x) * w for x, w in zip(points, weights))) def get_sphere_intersection( @@ -67,8 +67,8 @@ def get_sphere_intersection( cutoff: float | np.float64, ) -> npt.NDArray[np.float64]: """Returns the distance from the observer to a heliocentric sphere with radius `cutoff`.""" - x, y, z = obs_pos - r_obs = np.sqrt(x**2 + y**2 + z**2) + x_0, y_0, z_0 = obs_pos + r_obs = np.sqrt(x_0**2 + y_0**2 + z_0**2) if (r_obs > cutoff).any(): if obs_pos.ndim == 1: return np.array([np.finfo(float).eps]) @@ -79,7 +79,7 @@ def get_sphere_intersection( lat = np.arcsin(u_z) cos_lat = np.cos(lat) - b = 2 * (x * cos_lat * np.cos(lon) + y * cos_lat * np.sin(lon)) + b = 2 * (x_0 * cos_lat * np.cos(lon) + y_0 * cos_lat * np.sin(lon)) c = r_obs**2 - cutoff**2 q = -0.5 * b * (1 + np.sqrt(b**2 - 4 * c) / np.abs(b)) @@ -87,7 +87,7 @@ def get_sphere_intersection( return np.maximum(q, c / q) -def get_line_of_sight_range( +def get_line_of_sight_range_dicts( components: Iterable[ComponentLabel], unit_vectors: npt.NDArray[np.float64], obs_pos: npt.NDArray[np.float64], diff --git a/zodipy/model.py b/zodipy/model.py index 9d7d589..94ff412 100644 --- a/zodipy/model.py +++ b/zodipy/model.py @@ -2,7 +2,6 @@ import functools import multiprocessing -import multiprocessing.pool import platform import typing @@ -16,14 +15,14 @@ from zodipy.bodies import get_earthpos_xyz, get_obspos_xyz from zodipy.component import ComponentLabel from zodipy.line_of_sight import ( - get_line_of_sight_range, + get_line_of_sight_range_dicts, integrate_leggauss, ) from zodipy.model_registry import model_registry -from zodipy.number_density import populate_number_density_with_model -from zodipy.unpack_model import get_model_to_dicts_callable +from zodipy.number_density import get_partial_number_density_func, update_partial_earth_pos +from zodipy.unpack_model import get_model_interp_func -_PLATFORM_METHOD = "fork" if "windows" not in platform.system().lower() else None +PLATFORM_METHOD = "fork" if "windows" not in platform.system().lower() else None class Model: @@ -62,26 +61,26 @@ def __init__( """ try: if not x.isscalar and weights is None: - msg = "Several wavelengths are provided by no weights." + msg = "Bandpass weights must be provided for non-scalar `x`." raise ValueError(msg) except AttributeError as error: msg = "The input 'x' must be an astropy Quantity." raise TypeError(msg) from error if x.isscalar and weights is not None: - msg = "A single wavelength is provided with weights." + msg = "Bandpass weights should not be provided for scalar `x`." raise ValueError(msg) - self._interplanetary_dust_model = model_registry.get_model(name) + self._ipd_model = model_registry.get_model(name) - if not extrapolate and not self._interplanetary_dust_model.is_valid_at(x): + if not extrapolate and not self._ipd_model.is_valid_at(x): msg = ( "The requested frequencies are outside the valid range of the model. " "If this was intended, set the extrapolate argument to True." ) raise ValueError(msg) - bandpass_is_provided = weights is not None - if bandpass_is_provided: + # Bandpass is provided rather than a delta wavelength or frequency. + if weights is not None: weights = np.asarray(weights) if x.size != weights.size: msg = "Number of wavelengths and weights must be the same in the bandpass." @@ -91,15 +90,30 @@ def __init__( normalized_weights = None self._b_nu_table = tabulate_blackbody_emission(x, normalized_weights) - # Interpolate and convert the model parameters to dictionaries which can be used to evaluate - # the zodiacal light model. - brightness_callable_dicts = get_model_to_dicts_callable(self._interplanetary_dust_model)( - x, normalized_weights, self._interplanetary_dust_model + # We interpolate the spectrally dependant zodiacal light parameters over the provided + # bandpass or delta frequency/wavelength. + interp_and_unpack_func = get_model_interp_func(self._ipd_model) + self._interped_comp_params, self._interped_shared_params = interp_and_unpack_func( + x, normalized_weights, self._ipd_model ) - self._comp_parameters, self._common_parameters = brightness_callable_dicts self._ephemeris = ephemeris - self._leggauss_points_and_weights = np.polynomial.legendre.leggauss(gauss_quad_degree) + + quad_points, quad_weights = np.polynomial.legendre.leggauss(gauss_quad_degree) + self._integrate_leggauss = functools.partial( + integrate_leggauss, + points=quad_points, + weights=quad_weights, + ) + + # Build partial functions to be evaluated when simulating the zodiacal light. These partials + # are pre-populated functions that contains all non line-of-sight related parameters. + self._number_density_partials = get_partial_number_density_func(comps=self._ipd_model.comps) + self._shared_brightness_partial = functools.partial( + self._ipd_model.brightness_at_step_callable, + bp_interpolation_table=self._b_nu_table, + **self._interped_shared_params, + ) def evaluate( self, @@ -113,9 +127,8 @@ def evaluate( The zodiacal light is simulated for a single, or a sequence of observations. If a single `obspos` and `obstime` is provided for multiple coordinates, all coordinates are assumed to - be observed from that position at that time. Otherwise, when `obspos` and `obstime` contains - multiple values, corresponding to coordinates in `skycoord`, the zodiacal light is simulated - in a time-ordered manner. + be observed from that position at that time. Otherwise, each coordinate is simulated from + the corresponding observer position and time. Args: skycoord: `astropy.coordinates.SkyCoord` object representing the coordinates or @@ -140,6 +153,10 @@ def evaluate( earth_xyz = get_earthpos_xyz(obstime, self._ephemeris) obs_xyz = get_obspos_xyz(obstime, obspos, earth_xyz, self._ephemeris) + # Model evaluation is performed in heliocentric ecliptic coordinates. We transform + # to the barycentric frame, which is compatiable with the Galactic and Celestial, + # and pretend that that this is the heliocentric frame as we only need the correction + # rotation. skycoord = skycoord.transform_to(coords.BarycentricMeanEcliptic) if skycoord.isscalar: skycoord_xyz = typing.cast( @@ -148,78 +165,113 @@ def evaluate( else: skycoord_xyz = typing.cast(npt.NDArray[np.float64], skycoord.cartesian.xyz.value) - start, stop = get_line_of_sight_range( - components=self._interplanetary_dust_model.comps.keys(), + start, stop = get_line_of_sight_range_dicts( + components=self._ipd_model.comps.keys(), unit_vectors=skycoord_xyz, obs_pos=obs_xyz, ) - # If a time and obspos is provided per coordinate, we need to make arrays broadcastable. - if obs_xyz.ndim == 1: + instantaneous = obs_xyz.ndim == 1 + if instantaneous: obs_xyz = obs_xyz[:, np.newaxis] earth_xyz = earth_xyz[:, np.newaxis] - # Return a dict of partial functions corresponding to the number density of each zodiacal - # component in the interplanetary dust model. - density_callables = populate_number_density_with_model( - comps=self._interplanetary_dust_model.comps, - dynamic_params={"X_earth": earth_xyz}, - ) + number_density_partials = self._number_density_partials + shared_brightness_partial = self._shared_brightness_partial + dist_coords_to_cores = skycoord.size > nprocesses and nprocesses > 1 + if instantaneous or not dist_coords_to_cores: + # Populate the instantaneous Earth and observer position in the partial functions. + number_density_partials = update_partial_earth_pos( + number_density_partials, earth_pos=earth_xyz + ) + shared_brightness_partial = functools.partial(shared_brightness_partial, X_obs=obs_xyz) - # Create partial function of the brightness integral at a step along the line-of-sight with - # shared arguments between zodiacal components. - common_integrand = functools.partial( - self._interplanetary_dust_model.brightness_at_step_callable, - X_obs=obs_xyz, - bp_interpolation_table=self._b_nu_table, - **self._common_parameters, - ) + emission = np.zeros((self._ipd_model.ncomps, skycoord.size)) - emission = np.zeros((self._interplanetary_dust_model.ncomps, skycoord.size)) - dist_to_cores = skycoord.size > nprocesses and nprocesses > 1 - if dist_to_cores: + if dist_coords_to_cores: skycoord_xyz_splits = np.array_split(skycoord_xyz, nprocesses, axis=-1) - with multiprocessing.get_context(_PLATFORM_METHOD).Pool(nprocesses) as pool: - for idx, comp_label in enumerate(self._interplanetary_dust_model.comps.keys()): - stop_chunks = np.array_split(stop[comp_label], nprocesses, axis=-1) - if start[comp_label].size == 1: - start_chunks = [start[comp_label]] * nprocesses + if not instantaneous: + # The observer and Earth positions are applied into the partial functions. In the + # case where we have coordinate-by-coordinate observer positions, we need to ensure + # that these positions are distributed accordingly over the cores. This means that + # we need to create partial functions for each core. + earth_xyz_splits = np.array_split(earth_xyz, nprocesses, axis=-1) + obs_xyz_splits = np.array_split(obs_xyz, nprocesses, axis=-1) + + number_density_partial_splits = [ + update_partial_earth_pos( + number_density_partials, + earth_pos=earth_xyz_split, + ) + for earth_xyz_split in earth_xyz_splits + ] + shared_brightness_partial_splits = [ + functools.partial(shared_brightness_partial, X_obs=obs_xyz_split) + for obs_xyz_split in obs_xyz_splits + ] + with multiprocessing.get_context(PLATFORM_METHOD).Pool(nprocesses) as pool: + for idx, comp_label in enumerate(self._ipd_model.comps.keys()): + stop_chunks = ( + [stop[comp_label]] * nprocesses + if stop[comp_label].size == 1 + else np.array_split(stop[comp_label], nprocesses, axis=-1) + ) + start_chunks = ( + [start[comp_label]] * nprocesses + if start[comp_label].size == 1 + else np.array_split(start[comp_label], nprocesses, axis=-1) + ) + + if instantaneous: + comp_funcs = [ + functools.partial( + shared_brightness_partial, + u_los=skycoord_xyz, + start=start, + stop=stop, + number_density_func=number_density_partials[comp_label], + **self._interped_comp_params[comp_label], + ) + for skycoord_xyz, start, stop in zip( + skycoord_xyz_splits, start_chunks, stop_chunks + ) + ] else: - start_chunks = np.array_split(start[comp_label], nprocesses, axis=-1) - comp_integrands = [ - functools.partial( - common_integrand, - u_los=vec, - start=start, - stop=stop, - get_density_function=density_callables[comp_label], - **self._comp_parameters[comp_label], - ) - for vec, start, stop in zip(skycoord_xyz_splits, start_chunks, stop_chunks) - ] - + comp_funcs = [ + functools.partial( + brightness_partial, + u_los=skycoord_xyz, + start=start, + stop=stop, + number_density_func=dens_partial[comp_label], + **self._interped_comp_params[comp_label], + ) + for skycoord_xyz, start, stop, dens_partial, brightness_partial in zip( + skycoord_xyz_splits, + start_chunks, + stop_chunks, + number_density_partial_splits, + shared_brightness_partial_splits, + ) + ] proc_chunks = [ - pool.apply_async( - integrate_leggauss, - args=(comp_integrand, *self._leggauss_points_and_weights), - ) - for comp_integrand in comp_integrands + pool.apply_async(self._integrate_leggauss, args=(func,)) + for func in comp_funcs ] emission[idx] = np.concatenate([result.get() for result in proc_chunks]) + # Simulate the zodiacal light over the coordinates sequentially. else: - for idx, comp_label in enumerate(self._interplanetary_dust_model.comps.keys()): - comp_integrand = functools.partial( - common_integrand, + for idx, comp_label in enumerate(self._ipd_model.comps.keys()): + comp_func = functools.partial( + shared_brightness_partial, u_los=skycoord_xyz, start=start[comp_label], stop=stop[comp_label], - get_density_function=density_callables[comp_label], - **self._comp_parameters[comp_label], - ) - emission[idx] = integrate_leggauss( - comp_integrand, *self._leggauss_points_and_weights + number_density_func=number_density_partials[comp_label], + **self._interped_comp_params[comp_label], ) + emission[idx] = self._integrate_leggauss(comp_func) emission <<= units.MJy / units.sr return emission if return_comps else emission.sum(axis=0) @@ -232,7 +284,7 @@ def get_parameters(self) -> dict: Returns: parameters: Dictionary of parameters of the interplanetary dust model. """ - return self._interplanetary_dust_model.to_dict() + return self._ipd_model.to_dict() def update_parameters(self, parameters: dict) -> None: """Update the interplanetary dust model parameters. @@ -250,12 +302,12 @@ def update_parameters(self, parameters: dict) -> None: if key == "comps": for comp_key, comp_value in value.items(): _dict["comps"][ComponentLabel(comp_key)] = type( - self._interplanetary_dust_model.comps[ComponentLabel(comp_key)] + self._ipd_model.comps[ComponentLabel(comp_key)] )(**comp_value) elif isinstance(value, dict): _dict[key] = {ComponentLabel(k): v for k, v in value.items()} - self._interplanetary_dust_model = self._interplanetary_dust_model.__class__(**_dict) + self._ipd_model = self._ipd_model.__class__(**_dict) def validate_user_input(skycoord: coords.SkyCoord, obspos: units.Quantity | str) -> time.Time: diff --git a/zodipy/number_density.py b/zodipy/number_density.py index 379b259..9ce8d3c 100644 --- a/zodipy/number_density.py +++ b/zodipy/number_density.py @@ -1,12 +1,13 @@ from __future__ import annotations +import copy import inspect from dataclasses import asdict from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol +from typing import TYPE_CHECKING, Callable, Mapping, Protocol import numpy as np -import numpy.typing as npt # type: ignore +import numpy.typing as npt from zodipy.bodies import get_earthpos_xyz from zodipy.component import ( @@ -161,6 +162,7 @@ def feature_number_density( - X_feature[1] * cos_Omega_rad * sin_i_rad + X_feature[2] * cos_i_rad ) + X_earth_comp = X_earth - X_0 theta_comp = np.arctan2(X_feature[1], X_feature[0]) - np.arctan2( @@ -418,24 +420,23 @@ def rrm_feature_number_density( } -class ComponentNumberDensityCallable(Protocol): +class NumberDensityFunc(Protocol): """Protocol for a zodiacal components number density function.""" def __call__(self, X_helio: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]: """Return the number density of the component at the heliocentric position.""" -def populate_number_density_with_model( +def get_partial_number_density_func( comps: Mapping[ComponentLabel, ZodiacalComponent], - dynamic_params: dict[str, Any], -) -> dict[ComponentLabel, ComponentNumberDensityCallable]: +) -> dict[ComponentLabel, partial[npt.NDArray[np.float64]]]: """Construct density partials for components. Return a tuple of the density expressions above which has been prepopulated with model and configuration parameters, leaving only the `X_helio` argument to be supplied. Raises exception for incorrectly defined components or component density functions. """ - partial_density_funcs: dict[ComponentLabel, ComponentNumberDensityCallable] = {} + partial_density_funcs: dict[ComponentLabel, partial[npt.NDArray[np.float64]]] = {} for comp_label, comp in comps.items(): comp_dict = asdict(comp) func_params = inspect.signature(DENSITY_FUNCS[type(comp)]).parameters.keys() @@ -446,15 +447,16 @@ def populate_number_density_with_model( msg = "X_helio must be be the first argument to the density function of a component." raise ValueError(msg) from err + if "X_earth" in residual_params: + residual_params.remove("X_earth") + if residual_params: - if residual_params - dynamic_params.keys(): - msg = ( - f"Argument(s) {residual_params} required by the density function " - f"{DENSITY_FUNCS[type(comp)]} are not provided by instance variables in " - f"{type(comp)} or by the `computed_parameters` dict." - ) - raise ValueError(msg) - comp_dict.update(dynamic_params) + msg = ( + f"Argument(s) {residual_params} required by the density function " + f"{DENSITY_FUNCS[type(comp)]} are not provided by instance variables in " + f"{type(comp)} or by the `computed_parameters` dict." + ) + raise ValueError(msg) # Remove excess intermediate parameters from the component dict. comp_params = {key: value for key, value in comp_dict.items() if key in func_params} @@ -464,6 +466,19 @@ def populate_number_density_with_model( return partial_density_funcs +def update_partial_earth_pos( + partials: dict[ComponentLabel, partial[npt.NDArray[np.float64]]], + earth_pos: npt.NDArray[np.float64], +) -> dict[ComponentLabel, partial[npt.NDArray[np.float64]]]: + """Inplace populate the `X_earth` parameter in the partial density functions.""" + updated_partials = copy.deepcopy(partials) + for partial_func in updated_partials.values(): + remaining = inspect.signature(partial_func).parameters.keys() - partial_func.keywords.keys() + if "X_earth" in remaining: + partial_func.keywords["X_earth"] = earth_pos + return updated_partials + + def grid_number_density( x: units.Quantity, y: units.Quantity, @@ -496,9 +511,11 @@ def grid_number_density( for comp in ipd_model.comps.values(): comp.X_0 = comp.X_0.reshape(3, 1, 1, 1) - density_partials = populate_number_density_with_model( + density_partials = get_partial_number_density_func( comps=ipd_model.comps, - dynamic_params={"X_earth": earthpos_xyz[:, np.newaxis, np.newaxis, np.newaxis]}, + ) + density_partials = update_partial_earth_pos( + density_partials, earthpos_xyz[:, np.newaxis, np.newaxis, np.newaxis] ) number_density_grid = np.zeros((len(ipd_model.comps), *grid.shape[1:])) diff --git a/zodipy/unpack_model.py b/zodipy/unpack_model.py index c1a8259..48c99e8 100644 --- a/zodipy/unpack_model.py +++ b/zodipy/unpack_model.py @@ -17,7 +17,7 @@ UnpackModelCallable = Callable[[units.Quantity, Union[units.Quantity, None], T], UnpackedModelDicts] -def unpack_kelsall( +def interp_and_unpack_kelsall( wavelengths: units.Quantity, weights: units.Quantity | None, model: Kelsall, @@ -33,14 +33,14 @@ def unpack_kelsall( for comp_label in model.comps: comp_params[comp_label] = {} - comp_params[comp_label]["emissivity"] = interpolate_spectral_parameter( + comp_params[comp_label]["emissivity"] = interp_spectral_param( wavelengths, weights, model_spectrum, spectral_parameter=model.emissivities[comp_label], ) if model.albedos is not None: - comp_params[comp_label]["albedo"] = interpolate_spectral_parameter( + comp_params[comp_label]["albedo"] = interp_spectral_param( wavelengths, weights, model_spectrum, @@ -50,7 +50,7 @@ def unpack_kelsall( comp_params[comp_label]["albedo"] = 0 common_params["C1"] = ( - interpolate_spectral_parameter( + interp_spectral_param( wavelengths, weights, model_spectrum, @@ -61,7 +61,7 @@ def unpack_kelsall( else 0 ) common_params["C2"] = ( - interpolate_spectral_parameter( + interp_spectral_param( wavelengths, weights, model_spectrum, @@ -73,7 +73,7 @@ def unpack_kelsall( ) common_params["C3"] = ( - interpolate_spectral_parameter( + interp_spectral_param( wavelengths, weights, model_spectrum, @@ -86,7 +86,7 @@ def unpack_kelsall( if model.solar_irradiance is None: common_params["solar_irradiance"] = 0 else: - common_params["solar_irradiance"] = interpolate_spectral_parameter( + common_params["solar_irradiance"] = interp_spectral_param( wavelengths, weights, model_spectrum, @@ -96,7 +96,7 @@ def unpack_kelsall( return comp_params, common_params -def unpack_rrm( +def interp_and_unpack_rrm( wavelengths: units.Quantity, weights: units.Quantity | None, model: RRM, @@ -112,7 +112,7 @@ def unpack_rrm( comp_params[comp_label]["T_0"] = model.T_0[comp_label] comp_params[comp_label]["delta"] = model.delta[comp_label] - calibration = interpolate_spectral_parameter( + calibration = interp_spectral_param( wavelengths, weights, model_spectrum, spectral_parameter=model.calibration ) calibration_quantity = units.Quantity(calibration, unit=units.MJy / units.AU) @@ -120,7 +120,7 @@ def unpack_rrm( return comp_params, common_params -def interpolate_spectral_parameter( +def interp_spectral_param( wavelengths: units.Quantity, weights: units.Quantity | None, model_spectrum: units.Quantity, @@ -131,25 +131,25 @@ def interpolate_spectral_parameter( paramameter = np.asarray(spectral_parameter) if use_nearest: - interpolated_parameter = interpolate.interp1d( - model_spectrum.value, paramameter, kind="nearest" - )(wavelengths.value) + interped_param = interpolate.interp1d(model_spectrum.value, paramameter, kind="nearest")( + wavelengths.value + ) else: - interpolated_parameter = np.interp(wavelengths.value, model_spectrum.value, paramameter) + interped_param = np.interp(wavelengths.value, model_spectrum.value, paramameter) if weights is not None: - return integrate.trapezoid(weights.value * interpolated_parameter, x=wavelengths.value) - return interpolated_parameter + return integrate.trapezoid(weights.value * interped_param, x=wavelengths.value) + return interped_param -model_unpack_mapping: dict[type[ZodiacalLightModel], UnpackModelCallable] = { - Kelsall: unpack_kelsall, - RRM: unpack_rrm, +interp_and_unpack_func_mapping: dict[type[ZodiacalLightModel], UnpackModelCallable] = { + Kelsall: interp_and_unpack_kelsall, + RRM: interp_and_unpack_rrm, } -def get_model_to_dicts_callable( +def get_model_interp_func( model: ZodiacalLightModel, ) -> UnpackModelCallable: """Get the appropriate parameter unpacker for the model.""" - return model_unpack_mapping[type(model)] + return interp_and_unpack_func_mapping[type(model)]