
Merge pull request #256 from Deltares/feature/225-further-performance-improvements

chore: Small improvements while analysing ra2ce performance.
ArdtK authored Nov 27, 2023
2 parents 23f6f28 + ac3fad3 commit fa0ccc9
Showing 12 changed files with 318 additions and 273 deletions.
2 changes: 1 addition & 1 deletion ra2ce/analyses/indirect/analyses_indirect.py
@@ -794,7 +794,7 @@ def multi_link_origin_destination_regional_impact(self, gdf_ori):
"output_graph", "origin_destination_table.gpkg"
)
origin = gpd.read_file(origin_fn, engine="pyogrio")
index = [type(x) == str for x in origin["o_id"]]
index = [isinstance(x, str) for x in origin["o_id"]]
origin = origin[index]
origin.reset_index(inplace=True, drop=True)
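
Note on the isinstance change above: besides reading better, isinstance also accepts str subclasses, which type(x) == str rejects. A minimal, self-contained illustration (the OriginId subclass and the sample values are made up, not RA2CE code):

# Made-up example; OriginId and the sample values are not part of RA2CE.
class OriginId(str):
    pass

values = ["A_12", OriginId("B_3"), 7.0, float("nan")]
print([type(v) == str for v in values])      # [True, False, False, False]
print([isinstance(v, str) for v in values])  # [True, True, False, False]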

@@ -49,7 +49,7 @@ def __add__(
prioritarian=self.prioritarian + other,
egalitarian=self.egalitarian + other,
)
raise NotImplementedError(
raise TypeError(
"It is not possible to sum {} with a value of type {}.".format(
AccumulatedTraffic.__name__, type(other).__name__
)
@@ -74,7 +74,7 @@ def __mul__(
prioritarian=self.prioritarian * other,
egalitarian=self.egalitarian * other,
)
raise NotImplementedError(
raise TypeError(
"It is not possible to multiply {} with a value of type {}.".format(
AccumulatedTraffic.__name__, type(other).__name__
)
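
TypeError is indeed the exception Python itself raises for unsupported operands (NotImplementedError means something else entirely). An alternative, not used in this commit, is to return the NotImplemented sentinel and let the interpreter raise; a hypothetical sketch with a stand-in class:

# Hypothetical alternative to raising explicitly; Traffic is a stand-in, not AccumulatedTraffic.
class Traffic:
    def __init__(self, utilitarian: float = 0.0):
        self.utilitarian = utilitarian

    def __add__(self, other):
        if isinstance(other, (int, float)):
            return Traffic(self.utilitarian + other)
        # Returning NotImplemented lets Python try other.__radd__ and
        # raise TypeError itself if that also fails.
        return NotImplemented

Traffic(1.0) + "Lorem ipsum"  # TypeError: unsupported operand type(s) for +: 'Traffic' and 'str'

The explicit raise used in the commit keeps the custom error message, which the updated tests at the bottom of this diff check.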
@@ -23,12 +23,14 @@
from pathlib import Path
from typing import Callable
from geopandas import GeoDataFrame, read_file, sjoin
from joblib import Parallel, delayed
from networkx import Graph
from numpy import nanmean
from ra2ce.graph.hazard.hazard_intersect.hazard_intersect_builder_base import (
HazardIntersectBuilderBase,
)
from ra2ce.graph.hazard.hazard_intersect.hazard_intersect_parallel_run import (
get_hazard_parallel_process,
)


@dataclass
@@ -134,7 +136,10 @@ def geodataframe_overlay(hazard_shp_file: Path, ra2ce_name: str):

def _overlay_in_parallel(self, overlay_func: Callable):
# Run in parallel to boost performance.
Parallel(n_jobs=2, require="sharedmem")(
delayed(overlay_func)(self.hazard_gpkg_files[i], _ra2ce_name)
for i, _ra2ce_name in enumerate(self.ra2ce_names)
get_hazard_parallel_process(
overlay_func,
lambda delayed_func: (
delayed_func(self.hazard_gpkg_files[i], _ra2ce_name)
for i, _ra2ce_name in enumerate(self.ra2ce_names)
),
)
@@ -35,12 +35,15 @@

from networkx import Graph, set_edge_attributes
from geopandas import GeoDataFrame
from ra2ce.graph.hazard.hazard_intersect.hazard_intersect_parallel_run import (
get_hazard_parallel_process,
)
from ra2ce.graph.networks_utils import (
fraction_flooded,
get_graph_edges_extent,
get_valid_mean,
)
from joblib import Parallel, delayed
from tqdm import tqdm


@dataclass
@@ -71,9 +74,9 @@ def _from_networkx(self, hazard_overlay: Graph) -> Graph:
*graph* (NetworkX Graph) : NetworkX graph with hazard values
"""
# TODO apply multiprocessing?
from tqdm import (
tqdm, # somehow this only works when importing here and not at the top of the file
)
# from tqdm import (
# tqdm, # somehow this only works when importing here and not at the top of the file
# )

# Verify the graph type (networkx)
assert isinstance(hazard_overlay, Graph)
@@ -169,10 +172,6 @@ def _from_geodataframe(self, hazard_overlay: GeoDataFrame):
Returns:
"""
from tqdm import (
tqdm, # somehow this only works when importing here and not at the top of the file
)

assert isinstance(hazard_overlay, GeoDataFrame), "Network is not a GeoDataFrame"

# Make sure none of the geometries is a nonetype object (this will raise an error in zonal_stats)
@@ -202,18 +201,27 @@ def overlay_geodataframe(

tqdm.pandas(desc="Network hazard overlay with " + hazard_name)
_hazard_files_str = str(hazard_tif_file)
# Performance sinkhole
flood_stats = hazard_overlay.geometry.progress_apply(
lambda x, _hz_str=_hazard_files_str: zonal_stats(
x,
_hz_str,
lambda _geom_vector: zonal_stats(
vectors=_geom_vector,
raster=_hazard_files_str,
all_touched=True,
stats="min max",
add_stats={"mean": get_valid_mean},
)
)
hazard_overlay[ra2ce_name + "_mi"] = [x[0]["min"] for x in flood_stats]
hazard_overlay[ra2ce_name + "_ma"] = [x[0]["max"] for x in flood_stats]
hazard_overlay[ra2ce_name + "_me"] = [x[0]["mean"] for x in flood_stats]

def _get_attributes(gen_flood_stat: list[dict]) -> tuple:
# Just get the first element of the generator
_flood_stat = gen_flood_stat[0]
return _flood_stat["min"], _flood_stat["max"], _flood_stat["mean"]

(
hazard_overlay[ra2ce_name + "_mi"],
hazard_overlay[ra2ce_name + "_ma"],
hazard_overlay[ra2ce_name + "_me"],
) = list(zip(*map(_get_attributes, flood_stats)))

tqdm.pandas(desc="Network fraction with hazard overlay with " + hazard_name)
hazard_overlay[ra2ce_name + "_fr"] = hazard_overlay.geometry.progress_apply(
@@ -226,7 +234,10 @@

def _overlay_in_parallel(self, overlay_func: Callable):
# Run in parallel to boost performance.
Parallel(n_jobs=2, require="sharedmem")(
delayed(overlay_func)(self.hazard_tif_files[i], hn, rn)
for i, (hn, rn) in enumerate(self._combined_names)
get_hazard_parallel_process(
overlay_func,
lambda delayed_func: (
delayed_func(self.hazard_tif_files[i], hn, rn)
for i, (hn, rn) in enumerate(self._combined_names)
),
)
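
For reference, the refactored overlay above boils down to the following rasterstats pattern: one zonal_stats call per geometry, then the per-geometry min/max/mean dicts transposed into three columns. This sketch is illustrative only; the file names, the EV1 prefix and the _valid_mean helper are placeholders, not RA2CE code:

import numpy as np
from geopandas import read_file
from rasterstats import zonal_stats

def _valid_mean(masked_array):
    # Mean over the unmasked cells only (stand-in for get_valid_mean).
    return np.nanmean(masked_array.compressed()) if masked_array.count() else np.nan

gdf = read_file("network.gpkg")  # placeholder network
flood_stats = gdf.geometry.apply(
    lambda geom: zonal_stats(
        vectors=geom,
        raster="hazard.tif",  # placeholder raster
        all_touched=True,
        stats="min max",
        add_stats={"mean": _valid_mean},
    )
)
# Each element is a one-item list of dicts; transpose it into three columns.
gdf["EV1_mi"], gdf["EV1_ma"], gdf["EV1_me"] = zip(
    *((s[0]["min"], s[0]["max"], s[0]["mean"]) for s in flood_stats)
)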
@@ -26,6 +26,15 @@

class HazardIntersectBuilderProtocol(Protocol):
def get_intersection(
self, hazard_overlay: GeoDataFrame| Graph
self, hazard_overlay: GeoDataFrame | Graph
) -> GeoDataFrame | Graph:
"""
Retrieves the resulting network from intersecting the hazard layer with a graph.
Args:
hazard_overlay (GeoDataFrame | Graph): Layer containing hazards.
Returns:
GeoDataFrame | Graph: Intersected graph with hazards.
"""
pass
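
Since this is a typing.Protocol, implementers only need a structurally matching get_intersection; no inheritance is required. A hypothetical no-op implementer for illustration:

# Hypothetical implementer, not part of the codebase.
from geopandas import GeoDataFrame
from networkx import Graph

class DummyIntersectBuilder:
    def get_intersection(self, hazard_overlay: GeoDataFrame | Graph) -> GeoDataFrame | Graph:
        # No-op overlay: return the network unchanged.
        return hazard_overlay

# Structural typing: this object satisfies HazardIntersectBuilderProtocol
# (assuming the protocol above is imported from its module).
_builder: HazardIntersectBuilderProtocol = DummyIntersectBuilder()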
40 changes: 40 additions & 0 deletions ra2ce/graph/hazard/hazard_intersect/hazard_intersect_parallel_run.py
@@ -0,0 +1,40 @@
"""
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
Copyright (C) 2023 Stichting Deltares
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

# This file contains the common utils to run a method in parallel
from joblib import Parallel, delayed
from typing import Callable


def get_hazard_parallel_process(
delegated_func: Callable, func_iterable: Callable
) -> None:
"""
Runs `delegated_func` in parallel: `func_iterable` receives joblib's `delayed` wrapper around
`delegated_func` and must yield the delayed calls (with their arguments) to execute.
Args:
delegated_func (Callable): Method signature which will be run in parallel.
func_iterable (Callable): Method generating the arguments required for `delegated_func`.
"""
return Parallel(n_jobs=2, require="sharedmem")(
func_iterable(delayed(delegated_func))
)
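
Intended usage of the new helper, mirroring the two _overlay_in_parallel call sites above; the overlay function and the file/name lists below are placeholders, not RA2CE code:

# Placeholder call site for get_hazard_parallel_process.
def _overlay(hazard_file: str, ra2ce_name: str) -> None:
    print(f"Overlaying {hazard_file} as {ra2ce_name}")

_hazard_files = ["flood_1.tif", "flood_2.tif"]
_ra2ce_names = ["EV1", "EV2"]

get_hazard_parallel_process(
    _overlay,
    lambda delayed_func: (
        delayed_func(_hazard_files[i], _name) for i, _name in enumerate(_ra2ce_names)
    ),
)

The lambda receives delayed(_overlay) and yields the delayed calls, which Parallel then runs on the shared-memory (threading) backend.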
2 changes: 0 additions & 2 deletions ra2ce/graph/hazard/hazard_overlay.py
@@ -119,7 +119,6 @@ def overlay_hazard_raster_graph(
Returns:
*graph* (NetworkX Graph) : NetworkX graph with hazard values
"""
from tqdm import tqdm

# Verify the graph type (networkx)
assert isinstance(graph, nx.classes.graph.Graph)
@@ -322,7 +321,6 @@ def point_hazard_intersect(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
Returns:
gdf (GeoDataFrame): the point geodataframe with hazard raster(s) data joined
"""
from tqdm import tqdm

## Intersect the origin and destination nodes with the hazard map (now only geotiff possible)
for i, (hn, rn) in enumerate(zip(self.hazard_names, self.ra2ce_names)):
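
The function-level tqdm imports removed here (and in the builder above) rely on tqdm now being imported once at module level; the progress_apply pattern they support looks like this (the series content is made up):

import pandas as pd
from tqdm import tqdm

tqdm.pandas(desc="Network hazard overlay with EV1")  # registers progress_apply on pandas objects
_depths = pd.Series([0.1, 0.4, 1.2])                 # made-up values
_scaled = _depths.progress_apply(lambda d: d * 100.0)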
6 changes: 3 additions & 3 deletions ra2ce/graph/networks_utils.py
@@ -401,11 +401,11 @@ def snap_endpoints_lines(
return lines_gdf


def find_isolated_endpoints(linesIds: list, lines: list) -> list:
def find_isolated_endpoints(lines_ids: list[str], lines: list) -> list:
"""Find endpoints of lines that don't touch another line.
Args:
linesIds: a list of the IDs of lines
lines_ids: a list of the IDs of lines
lines: a list of LineStrings or a MultiLineString
Returns:
@@ -417,7 +417,7 @@ def find_isolated_endpoints(linesIds: list, lines: list) -> list:
Build on library from https://github.com/ojdo/python-tools/blob/master/shapelytools.py
"""
isolated_endpoints = []
for i, id_line in enumerate(zip(linesIds, lines)):
for i, id_line in enumerate(zip(lines_ids, lines)):
ids, line = id_line
other_lines = lines[:i] + lines[i + 1 :]
for q in [0, -1]:
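
For context, the renamed find_isolated_endpoints checks both endpoints of every line against all other lines; a hedged, simplified sketch of that idea with shapely (geometries made up, not the exact RA2CE implementation):

from shapely.geometry import LineString, Point

_lines = [LineString([(0, 0), (1, 0)]), LineString([(1, 0), (2, 0)]), LineString([(5, 5), (6, 5)])]
_isolated = []
for i, _line in enumerate(_lines):
    _other_lines = _lines[:i] + _lines[i + 1:]
    for q in [0, -1]:
        _endpoint = Point(_line.coords[q])
        if not any(_endpoint.touches(_other) for _other in _other_lines):
            _isolated.append(_endpoint)
# The detached third line contributes both endpoints; the connected chain only its outer ones.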
29 changes: 17 additions & 12 deletions ra2ce/graph/origins_destinations.py
@@ -144,19 +144,25 @@ def closest_node(node: np.ndarray, nodes: np.ndarray) -> np.ndarray:
return nodes[np.argmin(dist_2)]


def get_od(o_id, d_id):
def get_od(o_id: str, d_id: str) -> str:
"""
TODO: VERY UNCLEAR what this method is meant to do.
FIX: Solve below logic, it is not a correct paradigm. ADD TESTS AND TYPE HINTS.
Gets a valid origin id node from the given pair.
Args:
o_id (str): Id for the `origin` node.
d_id (str): Id for the `destination` node.
Returns:
str | np.nan: Valid value to represent the origin - destination node.
"""
match_name = o_id
if o_id == "nan":
# convert string nans to np.nans to be able to differentiate between origins and destinations in the next step.
match_name = np.nan
if not match_name == match_name:
# match_name is nan, the point is not an origin but a destination
match_name = d_id
return match_name
_nan_values = ["nan", np.nan]
if o_id not in _nan_values:
return o_id
if d_id not in _nan_values:
# `o_id` was nan, so it was a destination, not an origin.
# therefore we return `d_id`
return d_id
return np.nan
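
The rewritten get_od now reads as: return the origin id if it is a real value, otherwise fall back to the destination id, otherwise np.nan. Expected behaviour with illustrative ids (the values are made up):

import numpy as np
from ra2ce.graph.origins_destinations import get_od

assert get_od("origin_1", "dest_9") == "origin_1"  # a real origin id wins
assert get_od("nan", "dest_9") == "dest_9"         # origin missing: the point is a destination
assert np.isnan(get_od("nan", "nan"))              # neither id is usable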


def add_data_to_existing_node(graph, node, match_name):
@@ -425,7 +431,6 @@ def get_node_id_from_position(
coords = tuple(sorted([coord for coord in geometry_coords]))
if coords in checked_lines:
graph.remove_edge(*line[0:3])
continue
else:
inverse_vertices_dict.update(
{p: (line[0], line[1], line[2]) for p in set(geometry_coords[1:-1])}
@@ -32,15 +32,15 @@ def valid_accumulated_traffic() -> AccumulatedTraffic:

class TestAccumulatedTrafficDataclass:
def test_multiply_wrong_type_raises_error(self):
with pytest.raises(NotImplementedError) as exc_err:
with pytest.raises(TypeError) as exc_err:
AccumulatedTraffic() * "Lorem ipsum"
assert (
str(exc_err.value)
== "It is not possible to multiply AccumulatedTraffic with a value of type str."
)

def test_addition_wrong_type_raises_error(self):
with pytest.raises(NotImplementedError) as exc_err:
with pytest.raises(TypeError) as exc_err:
AccumulatedTraffic() + "Lorem ipsum"
assert (
str(exc_err.value)
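
Both tests now expect TypeError, matching the dataclass change earlier in this diff. A hypothetical, more compact variant of the same assertion using pytest.raises(match=...) — not part of the commit, and assuming the same imports as the test module:

# Hypothetical compact variant; match= applies re.search to the exception message.
def test_addition_wrong_type_raises_error(self):
    with pytest.raises(
        TypeError, match="not possible to sum AccumulatedTraffic"
    ):
        AccumulatedTraffic() + "Lorem ipsum"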