Skip to content

Commit

Permalink
Merge pull request #267 from Deltares/feature/266-streamline-handling…
Browse files Browse the repository at this point in the history
…-of-graph-files

Feature/266 streamline handling of graph files
  • Loading branch information
Carsopre authored Nov 24, 2023
2 parents 01b5818 + dbdf9cd commit 23f6f28
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 76 deletions.
3 changes: 0 additions & 3 deletions ra2ce/configuration/config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,6 @@ def get_analysis_config_without_network(
# Read existing files and graphs from static folder
_static_dir = config_data.static_path
if _static_dir and _static_dir.is_dir():
_network_config.files = NetworkConfigWrapper.get_existent_network_files(
_static_dir.joinpath("output_graph")
)
_network_config.graph_files = NetworkConfigWrapper.read_graphs_from_config(
_static_dir.joinpath("output_graph")
)
Expand Down
11 changes: 9 additions & 2 deletions ra2ce/graph/graph_files/graph_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

@dataclass
class GraphFile(GraphFileProtocol):
"""
Note this class resembles NetworkFile to a large extent
"""

name: str = ""
folder: Path = None
graph: MultiGraph = None
Expand All @@ -20,8 +24,11 @@ def file(self) -> Path | None:
return self.folder.joinpath(self.name)

def read_graph(self, folder: Path) -> None:
self.folder = folder
if self.file and self.file.is_file():
if not folder:
return
_file = folder.joinpath(self.name)
if _file and _file.is_file():
self.folder = folder
_pickle_reader = GraphPickleReader()
self.graph = _pickle_reader.read(self.file)

Expand Down
20 changes: 10 additions & 10 deletions ra2ce/graph/graph_files/graph_files_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

@dataclass
class GraphFilesCollection:
"""
Class containing a collection of the graphs and the paths to the files.
The names of the graph file are assumed to be standardized (e.g. "base_graph.p").
"""

base_graph: GraphFile = field(
default_factory=lambda: GraphFile(name="base_graph.p")
)
Expand Down Expand Up @@ -91,8 +96,8 @@ def get_graph(self, graph_file_type: str) -> MultiGraph | GeoDataFrame:
Returns:
GraphFileProtocol: Graph of that specific graph_file_type
"""
_graph_file = self._get_graph_file(graph_file_type)
return _graph_file.graph
_gf = self._get_graph_file(graph_file_type)
return _gf.get_graph()

def get_file(self, graph_file_type: str) -> Path | None:
"""
Expand All @@ -104,8 +109,8 @@ def get_file(self, graph_file_type: str) -> Path | None:
Returns:
Path: Path to the graph file
"""
_graph_file = self._get_graph_file(graph_file_type)
return _graph_file.file
_gf = self._get_graph_file(graph_file_type)
return _gf.file

@classmethod
def set_files(cls, parent_dir: Path) -> GraphFilesCollection:
Expand Down Expand Up @@ -139,12 +144,7 @@ def set_file(self, file: Path) -> None:
Raises:
ValueError: If the graph_file_type is not one of the known types
"""
_gf = next(
(gf for gf in self._graph_collection if gf.name == file.name),
None,
)
if _gf is None:
raise ValueError(f"Unknown graph file {file} provided.")
_gf = self._get_graph_file(file.stem)
_gf.folder = file.parent

def set_graph(self, graph_file_type: str, graph: MultiGraph | GeoDataFrame):
Expand Down
11 changes: 9 additions & 2 deletions ra2ce/graph/graph_files/network_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

@dataclass
class NetworkFile(GraphFileProtocol):
"""
Note this class resembles GraphFile to a large extent
"""

name: str = ""
folder: Path = None
graph: GeoDataFrame = None
Expand All @@ -19,8 +23,11 @@ def file(self) -> Path | None:
return self.folder.joinpath(self.name)

def read_graph(self, folder: Path) -> None:
self.folder = folder
if self.file and self.file.is_file():
if not folder:
return
_file = folder.joinpath(self.name)
if _file and _file.is_file():
self.folder = folder
self.graph = read_feather(self.file)

def get_graph(self) -> GeoDataFrame:
Expand Down
40 changes: 10 additions & 30 deletions ra2ce/graph/hazard/hazard_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,19 +570,10 @@ def create(self):
"Check your network folder."
)

# Iterate over the three graph/network types to load the file if necessary (when not yet loaded in memory).
for input_graph in ["base_graph", "base_network", "origins_destinations_graph"]:
file_path = self.graph_files.get_file(input_graph)
if (
file_path is not None
and self.graph_files.get_graph(input_graph) is None
):
self.graph_files.read_graph(file_path)

