Skip to content

Commit

Permalink
Merge pull request #269 from Deltares/feature/268-add-enums-for-netwo…
Browse files Browse the repository at this point in the history
…rk-config-settings

feature: add enums for network config settings
  • Loading branch information
ArdtK authored Nov 27, 2023
2 parents 3b5a3f9 + a550d02 commit 2daa93e
Show file tree
Hide file tree
Showing 25 changed files with 381 additions and 91 deletions.
6 changes: 4 additions & 2 deletions ra2ce/analyses/analysis_config_data/analysis_config_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from pathlib import Path
from typing import Optional

from ra2ce.analyses.analysis_config_data.enums.weighing_enum import WeighingEnum
from ra2ce.common.configuration.config_data_protocol import ConfigDataProtocol
from ra2ce.graph.network_config_data.enums.aggregate_wl_enum import AggregateWlEnum
from ra2ce.graph.network_config_data.network_config_data import (
NetworkSection,
OriginsDestinationsSection,
Expand Down Expand Up @@ -76,7 +78,7 @@ class AnalysisSectionIndirect(AnalysisSectionBase):
"""

# general
weighing: str = "" # should be enum
weighing: WeighingEnum = field(default_factory=lambda: WeighingEnum.NONE)
loss_per_distance: str = ""
loss_type: str = "" # should be enum
disruption_per_category: str = ""
Expand All @@ -90,7 +92,7 @@ class AnalysisSectionIndirect(AnalysisSectionBase):
maximum_jam: float = math.nan
partofday: str = ""
# accessiblity analyses
aggregate_wl: str = "" # should be enum
aggregate_wl: AggregateWlEnum = field(default_factory=lambda: AggregateWlEnum.NONE)
threshold: float = math.nan
threshold_destinations: float = math.nan
uniform_duration: float = math.nan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
IndirectAnalysisNameList,
ProjectSection,
)
from ra2ce.analyses.analysis_config_data.enums.weighing_enum import WeighingEnum
from ra2ce.common.configuration.ini_configuration_reader_protocol import (
ConfigDataReaderProtocol,
)
Expand Down Expand Up @@ -102,6 +103,12 @@ def _get_analysis_section_indirect(
_section.save_csv = self._parser.getboolean(
section_name, "save_csv", fallback=_section.save_csv
)
_weighing = self._parser.get(section_name, "weighing", fallback=None)
# Map distance -> length
if _weighing == "distance":
_section.weighing = WeighingEnum.LENGTH
else:
_section.weighing = WeighingEnum.get_enum(_weighing)
# losses
_section.traffic_cols = self._parser.getlist(
section_name, "traffic_cols", fallback=_section.traffic_cols
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from ra2ce.common.validation.ra2ce_validator_protocol import Ra2ceIoValidator
from ra2ce.common.validation.validation_report import ValidationReport
from ra2ce.configuration.ra2ce_enum_base import Ra2ceEnumBase
from ra2ce.graph.network_config_data.network_config_data_validator import (
NetworkDictValues,
)
Expand All @@ -49,15 +50,21 @@ def _validate_header(self, header: Any) -> ValidationReport:
for _item in header:
_report.merge(self._validate_header(_item))
else:
# check keys with predescribed values
for key, value in header.__dict__.items():
if not value:
continue
if key not in AnalysisNetworkDictValues.keys():
continue
_expected_values_list = AnalysisNetworkDictValues[key]
if isinstance(value, Ra2ceEnumBase):
# enumerations
_expected_values_list = value.list_valid_options()
else:
# other items with limited value options (should become enumerations)
if key not in AnalysisNetworkDictValues.keys():
continue
_expected_values_list = AnalysisNetworkDictValues[key]
if value not in _expected_values_list:
_report.error(
f"Wrong input to property [ {key} ], has to be one of: {_expected_values_list}"
f"Wrong input to property [ {key} ]; has to be one of: {_expected_values_list}."
)

return _report
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions ra2ce/analyses/analysis_config_data/enums/weighing_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ra2ce.configuration.ra2ce_enum_base import Ra2ceEnumBase


class WeighingEnum(Ra2ceEnumBase):
NONE = 0
LENGTH = 1
TIME = 2
INVALID = 99
44 changes: 28 additions & 16 deletions ra2ce/analyses/indirect/analyses_indirect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
AnalysisConfigData,
AnalysisSectionIndirect,
)
from ra2ce.analyses.analysis_config_data.enums.weighing_enum import WeighingEnum
from ra2ce.analyses.indirect.losses import Losses
from ra2ce.analyses.indirect.origin_closest_destination import OriginClosestDestination
from ra2ce.analyses.indirect.traffic_analysis.traffic_analysis_factory import (
Expand Down Expand Up @@ -112,7 +113,7 @@ def single_link_redundancy(
if nx.has_path(graph, u, v):
# calculate the alternative distance if that edge is unavailable
alt_dist = nx.dijkstra_path_length(
graph, u, v, weight=analysis.weighing
graph, u, v, weight=analysis.weighing.config_value
)
alt_dist_list.append(alt_dist)

Expand All @@ -121,7 +122,7 @@ def single_link_redundancy(
alt_nodes_list.append(alt_nodes)

# calculate the difference in distance
dif_dist_list.append(alt_dist - data[analysis.weighing])
dif_dist_list.append(alt_dist - data[analysis.weighing.config_value])

detour_exist_list.append(1)
else:
Expand Down Expand Up @@ -190,7 +191,10 @@ def _single_link_losses_uniform(
# detour_losses = traffic_per_day[veh/day] * detour_distance[meter] * cost_per_meter[USD/meter/vehicle] * duration_disruption[hour] / 24[hour/day]
gdf.loc[
(gdf["detour"] == 1)
& (gdf[hz + "_" + analysis.aggregate_wl] > analysis.threshold),
& (
gdf[hz + "_" + analysis.aggregate_wl.config_value]
> analysis.threshold
),
col + "_detour_losses",
] += (
gdf[col]
Expand All @@ -202,7 +206,10 @@ def _single_link_losses_uniform(
# no_detour_losses = traffic_per_day[veh/day] * occupancy[person/veh] * gdp_percapita_per_day[USD/person] * duration_disruption[hour] / 24[hour/day]
gdf.loc[
(gdf["detour"] == 0)
& (gdf[hz + "_" + analysis.aggregate_wl] > analysis.threshold),
& (
gdf[hz + "_" + analysis.aggregate_wl.config_value]
> analysis.threshold
),
col + "_nodetour_losses",
] += (
gdf[col]
Expand Down Expand Up @@ -246,8 +253,8 @@ def _single_link_losses_categorized(
ub = 1e10
for road_cat in _all_road_categories:
gdf.loc[
(gdf[hz + "_" + analysis.aggregate_wl] > lb)
& (gdf[hz + "_" + analysis.aggregate_wl] <= ub)
(gdf[hz + "_" + analysis.aggregate_wl.config_value] > lb)
& (gdf[hz + "_" + analysis.aggregate_wl.config_value] <= ub)
& (gdf["class_identifier"] == road_cat),
"duration_disruption",
] = disruption_df_.loc[
Expand Down Expand Up @@ -343,7 +350,7 @@ def multi_link_redundancy(

if nx.has_path(graph, u, v):
alt_dist = nx.dijkstra_path_length(
graph, u, v, weight=analysis.weighing
graph, u, v, weight=analysis.weighing.config_value
)
alt_nodes = nx.dijkstra_path(graph, u, v)
connected = 1
Expand Down Expand Up @@ -374,7 +381,9 @@ def multi_link_redundancy(
# previously here you would find if dist == dist which is a critical bug. Replaced by just verifying dist is a value.
gdf["diff_dist"] = [
dist - length if dist else np.NaN
for (dist, length) in zip(gdf["alt_dist"], gdf[analysis.weighing])
for (dist, length) in zip(
gdf["alt_dist"], gdf[analysis.weighing.config_value]
)
]

gdf["hazard"] = hazard_name
Expand Down Expand Up @@ -477,8 +486,11 @@ def multi_link_losses(self, gdf, analysis: AnalysisSectionIndirect):
ub = 1e10
for road_cat in all_road_categories:
gdf_.loc[
(gdf_[hz + "_" + analysis.aggregate_wl] > lb)
& (gdf_[hz + "_" + analysis.aggregate_wl] <= ub)
(gdf_[hz + "_" + analysis.aggregate_wl.config_value] > lb)
& (
gdf_[hz + "_" + analysis.aggregate_wl.config_value]
<= ub
)
& (gdf_["class_identifier"] == road_cat),
"duration_disruption",
] = disruption_df_.loc[
Expand Down Expand Up @@ -583,7 +595,7 @@ def optimal_route_origin_destination(
) -> gpd.GeoDataFrame:
# create list of origin-destination pairs
od_nodes = self._get_origin_destination_pairs(graph)
pref_routes = find_route_ods(graph, od_nodes, analysis.weighing)
pref_routes = find_route_ods(graph, od_nodes, analysis.weighing.config_value)
return pref_routes

def optimal_route_od_link(
Expand Down Expand Up @@ -631,7 +643,9 @@ def multi_link_origin_destination(
# igraph_hz = ig.Graph.from_networkx(igraph_hz)

# Find the routes
od_routes = find_route_ods(graph_hz, od_nodes, analysis.weighing)
od_routes = find_route_ods(
graph_hz, od_nodes, analysis.weighing.config_value
)
od_routes["hazard"] = hazard_name
all_results.append(od_routes)

Expand Down Expand Up @@ -1071,9 +1085,6 @@ def _save_gpkg_analysis(
)
graph_to_gpkg(base_graph, gpkg_path_edges, gpkg_path_nodes)

if analysis.weighing == "distance":
# The name is different in the graph.
analysis.weighing = "length"
if analysis.analysis == "single_link_redundancy":
g = self.graph_files.base_graph.get_graph()
gdf = self.single_link_redundancy(g, analysis)
Expand Down Expand Up @@ -1359,7 +1370,8 @@ def find_route_ods(
# get edge with the lowest weighing if there are multiple edges that connect u and v
_uv_graph = graph[u][v]
edge_key = sorted(
_uv_graph, key=lambda x, _fgraph=_uv_graph: _fgraph[x][weighing]
_uv_graph,
key=lambda x, _fgraph=_uv_graph: _fgraph[x][weighing],
)[0]
_uv_graph_edge = _uv_graph[edge_key]
if "geometry" in _uv_graph_edge:
Expand Down
2 changes: 1 addition & 1 deletion ra2ce/analyses/indirect/origin_closest_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.unit = "km"
self.network_threshold = analysis.threshold
self.threshold_destinations = analysis.threshold_destinations
self.weighing = analysis.weighing
self.weighing = analysis.weighing.config_value
self.o_name = config.origins_destinations.origins_names
self.d_name = config.origins_destinations.destinations_names
self.od_id = config.origins_destinations.id_name_origin_destination
Expand Down
67 changes: 67 additions & 0 deletions ra2ce/configuration/ra2ce_enum_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from enum import Enum


class Ra2ceEnumBase(Enum):
"""
Base class for enums defined within Ra2ce.
NONE = 0: Optional entry (config is optional and missing)
INVALID = 99: Mandatory entry (config contains invalid value)
"""

@classmethod
def get_enum(cls, input: str | None) -> Ra2ceEnumBase:
"""
Create an enum from a given input string.
Args:
input (str): Value from config.
Returns:
Ra2ceEnumBase: Enumeration instance.
NONE: This entry is used if the config is missing.
INVALID: This entry is used if the config value is invalid.
"""
try:
if not input:
return cls.NONE
return cls[input.upper().strip()]
except (AttributeError, KeyError):
return cls.INVALID

def is_valid(self) -> bool:
"""
Check if given value is valid.
Args:
key (str): Enum key (name)
Returns:
bool: If the given key is not a valid key
"""
if self.name == "INVALID":
return False
return True

def list_valid_options(self) -> list[Ra2ceEnumBase]:
"""
List the enum options as allowed in the config.
Returns:
list[str | None]: Concatenated options, separated by ", "
"""
return [_enum for _enum in type(self)][:-1]

@property
def config_value(self) -> str | None:
"""
Reconstruct the name as it is known in the config.
This could entail replacement of " " by "_" and lower() operations.
Returns:
str: Value as known in the config.
"""
if self.name == "NONE":
return None
return self.name.lower()
2 changes: 1 addition & 1 deletion ra2ce/graph/hazard/hazard_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self._hazard_id = config.hazard.hazard_id
self._hazard_map = config.hazard.hazard_map
self._hazard_crs = config.hazard.hazard_crs
self._hazard_aggregate_wl = config.hazard.aggregate_wl
self._hazard_aggregate_wl = config.hazard.aggregate_wl.config_value
self._hazard_directory = config.static_path.joinpath("hazard")

# graph files
Expand Down
Empty file.
9 changes: 9 additions & 0 deletions ra2ce/graph/network_config_data/enums/aggregate_wl_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ra2ce.configuration.ra2ce_enum_base import Ra2ceEnumBase


class AggregateWlEnum(Ra2ceEnumBase):
NONE = 0
MIN = 1
MAX = 2
MEAN = 3
INVALID = 99
11 changes: 11 additions & 0 deletions ra2ce/graph/network_config_data/enums/network_type_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ra2ce.configuration.ra2ce_enum_base import Ra2ceEnumBase


class NetworkTypeEnum(Ra2ceEnumBase):
NONE = 0
WALK = 1
BIKE = 2
DRIVE = 3
DRIVE_SERVICE = 4
ALL = 5
INVALID = 99
19 changes: 19 additions & 0 deletions ra2ce/graph/network_config_data/enums/road_type_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from ra2ce.configuration.ra2ce_enum_base import Ra2ceEnumBase


class RoadTypeEnum(Ra2ceEnumBase):
NONE = 0
MOTORWAY = 1
MOTORWAY_LINK = 2
TRUNK = 3
TRUNK_LINK = 4
PRIMARY = 5
PRIMARY_LINK = 6
SECONDARY = 7
SECONDARY_LINK = 8
TERTIARY = 9
TERTIARY_LINK = 10
RESIDENTIAL = 11
ROAD = 12
UNCLASSIFIED = 98
INVALID = 99
25 changes: 25 additions & 0 deletions ra2ce/graph/network_config_data/enums/source_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from ra2ce.configuration.ra2ce_enum_base import Ra2ceEnumBase


class SourceEnum(Ra2ceEnumBase):
OSB_BPF = 1
OSM_DOWNLOAD = 2
SHAPEFILE = 3
PICKLE = 4
INVALID = 99

@classmethod
def get_enum(cls, input: str) -> SourceEnum:
try:
return cls[input.replace(" ", "_").upper()]
except KeyError:
return cls.INVALID

@property
def config_value(self) -> str:
_parts = self.name.split("_")
return " ".join(
[_part if len(_part) == 3 else _part.lower() for _part in _parts]
)
Loading

0 comments on commit 2daa93e

Please sign in to comment.