diff --git a/examples/data/hazard_overlay/network.ini b/examples/data/hazard_overlay/network.ini index bc1c78404..cf553ba96 100644 --- a/examples/data/hazard_overlay/network.ini +++ b/examples/data/hazard_overlay/network.ini @@ -11,6 +11,7 @@ polygon = map.geojson # [origins_destinations] origins = None # / None diff --git a/poetry.lock b/poetry.lock index df00f57d7..fcffefcd1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1800,6 +1800,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = 
"sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3246,6 +3256,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4019,6 +4030,27 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +[[package]] +name = "snkit" +version = "1.9.0" +description = "a spatial networks toolkit" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "snkit-1.9.0-py3-none-any.whl", hash = "sha256:9304716dddef9c738c6d28d14de62e0ed8aab5c5bc3279eaefb4987b8f3e8f4a"}, + {file = "snkit-1.9.0.tar.gz", hash = "sha256:e715d608a100f54d888288e15b5a59bfbcecc863684bb47742150dd7e5171e41"}, +] + +[package.dependencies] +geopandas = ">=0.13" +shapely = ">=2.0" + +[package.extras] +dev = ["black", "mypy", "nbstripout", "pre-commit", "pytest", "pytest-cov", "ruff"] +docs = ["m2r2", "sphinx"] +networkx = ["networkx (>=3.0)"] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -4710,4 +4742,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9, <3.11" -content-hash = "b2befd24102178585c1cc0d347691890d62c6f7536b7db966adbcf299d0a9ea6" +content-hash = "b366a4070c5329084a9e11140fbe358eb637c6bba66784d523664b6d2278c516" diff --git a/pyproject.toml b/pyproject.toml index 54edc8582..f03f29bee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ geopy = "^2.4.0" pyogrio = "^0.6.0" joblib = "^1.3.2" pyinstaller = "^6.2.0" +snkit = "^1.9.0" [tool.poetry.group.dev.dependencies] isort = "^5.10.1" diff --git a/ra2ce/analysis/analysis_config_data/analysis_config_data.py b/ra2ce/analysis/analysis_config_data/analysis_config_data.py index 917b5612c..a327dec01 100644 --- a/ra2ce/analysis/analysis_config_data/analysis_config_data.py +++ b/ra2ce/analysis/analysis_config_data/analysis_config_data.py @@ -137,7 +137,7 @@ class AnalysisSectionDamages(AnalysisSectionBase): climate_factor: float = math.nan climate_period: float = math.nan # road damage - 
representative_damage_percentile: float = 100 + representative_damage_percentage: float = 100 event_type: EventTypeEnum = field(default_factory=lambda: EventTypeEnum.INVALID) damage_curve: DamageCurveEnum = field( default_factory=lambda: DamageCurveEnum.INVALID diff --git a/ra2ce/analysis/analysis_config_data/analysis_config_data_reader.py b/ra2ce/analysis/analysis_config_data/analysis_config_data_reader.py index ef42f4b10..4938daaa4 100644 --- a/ra2ce/analysis/analysis_config_data/analysis_config_data_reader.py +++ b/ra2ce/analysis/analysis_config_data/analysis_config_data_reader.py @@ -217,6 +217,11 @@ def _get_analysis_section_damages( "evaluation_period", fallback=_section.evaluation_period, ) + _section.representative_damage_percentage = self._parser.getfloat( + section_name, + "representative_damage_percentage", + fallback=_section.representative_damage_percentage, + ) _section.interest_rate = self._parser.getfloat( section_name, "interest_rate", diff --git a/ra2ce/analysis/damages/damage_calculation/damage_network_base.py b/ra2ce/analysis/damages/damage_calculation/damage_network_base.py index 6962ac9ed..a326c956e 100644 --- a/ra2ce/analysis/damages/damage_calculation/damage_network_base.py +++ b/ra2ce/analysis/damages/damage_calculation/damage_network_base.py @@ -44,14 +44,14 @@ def __init__( self, road_gdf: GeoDataFrame, val_cols: list[str], - representative_damage_percentile: float, + representative_damage_percentage: float, ): """Construct the Data""" self.val_cols = val_cols self.gdf = road_gdf # set of hazard info per event self.stats = set([x.split("_")[-1] for x in val_cols]) - self.representative_damage_percentile = representative_damage_percentile + self.representative_damage_percentage = representative_damage_percentage # TODO: also track the damage cols after the dam calculation, that is useful for the risk calc. 
module # TODO: also create constructors of the children of this class @@ -268,20 +268,20 @@ def calculate_damage_HZ(self, events): def calculate_damage_OSdaMage(self, events): """Damage calculation with the OSdaMage functions""" - def interpolate_damage(row, representative_damage_percentile): + def interpolate_damage(row, representative_damage_percentage): # Extract the tuple of damage values from the row damage_values = row["dam_{}_{}_quartiles".format(curve_name, event)] # Quantile values corresponding to the damage values - percentiles = [0, 25, 50, 75, 100] + percentages = [0, 25, 50, 75, 100] # Perform linear interpolation using interp1d from scipy _interpolator = interp1d( - percentiles, damage_values, kind="linear", fill_value="extrapolate" + percentages, damage_values, kind="linear", fill_value="extrapolate" ) - # Interpolate the damage value for the given representative_damage_percentile - interpolated_damage = _interpolator(representative_damage_percentile) + # Interpolate the damage value for the given representative_damage_percentage + interpolated_damage = _interpolator(representative_damage_percentage) return interpolated_damage @@ -331,7 +331,7 @@ def interpolate_damage(row, representative_damage_percentile): cols_to_scale = ["lower_damage", "upper_damage"] df = scale_damage_using_lanes(lane_scale_factors, df, cols_to_scale) - # create separate column for each percentile of construction costs (is faster then tuple) + # create separate column for each percentage of construction costs (is faster then tuple) for percentage in [ 0, 25, @@ -376,7 +376,7 @@ def interpolate_damage(row, representative_damage_percentile): df[f"dam_{curve_name}_{event}_representative"] = ( df.apply( lambda row: interpolate_damage( - row, self.representative_damage_percentile + row, self.representative_damage_percentage ), axis=1, ) diff --git a/ra2ce/analysis/damages/damage_calculation/damage_network_events.py 
b/ra2ce/analysis/damages/damage_calculation/damage_network_events.py index 04a750daa..58a9ee0e9 100644 --- a/ra2ce/analysis/damages/damage_calculation/damage_network_events.py +++ b/ra2ce/analysis/damages/damage_calculation/damage_network_events.py @@ -42,10 +42,10 @@ def __init__( self, road_gdf: GeoDataFrame, val_cols: list[str], - representative_damage_percentile: float, + representative_damage_percentage: float, ): # Construct using the parent class __init__ - super().__init__(road_gdf, val_cols, representative_damage_percentile) + super().__init__(road_gdf, val_cols, representative_damage_percentage) self.events = set([x.split("_")[1] for x in val_cols]) # set of unique events if not any(self.events): diff --git a/ra2ce/analysis/damages/damage_calculation/damage_network_return_periods.py b/ra2ce/analysis/damages/damage_calculation/damage_network_return_periods.py index 49429bf7e..c15f63e37 100644 --- a/ra2ce/analysis/damages/damage_calculation/damage_network_return_periods.py +++ b/ra2ce/analysis/damages/damage_calculation/damage_network_return_periods.py @@ -50,10 +50,10 @@ def __init__( self, road_gdf: GeoDataFrame, val_cols: list[str], - representative_damage_percentile: float, + representative_damage_percentage: float, ): # Construct using the parent class __init__ - super().__init__(road_gdf, val_cols, representative_damage_percentile) + super().__init__(road_gdf, val_cols, representative_damage_percentage) self.return_periods = set( [x.split("_")[1] for x in val_cols] @@ -64,13 +64,13 @@ def __init__( @classmethod def construct_from_csv( - cls, path: Path, representative_damage_percentile: float, sep: str = ";" + cls, path: Path, representative_damage_percentage: float, sep: str = ";" ): road_gdf = pd.read_csv(path, sep=sep) val_cols = [ c for c in road_gdf.columns if c.startswith("F_") ] # Find everything starting with 'F' - return cls(road_gdf, val_cols, representative_damage_percentile) + return cls(road_gdf, val_cols, 
representative_damage_percentage) ### Controlers for return period based damage and risk calculations def main(self, damage_function: DamageCurveEnum, manual_damage_functions): diff --git a/ra2ce/analysis/damages/damages.py b/ra2ce/analysis/damages/damages.py index 6565072e8..3c8569477 100644 --- a/ra2ce/analysis/damages/damages.py +++ b/ra2ce/analysis/damages/damages.py @@ -84,7 +84,7 @@ def _rename_road_gdf_to_conventions(road_gdf_columns: list[str]) -> list[str]: # Choose between event or return period based analysis if self.analysis.event_type == EventTypeEnum.EVENT: event_gdf = DamageNetworkEvents( - road_gdf, val_cols, self.analysis.representative_damage_percentile + road_gdf, val_cols, self.analysis.representative_damage_percentage ) event_gdf.main( damage_function=damage_function, @@ -95,7 +95,7 @@ def _rename_road_gdf_to_conventions(road_gdf_columns: list[str]) -> list[str]: elif self.analysis.event_type == EventTypeEnum.RETURN_PERIOD: return_period_gdf = DamageNetworkReturnPeriods( - road_gdf, val_cols, self.analysis.representative_damage_percentile + road_gdf, val_cols, self.analysis.representative_damage_percentage ) return_period_gdf.main( damage_function=damage_function, diff --git a/ra2ce/analysis/damages/damages_lookup.py b/ra2ce/analysis/damages/damages_lookup.py index fda9fe46b..0b54ee713 100644 --- a/ra2ce/analysis/damages/damages_lookup.py +++ b/ra2ce/analysis/damages/damages_lookup.py @@ -19,7 +19,6 @@ along with this program. If not, see . """ - import os from collections import OrderedDict from pathlib import Path diff --git a/ra2ce/network/avg_speed/avg_speed_calculator.py b/ra2ce/network/avg_speed/avg_speed_calculator.py index 5f6669ed2..0b729f49f 100644 --- a/ra2ce/network/avg_speed/avg_speed_calculator.py +++ b/ra2ce/network/avg_speed/avg_speed_calculator.py @@ -14,7 +14,9 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . 
""" + import logging +import math from pathlib import Path from re import split from statistics import mean @@ -41,13 +43,13 @@ def __init__(self, graph: nx.Graph, output_graph_dir: Path | None) -> None: self.avg_speed = self._calculate(output_graph_dir) @staticmethod - def parse_speed(speed_input: str | list[str]) -> float: + def parse_speed(speed_input: float | str | list[str]) -> float: """ Parse the average speed from the input string(s). Args: - speed_input (str | list[str]): (List of) string(s) with the speed(s). - Can have different formats, e.g. "50 mph", "50", "50;60", "50-60", "50|60". + speed_input (float | str | list[str]): (List of) string(s) with the speed(s). + Can have different formats, e.g. nan(float), 30.0, "50 mph", "50", "50;60", "50-60", "50|60". Returns: float: Average speed of the input string(s). @@ -55,6 +57,10 @@ def parse_speed(speed_input: str | list[str]) -> float: """ if not speed_input: return 0.0 + if isinstance(speed_input, float): + if math.isnan(speed_input): + return 0.0 + return speed_input if isinstance(speed_input, list): return mean(map(AvgSpeedCalculator.parse_speed, speed_input)) if " mph" in speed_input: diff --git a/ra2ce/network/network_config_data/network_config_data.py b/ra2ce/network/network_config_data/network_config_data.py index ec8985c7d..eb7aea48c 100644 --- a/ra2ce/network/network_config_data/network_config_data.py +++ b/ra2ce/network/network_config_data/network_config_data.py @@ -50,6 +50,7 @@ class NetworkSection: polygon: Optional[Path] = None network_type: NetworkTypeEnum = field(default_factory=lambda: NetworkTypeEnum.NONE) road_types: list[RoadTypeEnum] = field(default_factory=list) + attributes_to_exclude_in_simplification: list[str] = field(default_factory=list) save_gpkg: bool = False diff --git a/ra2ce/network/network_config_data/network_config_data_reader.py b/ra2ce/network/network_config_data/network_config_data_reader.py index a35cca5c8..d46c6b207 100644 --- 
a/ra2ce/network/network_config_data/network_config_data_reader.py +++ b/ra2ce/network/network_config_data/network_config_data_reader.py @@ -171,6 +171,9 @@ def get_network_section(self) -> NetworkSection: ) ) _network_section.polygon = self._get_str_as_path(_network_section.polygon) + _network_section.attributes_to_exclude_in_simplification = self._parser.getlist( + _section, "attributes_to_exclude_in_simplification", fallback=[] + ) return _network_section def get_origins_destinations_section(self) -> OriginsDestinationsSection: diff --git a/ra2ce/network/network_simplification/__init__.py b/ra2ce/network/network_simplification/__init__.py new file mode 100644 index 000000000..a47d4e0a2 --- /dev/null +++ b/ra2ce/network/network_simplification/__init__.py @@ -0,0 +1,3 @@ +from ra2ce.network.network_simplification.network_graph_simplificator import ( + NetworkGraphSimplificator, +) diff --git a/ra2ce/network/network_simplification/network_graph_simplificator.py b/ra2ce/network/network_simplification/network_graph_simplificator.py new file mode 100644 index 000000000..71d6a354f --- /dev/null +++ b/ra2ce/network/network_simplification/network_graph_simplificator.py @@ -0,0 +1,189 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" +import logging +from dataclasses import dataclass + +import networkx as nx +from tqdm import tqdm + +from ra2ce.network.network_simplification.network_simplification_with_attribute_exclusion import ( + NetworkSimplificationWithAttributeExclusion, +) +from ra2ce.network.network_simplification.network_simplification_without_attribute_exclusion import ( + NetworkSimplificationWithoutAttributeExclusion, +) + +NxGraph = nx.Graph | nx.MultiGraph | nx.MultiDiGraph + + +@dataclass(kw_only=True) +class NetworkGraphSimplificator: + """ + Factory dataclass to simplify the containing graph. + """ + + graph_complex: NxGraph + attributes_to_exclude: list[str] + new_id: str = "rfid" + + def simplify( + self, + ) -> tuple[nx.Graph, nx.Graph, tuple[dict, dict]]: + """ + Create a simplified graph with unique ids from a complex graph + + Returns: + tuple[nx.Graph, nx.Graph, tuple[dict, dict]]: The simple and complex graph and the "id" tables. + """ + logging.info("Simplifying graph") + try: + _graph_complex = self._graph_create_unique_ids( + self.graph_complex, "{}_c".format(self.new_id) + ) + _graph_simple = self._get_graph_simple() + + # Create look_up_tables between graphs with unique ids + ( + _simple_to_complex, + _complex_to_simple, + ) = self._graph_link_simple_id_to_complex(_graph_simple) + + # Store id table and add simple ids to complex graph + _id_tables = (_simple_to_complex, _complex_to_simple) + _graph_complex = self._add_simple_id_to_graph_complex( + _graph_complex, _complex_to_simple, self.new_id + ) + logging.info("Simplified graph successfully created") + except Exception as exc: + _graph_simple = None + _id_tables = None + logging.error("Did not create a simplified version of the graph (%s)", exc) + return _graph_simple, _graph_complex, _id_tables + + def _get_graph_simple(self) -> NxGraph: + if any(self.attributes_to_exclude): + _graph_simple = NetworkSimplificationWithAttributeExclusion( + nx_graph=self.graph_complex, + 
attributes_to_exclude=self.attributes_to_exclude, + ).simplify_graph() + else: + self.graph_complex = ( + self.graph_complex.to_directed() + ) # simplification function requires nx.MultiDiGraph + + # Create simplified graph and add unique ids + _graph_simple = NetworkSimplificationWithoutAttributeExclusion( + nx_graph=self.graph_complex + ).simplify_graph() + + return self._graph_create_unique_ids(_graph_simple, self.new_id) + + def _graph_create_unique_ids( + self, graph: nx.Graph, new_id_name: str = "rfid" + ) -> nx.Graph: + # Check if new_id_name exists and if unique + u, v, k = list(graph.edges)[0] + if new_id_name in graph.edges[u, v, k]: + return graph + # TODO: decide if we always add a new ID (in iGraph this is different) + # if len(set([str(e[-1][new_id_name]) for e in graph.edges.data(keys=True)])) < len(graph.edges()): + for i, (u, v, k) in enumerate(graph.edges(keys=True)): + graph[u][v][k][new_id_name] = i + 1 + logging.info("Added a new unique identifier field '%s'.", new_id_name) + return graph + + def _add_simple_id_to_graph_complex( + self, complex_graph: nx.classes.Graph, complex_to_simple, new_id + ) -> nx.classes.Graph: + """Adds the appropriate ID of the simple graph to each edge of the complex graph as a new attribute 'rfid' + + Arguments: + complex_graph (Graph) : The complex graph, still lacking 'rfid' + complex_to_simple (dict) : lookup table linking complex to simple graphs + + Returns: + complex_graph (Graph) : Same object, with added attribute 'rfid' + + """ + + obtained_complex_ids = nx.get_edge_attributes( + complex_graph, "{}_c".format(new_id) + ) # {(u,v,k) : 'rfid_c'} + simple_ids_per_complex_id = obtained_complex_ids # start with a copy + + for key, value in obtained_complex_ids.items(): # {(u,v,k) : 'rfid_c'} + try: + new_value = complex_to_simple[ + value + ] # find simple id belonging to the complex id + simple_ids_per_complex_id[key] = new_value + except KeyError as e: + logging.error( + "Could not find the simple ID belonging 
to complex ID %s; value set to None. Full error: %s", + key, + e, + ) + simple_ids_per_complex_id[key] = None + + # Now the format of simple_ids_per_complex_id is: {(u,v,k) : 'rfid} + nx.set_edge_attributes(complex_graph, simple_ids_per_complex_id, new_id) + + return complex_graph + + def _graph_link_simple_id_to_complex(self, graph_simple: nx.classes.graph.Graph): + """ + Create lookup tables (dicts) to match edges_ids of the complex and simple graph + Optionally, saves these lookup tables as json files. + + Arguments: + graph_simple (Graph) : Graph, containing attribute 'new_id' + + Returns: + simple_to_complex (dict): Keys are ids of the simple graph, values are lists with all matching complex ids + complex_to_simple (dict): Keys are the ids of the complex graph, value is the matching simple_ID + + We need this because the simple graph is derived from the complex graph, and therefore initially only the + simple graph knows from which complex edges it was created. To assign this information also to the complex + graph we invert the look-up dictionary + @author: Kees van Ginkel en Margreet van Marle + """ + # Iterate over the simple, because this already has the corresponding complex information + lookup_dict = {} + # keys are the ids of the simple graph, values are lists with all matching complex id's + for u, v, k in tqdm(graph_simple.edges(keys=True)): + key_1 = graph_simple[u][v][k]["{}".format(self.new_id)] + value_1 = graph_simple[u][v][k]["{}_c".format(self.new_id)] + lookup_dict[key_1] = value_1 + + inverted_lookup_dict = {} + # keys are the ids of the complex graph, value is the matching simple_ID + for key, value in lookup_dict.items(): + if isinstance(value, list): + for subvalue in value: + inverted_lookup_dict[subvalue] = key + elif isinstance(value, int): + inverted_lookup_dict[value] = key + + simple_to_complex = lookup_dict + complex_to_simple = inverted_lookup_dict + + logging.info("Lookup tables from complex to simple and vice versa were 
created") + return simple_to_complex, complex_to_simple diff --git a/ra2ce/network/network_simplification/network_simplification_with_attribute_exclusion.py b/ra2ce/network/network_simplification/network_simplification_with_attribute_exclusion.py new file mode 100644 index 000000000..9d44f09d1 --- /dev/null +++ b/ra2ce/network/network_simplification/network_simplification_with_attribute_exclusion.py @@ -0,0 +1,55 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from dataclasses import dataclass + +import networkx as nx + +from ra2ce.network.network_simplification.snkit_network_wrapper import ( + SnkitNetworkWrapper, +) + +NxGraph = nx.Graph | nx.MultiGraph | nx.MultiDiGraph + + +@dataclass(kw_only=True) +class NetworkSimplificationWithAttributeExclusion: + """ + Simplifies a network by excluding a given set of attributes (columns). + """ + + nx_graph: NxGraph + attributes_to_exclude: list[str] + + def simplify_graph(self) -> nx.Graph: + """ + Simplifies the inner graph by using the `snkit` package. + + Returns: + nx.Graph: Resulting simplified graph. 
+ """ + _snkit_network_wrapper = SnkitNetworkWrapper.from_networkx( + self.nx_graph, + column_names_dict=dict( + node_id_column_name="id", + edge_from_id_column_name="from_id", + edge_to_id_column_name="to_id", + ), + ) + _snkit_network_wrapper.merge_edges(self.attributes_to_exclude) + _snkit_network_wrapper.process_network() + return _snkit_network_wrapper.to_networkx() diff --git a/ra2ce/network/network_simplification/network_simplification_without_attribute_exclusion.py b/ra2ce/network/network_simplification/network_simplification_without_attribute_exclusion.py new file mode 100644 index 000000000..70e2d47e5 --- /dev/null +++ b/ra2ce/network/network_simplification/network_simplification_without_attribute_exclusion.py @@ -0,0 +1,59 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import logging +from dataclasses import dataclass + +import networkx as nx +from osmnx.simplification import simplify_graph + +from ra2ce.network.networks_utils import add_x_y_to_nodes + +NxGraph = nx.Graph | nx.MultiGraph | nx.MultiDiGraph + + +@dataclass(kw_only=True) +class NetworkSimplificationWithoutAttributeExclusion: + """ + Simplifies a network with the `osmnx.simplification` functionality. 
+ """ + + nx_graph: NxGraph + + def simplify_graph( + self, + ) -> nx.Graph: + """ + Simplify the graph after adding missing x and y attributes to nodes + + Returns: + nx.Graph: Simplified graph + """ + _complex_graph = add_x_y_to_nodes(self.nx_graph) + _simple_graph = simplify_graph( + _complex_graph, strict=True, remove_rings=True, track_merged=False + ) + + logging.info( + "Graph simplified from %s to %s nodes and %s to %s edges.", + _complex_graph.number_of_nodes(), + _simple_graph.number_of_nodes(), + _complex_graph.number_of_edges(), + _simple_graph.number_of_edges(), + ) + + return _simple_graph diff --git a/ra2ce/network/network_simplification/nx_to_snkit_network_converter.py b/ra2ce/network/network_simplification/nx_to_snkit_network_converter.py new file mode 100644 index 000000000..17e1d47c8 --- /dev/null +++ b/ra2ce/network/network_simplification/nx_to_snkit_network_converter.py @@ -0,0 +1,172 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +from collections import defaultdict +from dataclasses import dataclass + +import geopandas as gpd +import networkx as nx +from geopandas import GeoDataFrame +from numpy import int64 as npInt64 +from shapely.geometry import LineString, Point +from snkit.network import Network as SnkitNetwork + +NxGraph = nx.Graph | nx.MultiGraph | nx.MultiDiGraph + + +@dataclass(kw_only=True) +class NxToSnkitNetworkConverter: + """ + Class responsible to convert a `networkx.MultiGraph` into + a matching `snkit.network.Network`. + """ + + networkx_graph: NxGraph + node_id_column_name: str = "id" + edge_from_id_column_name: str = "from_id" + edge_to_id_column_name: str = "to_id" + + def convert(self) -> SnkitNetwork: + """ + Converts a regular `NetworkX.graph` into a `snkit.network.Network` object. + + Returns: + SnkitNetwork: The resulting `snkit.network.Network` converted object. + """ + # Extract graph values + _crs = self.networkx_graph.graph.get("crs", None) + + # Create new network + snkit_network = SnkitNetwork() + node_attributes = [ + {self.node_id_column_name: node, **data} + for node, data in self.networkx_graph.nodes(data=True) + ] + snkit_network.nodes = GeoDataFrame(node_attributes) + snkit_network.nodes = self._check_and_create_node_geometries( + snkit_network.nodes + ) + snkit_network.nodes.set_geometry("geometry", inplace=True, crs=_crs) + + edge_attributes = [ + {self.edge_from_id_column_name: u, self.edge_to_id_column_name: v, **data} + for u, v, data in self.networkx_graph.edges(data=True) + ] + snkit_network.edges = GeoDataFrame(edge_attributes) + snkit_network = self._check_and_create_edge_geometries(snkit_network) + snkit_network.edges.set_geometry("geometry", inplace=True, crs=_crs) + + # Set network CRS to default_crs + snkit_network.set_crs(_crs) + + # Checks + snkit_network = self._check_edge_ids(snkit_network) + snkit_network = self._get_nodes_degree(snkit_network) + + # Return converted and validated network + return snkit_network + + def 
_check_edge_ids(self, network: SnkitNetwork) -> SnkitNetwork: + if "id" not in network.edges.columns: + network.edges["id"] = network.edges.index + return network + + def _get_nodes_degree(self, network: SnkitNetwork) -> SnkitNetwork: + def _calculate_degree(snkit_network: SnkitNetwork) -> dict: + degrees = defaultdict(int) + + from_ids = snkit_network.edges["from_id"].to_numpy(dtype=npInt64) + for from_id in from_ids: + degrees[from_id] += 1 + + to_ids = snkit_network.edges["to_id"].to_numpy(dtype=npInt64) + for to_id in to_ids: + degrees[to_id] += 1 + + return degrees + + degrees = _calculate_degree(network) + network.nodes["degree"] = network.nodes["id"].apply( + lambda node_id: degrees.get(node_id, 0) + ) + return network + + def _check_and_create_node_geometries( + self, + geo_dataframe: gpd.GeoDataFrame, + ) -> gpd.GeoDataFrame: + """ + Check if there is a geometry column in network.nodes. + If not, check if both 'x' and 'y' columns are present. + If both are present, create Point geometries for each row. + If either 'x' or 'y' is missing, raise a ValueError. + + Parameters: + gdf (snkit.network.Network): The gdf object containing nodes as a GeoDataFrame. + + Raises: + ValueError: If neither geometry column nor both 'x' and 'y' columns are present in network.nodes. + """ + if "geometry" in geo_dataframe.columns: + return geo_dataframe + + if "x" in geo_dataframe.columns and "y" in geo_dataframe.columns: + geo_dataframe["geometry"] = geo_dataframe.apply( + lambda row: Point(row["x"], row["y"]), axis=1 + ) + return geo_dataframe + else: + raise ValueError( + "The network nodes must contain either a 'geometry' column or both 'x' and 'y' columns." + ) + + def _check_and_create_edge_geometries( + self, + network: SnkitNetwork, + ) -> SnkitNetwork: + """ + Creates a GEOMETRY attribute for each edge in the graph using the geometries of the nodes. + + Parameters: + G (nx.Graph): The NetworkX graph with nodes having geometries. 
+ + Returns: + nx.Graph: The NetworkX graph with edges having GEOMETRY attributes. + """ + + def create_linestring(row): + from_geom = node_geometries.get(row["from_id"]) + to_geom = node_geometries.get(row["to_id"]) + if from_geom is None or to_geom is None: + raise ValueError( + f"Geometry missing for from_id {row['from_id']} or to_id {row['to_id']}." + ) + return LineString([from_geom, to_geom]) + + if "geometry" in network.edges.columns: + return network + # Check if nodes have geometries + if not ("geometry" in network.nodes.columns): + network.nodes = self._check_and_create_node_geometries(network.nodes) + + # Convert nodes to a dictionary for fast lookup + node_geometries = network.nodes.set_index("id")["geometry"].to_dict() + + # Apply the function to create the geometry column + network.edges["geometry"] = network.edges.apply(create_linestring, axis=1) + + return network diff --git a/ra2ce/network/network_simplification/snkit_network_merge_wrapper.py b/ra2ce/network/network_simplification/snkit_network_merge_wrapper.py new file mode 100644 index 000000000..3923a8324 --- /dev/null +++ b/ra2ce/network/network_simplification/snkit_network_merge_wrapper.py @@ -0,0 +1,676 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +import logging +from collections import defaultdict + +import geopandas as gpd +import networkx as nx +import pandas as pd +from shapely.geometry import LineString, MultiLineString, MultiPoint, Point +from shapely.ops import linemerge +from snkit.network import Network as SnkitNetwork +from tqdm import tqdm + +NxGraph = nx.Graph | nx.MultiGraph | nx.MultiDiGraph + + +""" +Disclaimer! + +This file contains several complex logic introduced in feature #277. +At the moment it was not possible to streamline / improve it further than +its current state. +""" + + +def merge_edges( + snkit_network: SnkitNetwork, + aggregate_func: str | dict, + by: str | list, + id_col: str, +) -> SnkitNetwork: + """ + Merges the edges of a given `snkit.network.Network`. + + Args: + snkit_network (SnkitNetwork): network to merge. + aggregate_func (str | dict): Aggregation function to apply. + by (str | list): Arguments (column names). + id_col (str, optional): Name of the column representing the 'id'. + + Returns: + SnkitNetwork: _description_ + """ + + def _node_connectivity_degree(node, snkit_network: SnkitNetwork) -> int: + return len( + snkit_network.edges[ + (snkit_network.edges.from_id == node) + | (snkit_network.edges.to_id == node) + ] + ) + + def _get_edge_ids_to_update(edges_list: list) -> list: + ids_to_update = [] + for edges in edges_list: + ids_to_update.extend(edges.id.tolist()) + return ids_to_update + + def _get_merged_edges( + paths_to_group: list, + by: list, + aggfunc: str | dict, + net: SnkitNetwork, + ) -> gpd.GeoDataFrame: + updated_edges = gpd.GeoDataFrame( + columns=net.edges.columns, crs=net.edges.crs + ) # merged edges + + for edge_path in tqdm(paths_to_group, desc="merge_edge_paths"): + # Convert None values to a placeholder value + placeholder = "None" + for col in by: + edge_path[col] = edge_path[col].fillna(placeholder) + merged_edges = _get_merge_edge_paths(edge_path, by, aggfunc, net) + updated_edges = pd.concat([updated_edges, merged_edges], 
ignore_index=True) + + updated_edges_gdf = gpd.GeoDataFrame(updated_edges, geometry="geometry") + updated_edges_gdf.set_crs(net.edges.crs, inplace=True) + updated_edges_gdf = updated_edges_gdf.drop(columns=["id"]) + return updated_edges_gdf + + def _get_edge_paths(node_set: set, snkit_network: SnkitNetwork) -> list: + # Convert edges to an adjacency list using vectorized operations + edge_dict = defaultdict(set) + from_ids = snkit_network.edges["from_id"].values + to_ids = snkit_network.edges["to_id"].values + + for from_id, to_id in zip(from_ids, to_ids): + edge_dict[from_id].add(to_id) + edge_dict[to_id].add(from_id) + + edge_paths = [] + + while node_set: + popped_node = node_set.pop() + node_path = {popped_node} + candidates = {popped_node} + while candidates: + popped_cand = candidates.pop() + matches = edge_dict[popped_cand] + matches = matches - node_path + for match in matches: + if match in node_set: + candidates.add(match) + node_path.add(match) + node_set.remove(match) + else: + node_path.add(match) + if len(node_path) >= 2: + edge_paths.append( + snkit_network.edges.loc[ + (snkit_network.edges.from_id.isin(node_path)) + & (snkit_network.edges.to_id.isin(node_path)) + ] + ) + return edge_paths + + if "degree" not in snkit_network.nodes.columns: + snkit_network.nodes["degree"] = snkit_network.nodes[id_col].apply( + lambda x: _node_connectivity_degree(x, snkit_network) + ) + + degree_2 = list(snkit_network.nodes[id_col].loc[snkit_network.nodes.degree == 2]) + degree_2_set = set(degree_2) + edge_paths = _get_edge_paths(degree_2_set, snkit_network) + + edge_ids_to_update = _get_edge_ids_to_update(edge_paths) + edges_to_keep = snkit_network.edges[ + ~snkit_network.edges["id"].isin(edge_ids_to_update) + ] + + updated_edges = _get_merged_edges( + paths_to_group=edge_paths, + by=by, + aggfunc=aggregate_func, + net=snkit_network, + ) + edges_to_keep = edges_to_keep.drop(columns=["id"]) + updated_edges = updated_edges.reset_index(drop=True) + + new_edges = 
pd.concat([edges_to_keep, updated_edges], ignore_index=True) + new_edges_gdf = gpd.GeoDataFrame(new_edges, geometry="geometry") + new_edges_gdf.set_crs(edges_to_keep.crs, inplace=True) + new_edges_gdf = new_edges_gdf.reset_index(drop=True) + + nodes_to_keep = list(set(new_edges.from_id.tolist() + new_edges.to_id.tolist())) + new_nodes_gdf = snkit_network.nodes[snkit_network.nodes[id_col].isin(nodes_to_keep)] + new_nodes_gdf = new_nodes_gdf.reset_index(drop=True) + + merged_snkit_network = SnkitNetwork(nodes=new_nodes_gdf, edges=new_edges_gdf) + merged_snkit_network.nodes["degree"] = merged_snkit_network.nodes[id_col].apply( + lambda x: _node_connectivity_degree(x, merged_snkit_network) + ) + + return merged_snkit_network + + +def _merge_connected_lines( + gdf: gpd.GeoDataFrame, by: str, aggfunc: dict +) -> gpd.GeoDataFrame: + """ + Merge connected lines in a GeoDataFrame into a single LineString. + + Parameters: + gdf (gpd.GeoDataFrame): GeoDataFrame containing the lines to merge. + by (str): Column name to group by. + aggfunc (dict): Dictionary of aggregation functions for other columns. + + Returns: + gpd.GeoDataFrame: GeoDataFrame with merged lines. 
+ """ + + # Merge all geometries into a single MultiLineString + merged_geometry = linemerge(gdf.geometry.tolist()) + indices = gdf[by].iloc[0] + # Create a new GeoDataFrame with the merged geometry + merged_gdf = gpd.GeoDataFrame( + [{**indices, "geometry": merged_geometry}], crs=gdf.crs + ) + merged_gdf.set_index(by, inplace=True) + + # Combine the attributes using the aggregation function + for col, func in aggfunc.items(): + if col != "geometry": + # Try to convert the column to float if needed + if gdf[col].dtype != float: + try: + gdf[col] = gdf[col].astype(float) + except ValueError: + pass # Skip conversion if it fails + merged_gdf[col] = [func(gdf[col])] + + return merged_gdf + + +def _get_merge_edge_paths( + edges: gpd.GeoDataFrame, + excluded_edge_types: list, + aggfunc: str | dict, + snkit_network: SnkitNetwork, +) -> gpd.GeoDataFrame: + def get_connected_lines(ids: pd.Index): + """ + Find groups of connected lines in a GeoDataFrame. + + Parameters: + gdf (GeoDataFrame): A GeoDataFrame containing LINESTRING geometries. + + Returns: + list of lists: Each sublist contains indices of lines in gdf that are connected. 
+ """ + + # Initialize an empty graph + _networkx_graph = nx.Graph() + gdf = edges.loc[ids.tolist()] + # Add edges to the graph for each line in the GeoDataFrame + for idx, row in gdf.iterrows(): + # Get the start and end points of the line + line = row["geometry"] + start_point = line.coords[0] + end_point = line.coords[-1] + + # Add the line as an edge between its start and end points with the id as attribute + _networkx_graph.add_edge(start_point, end_point, index=idx, id=row["id"]) + + # Find connected components in the graph + connected_components = list(nx.connected_components(_networkx_graph)) + + # Map each component to the corresponding line ids + connected_line_groups = [] + for component in connected_components: + line_ids = [] + for _, _, data in _networkx_graph.edges(component, data=True): + line_ids.append(data["id"]) + connected_line_groups.append(line_ids) + + return connected_line_groups + + def _get_paths_to_merge(groups: dict) -> list: + _paths_to_merge = [] # list of gpds to merge + for _, edge_group_ids in groups.items(): + sub_path_parts = get_connected_lines(edge_group_ids) + _paths_to_merge.extend(sub_path_parts) + return _paths_to_merge + + # _get_merged_paths starts here + grouped_edges = edges.groupby(excluded_edge_types) + if len(grouped_edges.groups) == 1: + merged_edges = GdfSnkitNetworkMerger( + geo_dataframe=edges, snkit_network=snkit_network + ).merge( + by=excluded_edge_types, + aggregate_func=aggfunc, + ) + else: + merged_edges = gpd.GeoDataFrame( + columns=edges.columns, crs=edges.crs + ) # merged edges + edge_groups = edges.groupby(excluded_edge_types).groups + paths_to_merge = _get_paths_to_merge(edge_groups) + + for path_ids in paths_to_merge: + path_to_merge = edges[ + edges["id"].isin(path_ids) + ].copy() # indices of the edges in edges gdf + merged = GdfSnkitNetworkMerger( + geo_dataframe=path_to_merge, snkit_network=snkit_network + ).merge( + by=excluded_edge_types, + aggregate_func=aggfunc, + ) + merged_edges = 
pd.concat([merged_edges, merged], ignore_index=True) + + merged_edges.crs = edges.crs + + return merged_edges + + +class GdfSnkitNetworkMerger: + """ + Merger of a `gpd.GeoDataFrame` and a `snkit.network.network`. + This class was created to contain the related close and try reducing the code's complexity. + """ + + def __init__( + self, geo_dataframe: gpd.GeoDataFrame, snkit_network: SnkitNetwork + ) -> None: + self._geo_dataframe = geo_dataframe + self._snkit_network = snkit_network + + def merge( + self, + by: list, + aggregate_func: dict, + ) -> gpd.GeoDataFrame: + """ + Merges the inner defined `gpd.GeoDataFrame` and `snkit.network.Network` + based on the given arguments (`by`) and aggregation function ()`aggregate_func`). + """ + _geo_dataframe = self._geo_dataframe + _snkit_network = self._snkit_network + # _merge function starts from here: + self._geo_dataframe["intersections"] = _geo_dataframe.apply( + lambda x: self._get_intersections(x, _geo_dataframe), axis=1 + ) + # _merged = gdf.dissolve(by=by, aggfunc=_aggfunc, sort=False) + _merged = _merge_connected_lines(_geo_dataframe, by, aggregate_func) + if len(_geo_dataframe) == 1: + # 1. no merging is occurring + start_path_extremities = [_geo_dataframe.iloc[0]["from_id"]] + end_path_extremities = [_geo_dataframe.iloc[0]["to_id"]] + _merged.from_id = start_path_extremities[0] + _merged.to_id = end_path_extremities[0] + else: + # 2. merging is occurring + if ( + len( + _geo_dataframe[ + _geo_dataframe["intersections"].apply(lambda x: len(x) == 1) + ] + ) + == 0 + ): + # 2.1. a loop with two nodes degree > 2 + gdf_node_ids = list( + set(_geo_dataframe.from_id.tolist() + _geo_dataframe.to_id.tolist()) + ) + gdf_node_slice = _snkit_network.nodes[ + _snkit_network.nodes["id"].isin(gdf_node_ids) + ] + if len(gdf_node_slice[gdf_node_slice["degree"] > 2]) == 0: + # 2.1.1. 
a loop with only degree 2 edges => isolated from the rest of the graph + logging.warning( + """ + A sub-graph loop isolated from the main graph is detected and removed. + This isolated part had %s nodes with node_fids %s in + the input node graph. + """, + len(gdf_node_slice), + gdf_node_slice.id.tolist(), + ) + if "demand_edge" in _geo_dataframe.columns: + logging.warning( + """This sub-graph had these demand nodes %s""", + ( + _geo_dataframe[ + _geo_dataframe.demand_edge == 1 + ].from_id.tolist() + + _geo_dataframe[ + _geo_dataframe.demand_edge == 1 + ].to_id.tolist() + ), + ) + return gpd.GeoDataFrame( + data=None, + columns=_snkit_network.edges.columns, + crs=_snkit_network.edges.crs, + ) + + elif len(gdf_node_slice[gdf_node_slice["degree"] > 2]) == 1: + # 2.1.2. If there is only one node with the degree bigger than 2 + if ( + "demand_edge" not in _geo_dataframe.columns + or len(_geo_dataframe[_geo_dataframe["demand_edge"] == 1]) == 0 + ): + # No demand node is in this loop. Then omit this loop and return empty gdf + return gpd.GeoDataFrame( + data=None, + columns=_snkit_network.edges.columns, + crs=_snkit_network.edges.crs, + ) + elif ( + "demand_edge" in _geo_dataframe.columns + and len(_geo_dataframe[_geo_dataframe["demand_edge"] == 1]) > 0 + ): + demand_node_ids = [ + i + for i in set( + _geo_dataframe.from_id.tolist() + + _geo_dataframe.to_id.tolist() + ) + if ( + _geo_dataframe[ + _geo_dataframe.demand_edge == 1 + ].from_id.tolist() + + _geo_dataframe[ + _geo_dataframe.demand_edge == 1 + ].to_id.tolist() + ).count(i) + == 2 + ] + if len(demand_node_ids) > 1: + # merging this situation is skipped: not probable + complicated + return _geo_dataframe + else: + # Only one demand node exists in the loop + if isinstance( + linemerge(_merged.geometry.iloc[0]), MultiLineString + ): + # to exclude the merged geoms for which linemerge does not work + return _geo_dataframe + path_extremities_node_ids = { + x + for x in gdf_node_slice[ + gdf_node_slice["degree"] > 2 + 
].id.tolist() + + demand_node_ids + } + _merged = self._get_merged_multiple_demand_edges( + _merged, path_extremities_node_ids + ) + else: + # 2.1.3. the only remaining option is two nodes with degrees bigger than 2 + if ( + "demand_edge" not in _geo_dataframe.columns + or len(_geo_dataframe[_geo_dataframe["demand_edge"] == 1]) == 0 + ): + # No demand node is in this loop. Then merge + _merged = self._get_merged_in_a_loop(_merged, gdf_node_slice) + else: + return _geo_dataframe + else: + # 2.2. merging non-loop paths + path_extremities_node_ids = { + i + for i in set( + _geo_dataframe.from_id.tolist() + _geo_dataframe.to_id.tolist() + ) + if ( + _geo_dataframe.from_id.tolist() + _geo_dataframe.to_id.tolist() + ).count(i) + == 1 + } + # if len(path_extremities_node_ids) > 0: + if ("demand_edge" in _geo_dataframe.columns) and ( + len(_geo_dataframe[_geo_dataframe["demand_edge"] == 1]) > 1 + ): + _merged = self._get_merged_multiple_demand_edges( + _merged, path_extremities_node_ids + ) + elif ( + "demand_edge" in _geo_dataframe.columns + and len(_geo_dataframe[_geo_dataframe["demand_edge"] == 1]) <= 1 + ) or ("demand_edge" not in _geo_dataframe.columns): + # 2.2.2.no dem node is in the to_be_merged path or only one dem node. In the later case dem node + # will not be dissolved because it is in the path_extremities_node_ids + _merged = self._get_merged_one_or_none_demand_edges( + _merged, path_extremities_node_ids + ) + else: + raise ValueError( + f"""Check the lines with the following ids {_geo_dataframe.id.tolist()} """ + ) + + _merged.node_A = _merged.from_id + _merged.node_B = _merged.to_id + _merged.crs = _geo_dataframe.crs + return _merged + + def _get_merged_in_a_loop( + self, _merged: gpd.GeoDataFrame, gdf_node_slice: pd.DataFrame + ) -> gpd.GeoDataFrame: + # 2.1.2. 
pick one with one intersection point + start_path_extrms = [gdf_node_slice[gdf_node_slice["degree"] > 2].iloc[0].id] + end_path_extrms = [gdf_node_slice[gdf_node_slice["degree"] > 2].iloc[1].id] + _merged.from_id = [ + start_path_extremity for start_path_extremity in start_path_extrms + ] + _merged.to_id = [end_path_extremity for end_path_extremity in end_path_extrms] + return _merged + + def _get_merged_multiple_demand_edges( + self, _merged: gpd.GeoDataFrame, path_extrms_nod_ids: set + ) -> gpd.GeoDataFrame: + def get_node_id(r: gpd.GeoSeries, attr: str, path_extrms_nod_ids: set) -> int: + # to fill from_id and to_id of the to-be-merged paths + if r[attr] == -1: + for path_extremities_node_id in path_extrms_nod_ids: + path_extremities_node_geom = self._snkit_network.nodes[ + self._snkit_network.nodes.id == path_extremities_node_id + ].geometry.iloc[0] + if r.geometry.intersects(path_extremities_node_geom): + return path_extremities_node_id + else: + return r[attr] + + _mrgd = self._get_split_edges_info(_merged) + _mrgd.from_id = _mrgd.apply( + lambda row: get_node_id(row, "from_id", path_extrms_nod_ids), axis=1 + ) + _mrgd.to_id = _mrgd.apply( + lambda row: get_node_id(row, "to_id", path_extrms_nod_ids), axis=1 + ) + return _mrgd + + def _get_split_edges_info(self, merged: gpd.GeoDataFrame) -> tuple: + # used for the cases where demand nodes exist in the to-be-merged paths + # make the demand node from_id of the merged edge + geo_dataframe = self._geo_dataframe + dem_nod_ids = [ + i + for i in set(geo_dataframe.from_id.tolist() + geo_dataframe.to_id.tolist()) + if ( + geo_dataframe[geo_dataframe.demand_edge == 1].from_id.tolist() + + geo_dataframe[geo_dataframe.demand_edge == 1].to_id.tolist() + ).count(i) + == 2 + ] + split_parts = [merged["geometry"].iloc[0]] + split_edges_gdf = gpd.GeoDataFrame(columns=merged.columns) + for dem_nod_id in dem_nod_ids: + for part in split_parts: + part_splits, split_edges_gdf = self._split( + merged, part, dem_nod_id, 
split_edges_gdf + ) + if part_splits is not None: + split_parts.extend(part_splits) + split_parts.remove(part) + return split_edges_gdf + + def _split( + self, + merged: gpd.GeoDataFrame, + line_geom: MultiLineString | LineString, + dem_nod_id: int, + splits_gdf: gpd.GeoDataFrame, + ) -> tuple: + # used for the cases where demand nodes exist in the to-be-merged paths + dem_nod_geom = self._snkit_network.nodes[ + self._snkit_network.nodes.id == dem_nod_id + ].geometry.iloc[0] + if line_geom.contains(dem_nod_geom): + if isinstance(line_geom, MultiLineString): + coords = [ + linemerge(line_geom).coords[0], + linemerge(line_geom).coords[-1], + ] + else: + coords = [line_geom.coords[0], line_geom.coords[-1]] + # Add the coords from the points + coords += dem_nod_geom.coords + # Calculate the distance along the line for each point + dists = [linemerge(line_geom).project(Point(p)) for p in coords] + # sort the coordinates + coords = [p for (d, p) in sorted(zip(dists, coords))] + splits = [ + LineString([coords[i], coords[i + 1]]) for i in range(len(coords) - 1) + ] + splits_gdf = self._update_split_edges_gdf( + merged, splits, dem_nod_id, splits_gdf, line_geom + ) + return splits, splits_gdf + else: + return None, splits_gdf + + def _update_split_edges_gdf( + self, + merged: gpd.GeoDataFrame, + parts: list, + dem_nod_id: int, + splt_edgs: gpd.GeoDataFrame, + _split_line_geom: LineString, + ) -> gpd.GeoDataFrame: + # used for the cases where demand nodes exist in the to-be-merged paths + for part in parts: + # _split_line_geom is the line divided and produced parts + if _split_line_geom not in splt_edgs.geometry.tolist(): + part_gdf = gpd.GeoDataFrame( + { + "geometry": part, + "id": len(splt_edgs), + "from_id": dem_nod_id, + "to_id": -1, + **merged.drop(columns=["geometry", "id", "from_id", "to_id"]), + } + ) + else: + # if _split_line_geom is divided and n stored in splt_edgs, we need to retrieve from/to_id info + # and update splt_edgs + part_gdf = gpd.GeoDataFrame( + { 
+ "geometry": part, + "id": -1, + "from_id": splt_edgs[ + splt_edgs.geometry == _split_line_geom + ].apply( + lambda row: ( + row.from_id if row.from_id != -1 else dem_nod_id + ), + axis=1, + ), + "to_id": splt_edgs[ + splt_edgs.geometry == _split_line_geom + ].apply( + lambda row: (row.to_id if row.to_id != -1 else dem_nod_id), + axis=1, + ), + **merged.drop(columns=["geometry", "id", "from_id", "to_id"]), + } + ) + _split_line_index = splt_edgs.loc[ + splt_edgs.geometry == _split_line_geom + ].index[0] + splt_edgs = splt_edgs.drop(_split_line_index) + splt_edgs = pd.concat([splt_edgs, part_gdf], ignore_index=True) + return splt_edgs + + def _get_merged_one_or_none_demand_edges( + self, _merged, path_extrms_nod_ids: set + ) -> gpd.GeoDataFrame: + geo_dataframe = self._geo_dataframe + _start_edges = geo_dataframe[ + geo_dataframe["intersections"].apply(lambda x: len(x) == 1) + ] + if ("demand_edge" in geo_dataframe.columns) and ( + len(geo_dataframe[geo_dataframe["demand_edge"] == 1]) + ) == 1: + _start_edge = _start_edges[_start_edges.demand_edge == 1].iloc[0] + elif ("demand_edge" not in geo_dataframe.columns) or ( + len(geo_dataframe[geo_dataframe["demand_edge"] == 1]) + ) != 1: + _start_edge = _start_edges.iloc[0] + start_path_extrms = [ + ( + _start_edge["from_id"] + if _start_edge["from_id"] in list(path_extrms_nod_ids) + else _start_edge["to_id"] + ) + ] + end_path_extrms = [(path_extrms_nod_ids - set(start_path_extrms)).pop()] + _merged.from_id = [ + start_path_extremity for start_path_extremity in start_path_extrms + ] + _merged.to_id = [end_path_extremity for end_path_extremity in end_path_extrms] + return _merged + + def _get_intersections(self, _edge, _edges): + intersections = [] + edge_geometry = _edge.geometry.simplify(tolerance=1e-8) + + for _, other_edge in _edges.iterrows(): + other_edge_geometry = other_edge.geometry.simplify(tolerance=1e-8) + + if not edge_geometry.equals(other_edge_geometry): # avoid self-intersection + intersection = 
edge_geometry.intersection(other_edge_geometry) + + if not intersection.is_empty and any( + intersection.intersects(boundary) + for boundary in edge_geometry.boundary.geoms + ): + if isinstance(intersection, MultiPoint): + intersections.extend( + [ + point.coords[0] + for point in intersection.geoms + if point in other_edge_geometry.boundary.geoms + ] + ) + else: + intersections.append(intersection.coords[0]) + + return sorted(intersections, key=lambda x: x[0]) diff --git a/ra2ce/network/network_simplification/snkit_network_wrapper.py b/ra2ce/network/network_simplification/snkit_network_wrapper.py new file mode 100644 index 000000000..bae06d170 --- /dev/null +++ b/ra2ce/network/network_simplification/snkit_network_wrapper.py @@ -0,0 +1,171 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" +from __future__ import annotations + +from dataclasses import dataclass + +import networkx as nx +from shapely import MultiLineString +from shapely.ops import linemerge +from snkit.network import Network + +from ra2ce.network.network_simplification.nx_to_snkit_network_converter import ( + NxToSnkitNetworkConverter, +) +from ra2ce.network.network_simplification.snkit_network_merge_wrapper import merge_edges +from ra2ce.network.network_simplification.snkit_to_nx_network_converter import ( + SnkitToNxNetworkConverter, +) +from ra2ce.network.networks_utils import line_length + +NxGraph = nx.Graph | nx.MultiGraph | nx.MultiDiGraph + + +@dataclass(kw_only=True) +class SnkitNetworkWrapper: + """ + Wrapper created to reduce complexity of conversion and processing of a `snkit.network.Network` + within the rest of the code. + """ + + snkit_network: Network + node_id_column_name: str = "id" + edge_from_id_column_name: str = "from_id" + edge_to_id_column_name: str = "to_id" + + @classmethod + def from_networkx( + cls, + networkx_graph: NxGraph, + column_names_dict: dict[str, str], + ) -> SnkitNetworkWrapper: + """ + Generates a `SnkitNetworkWrapper` based on the given `NxGraph`. + + Args: + networkx_graph (NxGraph): Graph to convert. + column_names_dict (dict[str, str]): Column names to use. + + Returns: + SnkitNetworkWrapper: Wrapper containing the converted `snkit.network.Network`. 
+ """ + _snkit_converted_network = NxToSnkitNetworkConverter( + networkx_graph=networkx_graph, **column_names_dict + ).convert() + return cls( + snkit_network=_snkit_converted_network, + **column_names_dict, + ) + + def merge_edges(self, attributes_to_exclude: list[str]) -> None: + def filter_excluded_attributes() -> list[str]: + columns_set = set(self.snkit_network.edges.columns) + return [attr for attr in attributes_to_exclude if attr in columns_set] + + cols = [col for col in self.snkit_network.edges.columns if col != "geometry"] + _attributes_to_exclude = filter_excluded_attributes() + + if "demand_edge" not in _attributes_to_exclude: + _aggregate_function = self._aggfunc_with_demand_edge( + cols, _attributes_to_exclude + ) + else: + _aggregate_function = self._aggfunc_no_demand_edge( + cols, _attributes_to_exclude + ) + + # Overwrite the existing network with the merged edges. + self.snkit_network = merge_edges( + snkit_network=self.snkit_network, + aggregate_func=_aggregate_function, + by=_attributes_to_exclude, + id_col="id", + ) + + def process_network(self) -> None: + _network_crs = self.snkit_network.edges.crs + self.snkit_network.edges["length"] = self.snkit_network.edges["geometry"].apply( + lambda x: line_length(x, _network_crs) + ) # length in m + self.snkit_network.edges = self.snkit_network.edges[ + self.snkit_network.edges["length"] != 0 + ] # Remove zero-length edges + + def convert_to_line_string(geometry_to_convert) -> MultiLineString: + if isinstance(geometry_to_convert, MultiLineString): + return linemerge([line for line in geometry_to_convert.geoms]) + return geometry_to_convert + + self.snkit_network.edges["geometry"] = self.snkit_network.edges[ + "geometry" + ].apply(convert_to_line_string) + + def to_networkx(self) -> NxGraph: + """ + Converts the wrapped `snkit_network` into a corresponding `networkx.Graph`. + + Returns: + NxGraph: The converted graph. 
+ """ + return SnkitToNxNetworkConverter( + snkit_network=self.snkit_network, + node_id_column_name=self.node_id_column_name, + edge_from_id_column_name=self.edge_from_id_column_name, + edge_to_id_column_name=self.edge_to_id_column_name, + ).convert() + + def _aggfunc_with_demand_edge(self, cols, attributes_to_exclude: list[str]): + def aggregate_column(col_data, col_name: str): + if col_name in attributes_to_exclude: + return col_data.iloc[0] + elif col_name == "rfid_c": + return list(col_data) + elif col_name in ["maxspeed", "avgspeed"]: + return col_data.mean() + elif col_name == "demand_edge": + return max(col_data) + elif col_data.dtype == "O": + return "; ".join( + str(item) for item in col_data if isinstance(item, str) + ) + else: + return col_data.iloc[0] + + return { + col: (lambda col_data, col_name=col: aggregate_column(col_data, col_name)) + for col in cols + } + + def _aggfunc_no_demand_edge(self, cols, attributes_to_exclude: list[str]): + def aggregate_column(col_data, col_name: str): + if col_name in attributes_to_exclude: + return col_data.iloc[0] + elif col_name == "rfid_c": + return list(col_data) + elif col_name in ["maxspeed", "avgspeed"]: + return col_data.mean() + elif col_data.dtype == "O": + return "; ".join( + str(item) for item in col_data if isinstance(item, str) + ) + else: + return col_data.iloc[0] + + return { + col: (lambda col_data, col_name=col: aggregate_column(col_data, col_name)) + for col in cols + } diff --git a/ra2ce/network/network_simplification/snkit_to_nx_network_converter.py b/ra2ce/network/network_simplification/snkit_to_nx_network_converter.py new file mode 100644 index 000000000..15f9128bd --- /dev/null +++ b/ra2ce/network/network_simplification/snkit_to_nx_network_converter.py @@ -0,0 +1,68 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). 
+ Copyright (C) 2023 Stichting Deltares + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from dataclasses import dataclass + +import networkx as nx +from snkit.network import Network as SnkitNetwork + + +@dataclass(kw_only=True) +class SnkitToNxNetworkConverter: + """ + Class responsible to convert a `snkit.network.Network` into + a matching `networkx.MultiGraph`. + """ + + snkit_network: SnkitNetwork + node_id_column_name: str = "id" + edge_from_id_column_name: str = "from_id" + edge_to_id_column_name: str = "to_id" + + def convert(self) -> nx.MultiGraph: + """ + Converts the given `snkit.network.Network` into a matching + `networkx.MultiGraph`. + + Args: + snkit_network (SnkitNetwork): The snkit network to convert. + + Returns: + `networkx.MultiGraph`: The converted graph. 
+ """ + # Define new graph + _nx_graph = nx.MultiGraph() + _crs = self.snkit_network.edges.crs + + # Add nodes to the graph + for _, row in self.snkit_network.nodes.iterrows(): + node_id = row[self.node_id_column_name] + attributes = {k: v for k, v in row.items()} + _nx_graph.add_node(node_id, **attributes) + + # Add edges to the graph + for _, row in self.snkit_network.edges.iterrows(): + u = row[self.edge_from_id_column_name] + v = row[self.edge_to_id_column_name] + attributes = {k: v for k, v in row.items()} + _nx_graph.add_edge(u, v, **attributes) + + # Add CRS information to the graph + if "crs" not in _nx_graph.graph: + _nx_graph.graph["crs"] = _crs + + return _nx_graph diff --git a/ra2ce/network/network_wrappers/osm_network_wrapper/osm_network_wrapper.py b/ra2ce/network/network_wrappers/osm_network_wrapper/osm_network_wrapper.py index 25a136d47..432e4002d 100644 --- a/ra2ce/network/network_wrappers/osm_network_wrapper/osm_network_wrapper.py +++ b/ra2ce/network/network_wrappers/osm_network_wrapper/osm_network_wrapper.py @@ -34,6 +34,7 @@ from ra2ce.network.network_config_data.enums.network_type_enum import NetworkTypeEnum from ra2ce.network.network_config_data.enums.road_type_enum import RoadTypeEnum from ra2ce.network.network_config_data.network_config_data import NetworkConfigData +from ra2ce.network.network_simplification import NetworkGraphSimplificator from ra2ce.network.network_wrappers.network_wrapper_protocol import ( NetworkWrapperProtocol, ) @@ -54,6 +55,9 @@ class OsmNetworkWrapper(NetworkWrapperProtocol): road_types: list[RoadTypeEnum] def __init__(self, config_data: NetworkConfigData) -> None: + self.attributes_to_exclude_in_simplification = ( + config_data.network.attributes_to_exclude_in_simplification + ) self.output_graph_dir = config_data.output_graph_dir self.graph_crs = config_data.crs @@ -72,7 +76,6 @@ def with_polygon( clean graph as the `polygon_graph` property. 
Args: - config_data (NetworkConfigData): Basic configuration data which contain information required the different methods of this wrapper. polygon (BaseGeometry): Base polygon from which to generate the graph. Returns: @@ -134,9 +137,10 @@ def get_network_from_geojson(config_data: NetworkConfigData): def get_network(self) -> tuple[MultiGraph, GeoDataFrame]: # Create 'graph_simple' - graph_simple, graph_complex, link_tables = nut.create_simplified_graph( - self.polygon_graph - ) + graph_simple, graph_complex, link_tables = NetworkGraphSimplificator( + graph_complex=self.polygon_graph, + attributes_to_exclude=self.attributes_to_exclude_in_simplification, + ).simplify() # Assign the average speed and time to the graphs graph_simple = AvgSpeedCalculator(graph_simple, self.output_graph_dir).assign() diff --git a/ra2ce/network/network_wrappers/vector_network_wrapper.py b/ra2ce/network/network_wrappers/vector_network_wrapper.py index ad1344387..e4ecb0d4f 100644 --- a/ra2ce/network/network_wrappers/vector_network_wrapper.py +++ b/ra2ce/network/network_wrappers/vector_network_wrapper.py @@ -35,6 +35,7 @@ from ra2ce.network.avg_speed.avg_speed_calculator import AvgSpeedCalculator from ra2ce.network.exporters.json_exporter import JsonExporter from ra2ce.network.network_config_data.network_config_data import NetworkConfigData +from ra2ce.network.network_simplification import NetworkGraphSimplificator from ra2ce.network.network_wrappers.network_wrapper_protocol import ( NetworkWrapperProtocol, ) @@ -51,6 +52,9 @@ def __init__( self, config_data: NetworkConfigData, ) -> None: + self.attributes_to_exclude_in_simplification = ( + config_data.network.attributes_to_exclude_in_simplification + ) self.crs = config_data.crs # Network options @@ -87,16 +91,15 @@ def get_network( ) edges, nodes = self.get_network_edges_and_nodes_from_graph(graph) graph_complex = nut.graph_from_gdf(edges, nodes, node_id="node_fid") - graph_complex = ( - graph_complex.to_directed() - ) # simplification 
function requires nx.MultiDiGraph if self.delete_duplicate_nodes: graph_complex = self._delete_duplicate_nodes(graph_complex) logging.info("Start converting the complex graph to a simple graph") - graph_simple, graph_complex, link_tables = nut.create_simplified_graph( - graph_complex - ) + # Create 'graph_simple' + graph_simple, graph_complex, link_tables = NetworkGraphSimplificator( + graph_complex=graph_complex, + attributes_to_exclude=self.attributes_to_exclude_in_simplification, + ).simplify() # Assign the average speed and time to the graphs graph_simple = AvgSpeedCalculator(graph_simple, self.output_graph_dir).assign() @@ -241,27 +244,24 @@ def _delete_duplicate_nodes(graph_complex: nx.MultiGraph) -> nx.MultiGraph: updated_graph.graph["name"] = graph_complex.graph.get("name", None) return updated_graph - def _get_direct_graph_from_vector( - self, gdf: gpd.GeoDataFrame, edge_attributes_to_include: list - ) -> nx.DiGraph: - """Creates a simple directed graph with node and edge geometries based on a given GeoDataFrame. + def _create_graph_from_gdf( + self, + geo_dataframe: gpd.GeoDataFrame, + edge_attributes_to_include: list, + ) -> nx.Graph | nx.DiGraph: + """ + Creates a simple undirected graph with node and edge geometries based on a given GeoDataFrame. Args: gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries. Allow both LineString and MultiLineString. - edge_attributes_to_include: Attributes needed to be included from gdf in the graph - + edge_attributes_to_include (List[str], optional): Additional attributes to include from the GeoDataFrame in the graph. Returns: - nx.DiGraph: NetworkX graph object with "crs", "approach" as graph properties. + nx.Graph: NetworkX graph object with node and edge geometries and specified attributes. 
""" - - # simple geometry handling - gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf) - - # to graph - digraph = nx.DiGraph(crs=gdf.crs, approach="primal") - for _, row in gdf.iterrows(): + _networkx_graph = nx.DiGraph(crs=geo_dataframe.crs, approach="primal") + for _, row in geo_dataframe.iterrows(): link_id = row.get(self.file_id, None) link_type = row.get(self.link_type_column, None) @@ -273,15 +273,15 @@ def _get_direct_graph_from_vector( "avgspeed": row.pop("avgspeed") if "avgspeed" in row else None, "geometry": row.pop("geometry"), } - digraph.add_node(from_node, geometry=Point(from_node)) - digraph.add_node(to_node, geometry=Point(to_node)) - digraph.add_edge( + _networkx_graph.add_node(from_node, geometry=Point(from_node)) + _networkx_graph.add_node(to_node, geometry=Point(to_node)) + _networkx_graph.add_edge( from_node, to_node, link_id=link_id, **_edge_attributes, ) - if len(edge_attributes_to_include) > 0: + if edge_attributes_to_include: for edge_attribute_to_include in edge_attributes_to_include: edge_attribute = ( row[edge_attribute_to_include] @@ -289,9 +289,31 @@ def _get_direct_graph_from_vector( else None ) if edge_attribute: - edge_data = digraph[from_node][to_node] + edge_data = _networkx_graph[from_node][to_node] edge_data[edge_attribute_to_include] = edge_attribute - return digraph + + return _networkx_graph + + def _get_direct_graph_from_vector( + self, gdf: gpd.GeoDataFrame, edge_attributes_to_include: list + ) -> nx.DiGraph: + """Creates a simple directed graph with node and edge geometries based on a given GeoDataFrame. + + Args: + gdf (gpd.GeoDataFrame): Input GeoDataFrame containing line geometries. + Allow both LineString and MultiLineString. + edge_attributes_to_include: Attributes needed to be included from gdf in the graph + + + Returns: + nx.DiGraph: NetworkX graph object with "crs", "approach" as graph properties. 
+ """ + + # simple geometry handling + gdf = VectorNetworkWrapper.explode_and_deduplicate_geometries(gdf) + + # to graph + return self._create_graph_from_gdf(gdf, edge_attributes_to_include) def _get_undirected_graph_from_vector( self, gdf: gpd.GeoDataFrame, edge_attributes_to_include: list diff --git a/ra2ce/network/networks.py b/ra2ce/network/networks.py index 6ae7f2cd9..385e43edc 100644 --- a/ra2ce/network/networks.py +++ b/ra2ce/network/networks.py @@ -174,6 +174,12 @@ def _include_attributes(self, attributes: list, graph: nx.Graph) -> nx.Graph: return updated_graph def _get_new_network_and_graph(self, export_types: list[str]) -> None: + """ + TODO: This method should be relying on a generic definition of a network result + from `.get_network`. This means, instead of getting `_base_graph, _network_gdf` + we get a generic `_ra2ce_network_wrapper` from which can later on just do a + `.simplify_network` or `.add_eges`, etc. using inheritance. + """ _base_graph, _network_gdf = NetworkWrapperFactory( self._config_data diff --git a/ra2ce/network/networks_utils.py b/ra2ce/network/networks_utils.py index 5f3609ea2..28bf34778 100644 --- a/ra2ce/network/networks_utils.py +++ b/ra2ce/network/networks_utils.py @@ -20,6 +20,7 @@ """ import itertools import logging +import math import os import sys import warnings @@ -38,7 +39,6 @@ from numpy.ma import MaskedArray from osgeo import gdal from osmnx import graph_to_gdfs -from osmnx.simplification import simplify_graph from rasterio.features import shapes from rasterio.mask import mask from shapely.geometry import LineString, MultiLineString, Point, box, shape @@ -980,38 +980,6 @@ def delete_duplicates(all_points: list[Point]) -> list[Point]: return uniquepoints -def create_simplified_graph( - graph_complex: nx.Graph, new_id: str = "rfid" -) -> tuple[nx.Graph, nx.Graph, tuple[dict, dict]]: - """Create a simplified graph with unique ids from a complex graph""" - logging.info("Simplifying graph") - try: - graph_complex = 
graph_create_unique_ids(graph_complex, "{}_c".format(new_id)) - - # Create simplified graph and add unique ids - graph_simple = simplify_graph_count(graph_complex) - graph_simple = graph_create_unique_ids(graph_simple, new_id) - - # Create look_up_tables between graphs with unique ids - simple_to_complex, complex_to_simple = graph_link_simple_id_to_complex( - graph_simple, new_id=new_id - ) - - # Store id table and add simple ids to complex graph - id_tables = (simple_to_complex, complex_to_simple) - graph_complex = add_simple_id_to_graph_complex( - graph_complex, complex_to_simple, new_id - ) - logging.info("Simplified graph succesfully created") - except Exception as exc: - graph_simple = None - id_tables = None - logging.error( - "Did not create a simplified version of the graph ({})".format(exc) - ) - return graph_simple, graph_complex, id_tables - - def gdf_check_create_unique_ids( gdf: gpd.GeoDataFrame, id_name: str, new_id_name: str = "rfid" ) -> tuple[gpd.GeoDataFrame, str]: @@ -1074,19 +1042,6 @@ def graph_check_create_unique_ids( return graph, id_name -def graph_create_unique_ids(graph: nx.Graph, new_id_name: str = "rfid") -> nx.Graph: - # Check if new_id_name exists and if unique - u, v, k = list(graph.edges)[0] - if new_id_name in graph.edges[u, v, k]: - return graph - # TODO: decide if we always add a new ID (in iGraph this is different) - # if len(set([str(e[-1][new_id_name]) for e in graph.edges.data(keys=True)])) < len(graph.edges()): - for i, (u, v, k) in enumerate(graph.edges(keys=True)): - graph[u][v][k][new_id_name] = i + 1 - logging.info("Added a new unique identifier field '{}'.".format(new_id_name)) - return graph - - def add_missing_geoms_graph(graph: nx.Graph, geom_name: str = "geometry") -> nx.Graph: # Not all nodes have geometry attributed (some only x and y coordinates) so add a geometry columns nodes_without_geom = [ @@ -1110,7 +1065,8 @@ def add_missing_geoms_graph(graph: nx.Graph, geom_name: str = "geometry") -> nx. 
def add_x_y_to_nodes(graph: nx.Graph) -> nx.Graph: """ - Add missing x and y attributes to nodes + Add missing x and y attributes to nodes. + TODO: Should this be moved to `network_simplification`? Args: graph (nx.Graph): Graph to add x and y attributes to @@ -1127,32 +1083,6 @@ def add_x_y_to_nodes(graph: nx.Graph) -> nx.Graph: return graph -def simplify_graph_count(complex_graph: nx.Graph) -> nx.Graph: - """ - Simplify the graph after adding missing x and y attributes to nodes - - Args: - complex_graph (nx.Graph): Graph to simplify - - Returns: - nx.Graph: Simplified graph - """ - complex_graph = add_x_y_to_nodes(complex_graph) - simple_graph = simplify_graph( - complex_graph, strict=True, remove_rings=True, track_merged=False - ) - - logging.info( - "Graph simplified from %s to %s nodes and %s to %s edges.", - complex_graph.number_of_nodes(), - simple_graph.number_of_nodes(), - complex_graph.number_of_edges(), - simple_graph.number_of_edges(), - ) - - return simple_graph - - def graph_from_gdf( gdf: gpd.GeoDataFrame, gdf_nodes, name: str = "network", node_id: str = "ID" ) -> nx.MultiGraph: @@ -1432,7 +1362,7 @@ def get_graph_edges_extent( if maxy > max_y: max_y = maxy - return (min_x, max_x, min_y, max_y) + return min_x, max_x, min_y, max_y def reproject_graph(original_graph: nx.Graph, crs_in: str, crs_out: str) -> nx.Graph: diff --git a/tests/analysis/damages/damage_calculation/test_damage_network_events.py b/tests/analysis/damages/damage_calculation/test_damage_network_events.py index 9b24e4d36..fdded825d 100644 --- a/tests/analysis/damages/damage_calculation/test_damage_network_events.py +++ b/tests/analysis/damages/damage_calculation/test_damage_network_events.py @@ -14,10 +14,10 @@ def test_init_with_valid_args(self): # 1. Define test data. _road_gf = None _val_cols = ["an_event_01", "an_event_02"] - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 # 2. 
Run test _dne = DamageNetworkEvents( - _road_gf, _val_cols, _representative_damage_percentile + _road_gf, _val_cols, _representative_damage_percentage ) # 3. Verify final expectations @@ -29,12 +29,12 @@ def test_init_with_invalid_args(self): # 1. Define test data. _road_gf = None _val_cols = [] - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 # 2. Run test. with pytest.raises(ValueError) as exc_err: _dne = DamageNetworkEvents( - _road_gf, _val_cols, _representative_damage_percentile + _road_gf, _val_cols, _representative_damage_percentage ) # 3. Verify final expectations. diff --git a/tests/analysis/damages/damage_calculation/test_damage_network_return_periods.py b/tests/analysis/damages/damage_calculation/test_damage_network_return_periods.py index 781de691e..5e9f001aa 100644 --- a/tests/analysis/damages/damage_calculation/test_damage_network_return_periods.py +++ b/tests/analysis/damages/damage_calculation/test_damage_network_return_periods.py @@ -16,11 +16,11 @@ def test_init(self): # 1. Define test data. _road_gf = None _val_cols = ["an_event_ab", "an_event_cd"] - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 # 2. Run test. _damage = DamageNetworkReturnPeriods( - _road_gf, _val_cols, _representative_damage_percentile + _road_gf, _val_cols, _representative_damage_percentage ) # 2. Verify expectations. @@ -32,12 +32,12 @@ def test_init_with_no_event_raises_error(self): # 1. Define test data. _road_gf = None _val_cols = [] - _representative_damage_percentile = None + _representative_damage_percentage = None # 2. Run test. with pytest.raises(ValueError) as exc_err: DamageNetworkReturnPeriods( - _road_gf, _val_cols, _representative_damage_percentile + _road_gf, _val_cols, _representative_damage_percentage ) # 3. 
Verify expectations diff --git a/tests/analysis/damages/test_damages.py b/tests/analysis/damages/test_damages.py index 69ce6a4b2..e3d111e19 100644 --- a/tests/analysis/damages/test_damages.py +++ b/tests/analysis/damages/test_damages.py @@ -83,10 +83,10 @@ def test_event_based_damage_calculation_huizinga_stylized(self): val_cols = [ col for col in road_gdf.columns if (col[0].isupper() and col[1] == "_") ] - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 # DO ACTUAL DAMAGE CALCULATION event_gdf = DamageNetworkEvents( - road_gdf, val_cols, _representative_damage_percentile + road_gdf, val_cols, _representative_damage_percentage ) event_gdf.main(damage_function=damage_function) @@ -120,9 +120,9 @@ def test_event_based_damage_calculation_huizinga( val_cols = [ col for col in road_gdf.columns if (col[0].isupper() and col[1] == "_") ] - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 event_gdf = DamageNetworkEvents( - road_gdf, val_cols, _representative_damage_percentile + road_gdf, val_cols, _representative_damage_percentage ) event_gdf.main(damage_function=damage_function) @@ -183,11 +183,11 @@ def test_event_based_damage_calculation_osdamage_stylized(self): col for col in road_gdf.columns if (col[0].isupper() and col[1] == "_") ] - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 # DO ACTUAL DAMAGE CALCULATION event_gdf = DamageNetworkEvents( - road_gdf, val_cols, _representative_damage_percentile + road_gdf, val_cols, _representative_damage_percentage ) event_gdf.main(damage_function=damage_function) @@ -268,11 +268,11 @@ def test_event_based_damage_calculation_manual_stylized(self): fun0 = manual_damage_functions.loaded[0] assert fun0.prefix == "te" - _representative_damage_percentile = 100 + _representative_damage_percentage = 100 # DO ACTUAL DAMAGE CALCULATION event_gdf = DamageNetworkEvents( - road_gdf, val_cols, _representative_damage_percentile + road_gdf, 
val_cols, _representative_damage_percentage ) event_gdf.main( damage_function=damage_function, @@ -370,7 +370,7 @@ def test_old_event_based_damage_calculation_manualfunction( def test_construct_damage_network_return_periods(self, risk_data_file: Path): damage_network = DamageNetworkReturnPeriods.construct_from_csv( - risk_data_file, sep=";", representative_damage_percentile=100 + risk_data_file, sep=";", representative_damage_percentage=100 ) assert ( type(damage_network) == DamageNetworkReturnPeriods @@ -378,7 +378,7 @@ def test_construct_damage_network_return_periods(self, risk_data_file: Path): def test_risk_calculation_default(self, risk_data_file: Path): damage_network = DamageNetworkReturnPeriods.construct_from_csv( - risk_data_file, sep=";", representative_damage_percentile=100 + risk_data_file, sep=";", representative_damage_percentage=100 ) damage_network.control_risk_calculation(mode=RiskCalculationModeEnum.DEFAULT) assert ( @@ -388,7 +388,7 @@ def test_risk_calculation_default(self, risk_data_file: Path): def test_risk_calculation_cutoff(self, risk_data_file: Path): for rp in [15, 200, 25]: damage_network = DamageNetworkReturnPeriods.construct_from_csv( - risk_data_file, sep=";", representative_damage_percentile=100 + risk_data_file, sep=";", representative_damage_percentage=100 ) damage_network.control_risk_calculation( mode=RiskCalculationModeEnum.CUT_FROM_YEAR, year=rp @@ -401,7 +401,7 @@ def test_risk_calculation_cutoff(self, risk_data_file: Path): def test_risk_calculation_triangle_to_null(self, risk_data_file: Path): damage_network = DamageNetworkReturnPeriods.construct_from_csv( - risk_data_file, sep=";", representative_damage_percentile=100 + risk_data_file, sep=";", representative_damage_percentage=100 ) for triangle_rp in [8, 2]: damage_network.control_risk_calculation( diff --git a/tests/network/network_simplification/__init__.py b/tests/network/network_simplification/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/network/network_simplification/test_network_graph_simplificator.py b/tests/network/network_simplification/test_network_graph_simplificator.py new file mode 100644 index 000000000..f2b05c1a9 --- /dev/null +++ b/tests/network/network_simplification/test_network_graph_simplificator.py @@ -0,0 +1,309 @@ +import random +from typing import Callable, Iterator +import math +import networkx as nx +import numpy as np +import pytest +from shapely.geometry import LineString, Point + +from ra2ce.network import add_missing_geoms_graph +from ra2ce.network.network_simplification.network_graph_simplificator import ( + NetworkGraphSimplificator, +) +from ra2ce.network.network_simplification.network_simplification_with_attribute_exclusion import ( + NetworkSimplificationWithAttributeExclusion, +) +from ra2ce.network.network_simplification.nx_to_snkit_network_converter import ( + NxToSnkitNetworkConverter, +) +from ra2ce.network.network_simplification.snkit_to_nx_network_converter import ( + SnkitToNxNetworkConverter, +) +from ra2ce.network.networks_utils import line_length + + +def _detailed_edge_comparison( + graph1: nx.MultiDiGraph | nx.MultiGraph, graph2: nx.MultiDiGraph | nx.MultiGraph +) -> bool: + def _dicts_comparison( + graph1: nx.MultiDiGraph | nx.MultiGraph, graph2: nx.MultiDiGraph | nx.MultiGraph + ) -> bool: + for u, v, k, data1 in graph1.edges(keys=True, data=True): + data2 = graph2.get_edge_data(u, v, k) + for key1, value1 in data1.items(): + if key1 not in data2: + return False + if isinstance(value1, float) and math.isnan(value1): + if not math.isnan(data2[key1]): + return False + continue + if value1 != data2[key1]: + return False + return True + + check_1_2 = _dicts_comparison(graph1, graph2) + check_2_1 = _dicts_comparison(graph2, graph1) + + if check_1_2 and check_2_1: + return True + else: + return False + + +class TestNetworkGraphSimplificator: + @pytest.fixture(name="network_graph_simplificator_factory") + def _get_network_graph_simplificator( + self, + 
) -> Iterator[Callable[[], NetworkGraphSimplificator]]: + def get_network_graph_simplificator() -> NetworkGraphSimplificator: + return NetworkGraphSimplificator( + graph_complex=None, attributes_to_exclude=[] + ) + + yield get_network_graph_simplificator + + def test_validate_fixture_init( + self, + network_graph_simplificator_factory: Callable[[], NetworkGraphSimplificator], + ): + # 1. Define test data. + _network_graph_simplificator = network_graph_simplificator_factory() + + # 2. Verify expectations + assert isinstance(_network_graph_simplificator, NetworkGraphSimplificator) + + @pytest.fixture(name="multigraph_fixture") + def _get_multigraph_fixture(self) -> Iterator[nx.MultiGraph]: + _graph = nx.MultiGraph() + _graph.add_edge(2, 3, weight=5) + _graph.add_edge(2, 1, weight=2) + yield _graph + + def test__graph_create_unique_ids_with_missing_id_data( + self, + network_graph_simplificator_factory: Callable[[], NetworkGraphSimplificator], + multigraph_fixture: nx.MultiGraph, + ): + # 1. Define test data + _network_graph_simplificator = network_graph_simplificator_factory() + assert isinstance(multigraph_fixture, nx.MultiGraph) + _new_id_name = "dummy_id" + + # 2. Run test + _return_graph = _network_graph_simplificator._graph_create_unique_ids( + multigraph_fixture, _new_id_name + ) + + # 3. Verify final expectations + assert _return_graph == multigraph_fixture + _dicts_keys = [_k[-1].keys() for _k in multigraph_fixture.edges.data(keys=True)] + assert all(_new_id_name in _keys for _keys in _dicts_keys) + + def test__graph_create_unique_ids_with_existing_id( + self, + network_graph_simplificator_factory: NetworkGraphSimplificator, + multigraph_fixture: nx.MultiGraph, + ): + # 1. Define test data + _network_graph_simplificator = network_graph_simplificator_factory() + assert isinstance(multigraph_fixture, nx.MultiGraph) + _new_id_name = "weight" + + # 2. 
Run test + _return_graph = _network_graph_simplificator._graph_create_unique_ids( + multigraph_fixture, _new_id_name + ) + + # 3. Verify final expectations + assert _return_graph == multigraph_fixture + + +class TestNetworkSimplificationWithAttributeExclusion: + @pytest.fixture(name="network_simplification_with_attribute_exclusion") + def _get_network_simplification_with_attribute_exclusion( + self, + ) -> Iterator[NetworkSimplificationWithAttributeExclusion]: + yield NetworkSimplificationWithAttributeExclusion( + nx_graph=None, attributes_to_exclude=[] + ) + + @pytest.fixture(name="nx_digraph_factory") + def _get_nx_digraph_factory(self) -> Iterator[Callable[[], nx.MultiDiGraph]]: + def create_nx_multidigraph(): + _nx_digraph = nx.MultiDiGraph() + for i in range(1, 16): + _nx_digraph.add_node(i, x=i, y=i * 10) + + _nx_digraph.add_edge(1, 2, a=np.nan) + _nx_digraph.add_edge(2, 1, a=np.nan) + _nx_digraph.add_edge(2, 3, a=np.nan) + _nx_digraph.add_edge(3, 4, a=np.nan) + _nx_digraph.add_edge(4, 5, a="yes") + _nx_digraph.add_edge(5, 6, a="yes") + _nx_digraph.add_edge(6, 7, a="yes") + _nx_digraph.add_edge(7, 8, a=np.nan) + _nx_digraph.add_edge(8, 9, a=np.nan) + _nx_digraph.add_edge(8, 12, a=np.nan) + _nx_digraph.add_edge(8, 13, a="yes") + _nx_digraph.add_edge(9, 10, a=np.nan) + _nx_digraph.add_edge(10, 11, a=np.nan) + _nx_digraph.add_edge(11, 12, a="yes") + _nx_digraph.add_edge(13, 14, a="yes") + _nx_digraph.add_edge(14, 15, a="yes") + _nx_digraph.add_edge(15, 11, a="yes") + + _nx_digraph = add_missing_geoms_graph(_nx_digraph, "geometry") + _nx_digraph.graph["crs"] = "EPSG:4326" + + _nx_digraph = add_missing_geoms_graph(_nx_digraph, "geometry") + return _nx_digraph + + yield create_nx_multidigraph + + @pytest.fixture(name="expected_result_graph_fixture") + def _get_expected_result_graph_fixture( + self, nx_digraph_factory: nx.MultiDiGraph + ) -> nx.MultiGraph: + _nx_digraph = nx_digraph_factory() + _result_digraph = nx.MultiGraph() + node_ids_degrees = {2: 1, 4: 2, 7: 
2, 8: 4, 11: 3, 12: 2} + for node_id, degree in node_ids_degrees.items(): + node_data = _nx_digraph.nodes[node_id] + node_data["id"] = node_id + node_data["degree"] = degree + _result_digraph.add_node(node_id, **node_data) + _result_digraph = add_missing_geoms_graph(_result_digraph, "geometry") + + _result_digraph.add_edge( + 2, + 4.0, + a="None", + from_node=2, + to_node=4, + geometry=LineString( + [ + _nx_digraph.nodes[2]["geometry"], + _nx_digraph.nodes[3]["geometry"], + _nx_digraph.nodes[4]["geometry"], + ] + ), + ) + + _result_digraph.add_edge( + 4, + 7.0, + a="yes", + from_node=4, + to_node=7, + geometry=LineString( + [ + _nx_digraph.nodes[4]["geometry"], + _nx_digraph.nodes[5]["geometry"], + _nx_digraph.nodes[6]["geometry"], + _nx_digraph.nodes[7]["geometry"], + ] + ), + ) + _result_digraph.add_edge( + 7, + 8.0, + a="None", + from_node=7, + to_node=8, + geometry=LineString( + [ + _nx_digraph.nodes[7]["geometry"], + _nx_digraph.nodes[8]["geometry"], + ] + ), + ) + _result_digraph.add_edge( + 8, + 11.0, + a="None", + from_node=8, + to_node=11, + geometry=LineString( + [ + _nx_digraph.nodes[8]["geometry"], + _nx_digraph.nodes[9]["geometry"], + _nx_digraph.nodes[10]["geometry"], + _nx_digraph.nodes[11]["geometry"], + ] + ), + ) + _result_digraph.add_edge( + 8, + 11.0, + a="yes", + from_node=8, + to_node=11, + geometry=LineString( + [ + _nx_digraph.nodes[8]["geometry"], + _nx_digraph.nodes[13]["geometry"], + _nx_digraph.nodes[14]["geometry"], + _nx_digraph.nodes[15]["geometry"], + _nx_digraph.nodes[11]["geometry"], + ] + ), + ) + _result_digraph.add_edge( + 8, + 12.0, + a="None", + from_node=8, + to_node=12, + geometry=LineString( + [ + _nx_digraph.nodes[8]["geometry"], + _nx_digraph.nodes[12]["geometry"], + ] + ), + ) + _result_digraph.add_edge( + 11, + 12.0, + a="yes", + from_node=11, + to_node=12, + geometry=LineString( + [ + _nx_digraph.nodes[11]["geometry"], + _nx_digraph.nodes[12]["geometry"], + ] + ), + ) + + _result_digraph.graph["crs"] = "EPSG:4326" + + 
snkit_network = NxToSnkitNetworkConverter( + networkx_graph=_result_digraph + ).convert() + snkit_network.edges["length"] = snkit_network.edges["geometry"].apply( + lambda x: line_length(x, snkit_network.edges.crs) + ) + snkit_network.edges = snkit_network.edges.drop( + columns=["id", "from_node", "to_node"] + ) + return SnkitToNxNetworkConverter(snkit_network=snkit_network).convert() + + def test_simplify_graph( + self, + network_simplification_with_attribute_exclusion: NetworkSimplificationWithAttributeExclusion, + nx_digraph_factory: Callable[[], nx.MultiDiGraph], + expected_result_graph_fixture: nx.MultiDiGraph, + ): + network_simplification_with_attribute_exclusion.nx_graph = nx_digraph_factory() + network_simplification_with_attribute_exclusion.attributes_to_exclude = ["a"] + + _graph_simple = network_simplification_with_attribute_exclusion.simplify_graph() + + # Compare nodes with attributes + assert _graph_simple.nodes(data=True) == expected_result_graph_fixture.nodes( + data=True + ) + # Compare edges topology + assert set(_graph_simple.edges()) == set(expected_result_graph_fixture.edges()) + # Compare edges with attributes + assert _detailed_edge_comparison(_graph_simple, expected_result_graph_fixture) diff --git a/tests/network/test_networks_utils.py b/tests/network/test_networks_utils.py index c5184f7b7..d0fc9f635 100644 --- a/tests/network/test_networks_utils.py +++ b/tests/network/test_networks_utils.py @@ -265,12 +265,6 @@ def test_with_valid_data(self): assert any(_point.almost_equals(_p) for _p in _points[:3]) -class TestCreateSimplifiedGraph: - def test_with_none_graph_complex_doesnot_raise(self): - _return_result = nu.create_simplified_graph(None, "") - assert _return_result == (None, None, None) - - class TestGdfCheckCreateUniqueIds: def test_with_user_defined_identifier(self): # 1. Define test data. 
@@ -335,37 +329,6 @@ def test_with_valid_graph(self): assert _return_graph == _graph assert _return_id == _find_id - -class TestGraphCreateUniqueIds: - def test_with_missing_id_data(self): - # 1. Define test data - _graph = nx.MultiGraph() - _graph.add_edge(2, 3, weight=5) - _graph.add_edge(2, 1, weight=2) - _new_id_name = "dummy_id" - - # 2. Run test - _return_graph = nu.graph_create_unique_ids(_graph, _new_id_name) - - # 3. Verify final expectations - assert _return_graph == _graph - _dicts_keys = [_k[-1].keys() for _k in _graph.edges.data(keys=True)] - assert all(_new_id_name in _keys for _keys in _dicts_keys) - - def test_with_existing_id(self): - # 1. Define test data - _graph = nx.MultiGraph() - _graph.add_edge(2, 3, weight=5) - _graph.add_edge(2, 1, weight=2) - _new_id_name = "weight" - - # 2. Run test - _return_graph = nu.graph_create_unique_ids(_graph, _new_id_name) - - # 3. Verify final expectations - assert _return_graph == _graph - - class TestNetworksUtils: def test_get_normalized_geojson_polygon_from_geojson(self): # 1. Define test data.