Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Create adaptation runner #634

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ra2ce/analysis/analysis_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Type

from ra2ce.analysis.adaptation.adaptation import Adaptation
from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper
from ra2ce.analysis.analysis_factory import AnalysisFactory
from ra2ce.analysis.analysis_protocol import AnalysisProtocol
from ra2ce.analysis.damages.analysis_damages_protocol import AnalysisDamagesProtocol
from ra2ce.analysis.losses.analysis_losses_protocol import AnalysisLossesProtocol

Expand All @@ -33,7 +36,7 @@
class AnalysisCollection:
damages_analyses: list[AnalysisDamagesProtocol] = field(default_factory=list)
losses_analyses: list[AnalysisLossesProtocol] = field(default_factory=list)
adaptation_analysis: AnalysisDamagesProtocol = None
adaptation_analysis: Adaptation = None

@classmethod
def from_config(cls, analysis_config: AnalysisConfigWrapper) -> AnalysisCollection:
Expand Down
9 changes: 6 additions & 3 deletions ra2ce/network/origins_destinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ def read_origin_destination_files(

for dp, dn in zip(destination_paths, destination_names):
destination_new = gpd.read_file(dp, crs=crs_, engine="pyogrio")
try:
assert destination_new[od_id]
except Exception:
if not destination_new[od_id].any():
logging.warning(
"No destination found at %s for %s, using default index instead.".format(
dp, od_id
)
)
destination_new[od_id] = destination_new.index
destination_new = destination_new[destination_columns_add]
destination_new["d_id"] = dn + "_" + destination_new[od_id].astype(str)
Expand Down
3 changes: 1 addition & 2 deletions ra2ce/ra2ce_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def run_analysis(self) -> list[AnalysisResultWrapper]:
logging.error(_error)
raise ValueError(_error)

_runner = AnalysisRunnerFactory.get_runner(self.input_config)
return _runner.run(self.input_config.analysis_config)
return AnalysisRunnerFactory.run(self.input_config)

@staticmethod
def run_with_ini_files(
Expand Down
5 changes: 3 additions & 2 deletions ra2ce/runners/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

This module contains all the available runners that can be used in this tool.

Each runner is specific for a given analysis type, to decide which one should be picked we make use of a _factory_ (`AnalysisRunnerFactory`).
Each runner is specific for a given analysis type, to decide which one should be picked we make use of a _factory_ (`AnalysisRunnerFactory`). In addition, you can directly run all available (and supported) analyses for a given configuration (`ConfigWrapper`) simply by doing `AnalysisRunnerFactory.run(config)`.

The result of an analysis runner __execution__ will be an `AnalysisResultWrapper`, an object containing information of the type of analysis ran and its result.
The result of an analysis runner __execution__ will be a collection of `AnalysisResultWrapper`, an object containing information of the type of analysis ran and its result.

# How to add new analysis?
* Create your own runner which should implement the `AnalysisRunnerProtocol`.
* Define in the run method how the analysis should be run.
* If you require extra arguments try using dependency injection while creating the object.
* You may use the `SimpleAnalysisRunnerBase` to avoid code duplication.
* Implement its selection criteria in the `AnalysisRunnerFactory`.
25 changes: 25 additions & 0 deletions ra2ce/runners/adaptation_analysis_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from ra2ce.analysis.adaptation.adaptation import Adaptation
from ra2ce.analysis.analysis_collection import AnalysisCollection
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.runners.damages_analysis_runner import DamagesAnalysisRunner
from ra2ce.runners.simple_analysis_runner_base import SimpleAnalysisRunnerBase


class AdaptationAnalysisRunner(SimpleAnalysisRunnerBase):
def __str__(self):
return "Adaptation Analysis Runner"

@staticmethod
def can_run(ra2ce_input: ConfigWrapper) -> bool:
if (
not ra2ce_input.analysis_config
or not ra2ce_input.analysis_config.config_data.adaptation
):
return False
return DamagesAnalysisRunner.can_run(ra2ce_input)
ArdtK marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def filter_supported_analyses(
analysis_collection: AnalysisCollection,
) -> list[Adaptation]:
return [analysis_collection.adaptation_analysis]
37 changes: 30 additions & 7 deletions ra2ce/runners/analysis_runner_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,36 @@
"""

import logging
from typing import Type

from ra2ce.analysis.analysis_collection import AnalysisCollection
from ra2ce.analysis.analysis_result.analysis_result_wrapper import AnalysisResultWrapper
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.runners.adaptation_analysis_runner import AdaptationAnalysisRunner
from ra2ce.runners.analysis_runner_protocol import AnalysisRunner
from ra2ce.runners.damages_analysis_runner import DamagesAnalysisRunner
from ra2ce.runners.losses_analysis_runner import LossesAnalysisRunner


class AnalysisRunnerFactory:
@staticmethod
def get_runner(ra2ce_input: ConfigWrapper) -> AnalysisRunner:
def get_supported_runners(ra2ce_input: ConfigWrapper) -> list[Type[AnalysisRunner]]:
"""
Gets the supported first analysis runner for the given input. The runner is initialized within this method.
Gets the supported analysis runners for the given input.

Args:
ra2ce_input (Ra2ceInput): Input representing a set of network and analysis ini configurations.

Returns:
AnalysisRunner: Initialized Ra2ce analysis runner.
list[AnalysisRunner]: Supported runners for the given configuration.
"""
_available_runners = [DamagesAnalysisRunner, LossesAnalysisRunner]
if not ra2ce_input.analysis_config:
return []
_available_runners = [
DamagesAnalysisRunner,
LossesAnalysisRunner,
AdaptationAnalysisRunner,
]
_supported_runners = [
_runner for _runner in _available_runners if _runner.can_run(ra2ce_input)
]
Expand All @@ -49,7 +59,20 @@ def get_runner(ra2ce_input: ConfigWrapper) -> AnalysisRunner:
raise ValueError(_err_mssg)

# Initialized selected supported runner (First one available).
_selected_runner = _supported_runners[0]()
if len(_supported_runners) > 1:
logging.warning(f"More than one runner available, using {_selected_runner}")
return _selected_runner
logging.warning(
"More than one runner available, computation time could be longer than expected."
)
return _supported_runners

@staticmethod
def run(ra2ce_input: ConfigWrapper) -> list[AnalysisResultWrapper]:
_supported_runners = AnalysisRunnerFactory.get_supported_runners(ra2ce_input)
_analysis_collection = AnalysisCollection.from_config(
ra2ce_input.analysis_config
)
_results = []
for _runner_type in _supported_runners:
_run_results = _runner_type().run(_analysis_collection)
_results.extend(_run_results)
return _results
8 changes: 4 additions & 4 deletions ra2ce/runners/analysis_runner_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from typing import Protocol, runtime_checkable

from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper
from ra2ce.analysis.analysis_collection import AnalysisCollection
from ra2ce.analysis.analysis_result.analysis_result_wrapper_protocol import (
AnalysisResultWrapperProtocol,
)
Expand All @@ -43,13 +43,13 @@ def can_run(ra2ce_input: ConfigWrapper) -> bool:
"""

def run(
self, analysis_config: AnalysisConfigWrapper
self, analysis_collection: AnalysisCollection
) -> list[AnalysisResultWrapperProtocol]:
"""
Runs this `AnalysisRunner` with the given analysis configuration.
Runs this `AnalysisRunner` for the given analysis collection.

Args:
analysis_config (AnalysisConfigWrapper): Analysis configuration representation to be run on this `AnalysisRunner`.
analysis_collection (AnalysisCollection): Collection of analyses to be run on this `AnalysisRunner`.

Returns:
list[AnalysisResultWrapperProtocol]: List of all results for all ran analysis.
Expand Down
39 changes: 9 additions & 30 deletions ra2ce/runners/damages_analysis_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,23 @@
"""

import logging
import time

from ra2ce.analysis.analysis_collection import AnalysisCollection
from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper
from ra2ce.analysis.analysis_result.analysis_result_wrapper_exporter import (
AnalysisResultWrapperExporter,
)
from ra2ce.analysis.damages.damages_result_wrapper import DamagesResultWrapper
from ra2ce.analysis.damages.analysis_damages_protocol import AnalysisDamagesProtocol
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.runners.analysis_runner_protocol import AnalysisRunner
from ra2ce.runners.simple_analysis_runner_base import SimpleAnalysisRunnerBase


class DamagesAnalysisRunner(AnalysisRunner):
class DamagesAnalysisRunner(SimpleAnalysisRunnerBase):
def __str__(self) -> str:
return "Damages Analysis Runner"

@staticmethod
def filter_supported_analyses(
analysis_collection: AnalysisCollection,
) -> list[AnalysisDamagesProtocol]:
return analysis_collection.damages_analyses

@staticmethod
def can_run(ra2ce_input: ConfigWrapper) -> bool:
if (
Expand All @@ -52,25 +53,3 @@ def can_run(ra2ce_input: ConfigWrapper) -> bool:
)
return False
return True

def run(self, analysis_config: AnalysisConfigWrapper) -> list[DamagesResultWrapper]:
_analysis_collection = AnalysisCollection.from_config(analysis_config)
_results = []
for analysis in _analysis_collection.damages_analyses:
logging.info(
"----------------------------- Started analyzing '%s' -----------------------------",
analysis.analysis.name,
)
starttime = time.time()

_result_wrapper = analysis.execute()
AnalysisResultWrapperExporter().export_result(_result_wrapper)

endtime = time.time()
logging.info(
"----------------------------- Analysis '%s' finished. "
"Time: %ss -----------------------------",
analysis.analysis.name,
str(round(endtime - starttime, 2)),
)
return _results
45 changes: 9 additions & 36 deletions ra2ce/runners/losses_analysis_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,22 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import logging
import time

from ra2ce.analysis.analysis_collection import AnalysisCollection
from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper
from ra2ce.analysis.analysis_result.analysis_result_wrapper import AnalysisResultWrapper
from ra2ce.analysis.analysis_result.analysis_result_wrapper_exporter import (
AnalysisResultWrapperExporter,
)
from ra2ce.analysis.losses.analysis_losses_protocol import AnalysisLossesProtocol
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.runners.analysis_runner_protocol import AnalysisRunner
from ra2ce.runners.simple_analysis_runner_base import SimpleAnalysisRunnerBase


class LossesAnalysisRunner(AnalysisRunner):
class LossesAnalysisRunner(SimpleAnalysisRunnerBase):
def __str__(self) -> str:
return "Losses Analysis Runner"

@staticmethod
def filter_supported_analyses(
analysis_collection: AnalysisCollection,
) -> list[AnalysisLossesProtocol]:
return analysis_collection.losses_analyses

@staticmethod
def can_run(ra2ce_input: ConfigWrapper) -> bool:
if (
Expand All @@ -44,28 +42,3 @@ def can_run(ra2ce_input: ConfigWrapper) -> bool:
):
return False
return True

def run(
self, analysis_config: AnalysisConfigWrapper
) -> list[AnalysisResultWrapper]:
_analysis_collection = AnalysisCollection.from_config(analysis_config)
_results = []
for analysis in _analysis_collection.losses_analyses:
logging.info(
"----------------------------- Started analyzing '%s' -----------------------------",
analysis.analysis.name,
)
starttime = time.time()

_result_wrapper = analysis.execute()
_results.append(_result_wrapper)
AnalysisResultWrapperExporter().export_result(_result_wrapper)

endtime = time.time()
logging.info(
"----------------------------- Analysis '%s' finished. "
"Time: %ss -----------------------------",
analysis.analysis.name,
str(round(endtime - starttime, 2)),
)
return _results
84 changes: 84 additions & 0 deletions ra2ce/runners/simple_analysis_runner_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Risk Assessment and Adaptation for Critical Infrastructure (RA2CE).
Copyright (C) 2023 Stichting Deltares

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import logging
import time
from abc import abstractmethod

from ra2ce.analysis.analysis_collection import AnalysisCollection
from ra2ce.analysis.analysis_protocol import AnalysisProtocol
from ra2ce.analysis.analysis_result.analysis_result_wrapper import AnalysisResultWrapper
from ra2ce.analysis.analysis_result.analysis_result_wrapper_exporter import (
AnalysisResultWrapperExporter,
)
from ra2ce.configuration.config_wrapper import ConfigWrapper
from ra2ce.runners.analysis_runner_protocol import AnalysisRunner


class SimpleAnalysisRunnerBase(AnalysisRunner):
@abstractmethod
def __str__(self) -> str:
raise NotImplementedError()

@staticmethod
@abstractmethod
def filter_supported_analyses(
analysis_collection: AnalysisCollection,
) -> list[AnalysisProtocol]:
"""
Gets the supported analysis for a concrete runner.

Args:
analysis_collection (AnalysisCollection): Collection of analyses to filter.

Returns:
list[AnalysisProtocol]: Supported analyses from the provided collection.
"""
raise NotImplementedError()

@staticmethod
@abstractmethod
def can_run(ra2ce_input: ConfigWrapper) -> bool:
raise NotImplementedError()

def run(
self, analysis_collection: AnalysisCollection
) -> list[AnalysisResultWrapper]:
_results = []
for analysis in self.filter_supported_analyses(analysis_collection):
logging.info(
"----------------------------- Started analyzing '%s' -----------------------------",
analysis.analysis.name,
)
starttime = time.time()

_result_wrapper = analysis.execute()
_results.append(_result_wrapper)
AnalysisResultWrapperExporter().export_result(_result_wrapper)

endtime = time.time()
logging.info(
"----------------------------- Analysis '%s' finished. "
"Time: %ss -----------------------------",
analysis.analysis.name,
str(round(endtime - starttime, 2)),
)
return _results
Loading