Skip to content

Commit

Permalink
Attempts to fix #27, currently fails due to keyword arguments in inte…
Browse files Browse the repository at this point in the history
…rp_kind.
  • Loading branch information
dncnwtts committed Apr 29, 2024
1 parent 9897bda commit 090b557
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
14 changes: 7 additions & 7 deletions zodipy/_interpolate_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def get_source_parameters_kelsall_comp(
if not bandpass.frequencies.unit.is_equivalent(model.spectrum.unit):
bandpass.switch_convention()

spectrum = (
center_freqs = (
model.spectrum.to_value(u.Hz)
if model.spectrum.unit.is_equivalent(u.Hz)
else model.spectrum.to_value(u.micron)
)

interpolator = partial(interpolator, x=spectrum)
interpolator = partial(interpolator, x=center_freqs)

source_parameters: dict[ComponentLabel | str, dict[str, Any]] = {}
for comp_label in model.comps:
Expand All @@ -52,11 +52,11 @@ def get_source_parameters_kelsall_comp(
source_parameters[comp_label]["albedo"] = albedo

if model.phase_coefficients is not None:
phase_coefficients = interpolator(y=np.asarray(model.phase_coefficients))(
bandpass.frequencies.value
)
phase_coefficients = interpolator(y=np.asarray(model.phase_coefficients))(
bandpass.frequencies.value
phase_coefficients = np.array(
[
interpolator(y=np.asarray(model.phase_coefficients[i]))(bandpass.frequencies.value)
for i in range(len(model.phase_coefficients))
]
)
else:
phase_coefficients = np.repeat(np.zeros((3, 1)), repeats=bandpass.frequencies.size, axis=-1)
Expand Down
14 changes: 4 additions & 10 deletions zodipy/zodipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import healpy as hp
import numpy as np
from astropy.coordinates import solar_system_ephemeris
from scipy.interpolate import interp1d
from scipy.interpolate import CubicSpline

from zodipy._bandpass import get_bandpass_interpolation_table, validate_and_get_bandpass
from zodipy._constants import SPECIFIC_INTENSITY_UNITS
Expand Down Expand Up @@ -45,10 +45,6 @@ class Zodipy:
Defaults to DIRBE.
gauss_quad_degree (int): Order of the Gaussian-Legendre quadrature used to evaluate
the line-of-sight integral in the simulations. Default is 50 points.
interp_kind (str): Interpolation kind used to interpolate relevant model parameters.
Defaults to 'linear'. For more information on available interpolation methods,
please visit the [Scipy documentation](
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interp1d.html).
extrapolate (bool): If `True` all spectral quantities in the selected model are
extrapolated to the requested frequencies or wavelengths. If `False`, an
exception is raised on requested frequencies/wavelengths outside of the
Expand All @@ -72,7 +68,6 @@ def __init__(
model: str = "dirbe",
gauss_quad_degree: int = 50,
extrapolate: bool = False,
interp_kind: str = "linear",
ephemeris: str = "de432s",
solar_cut: u.Quantity[u.deg] | None = None,
solar_cut_fill_value: float = np.nan,
Expand All @@ -81,16 +76,15 @@ def __init__(
self.model = model
self.gauss_quad_degree = gauss_quad_degree
self.extrapolate = extrapolate
self.interp_kind = interp_kind
self.ephemeris = ephemeris
self.solar_cut = solar_cut.to(u.rad) if solar_cut is not None else solar_cut
self.solar_cut_fill_value = solar_cut_fill_value
self.n_proc = n_proc

self._interpolator = partial(
interp1d,
kind=self.interp_kind,
fill_value="extrapolate" if self.extrapolate else np.nan,
CubicSpline,
extrapolate=True,
bc_type="natural",
)
self._ipd_model = model_registry.get_model(model)
self._gauss_points_and_weights = np.polynomial.legendre.leggauss(gauss_quad_degree)
Expand Down

0 comments on commit 090b557

Please sign in to comment.