diff --git a/ra2ce/analysis/analysis_base.py b/ra2ce/analysis/analysis_base.py index 71389ed7a..d545ef318 100644 --- a/ra2ce/analysis/analysis_base.py +++ b/ra2ce/analysis/analysis_base.py @@ -34,6 +34,18 @@ class AnalysisBase(ABC, AnalysisProtocol): the `AnalysisProtocol`. """ + def _get_analysis_result( + self, gdf_result: GeoDataFrame, custom_name: str + ) -> AnalysisResult: + _ar = AnalysisResult( + analysis_result=gdf_result, + analysis_config=self.analysis, + output_path=self.output_path, + ) + if custom_name: + _ar.analysis_name = custom_name + return _ar + def generate_result_wrapper( self, *analysis_result: GeoDataFrame ) -> AnalysisResultWrapper: @@ -48,13 +60,8 @@ def generate_result_wrapper( AnalysisResultWrapper: Wrapping result with configuration details. """ - def get_analysis_result(gdf_result: GeoDataFrame) -> AnalysisResult: - return AnalysisResult( - analysis_result=gdf_result, - analysis_config=self.analysis, - output_path=self.output_path, - ) - return AnalysisResultWrapper( - results_collection=list(map(get_analysis_result, analysis_result)) + results_collection=[ + self._get_analysis_result(_ar, "") for _ar in analysis_result + ] ) diff --git a/ra2ce/analysis/losses/multi_link_origin_closest_destination.py b/ra2ce/analysis/losses/multi_link_origin_closest_destination.py index b9d0a9010..bf41d5f09 100644 --- a/ra2ce/analysis/losses/multi_link_origin_closest_destination.py +++ b/ra2ce/analysis/losses/multi_link_origin_closest_destination.py @@ -15,8 +15,7 @@ from ra2ce.network.network_config_data.network_config_data import ( OriginsDestinationsSection, ) -from ra2ce.network.networks_utils import graph_to_gpkg -from ra2ce.ra2ce_logger import logging +from ra2ce.network.networks_utils import get_nodes_and_edges_from_origin_graph class MultiLinkOriginClosestDestination(AnalysisBase, AnalysisLossesProtocol): @@ -46,56 +45,13 @@ def __init__( self.file_id = analysis_input.file_id self._analysis_input = analysis_input - def _save_gdf(self, gdf: GeoDataFrame, save_path: Path) -> None: - """Takes in a geodataframe object and outputs shapefiles at the paths indicated by edge_shp and node_shp - - Arguments: - gdf [geodataframe]: geodataframe object to be converted - save_path [str]: output path including extension for edges shapefile - Returns: - None - """ - # save to shapefile - gdf.crs = "epsg:4326" # TODO: decide if this should be variable with e.g. an output_crs configured - - for col in gdf.columns: - if gdf[col].dtype == object and col != gdf.geometry.name: - gdf[col] = gdf[col].astype(str) - - if save_path.exists(): - save_path.unlink() - gdf.to_file(save_path, driver="GPKG") - logging.info("Results saved to: {}".format(save_path)) - def execute(self) -> AnalysisResultWrapper: - def _save_gpkg_analysis( - base_graph, - to_save_gdf: list[GeoDataFrame], - to_save_gdf_names: list[str], - ): - for to_save, save_name in zip(to_save_gdf, to_save_gdf_names): - if not to_save.empty: - gpkg_path = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + f"_{save_name}.gpkg" - ) - self._save_gdf(to_save, gpkg_path) - - # Save the Graph - gpkg_path_nodes = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_results_nodes.gpkg" - ) - gpkg_path_edges = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_results_edges.gpkg" - ) - graph_to_gpkg(base_graph, gpkg_path_edges, gpkg_path_nodes) - _output_path = self.output_path.joinpath(self.analysis.analysis.config_value) analyzer = OriginClosestDestination(self._analysis_input) - if self.analysis.calculate_route_without_disruption: ( - base_graph, + _base_graph, opt_routes_without_hazard, destinations, ) = analyzer.optimal_route_origin_closest_destination() @@ -105,7 +61,7 @@ def _save_gpkg_analysis( opt_routes_with_hazard = GeoDataFrame(data=None) else: ( - base_graph, + _base_graph, origins, destinations, agg_results, @@ -119,7 +75,7 @@ def _save_gpkg_analysis( ) else: ( - base_graph, + _base_graph, origins, destinations, agg_results, @@ -127,40 +83,34 @@ def _save_gpkg_analysis( ) = analyzer.multi_link_origin_closest_destination() opt_routes_without_hazard = GeoDataFrame() - if self.analysis.save_gpkg: - # Save the GeoDataFrames - to_save_gdf = [ - origins, - destinations, - opt_routes_without_hazard, - opt_routes_with_hazard, - ] - to_save_gdf_names = [ - "origins", - "destinations", - "optimal_routes_without_hazard", - "optimal_routes_with_hazard", + _nodes_graph, _edges_graph = get_nodes_and_edges_from_origin_graph(_base_graph) + _base_name = self.analysis.name.replace(" ", "_") + _analysis_result_wrapper = AnalysisResultWrapper( + results_collection=[ + self._get_analysis_result(origins, _base_name + "_origins"), + self._get_analysis_result(destinations, _base_name + "_destinations"), + self._get_analysis_result(_nodes_graph, _base_name + "_results_nodes"), + self._get_analysis_result(_edges_graph, _base_name + "_results_edges"), + self._get_analysis_result( + opt_routes_without_hazard, + _base_name + "_optimal_routes_without_hazard", + ), + self._get_analysis_result( + opt_routes_with_hazard, _base_name + "_optimal_routes_with_hazard" + ), ] - _save_gpkg_analysis(base_graph, to_save_gdf, to_save_gdf_names) - if self.analysis.save_csv: - csv_path = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_destinations.csv" - ) - if "geometry" in destinations.columns: - del destinations["geometry"] - if not csv_path.parent.exists(): - csv_path.parent.mkdir(parents=True) - destinations.to_csv(csv_path, index=False) + ) - csv_path = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_optimal_routes.csv" + # Legacy code, previously only done to export to CSV. + _opt_routes_name = _base_name + "_optimal_routes" + if not opt_routes_without_hazard.empty: + _analysis_result_wrapper.results_collection.append( + self._get_analysis_result(opt_routes_with_hazard, _opt_routes_name) + ) + if not opt_routes_with_hazard.empty: + _analysis_result_wrapper.results_collection.append( + self._get_analysis_result(opt_routes_without_hazard, _opt_routes_name) ) - if not opt_routes_without_hazard.empty: - del opt_routes_without_hazard["geometry"] - opt_routes_without_hazard.to_csv(csv_path, index=False) - if not opt_routes_with_hazard.empty: - del opt_routes_with_hazard["geometry"] - opt_routes_with_hazard.to_csv(csv_path, index=False) if self.graph_file_hazard.file is not None: agg_results.to_excel( @@ -170,5 +120,4 @@ def _save_gpkg_analysis( index=False, ) - # TODO: This does not seem correct, why were we returning None? - return self.generate_result_wrapper(None) + return _analysis_result_wrapper diff --git a/ra2ce/analysis/losses/optimal_route_origin_closest_destination.py b/ra2ce/analysis/losses/optimal_route_origin_closest_destination.py index 43096574a..9bb60b24d 100644 --- a/ra2ce/analysis/losses/optimal_route_origin_closest_destination.py +++ b/ra2ce/analysis/losses/optimal_route_origin_closest_destination.py @@ -1,8 +1,5 @@ -import logging from pathlib import Path -from geopandas import GeoDataFrame - from ra2ce.analysis.analysis_base import AnalysisBase from ra2ce.analysis.analysis_config_data.analysis_config_data import ( AnalysisSectionLosses, @@ -16,7 +13,6 @@ from ra2ce.network.network_config_data.network_config_data import ( OriginsDestinationsSection, ) -from ra2ce.network.networks_utils import graph_to_gpkg class OptimalRouteOriginClosestDestination(AnalysisBase, AnalysisLossesProtocol): @@ -44,74 +40,20 @@ def __init__( self.file_id = analysis_input.file_id self._analysis_input = analysis_input - def _save_gdf(self, gdf: GeoDataFrame, save_path: Path): - """Takes in a geodataframe object and outputs shapefiles at the paths indicated by edge_shp and node_shp - - Arguments: - gdf [geodataframe]: geodataframe object to be converted - save_path [str]: output path including extension for edges shapefile - Returns: - None - """ - # save to shapefile - gdf.crs = "epsg:4326" # TODO: decide if this should be variable with e.g. an output_crs configured - - for col in gdf.columns: - if gdf[col].dtype == object and col != gdf.geometry.name: - gdf[col] = gdf[col].astype(str) - - if save_path.exists(): - save_path.unlink() - gdf.to_file(save_path, driver="GPKG") - logging.info("Results saved to: {}".format(save_path)) - def execute(self) -> AnalysisResultWrapper: - def _save_gpkg_analysis( - base_graph, - to_save_gdf: list[GeoDataFrame], - to_save_gdf_names: list[str], - ): - for to_save, save_name in zip(to_save_gdf, to_save_gdf_names): - if not to_save.empty: - gpkg_path = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + f"_{save_name}.gpkg" - ) - self._save_gdf(to_save, gpkg_path) - - # Save the Graph - gpkg_path_nodes = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_results_nodes.gpkg" - ) - gpkg_path_edges = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_results_edges.gpkg" - ) - graph_to_gpkg(base_graph, gpkg_path_edges, gpkg_path_nodes) - - _output_path = self.output_path.joinpath(self.analysis.analysis.config_value) - analyzer = OriginClosestDestination(self._analysis_input) + + # Get gdfs ( base_graph, opt_routes, destinations, ) = analyzer.optimal_route_origin_closest_destination() - if self.analysis.save_gpkg: - # Save the GeoDataFrames - to_save_gdf = [destinations, opt_routes] - to_save_gdf_names = ["destinations", "optimal_routes"] - _save_gpkg_analysis(base_graph, to_save_gdf, to_save_gdf_names) - - if self.analysis.save_csv: - csv_path = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_destinations.csv" - ) - del destinations["geometry"] - destinations.to_csv(csv_path, index=False) - - csv_path = _output_path.joinpath( - self.analysis.name.replace(" ", "_") + "_optimal_routes.csv" - ) - del opt_routes["geometry"] - opt_routes.to_csv(csv_path, index=False) - # TODO: This does not seem correct, why were we returning None? - return self.generate_result_wrapper(None) + _base_name = self.analysis.name.replace(" ", "_") + return AnalysisResultWrapper( + results_collection=[ + self._get_analysis_result(base_graph, _base_name + "_origins"), + self._get_analysis_result(destinations, _base_name + "_destinations"), + self._get_analysis_result(opt_routes, _base_name + "_optimal_routes"), + ] + ) diff --git a/ra2ce/network/exporters/geodataframe_network_exporter.py b/ra2ce/network/exporters/geodataframe_network_exporter.py index 433ceaabc..727e15b14 100644 --- a/ra2ce/network/exporters/geodataframe_network_exporter.py +++ b/ra2ce/network/exporters/geodataframe_network_exporter.py @@ -29,13 +29,18 @@ class GeoDataFrameNetworkExporter(NetworkExporterBase): def export_to_gpkg(self, output_dir: Path, export_data: gpd.GeoDataFrame) -> None: - _output_shp_path = output_dir / (self._basename + ".gpkg") + _output_gpkg_path = output_dir.joinpath(self.basename + ".gpkg") + + if _output_gpkg_path.exists(): + logging.info("Removing previous gpkg file %s.", _output_gpkg_path) + _output_gpkg_path.unlink() + export_data.to_file( - _output_shp_path, index=False - ) # , encoding='utf-8' -Removed the encoding type because this causes some shapefiles not to save. - logging.info(f"Saved {_output_shp_path.stem} in {output_dir}.") + _output_gpkg_path, index=False, driver="GPKG", encoding="utf-8" + ) + logging.info("Saved %s in %s.", _output_gpkg_path.stem, output_dir) def export_to_pickle(self, output_dir: Path, export_data: gpd.GeoDataFrame) -> None: - self.pickle_path = output_dir / (self._basename + ".feather") + self.pickle_path = output_dir.joinpath(self.basename + ".feather") export_data.to_feather(self.pickle_path, index=False) - logging.info(f"Saved {self.pickle_path.stem} in {output_dir}.") + logging.info("Saved %s in %s.", self.pickle_path.stem, output_dir) diff --git a/ra2ce/network/exporters/multi_graph_network_exporter.py b/ra2ce/network/exporters/multi_graph_network_exporter.py index 47941ff54..b3bf4527e 100644 --- a/ra2ce/network/exporters/multi_graph_network_exporter.py +++ b/ra2ce/network/exporters/multi_graph_network_exporter.py @@ -24,11 +24,16 @@ from pathlib import Path from typing import Optional +from geopandas import GeoDataFrame + +from ra2ce.network.exporters.geodataframe_network_exporter import ( + GeoDataFrameNetworkExporter, +) from ra2ce.network.exporters.network_exporter_base import ( MULTIGRAPH_TYPE, NetworkExporterBase, ) -from ra2ce.network.networks_utils import graph_to_gpkg +from ra2ce.network.networks_utils import get_nodes_and_edges_from_origin_graph class MultiGraphNetworkExporter(NetworkExporterBase): @@ -38,20 +43,23 @@ def export_to_gpkg(self, output_dir: Path, export_data: MULTIGRAPH_TYPE) -> None if not output_dir.is_dir(): output_dir.mkdir(parents=True) - # TODO: This method should be a writer itself. - graph_to_gpkg( - export_data, - output_dir / (self._basename + "_edges.gpkg"), - output_dir / (self._basename + "_nodes.gpkg"), - ) - logging.info( - f"Saved {self._basename + '_edges.gpkg'} and {self._basename + '_nodes.gpkg'} in {output_dir}." - ) + _nodes_graph, _edges_graph = get_nodes_and_edges_from_origin_graph(export_data) + + def export_gdf(gdf_data: GeoDataFrame, suffix: str): + """ + Different from `GeoDataFrameNetworkExporter` at `index=True`. + """ + _export_file = output_dir.joinpath(self.basename + suffix + ".gpkg") + gdf_data.to_file(_export_file, index=True, driver="GPKG", encoding="utf-8") + logging.info("Saved %s in %s.", _export_file.stem, output_dir) + + export_gdf(_edges_graph, "_edges") + export_gdf(_nodes_graph, "_nodes") def export_to_pickle(self, output_dir: Path, export_data: MULTIGRAPH_TYPE) -> None: - self.pickle_path = output_dir / (self._basename + ".p") + self.pickle_path = output_dir.joinpath(self.basename + ".p") with open(self.pickle_path, "wb") as f: pickle.dump(export_data, f, protocol=4) logging.info( - f"Saved {self.pickle_path.stem} in {self.pickle_path.resolve().parent}." + "Saved %s in %s.", self.pickle_path.stem, self.pickle_path.resolve().parent ) diff --git a/ra2ce/network/exporters/network_exporter_base.py b/ra2ce/network/exporters/network_exporter_base.py index ed27a93e0..bea1027c3 100644 --- a/ra2ce/network/exporters/network_exporter_base.py +++ b/ra2ce/network/exporters/network_exporter_base.py @@ -20,6 +20,7 @@ """ +from dataclasses import dataclass, field from pathlib import Path import geopandas as gpd @@ -31,15 +32,11 @@ NETWORK_TYPE = gpd.GeoDataFrame | MULTIGRAPH_TYPE +@dataclass(kw_only=True) class NetworkExporterBase(Ra2ceExporterProtocol): - _basename: str - _export_types: list[str] = ["pickle"] - pickle_path: Path - - def __init__(self, basename: str, export_types: list[str]) -> None: - self._basename = basename - self._export_types = export_types - self.pickle_path = None + basename: str + export_types: list[str] = field(default_factory=lambda: ["pickle"]) + pickle_path: Path = None def export_to_gpkg(self, output_dir: Path, export_data: NETWORK_TYPE) -> None: """ @@ -70,8 +67,8 @@ def export(self, export_path: Path, export_data: NETWORK_TYPE) -> None: export_path (Path): Path to the output directory where to export the data. export_data (NETWORK_TYPE): Data that needs to be exported. """ - if "pickle" in self._export_types: + if "pickle" in self.export_types: self.export_to_pickle(export_path, export_data) - if "gpkg" in self._export_types: + if "gpkg" in self.export_types: self.export_to_gpkg(export_path, export_data) diff --git a/ra2ce/network/exporters/network_exporter_factory.py b/ra2ce/network/exporters/network_exporter_factory.py index a403daf21..eaa6be594 100644 --- a/ra2ce/network/exporters/network_exporter_factory.py +++ b/ra2ce/network/exporters/network_exporter_factory.py @@ -47,7 +47,7 @@ def export( export_types: list[str], ) -> None: _exporter_type = self.get_exporter_type(network) - self._exporter = _exporter_type(basename, export_types) + self._exporter = _exporter_type(basename=basename, export_types=export_types) self._exporter.export(output_dir, network) def get_pickle_path(self) -> Path: diff --git a/ra2ce/network/networks_utils.py b/ra2ce/network/networks_utils.py index 0e9a76f54..ef66d97de 100644 --- a/ra2ce/network/networks_utils.py +++ b/ra2ce/network/networks_utils.py @@ -1142,47 +1142,45 @@ def graph_to_gdf( return edges, nodes -def graph_to_gpkg(origin_graph: nx.Graph, edge_gpkg: str, node_gpkg: str) -> None: - """Takes in a networkx graph object and outputs shapefiles at the paths indicated by edge_gpkg and node_gpkg +def get_nodes_and_edges_from_origin_graph( + origin_graph: nx.Graph, +) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: + """ + Takes in a networkx graph object and returns the `GeoDataFrame` with separated + nodes and edges. Arguments: origin_graph [nx.Graph]: networkx graph object to be converted - edge_gpkg [str]: output path including extension for edges geopackage - node_gpkg [str]: output path including extension for nodes geopackage Returns: - None + tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: + resulting tuple of formatted `gpd.GeoDataFrame` for nodes and edges. """ # now only multidigraphs and graphs are used if type(origin_graph) == nx.Graph: + # isinstance / issubclass will not work as nx.MultiGraph would return True origin_graph = nx.MultiGraph(origin_graph) # The nodes should have a geometry attribute (perhaps on top of the x and y attributes) - nodes, edges = graph_to_gdfs(origin_graph, node_geometry=False) + _nodes, _edges = graph_to_gdfs(origin_graph, node_geometry=False) - dfs = [edges, nodes] + dfs = [_edges, _nodes] for df in dfs: for col in df.columns: if df[col].dtype == object and col != df.geometry.name: df[col] = df[col].astype(str) # Add a CRS to the nodes - if nodes.crs is None and edges.crs is not None: - nodes.crs = edges.crs - - logging.info("Saving nodes as shapefile: {}".format(node_gpkg)) - logging.info("Saving edges as shapefile: {}".format(edge_gpkg)) + if _nodes.crs is None and _edges.crs is not None: + _nodes.crs = _edges.crs # The encoding utf-8 might result in an empty shapefile if the wrong encoding is used. - for entity in [nodes, edges]: + for entity in [_nodes, _edges]: if "osmid" in entity: # Otherwise it gives this error: cannot insert osmid, already exist entity["osmid_original"] = entity.pop("osmid") - for _path in [node_gpkg, edge_gpkg]: - if _path.exists(): - _path.unlink() - nodes.to_file(node_gpkg, driver="GPKG", encoding="utf-8") - edges.to_file(edge_gpkg, driver="GPKG", encoding="utf-8") + + return _nodes, _edges @staticmethod diff --git a/tests/network/exporters/test_geodataframe_network_exporter.py b/tests/network/exporters/test_geodataframe_network_exporter.py index e568a77aa..60e90859b 100644 --- a/tests/network/exporters/test_geodataframe_network_exporter.py +++ b/tests/network/exporters/test_geodataframe_network_exporter.py @@ -1,4 +1,5 @@ import shutil +from pathlib import Path import pytest from geopandas import GeoDataFrame @@ -7,25 +8,28 @@ GeoDataFrameNetworkExporter, ) from ra2ce.network.exporters.network_exporter_base import NetworkExporterBase -from tests import test_results class TestGeodataframeNetworkExporter: def test_initialize(self): _basename = "dummy_test" - _exporter = GeoDataFrameNetworkExporter(_basename, ["pickle", "gpkg"]) + _exporter = GeoDataFrameNetworkExporter( + basename=_basename, export_types=["pickle", "gpkg"] + ) assert isinstance(_exporter, GeoDataFrameNetworkExporter) assert isinstance(_exporter, NetworkExporterBase) @pytest.mark.skip(reason="TODO: Needs to define GeoDataFrame dummydata.") - def test_export_to_gpkg(self, request: pytest.FixtureRequest): + def test_export_to_gpkg(self, test_result_param_case: Path): # 1. Define test data. - _output_dir = test_results / request.node.name + _output_dir = test_result_param_case if _output_dir.is_dir(): shutil.rmtree(_output_dir) _basename = "dummy_test" - _exporter = GeoDataFrameNetworkExporter(_basename, ["pickle", "gpkg"]) + _exporter = GeoDataFrameNetworkExporter( + basename=_basename, export_types=["pickle", "gpkg"] + ) _export_data = GeoDataFrame() @@ -37,14 +41,16 @@ def test_export_to_gpkg(self, request: pytest.FixtureRequest): assert (_output_dir / (_basename + ".gpkg")).is_file() @pytest.mark.skip(reason="TODO: Needs to define GeoDataFrame dummydata.") - def test_export_to_pickle(self, request: pytest.FixtureRequest): + def test_export_to_pickle(self, test_result_param_case: Path): # 1. Define test data. - _output_dir = test_results / request.node.name + _output_dir = test_result_param_case if _output_dir.is_dir(): shutil.rmtree(_output_dir) _basename = "dummy_test" - _exporter = GeoDataFrameNetworkExporter(_basename, ["pickle", "gpkg"]) + _exporter = GeoDataFrameNetworkExporter( + basename=_basename, export_types=["pickle", "gpkg"] + ) _export_data = GeoDataFrame() diff --git a/tests/network/exporters/test_multi_graph_network_exporter.py b/tests/network/exporters/test_multi_graph_network_exporter.py index 266c0e17f..31c807037 100644 --- a/tests/network/exporters/test_multi_graph_network_exporter.py +++ b/tests/network/exporters/test_multi_graph_network_exporter.py @@ -12,7 +12,9 @@ class TestMultigraphNetworkExporter: def test_initialize(self): - _exporter = MultiGraphNetworkExporter("_basename", ["pickle", "gpkg"]) + _exporter = MultiGraphNetworkExporter( + basename="_basename", export_types=["pickle", "gpkg"] + ) assert isinstance(_exporter, MultiGraphNetworkExporter) assert isinstance(_exporter, NetworkExporterBase) assert isinstance(_exporter, Ra2ceExporterProtocol) @@ -23,7 +25,9 @@ def test_export_to_gpkg_creates_dir(self, request: pytest.FixtureRequest): """ # 1. Define test data. _basename = "dummy_test" - _exporter = MultiGraphNetworkExporter(_basename, ["pickle", "gpkg"]) + _exporter = MultiGraphNetworkExporter( + basename=_basename, export_types=["pickle", "gpkg"] + ) _test_dir = test_results / request.node.name if _test_dir.is_dir(): shutil.rmtree(_test_dir) @@ -42,7 +46,9 @@ def test_export_to_gpkg_creates_dir(self, request: pytest.FixtureRequest): def test_export_to_pickle(self, request: pytest.FixtureRequest): # 1. Define test data. _basename = "dummy_test" - _exporter = MultiGraphNetworkExporter(_basename, ["pickle", "gpkg"]) + _exporter = MultiGraphNetworkExporter( + basename=_basename, export_types=["pickle", "gpkg"] + ) _test_dir = test_results / request.node.name if _test_dir.is_dir(): shutil.rmtree(_test_dir) diff --git a/tests/network/exporters/test_network_exporter_base.py b/tests/network/exporters/test_network_exporter_base.py index 94ca1503f..fc6121eb9 100644 --- a/tests/network/exporters/test_network_exporter_base.py +++ b/tests/network/exporters/test_network_exporter_base.py @@ -7,10 +7,10 @@ class TestNetworkExporterBase: def test_initialize(self): - _exporter_base = NetworkExporterBase("a_name", []) + _exporter_base = NetworkExporterBase(basename="a_name", export_types=[]) assert isinstance(_exporter_base, NetworkExporterBase) assert isinstance(_exporter_base, Ra2ceExporterProtocol) - assert _exporter_base._export_types == [] + assert _exporter_base.export_types == [] @pytest.mark.parametrize("export_type", [("pickle"), ("gpkg")]) def test_export_data(self, export_type: str, request: pytest.FixtureRequest): @@ -18,7 +18,9 @@ def test_export_data(self, export_type: str, request: pytest.FixtureRequest): _output_dir = test_results / request.node.name # 2. Run test. - _exporter_base = NetworkExporterBase("a_name", [export_type]) + _exporter_base = NetworkExporterBase( + basename="a_name", export_types=[export_type] + ) _result = _exporter_base.export(_output_dir, None) # 3. Verify expectations. diff --git a/tests/network/test_networks.py b/tests/network/test_networks.py index 9de008892..13c40cfb7 100644 --- a/tests/network/test_networks.py +++ b/tests/network/test_networks.py @@ -105,7 +105,7 @@ def test_network_creation( # 3. Then verify expectations. def validate_file(filename: str): - _graph_file = _output_graph_dir / filename + _graph_file = _output_graph_dir.joinpath(filename) return _graph_file.is_file() and _graph_file.exists() assert isinstance(_network_controller, GraphFilesCollection) diff --git a/tests/test_main.py b/tests/test_main.py index 92b583ffc..9ba8490c1 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -228,16 +228,17 @@ def _verify_file(filepath: Path) -> bool: return filepath.exists() and filepath.is_file() # Graph files - assert all(_verify_file(_graph_dir / _f) for _f in expected_graph_files) + assert all(_verify_file(_graph_dir.joinpath(_f)) for _f in expected_graph_files) # Analysis files - assert all( - list( - chain( - *( - list(map(lambda x: _verify_file(_analysis_dir / k / x), v)) - for k, v in expected_analysis_files.items() - ) + _not_generated_files = [] + for _subdir_name, _subdir_files in expected_analysis_files.items(): + _not_generated_files.extend( + filter( + lambda x: not _verify_file(_analysis_dir.joinpath(_subdir_name, x)), + _subdir_files, ) ) - ) + _err_mssg = ", ".join(_not_generated_files) + if any(_not_generated_files): + pytest.fail(f"The following expected files were not generated: {_err_mssg}")