#### Step 1: hazard overlay of the base graph (NetworkX) ###
if self.graph_files.base_graph.file:
if self.graph_files.base_graph_hazard.file is None:
graph = self.graph_files.base_graph.graph
graph = self.graph_files.base_graph.get_graph()

# Check if the graph needs to be reprojected
hazard_crs = pyproj.CRS.from_user_input(self._hazard_crs)
Expand Down Expand Up @@ -623,17 +614,6 @@ def create(self):

# Save graphs/network with hazard
self._export_network_files("base_graph_hazard", types_to_export)
else:
_hazard_base_graph = self._output_graph_dir.joinpath(
"base_graph_hazard.p"
)
# Try to find the base graph hazard file
self.graph_files.base_graph_hazard.read_graph(_hazard_base_graph.parent)
if not self.graph_files.base_graph_hazard.get_graph():
# File not found
logging.warning(
f"Base graph hazard file not found at {_hazard_base_graph}"
)

#### Step 2: hazard overlay of the origins_destinations (NetworkX) ###
if (
Expand All @@ -642,7 +622,7 @@ def create(self):
and self._destinations
and (not self.graph_files.origins_destinations_graph_hazard.file)
):
graph = self.graph_files.origins_destinations_graph.graph
graph = self.graph_files.origins_destinations_graph.get_graph()
ods = self.load_origins_destinations()

# Check if the graph needs to be reprojected
Expand Down Expand Up @@ -735,7 +715,7 @@ def create(self):
# Check if the graph needs to be reprojected
hazard_crs = pyproj.CRS.from_user_input(self._hazard_crs)
gdf_crs = pyproj.CRS.from_user_input(
self.graph_files.base_network.graph.crs
self.graph_files.base_network.get_graph().crs
)

