Skip to content

Commit

Permalink
Change quadrature from Gauss-Legendre to Gauss-Laguerre and rescale i…
Browse files Browse the repository at this point in the history
…ntegration intervals. Remove usage of start position. Integration is now from 0 -> ~5.2
  • Loading branch information
MetinSa committed Feb 2, 2024
1 parent 6b9894f commit 5ec2cd6
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 55 deletions.
2 changes: 1 addition & 1 deletion tests/_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def any_obs(draw: DrawFn, model: zodipy.Zodipy) -> str:

MODEL_STRATEGY_MAPPINGS: dict[str, SearchStrategy[Any]] = {
"model": sampled_from(AVAILABLE_MODELS),
"gauss_quad_degree": integers(min_value=1, max_value=200),
"gauss_quad_degree": integers(min_value=10, max_value=100),
"extrapolate": booleans(),
"solar_cut": quantities(min_value=0, max_value=360, unit=u.deg),
}
Expand Down
14 changes: 3 additions & 11 deletions zodipy/_emission.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

def kelsall(
r: npt.NDArray[np.float64],
start: np.float64,
stop: npt.NDArray[np.float64],
X_obs: npt.NDArray[np.float64],
u_los: npt.NDArray[np.float64],
get_density_function: ComponentDensityFn,
Expand All @@ -38,9 +36,7 @@ def kelsall(
bp_interpolation_table: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
"""Kelsall uses common line of sight grid from obs to 5.2 AU."""
# Convert the quadrature range from [-1, 1] to the true ecliptic positions
R_los = ((stop - start) / 2) * r + (stop + start) / 2
X_los = R_los * u_los
X_los = r * u_los
X_helio = X_los + X_obs
R_helio = np.sqrt(X_helio[0] ** 2 + X_helio[1] ** 2 + X_helio[2] ** 2)

Expand All @@ -50,7 +46,7 @@ def kelsall(

if albedo != 0:
solar_flux = solar_irradiance / R_helio**2
scattering_angle = get_scattering_angle(R_los, R_helio, X_los, X_helio)
scattering_angle = get_scattering_angle(r, R_helio, X_los, X_helio)
phase_function = get_phase_function(scattering_angle, phase_coefficients)

emission += albedo * solar_flux * phase_function
Expand All @@ -60,8 +56,6 @@ def kelsall(

def rrm(
r: npt.NDArray[np.float64],
start: npt.NDArray[np.float64],
stop: npt.NDArray[np.float64],
X_obs: npt.NDArray[np.float64],
u_los: npt.NDArray[np.float64],
get_density_function: ComponentDensityFn,
Expand All @@ -71,9 +65,7 @@ def rrm(
bp_interpolation_table: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
"""RRM is implented with component specific line-of-sight grids."""
# Convert the quadrature range from [-1, 1] to the true ecliptic positions
R_los = ((stop - start) / 2) * r + (stop + start) / 2
X_los = R_los * u_los
X_los = r * u_los
X_helio = X_los + X_obs
R_helio = np.sqrt(X_helio[0] ** 2 + X_helio[1] ** 2 + X_helio[2] ** 2)

Expand Down
20 changes: 6 additions & 14 deletions zodipy/_line_of_sight.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
ComponentLabel.BAND1: (R_0, R_JUPITER),
ComponentLabel.BAND2: (R_0, R_JUPITER),
ComponentLabel.BAND3: (R_0, R_JUPITER),
ComponentLabel.RING: (R_EARTH - 0.2, R_EARTH + 0.2),
ComponentLabel.FEATURE: (R_EARTH - 0.2, R_EARTH + 0.2),
ComponentLabel.RING: (R_EARTH - 0.5, R_EARTH + 0.5),
ComponentLabel.FEATURE: (R_EARTH - 0.5, R_EARTH + 0.5),
}

RRM_CUTOFFS: dict[ComponentLabel, tuple[float | np.float64, float]] = {
Expand Down Expand Up @@ -78,21 +78,13 @@ def get_sphere_intersection(
return np.maximum(q, c / q)


def get_line_of_sight_start_and_stop_distances(
def get_line_of_sight_distances(
components: Iterable[ComponentLabel],
unit_vectors: npt.NDArray[np.float64],
obs_pos: npt.NDArray[np.float64],
) -> tuple[
dict[ComponentLabel, npt.NDArray[np.float64]],
dict[ComponentLabel, npt.NDArray[np.float64]],
]:
"""Get the start and stop distances for each component."""
start = {
comp: get_sphere_intersection(obs_pos, unit_vectors, COMPONENT_CUTOFFS[comp][0])
for comp in components
}
stop = {
) -> dict[ComponentLabel, npt.NDArray[np.float64]]:
"""Get the maximum integration length for each component."""
return {
comp: get_sphere_intersection(obs_pos, unit_vectors, COMPONENT_CUTOFFS[comp][1])
for comp in components
}
return start, stop
61 changes: 32 additions & 29 deletions zodipy/zodipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from zodipy._interpolate_source import SOURCE_PARAMS_MAPPING
from zodipy._ipd_comps import ComponentLabel
from zodipy._ipd_dens_funcs import construct_density_partials_comps
from zodipy._line_of_sight import get_line_of_sight_start_and_stop_distances
from zodipy._line_of_sight import get_line_of_sight_distances
from zodipy._sky_coords import get_obs_and_earth_positions
from zodipy._unit_vectors import get_unit_vectors_from_ang, get_unit_vectors_from_pixels
from zodipy._validators import get_validated_ang, get_validated_pix
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(
fill_value="extrapolate" if self.extrapolate else np.nan,
)
self._ipd_model = model_registry.get_model(model)
self._gauss_points_and_weights = np.polynomial.legendre.leggauss(gauss_quad_degree)
self._gauss_points_and_weights = np.polynomial.laguerre.laggauss(self.gauss_quad_degree)

@property
def supported_observers(self) -> list[str]:
Expand Down Expand Up @@ -429,7 +429,7 @@ def _compute_emission(

# Get the integration limits for each zodiacal component (which may be
# different or the same depending on the model) along all line of sights.
start, stop = get_line_of_sight_start_and_stop_distances(
stop = get_line_of_sight_distances(
components=self._ipd_model.comps.keys(),
unit_vectors=unit_vectors,
obs_pos=observer_position,
Expand Down Expand Up @@ -457,37 +457,36 @@ def _compute_emission(
integrated_comp_emission = np.zeros((len(self._ipd_model.comps), unit_vectors.shape[1]))
with multiprocessing.get_context(SYS_PROC_START_METHOD).Pool(processes=n_proc) as pool:
for idx, comp_label in enumerate(self._ipd_model.comps.keys()):
stop_chunks = np.array_split(stop[comp_label], n_proc, axis=-1)
if start[comp_label].size == 1:
start_chunks = [start[comp_label]] * n_proc
else:
start_chunks = np.array_split(start[comp_label], n_proc, axis=-1)
comp_integrands = [
partial(
common_integrand,
u_los=np.expand_dims(unit_vectors, axis=-1),
start=np.expand_dims(start, axis=-1),
stop=np.expand_dims(stop, axis=-1),
get_density_function=density_partials[comp_label],
**source_parameters[comp_label],
)
for unit_vectors, start, stop in zip(
unit_vector_chunks, start_chunks, stop_chunks
for unit_vectors in unit_vector_chunks
]
stop_chunks = np.array_split(stop[comp_label], n_proc, axis=-1)
quad_partials = [
partial(
_integrate_gauss_laguerre,
points=self._gauss_points_and_weights[0],
weights=self._gauss_points_and_weights[1],
stop=np.expand_dims(stop, axis=-1),
)
for stop in stop_chunks
]

proc_chunks = [
pool.apply_async(
_integrate_gauss_quad,
args=(comp_integrand, *self._gauss_points_and_weights),
quad_partial,
args=[comp_integrand],
)
for comp_integrand in comp_integrands
for comp_integrand, quad_partial in zip(comp_integrands, quad_partials)
]

integrated_comp_emission[idx] += (
np.concatenate([result.get() for result in proc_chunks])
* 0.5
* (stop[comp_label] - start[comp_label])
integrated_comp_emission[idx] += np.concatenate(
[result.get() for result in proc_chunks]
)

else:
Expand All @@ -498,16 +497,13 @@ def _compute_emission(
comp_integrand = partial(
common_integrand,
u_los=unit_vectors_expanded,
start=np.expand_dims(start[comp_label], axis=-1),
stop=np.expand_dims(stop[comp_label], axis=-1),
get_density_function=density_partials[comp_label],
**source_parameters[comp_label],
)

integrated_comp_emission[idx] = (
_integrate_gauss_quad(comp_integrand, *self._gauss_points_and_weights)
* 0.5
* (stop[comp_label] - start[comp_label])
integrated_comp_emission[idx] = _integrate_gauss_laguerre(
comp_integrand,
*self._gauss_points_and_weights,
stop=np.expand_dims(stop[comp_label], axis=-1),
)

emission = np.zeros(
Expand Down Expand Up @@ -543,10 +539,17 @@ def __repr__(self) -> str:
return repr_str[:-2] + ")"


def _integrate_gauss_quad(
def _integrate_gauss_laguerre(
fn: Callable[[float], npt.NDArray[np.float64]],
points: npt.NDArray[np.float64],
weights: npt.NDArray[np.float64],
stop: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
"""Integrate a function using Gauss-Legendre quadrature."""
return np.squeeze(sum(fn(x) * w for x, w in zip(points, weights)))
"""Integrate a function using Gauss-Laguerre quadrature.
If a stop is provided, the integral is rescaled from 0 -> infty to 0 -> stop.
"""
scale_factor = stop / points[-1]
return np.squeeze(
scale_factor * sum(fn(x * scale_factor) * np.exp(x) * w for x, w in zip(points, weights))
)

0 comments on commit 5ec2cd6

Please sign in to comment.