Skip to content

Commit

Permalink
Merge pull request #308 from Deltares/feature/179-make-optional-expor…
Browse files Browse the repository at this point in the history
…t-of-files-in-osmnetworkwrapper

Feature/179 make optional export of files in osmnetworkwrapper
  • Loading branch information
Carsopre authored Mar 8, 2024
2 parents a959f32 + 28930a1 commit 89d9020
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 19 deletions.
4 changes: 2 additions & 2 deletions ra2ce/network/network_wrappers/network_wrapper_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def get_network(self) -> tuple[MultiGraph, GeoDataFrame]:
elif source == SourceEnum.PICKLE:
logging.info("Start importing a network from pickle")
base_graph = GraphPickleReader().read(
self.output_graph_dir.joinpath("base_graph.p")
self._config_data.output_graph_dir.joinpath("base_graph.p")
)
network_gdf = gpd.read_feather(
self.output_graph_dir.joinpath("base_network.feather")
self._config_data.output_graph_dir.joinpath("base_network.feather")
)
return base_graph, network_gdf
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,48 @@ def get_network(self) -> tuple[MultiGraph, GeoDataFrame]:

# Check if all geometries between nodes are there, if not, add them as a straight line.
graph_simple = nut.add_missing_geoms_graph(graph_simple, geom_name="geometry")
graph_simple = self._get_avg_speed(graph_simple)
graph_simple = self._set_avg_speed_to_graph(graph_simple)
return graph_simple, edges_complex

def _get_avg_speed(
def _get_avg_speeds(self, original_graph: nx.classes.graph.Graph) -> pd.DataFrame:
_save_csv = False
_avg_speed_filepath = None
if self.output_graph_dir is not None:
_save_csv = True
_avg_speed_filepath = self.output_graph_dir.joinpath("avg_speed.csv")
if _avg_speed_filepath.is_file():
return pd.read_csv(_avg_speed_filepath)
logging.warning(
"No valid file found with average speeds in {}, calculating and saving them instead.".format(
_avg_speed_filepath
)
)

return nut.calc_avg_speed(
original_graph,
"highway",
save_csv=_save_csv,
save_path=_avg_speed_filepath,
)

def _set_avg_speed_to_graph(
self, original_graph: nx.classes.graph.Graph
) -> nx.classes.graph.Graph:
if all(["length" in e for u, v, e in original_graph.edges.data()]) and any(
["maxspeed" in e for u, v, e in original_graph.edges.data()]
):
# Add time weighing - Define and assign average speeds; or take the average speed from an existing CSV
path_avg_speed = self.output_graph_dir.joinpath("avg_speed.csv")
if path_avg_speed.is_file():
avg_speeds = pd.read_csv(path_avg_speed)
else:
avg_speeds = nut.calc_avg_speed(
original_graph,
"highway",
save_csv=True,
save_path=path_avg_speed,

_length_array, _maxspeed_array = list(
zip(
*(
("length" in e, "maxspeed" in e)
for _, _, e in original_graph.edges.data()
)
original_graph = nut.assign_avg_speed(original_graph, avg_speeds, "highway")
)
)
if all(_length_array) and any(_maxspeed_array):
# Add time weighing - Define and assign average speeds; or take the average speed from an existing CSV
_avg_speeds = self._get_avg_speeds(original_graph)
original_graph = nut.assign_avg_speed(
original_graph, _avg_speeds, "highway"
)

# make a time value of seconds, length of road streches is in meters
for u, v, k, edata in original_graph.edges.data(keys=True):
Expand All @@ -135,6 +156,11 @@ def _get_avg_speed(
return original_graph

def _export_linking_tables(self, linking_tables: tuple[Any]) -> None:
if not self.output_graph_dir:
logging.warning(
"No `output_graph_dir` is set, therefore no intermediate results will be exported."
)
return
_exporter = JsonExporter()
_exporter.export(
self.output_graph_dir.joinpath("simple_to_complex.json"), linking_tables[0]
Expand Down
27 changes: 26 additions & 1 deletion tests/network/network_wrappers/test_osm_network_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path

import networkx as nx
from geopandas import GeoDataFrame
import pytest
from networkx import Graph, MultiDiGraph
from networkx import Graph, MultiDiGraph, MultiGraph
from networkx.utils import graphs_equal
from shapely.geometry import LineString, Polygon
from shapely.geometry.base import BaseGeometry
Expand Down Expand Up @@ -301,3 +302,27 @@ def test_given_valid_base_geometry_with_polygon(
# 3. Verify expectations.
assert isinstance(_wrapper, OsmNetworkWrapper)
assert isinstance(_wrapper.polygon_graph, MultiDiGraph)

@slow_test
def test_given_no_output_graph_dir_when_get_network(self):
# 1. Define test data.
_test_input_directory = test_data.joinpath("graph", "test_osm_network_wrapper")
_polygon_file = _test_input_directory.joinpath("_test_polygon.geojson")
assert _polygon_file.exists()

_network_config_data = self._get_dummy_network_config_data()
_network_config_data.network.polygon = _polygon_file
_network_config_data.network.network_type = NetworkTypeEnum.DRIVE
_network_config_data.network.road_types = []
# `output_graph_dir` is a property indirectly derived from `static_path`.
_network_config_data.static_path = None

# 2. Run test.
_wrapper = OsmNetworkWrapper(_network_config_data)
_result_mg, _result_gdf = _wrapper.get_network()

# 3. Verify expectations.
assert isinstance(_wrapper, OsmNetworkWrapper)
assert isinstance(_wrapper.polygon_graph, MultiDiGraph)
assert isinstance(_result_mg, MultiGraph)
assert isinstance(_result_gdf, GeoDataFrame)

0 comments on commit 89d9020

Please sign in to comment.