if (
Expand All @@ -747,7 +727,7 @@ def create(self):
hazard_crs, gdf_crs
)
)
extent_gdf = self.graph_files.base_network.graph.total_bounds
extent_gdf = self.graph_files.base_network.get_graph().total_bounds
logging.info("Gdf extent before reprojecting: {}".format(extent_gdf))
gdf_reprojected = self.graph_files.base_network.graph.copy().to_crs(
hazard_crs
Expand All @@ -761,18 +741,18 @@ def create(self):
gdf_reprojected = self.hazard_intersect(gdf_reprojected)

# Assign the original geometries to the reprojected raster
original_geometries = self.graph_files.base_network.graph["geometry"]
original_geometries = self.graph_files.base_network.get_graph()[
"geometry"
]
gdf_reprojected["geometry"] = original_geometries
self.graph_files.base_network_hazard.graph = gdf_reprojected.copy()
del gdf_reprojected
else:
# read previously created file
logging.info("Setting 'base_network_hazard' graph.")
if self.graph_files.base_network_hazard.file:
self.graph_files.base_network_hazard.get_graph()
else:
if not self.graph_files.base_network_hazard.file:
self.graph_files.base_network_hazard.graph = self.hazard_intersect(
self.graph_files.base_network.graph
self.graph_files.base_network.get_graph()
)

#### Step 4: hazard overlay of the locations that are checked for isolation ###
Expand All @@ -784,7 +764,7 @@ def create(self):
# get hazard at locations from network based on nearest
logging.info("Get hazard at locations from network.")
locations_hazard = self.get_point_hazard_from_network(
locations, self.graph_files.base_network_hazard.graph
locations, self.graph_files.base_network_hazard.get_graph()
)

_exporter = NetworkExporterFactory()
Expand Down
4 changes: 0 additions & 4 deletions ra2ce/graph/network_config_data/network_config_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from pyproj import CRS

from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol
from ra2ce.graph.graph_files.graph_files_collection import GraphFilesCollection


@dataclass
Expand Down Expand Up @@ -98,9 +97,6 @@ class NetworkConfigData(ConfigDataProtocol):
crs: CRS = field(default_factory=lambda: CRS.from_user_input(4326))
project: ProjectSection = field(default_factory=lambda: ProjectSection())
network: NetworkSection = field(default_factory=lambda: NetworkSection())
graph_files: GraphFilesCollection = field(
default_factory=lambda: GraphFilesCollection()
)
origins_destinations: OriginsDestinationsSection = field(
default_factory=lambda: OriginsDestinationsSection()
)
Expand Down
24 changes: 8 additions & 16 deletions ra2ce/graph/network_config_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -63,24 +62,17 @@ def from_data(
_new_network_config = cls()
_new_network_config.ini_file = ini_file
_new_network_config.config_data = config_data
if config_data.output_graph_dir and config_data.output_graph_dir.is_dir():
_new_network_config.graph_files = (
_new_network_config.get_existent_network_files(
config_data.output_graph_dir
if config_data.output_graph_dir:
if config_data.output_graph_dir.is_dir():
_new_network_config.graph_files = (
_new_network_config.read_graphs_from_config(
config_data.output_graph_dir
)
)
)
else:
logging.error(
f"Graph dir not found. Value provided: {config_data.output_graph_dir}"
)
else:
config_data.output_graph_dir.mkdir(parents=True)
return _new_network_config

@staticmethod
def get_existent_network_files(output_graph_dir: Path) -> GraphFilesCollection:
"""Checks if file of graph exist in network folder and adds filename to the graph object"""
_graph_files = GraphFilesCollection()
return _graph_files.set_files(output_graph_dir)

@staticmethod
def read_graphs_from_config(static_output_dir: Path) -> GraphFilesCollection:
if not static_output_dir.exists():
Expand Down
6 changes: 3 additions & 3 deletions ra2ce/graph/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _get_stored_network_and_graph(self) -> None:
+ "Ra2ce will use this: {}".format(base_graph_filepath)
)

def read_graph(
def get_graph(
file_type: str, file_path: Path | None
) -> nx.MultiGraph | GeoDataFrame:
graph = self.graph_files.get_graph(file_type)
Expand All @@ -198,8 +198,8 @@ def read_graph(
)
return graph

_base_graph = read_graph("base_graph", base_graph_filepath)
_network_gdf = read_graph("base_network", base_network_filepath)
_base_graph = get_graph("base_graph", base_graph_filepath)
_network_gdf = get_graph("base_network", base_network_filepath)

# Assuming the same CRS for both the network and graph
self.base_graph_crs = _network_gdf.crs
Expand Down
8 changes: 3 additions & 5 deletions ra2ce/ra2ce_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,9 @@
from shapely.errors import ShapelyDeprecationWarning

from ra2ce.analyses.analysis_config_data.analysis_config_data import AnalysisConfigData
from ra2ce.analyses.analysis_config_wrapper import (
AnalysisConfigWrapper,
)
from ra2ce.configuration.config_factory import ConfigFactory
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.graph.network_config_data.network_config_data import NetworkConfigData
from ra2ce.graph.network_config_wrapper import NetworkConfigWrapper
from ra2ce.ra2ce_logging import Ra2ceLogger
from ra2ce.runners import AnalysisRunnerFactory

Expand All @@ -49,7 +45,7 @@


class Ra2ceHandler:
input_config: Optional[ConfigWrapper] = None
input_config: ConfigWrapper

def __init__(self, network: Optional[Path], analysis: Optional[Path]) -> None:
self._initialize_logger(network, analysis)
Expand All @@ -76,6 +72,8 @@ def run_analysis(self) -> None:
"""
Runs a Ra2ce analysis based on the provided network and analysis files.
"""
if not self.input_config.analysis_config:
return
if not self.input_config.is_valid_input():
_error = "Error validating input files. Ra2ce will close now."
logging.error(_error)
Expand Down
29 changes: 28 additions & 1 deletion tests/graph/graph_files/test_graph_files_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def test_set_file_given_unknown_graph_file_raises_value_error(self):
_collection.set_file(_filepath)

# 3. Verify expectations.
assert str(exc_err.value) == f"Unknown graph file {_filepath} provided."
assert (
str(exc_err.value) == f"Unknown graph file type {_filepath.stem} provided."
)

def test_set_files(self):
# 1. Define test data
Expand All @@ -52,6 +54,31 @@ def test_set_files(self):
# 3. Verify results
assert _collection.base_graph.file == _file

def test_get_file(self):
# 1. Define test data
_type = "base_graph"
_file = test_data.joinpath("readers_test_data", f"{_type}.p")
_collection = GraphFilesCollection()
_collection.set_file(_file)

# 2. Execute test
_f = _collection.get_file(_type)

# 3. Verify results
assert _f == _file

def test_get_file_given_unknown_type_raises_value_error(self):
# 1. Define test data
_type = "unknown_type"
_collection = GraphFilesCollection()

# 2. Execute test
with pytest.raises(ValueError) as exc_err:
_f = _collection.get_file(_type)

# 3. Verify results
assert str(exc_err.value) == f"Unknown graph file type {_type} provided."

def test_read_graph(self):
# 1. Define test data
_folder = test_data.joinpath("readers_test_data")
Expand Down

0 comments on commit 23f6f28

Please sign in to comment.