From eacc8f1d938f577d03413c09067423842c8fb4f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carles=20S=2E=20Soriano=20P=C3=A9rez?= Date: Tue, 10 Dec 2024 10:15:10 +0100 Subject: [PATCH] feat: Create adaptation runner (#634) * chore: Refactored code to make the runs analysis-independent * test: Fixed failing 'runners' tests * chore: Reworked code to filter analysis in collection * chore: Updated `Ra2ceHandler.run_analysis` to directly call the factory run method * chore: Fixed incorrect call for generating analysis collection * chore: Added new analysis runner * test: Slight refactor of `tests.runners` * test: Adapted tests * test: Fixed failing test * chore: Removed outdated functionality * chore: Small fix --- ra2ce/analysis/analysis_collection.py | 5 +- ra2ce/network/origins_destinations.py | 9 +- ra2ce/ra2ce_handler.py | 3 +- ra2ce/runners/README.md | 5 +- ra2ce/runners/adaptation_analysis_runner.py | 25 ++++++ ra2ce/runners/analysis_runner_factory.py | 37 ++++++-- ra2ce/runners/analysis_runner_protocol.py | 8 +- ra2ce/runners/damages_analysis_runner.py | 39 ++------- ra2ce/runners/losses_analysis_runner.py | 45 ++-------- ra2ce/runners/simple_analysis_runner_base.py | 84 +++++++++++++++++++ tests/analysis/test_analysis_collection.py | 5 ++ .../runners/{dummy_classes.py => conftest.py} | 28 ++++++- .../test_adaptation_analysis_runner.py | 57 +++++++++++++ tests/runners/test_analysis_runner_factory.py | 20 +++-- tests/runners/test_damages_analysis_runner.py | 21 ++--- tests/runners/test_losses_analysis_runner.py | 7 -- .../acceptance_test_data/output/analyses.ini | 81 ------------------ 17 files changed, 283 insertions(+), 196 deletions(-) create mode 100644 ra2ce/runners/adaptation_analysis_runner.py create mode 100644 ra2ce/runners/simple_analysis_runner_base.py rename tests/runners/{dummy_classes.py => conftest.py} (51%) create mode 100644 tests/runners/test_adaptation_analysis_runner.py delete mode 100644 tests/test_data/acceptance_test_data/output/analyses.ini diff --git a/ra2ce/analysis/analysis_collection.py b/ra2ce/analysis/analysis_collection.py index 4eae56807..e2664ca48 100644 --- a/ra2ce/analysis/analysis_collection.py +++ b/ra2ce/analysis/analysis_collection.py @@ -22,9 +22,12 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Type +from ra2ce.analysis.adaptation.adaptation import Adaptation from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper from ra2ce.analysis.analysis_factory import AnalysisFactory +from ra2ce.analysis.analysis_protocol import AnalysisProtocol from ra2ce.analysis.damages.analysis_damages_protocol import AnalysisDamagesProtocol from ra2ce.analysis.losses.analysis_losses_protocol import AnalysisLossesProtocol @@ -33,7 +36,7 @@ class AnalysisCollection: damages_analyses: list[AnalysisDamagesProtocol] = field(default_factory=list) losses_analyses: list[AnalysisLossesProtocol] = field(default_factory=list) - adaptation_analysis: AnalysisDamagesProtocol = None + adaptation_analysis: Adaptation = None @classmethod def from_config(cls, analysis_config: AnalysisConfigWrapper) -> AnalysisCollection: diff --git a/ra2ce/network/origins_destinations.py b/ra2ce/network/origins_destinations.py index e61b23b5e..70b69a209 100644 --- a/ra2ce/network/origins_destinations.py +++ b/ra2ce/network/origins_destinations.py @@ -120,9 +120,12 @@ def read_origin_destination_files( for dp, dn in zip(destination_paths, destination_names): destination_new = gpd.read_file(dp, crs=crs_, engine="pyogrio") - try: - assert destination_new[od_id] - except Exception: + if not destination_new[od_id].any(): + logging.warning( + "No destination found at %s for %s, using default index instead.".format( + dp, od_id + ) + ) destination_new[od_id] = destination_new.index destination_new = destination_new[destination_columns_add] destination_new["d_id"] = dn + "_" + destination_new[od_id].astype(str) diff --git a/ra2ce/ra2ce_handler.py b/ra2ce/ra2ce_handler.py index 3ae7d86e4..5fba5ae81 100644 --- a/ra2ce/ra2ce_handler.py +++ b/ra2ce/ra2ce_handler.py @@ -178,8 +178,7 @@ def run_analysis(self) -> list[AnalysisResultWrapper]: logging.error(_error) raise ValueError(_error) - _runner = AnalysisRunnerFactory.get_runner(self.input_config) - return _runner.run(self.input_config.analysis_config) + return AnalysisRunnerFactory.run(self.input_config) @staticmethod def run_with_ini_files( diff --git a/ra2ce/runners/README.md b/ra2ce/runners/README.md index ece33ab83..63fe3e3dd 100644 --- a/ra2ce/runners/README.md +++ b/ra2ce/runners/README.md @@ -2,12 +2,13 @@ This module contains all the available runners that can be used in this tool. -Each runner is specific for a given analysis type, to decide which one should be picked we make use of a _factory_ (`AnalysisRunnerFactory`). +Each runner is specific for a given analysis type, to decide which one should be picked we make use of a _factory_ (`AnalysisRunnerFactory`). In addition, you can directly run all available (and supported) analyses for a given configuration (`ConfigWrapper`) simply by doing `AnalysisRunnerFactory.run(config)`. -The result of an analysis runner __execution__ will be an `AnalysisResultWrapper`, an object containing information of the type of analysis ran and its result. +The result of an analysis runner __execution__ will be a collection of `AnalysisResultWrapper`, an object containing information of the type of analysis ran and its result. # How to add new analysis? * Create your own runner which should implement the `AnalysisRunnerProtocol`. * Define in the run method how the analysis should be run. * If you require extra arguments try using dependency injection while creating the object. + * You may use the `SimpleAnalysisRunnerBase` to avoid code duplication. * Implement its selection criteria in the `AnalysisRunnerFactory`. \ No newline at end of file diff --git a/ra2ce/runners/adaptation_analysis_runner.py b/ra2ce/runners/adaptation_analysis_runner.py new file mode 100644 index 000000000..6fe43136e --- /dev/null +++ b/ra2ce/runners/adaptation_analysis_runner.py @@ -0,0 +1,25 @@ +from ra2ce.analysis.adaptation.adaptation import Adaptation +from ra2ce.analysis.analysis_collection import AnalysisCollection +from ra2ce.configuration.config_wrapper import ConfigWrapper +from ra2ce.runners.damages_analysis_runner import DamagesAnalysisRunner +from ra2ce.runners.simple_analysis_runner_base import SimpleAnalysisRunnerBase + + +class AdaptationAnalysisRunner(SimpleAnalysisRunnerBase): + def __str__(self): + return "Adaptation Analysis Runner" + + @staticmethod + def can_run(ra2ce_input: ConfigWrapper) -> bool: + if ( + not ra2ce_input.analysis_config + or not ra2ce_input.analysis_config.config_data.adaptation + ): + return False + return DamagesAnalysisRunner.can_run(ra2ce_input) + + @staticmethod + def filter_supported_analyses( + analysis_collection: AnalysisCollection, + ) -> list[Adaptation]: + return [analysis_collection.adaptation_analysis] diff --git a/ra2ce/runners/analysis_runner_factory.py b/ra2ce/runners/analysis_runner_factory.py index 5d327e47a..43ad4607b 100644 --- a/ra2ce/runners/analysis_runner_factory.py +++ b/ra2ce/runners/analysis_runner_factory.py @@ -20,8 +20,12 @@ """ import logging +from typing import Type +from ra2ce.analysis.analysis_collection import AnalysisCollection +from ra2ce.analysis.analysis_result.analysis_result_wrapper import AnalysisResultWrapper from ra2ce.configuration.config_wrapper import ConfigWrapper +from ra2ce.runners.adaptation_analysis_runner import AdaptationAnalysisRunner from ra2ce.runners.analysis_runner_protocol import AnalysisRunner from ra2ce.runners.damages_analysis_runner import DamagesAnalysisRunner from ra2ce.runners.losses_analysis_runner import LossesAnalysisRunner @@ -29,17 +33,23 @@ class AnalysisRunnerFactory: @staticmethod - def get_runner(ra2ce_input: ConfigWrapper) -> AnalysisRunner: + def get_supported_runners(ra2ce_input: ConfigWrapper) -> list[Type[AnalysisRunner]]: """ - Gets the supported first analysis runner for the given input. The runner is initialized within this method. + Gets the supported analysis runners for the given input. Args: ra2ce_input (Ra2ceInput): Input representing a set of network and analysis ini configurations. Returns: - AnalysisRunner: Initialized Ra2ce analysis runner. + list[AnalysisRunner]: Supported runners for the given configuration. """ - _available_runners = [DamagesAnalysisRunner, LossesAnalysisRunner] + if not ra2ce_input.analysis_config: + return [] + _available_runners = [ + DamagesAnalysisRunner, + LossesAnalysisRunner, + AdaptationAnalysisRunner, + ] _supported_runners = [ _runner for _runner in _available_runners if _runner.can_run(ra2ce_input) ] @@ -49,7 +59,20 @@ def get_runner(ra2ce_input: ConfigWrapper) -> AnalysisRunner: raise ValueError(_err_mssg) # Initialized selected supported runner (First one available). - _selected_runner = _supported_runners[0]() if len(_supported_runners) > 1: - logging.warning(f"More than one runner available, using {_selected_runner}") - return _selected_runner + logging.warning( + "More than one runner available, computation time could be longer than expected." + ) + return _supported_runners + + @staticmethod + def run(ra2ce_input: ConfigWrapper) -> list[AnalysisResultWrapper]: + _supported_runners = AnalysisRunnerFactory.get_supported_runners(ra2ce_input) + _analysis_collection = AnalysisCollection.from_config( + ra2ce_input.analysis_config + ) + _results = [] + for _runner_type in _supported_runners: + _run_results = _runner_type().run(_analysis_collection) + _results.extend(_run_results) + return _results diff --git a/ra2ce/runners/analysis_runner_protocol.py b/ra2ce/runners/analysis_runner_protocol.py index e687fc347..cb1f2937d 100644 --- a/ra2ce/runners/analysis_runner_protocol.py +++ b/ra2ce/runners/analysis_runner_protocol.py @@ -21,7 +21,7 @@ from typing import Protocol, runtime_checkable -from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper +from ra2ce.analysis.analysis_collection import AnalysisCollection from ra2ce.analysis.analysis_result.analysis_result_wrapper_protocol import ( AnalysisResultWrapperProtocol, ) @@ -43,13 +43,13 @@ def can_run(ra2ce_input: ConfigWrapper) -> bool: """ def run( - self, analysis_config: AnalysisConfigWrapper + self, analysis_collection: AnalysisCollection ) -> list[AnalysisResultWrapperProtocol]: """ - Runs this `AnalysisRunner` with the given analysis configuration. + Runs this `AnalysisRunner` for the given analysis collection. Args: - analysis_config (AnalysisConfigWrapper): Analysis configuration representation to be run on this `AnalysisRunner`. + analysis_collection (AnalysisCollection): Collection of analyses to be run on this `AnalysisRunner`. Returns: list[AnalysisResultWrapperProtocol]: List of all results for all ran analysis. diff --git a/ra2ce/runners/damages_analysis_runner.py b/ra2ce/runners/damages_analysis_runner.py index 7e307cc90..d47e13162 100644 --- a/ra2ce/runners/damages_analysis_runner.py +++ b/ra2ce/runners/damages_analysis_runner.py @@ -20,22 +20,23 @@ """ import logging -import time from ra2ce.analysis.analysis_collection import AnalysisCollection -from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper -from ra2ce.analysis.analysis_result.analysis_result_wrapper_exporter import ( - AnalysisResultWrapperExporter, -) -from ra2ce.analysis.damages.damages_result_wrapper import DamagesResultWrapper +from ra2ce.analysis.damages.analysis_damages_protocol import AnalysisDamagesProtocol from ra2ce.configuration.config_wrapper import ConfigWrapper -from ra2ce.runners.analysis_runner_protocol import AnalysisRunner +from ra2ce.runners.simple_analysis_runner_base import SimpleAnalysisRunnerBase -class DamagesAnalysisRunner(AnalysisRunner): +class DamagesAnalysisRunner(SimpleAnalysisRunnerBase): def __str__(self) -> str: return "Damages Analysis Runner" + @staticmethod + def filter_supported_analyses( + analysis_collection: AnalysisCollection, + ) -> list[AnalysisDamagesProtocol]: + return analysis_collection.damages_analyses + @staticmethod def can_run(ra2ce_input: ConfigWrapper) -> bool: if ( @@ -52,25 +53,3 @@ def can_run(ra2ce_input: ConfigWrapper) -> bool: ) return False return True - - def run(self, analysis_config: AnalysisConfigWrapper) -> list[DamagesResultWrapper]: - _analysis_collection = AnalysisCollection.from_config(analysis_config) - _results = [] - for analysis in _analysis_collection.damages_analyses: - logging.info( - "----------------------------- Started analyzing '%s' -----------------------------", - analysis.analysis.name, - ) - starttime = time.time() - - _result_wrapper = analysis.execute() - AnalysisResultWrapperExporter().export_result(_result_wrapper) - - endtime = time.time() - logging.info( - "----------------------------- Analysis '%s' finished. " - "Time: %ss -----------------------------", - analysis.analysis.name, - str(round(endtime - starttime, 2)), - ) - return _results diff --git a/ra2ce/runners/losses_analysis_runner.py b/ra2ce/runners/losses_analysis_runner.py index 840c42f39..0bcea6444 100644 --- a/ra2ce/runners/losses_analysis_runner.py +++ b/ra2ce/runners/losses_analysis_runner.py @@ -18,24 +18,22 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - -import logging -import time - from ra2ce.analysis.analysis_collection import AnalysisCollection -from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper -from ra2ce.analysis.analysis_result.analysis_result_wrapper import AnalysisResultWrapper -from ra2ce.analysis.analysis_result.analysis_result_wrapper_exporter import ( - AnalysisResultWrapperExporter, -) +from ra2ce.analysis.losses.analysis_losses_protocol import AnalysisLossesProtocol from ra2ce.configuration.config_wrapper import ConfigWrapper -from ra2ce.runners.analysis_runner_protocol import AnalysisRunner +from ra2ce.runners.simple_analysis_runner_base import SimpleAnalysisRunnerBase -class LossesAnalysisRunner(AnalysisRunner): +class LossesAnalysisRunner(SimpleAnalysisRunnerBase): def __str__(self) -> str: return "Losses Analysis Runner" + @staticmethod + def filter_supported_analyses( + analysis_collection: AnalysisCollection, + ) -> list[AnalysisLossesProtocol]: + return analysis_collection.losses_analyses + @staticmethod def can_run(ra2ce_input: ConfigWrapper) -> bool: if ( @@ -44,28 +42,3 @@ def can_run(ra2ce_input: ConfigWrapper) -> bool: ): return False return True - - def run( - self, analysis_config: AnalysisConfigWrapper - ) -> list[AnalysisResultWrapper]: - _analysis_collection = AnalysisCollection.from_config(analysis_config) - _results = [] - for analysis in _analysis_collection.losses_analyses: - logging.info( - "----------------------------- Started analyzing '%s' -----------------------------", - analysis.analysis.name, - ) - starttime = time.time() - - _result_wrapper = analysis.execute() - _results.append(_result_wrapper) - AnalysisResultWrapperExporter().export_result(_result_wrapper) - - endtime = time.time() - logging.info( - "----------------------------- Analysis '%s' finished. " - "Time: %ss -----------------------------", - analysis.analysis.name, - str(round(endtime - starttime, 2)), - ) - return _results diff --git a/ra2ce/runners/simple_analysis_runner_base.py b/ra2ce/runners/simple_analysis_runner_base.py new file mode 100644 index 000000000..fdc893b56 --- /dev/null +++ b/ra2ce/runners/simple_analysis_runner_base.py @@ -0,0 +1,84 @@ +""" + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Risk Assessment and Adaptation for Critical Infrastructure (RA2CE). + Copyright (C) 2023 Stichting Deltares + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import logging +import time +from abc import abstractmethod + +from ra2ce.analysis.analysis_collection import AnalysisCollection +from ra2ce.analysis.analysis_protocol import AnalysisProtocol +from ra2ce.analysis.analysis_result.analysis_result_wrapper import AnalysisResultWrapper +from ra2ce.analysis.analysis_result.analysis_result_wrapper_exporter import ( + AnalysisResultWrapperExporter, +) +from ra2ce.configuration.config_wrapper import ConfigWrapper +from ra2ce.runners.analysis_runner_protocol import AnalysisRunner + + +class SimpleAnalysisRunnerBase(AnalysisRunner): + @abstractmethod + def __str__(self) -> str: + raise NotImplementedError() + + @staticmethod + @abstractmethod + def filter_supported_analyses( + analysis_collection: AnalysisCollection, + ) -> list[AnalysisProtocol]: + """ + Gets the supported analysis for a concrete runner. + + Args: + analysis_collection (AnalysisCollection): Collection of analyses to filter. + + Returns: + list[AnalysisProtocol]: Supported analyses from the provided collection. + """ + raise NotImplementedError() + + @staticmethod + @abstractmethod + def can_run(ra2ce_input: ConfigWrapper) -> bool: + raise NotImplementedError() + + def run( + self, analysis_collection: AnalysisCollection + ) -> list[AnalysisResultWrapper]: + _results = [] + for analysis in self.filter_supported_analyses(analysis_collection): + logging.info( + "----------------------------- Started analyzing '%s' -----------------------------", + analysis.analysis.name, + ) + starttime = time.time() + + _result_wrapper = analysis.execute() + _results.append(_result_wrapper) + AnalysisResultWrapperExporter().export_result(_result_wrapper) + + endtime = time.time() + logging.info( + "----------------------------- Analysis '%s' finished. " + "Time: %ss -----------------------------", + analysis.analysis.name, + str(round(endtime - starttime, 2)), + ) + return _results diff --git a/tests/analysis/test_analysis_collection.py b/tests/analysis/test_analysis_collection.py index c47f66840..8194d8adc 100644 --- a/tests/analysis/test_analysis_collection.py +++ b/tests/analysis/test_analysis_collection.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from pathlib import Path +from typing import Type import pytest +from ra2ce.analysis.adaptation.adaptation import Adaptation from ra2ce.analysis.analysis_collection import AnalysisCollection from ra2ce.analysis.analysis_config_data.analysis_config_data import ( AnalysisSectionAdaptation, @@ -17,8 +19,11 @@ AnalysisLossesEnum, ) from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper +from ra2ce.analysis.analysis_protocol import AnalysisProtocol from ra2ce.analysis.damages.analysis_damages_protocol import AnalysisDamagesProtocol +from ra2ce.analysis.damages.damages import Damages from ra2ce.analysis.losses.analysis_losses_protocol import AnalysisLossesProtocol +from ra2ce.analysis.losses.losses_base import LossesBase class TestAnalysisCollection: diff --git a/tests/runners/dummy_classes.py b/tests/runners/conftest.py similarity index 51% rename from tests/runners/dummy_classes.py rename to tests/runners/conftest.py index 1aee01383..ebb0039e3 100644 --- a/tests/runners/dummy_classes.py +++ b/tests/runners/conftest.py @@ -1,4 +1,12 @@ -from ra2ce.analysis.analysis_config_data.analysis_config_data import AnalysisConfigData +import pytest + +from ra2ce.analysis.analysis_config_data.analysis_config_data import ( + AnalysisConfigData, + AnalysisSectionDamages, +) +from ra2ce.analysis.analysis_config_data.enums.analysis_damages_enum import ( + AnalysisDamagesEnum, +) from ra2ce.analysis.analysis_config_wrapper import AnalysisConfigWrapper from ra2ce.configuration.config_wrapper import ConfigWrapper from ra2ce.network.network_config_wrapper import NetworkConfigWrapper @@ -23,3 +31,21 @@ class DummyRa2ceInput(ConfigWrapper): def __init__(self) -> None: self.analysis_config = DummyAnalysisConfigWrapper() self.network_config = NetworkConfigWrapper() + + +@pytest.fixture(name="dummy_ra2ce_input") +def _get_dummy_ra2ce_input() -> ConfigWrapper: + _ra2ce_input = DummyRa2ceInput() + assert isinstance(_ra2ce_input, ConfigWrapper) + return _ra2ce_input + + +@pytest.fixture(name="damages_ra2ce_input") +def _get_dummy_ra2ce_input_with_damages( + dummy_ra2ce_input: ConfigWrapper, +) -> ConfigWrapper: + dummy_ra2ce_input.analysis_config.config_data.analyses = [ + AnalysisSectionDamages(analysis=AnalysisDamagesEnum.DAMAGES) + ] + dummy_ra2ce_input.network_config.config_data.hazard.hazard_map = "A value" + return dummy_ra2ce_input diff --git a/tests/runners/test_adaptation_analysis_runner.py b/tests/runners/test_adaptation_analysis_runner.py new file mode 100644 index 000000000..c55d8b2dd --- /dev/null +++ b/tests/runners/test_adaptation_analysis_runner.py @@ -0,0 +1,57 @@ +from ra2ce.analysis.analysis_config_data.analysis_config_data import ( + AnalysisSectionAdaptation, + AnalysisSectionAdaptationOption, +) +from ra2ce.configuration.config_wrapper import ConfigWrapper +from ra2ce.runners.adaptation_analysis_runner import AdaptationAnalysisRunner + + +class TestAdaptationAnalysisRunner: + def test_init_adaptation_analysis_runner(self): + _runner = AdaptationAnalysisRunner() + assert str(_runner) == "Adaptation Analysis Runner" + + def test_given_wrong_analysis_configuration_cannot_run( + self, dummy_ra2ce_input: ConfigWrapper + ): + # 1. Define test data. + assert dummy_ra2ce_input.analysis_config.config_data.adaptation is None + + # 2. Run test. + _result = AdaptationAnalysisRunner.can_run(dummy_ra2ce_input) + + # 3. Verify expectations. + assert not _result + + def test_given_valid_damages_input_configuration_cannot_run( + self, damages_ra2ce_input: ConfigWrapper + ): + # 1. Define test data. + assert damages_ra2ce_input.analysis_config.config_data.adaptation is None + + # 2. Run test. + _result = AdaptationAnalysisRunner.can_run(damages_ra2ce_input) + + # 3. Verify expectation + assert _result is False + + def test_given_valid_damages_and_adaptation_input_configuration_can_run( + self, damages_ra2ce_input: ConfigWrapper + ): + # 1. Define test data. + assert damages_ra2ce_input.analysis_config.config_data.adaptation is None + _adaptation_config = AnalysisSectionAdaptation() + _adaptation_config.adaptation_options = [ + AnalysisSectionAdaptationOption(id="AO0"), + AnalysisSectionAdaptationOption(id="AO1"), + AnalysisSectionAdaptationOption(id="AO2"), + ] + damages_ra2ce_input.analysis_config.config_data.analyses.append( + _adaptation_config + ) + + # 2. Run test. + _result = AdaptationAnalysisRunner.can_run(damages_ra2ce_input) + + # 3. Verify expectation + assert _result is True diff --git a/tests/runners/test_analysis_runner_factory.py b/tests/runners/test_analysis_runner_factory.py index 5ec4b9310..50d388980 100644 --- a/tests/runners/test_analysis_runner_factory.py +++ b/tests/runners/test_analysis_runner_factory.py @@ -11,16 +11,18 @@ from ra2ce.analysis.analysis_config_data.enums.analysis_losses_enum import ( AnalysisLossesEnum, ) +from ra2ce.configuration.config_wrapper import ConfigWrapper from ra2ce.network.network_config_data.network_config_data import NetworkConfigData from ra2ce.runners.analysis_runner_factory import AnalysisRunnerFactory from ra2ce.runners.analysis_runner_protocol import AnalysisRunner -from tests.runners.dummy_classes import DummyRa2ceInput class TestAnalysisRunnerFactory: - def test_get_runner_unknown_input_raises_error(self): + def test_get_runner_unknown_input_raises_error( + self, dummy_ra2ce_input: ConfigWrapper + ): with pytest.raises(ValueError) as exc_err: - AnalysisRunnerFactory.get_runner(DummyRa2ceInput()) + AnalysisRunnerFactory.get_supported_runners(dummy_ra2ce_input) assert ( str(exc_err.value) @@ -28,10 +30,10 @@ def test_get_runner_unknown_input_raises_error(self): ) def test_get_runner_with_many_supported_runners_returns_analysis_runner_instance( - self, + self, dummy_ra2ce_input: ConfigWrapper ): # 1. Define test data. - _config_wrapper = DummyRa2ceInput() + _config_wrapper = dummy_ra2ce_input _config_wrapper.analysis_config.config_data = AnalysisConfigData( analyses=[ AnalysisSectionDamages(analysis=AnalysisDamagesEnum.DAMAGES), @@ -44,7 +46,11 @@ def test_get_runner_with_many_supported_runners_returns_analysis_runner_instance _config_wrapper.network_config.config_data.hazard.hazard_map = 4224 # 2. Run test. - _runner = AnalysisRunnerFactory.get_runner(_config_wrapper) + _supported_runners = AnalysisRunnerFactory.get_supported_runners( + _config_wrapper + ) # 3. Verify final expectations. - assert isinstance(_runner, AnalysisRunner) + assert isinstance(_supported_runners, list) + assert len(_supported_runners) == 2 + assert all(issubclass(_sr, AnalysisRunner) for _sr in _supported_runners) diff --git a/tests/runners/test_damages_analysis_runner.py b/tests/runners/test_damages_analysis_runner.py index 00134895d..729fcc5d5 100644 --- a/tests/runners/test_damages_analysis_runner.py +++ b/tests/runners/test_damages_analysis_runner.py @@ -1,5 +1,3 @@ -import pytest - from ra2ce.analysis.analysis_config_data.analysis_config_data import ( AnalysisSectionDamages, ) @@ -8,7 +6,6 @@ ) from ra2ce.configuration.config_wrapper import ConfigWrapper from ra2ce.runners.damages_analysis_runner import DamagesAnalysisRunner -from tests.runners.dummy_classes import DummyRa2ceInput class TestDamagesAnalysisRunner: @@ -16,23 +13,17 @@ def test_init_damages_analysis_runner(self): _runner = DamagesAnalysisRunner() assert str(_runner) == "Damages Analysis Runner" - @pytest.fixture - def dummy_ra2ce_input(self): - _ra2ce_input = DummyRa2ceInput() - assert isinstance(_ra2ce_input, ConfigWrapper) - yield _ra2ce_input - def test_given_damages_configuration_can_run( - self, dummy_ra2ce_input: ConfigWrapper + self, damages_ra2ce_input: ConfigWrapper ): # 1. Define test data. - dummy_ra2ce_input.analysis_config.config_data.analyses = [ - AnalysisSectionDamages(analysis=AnalysisDamagesEnum.DAMAGES) - ] - dummy_ra2ce_input.network_config.config_data.hazard.hazard_map = "A value" + assert any( + isinstance(_ad, AnalysisSectionDamages) + for _ad in damages_ra2ce_input.analysis_config.config_data.analyses + ) # 2. Run test. - _result = DamagesAnalysisRunner.can_run(dummy_ra2ce_input) + _result = DamagesAnalysisRunner.can_run(damages_ra2ce_input) # 3. Verify expectations. assert _result diff --git a/tests/runners/test_losses_analysis_runner.py b/tests/runners/test_losses_analysis_runner.py index d5893c6a4..de02b5724 100644 --- a/tests/runners/test_losses_analysis_runner.py +++ b/tests/runners/test_losses_analysis_runner.py @@ -8,7 +8,6 @@ ) from ra2ce.configuration.config_wrapper import ConfigWrapper from ra2ce.runners.losses_analysis_runner import LossesAnalysisRunner -from tests.runners.dummy_classes import DummyRa2ceInput class TestLossesAnalysisRunner: @@ -16,12 +15,6 @@ def test_init_losses_analysis_runner(self): _runner = LossesAnalysisRunner() assert str(_runner) == "Losses Analysis Runner" - @pytest.fixture - def dummy_ra2ce_input(self): - _ra2ce_input = DummyRa2ceInput() - assert isinstance(_ra2ce_input, ConfigWrapper) - yield _ra2ce_input - def test_given_losses_configuration_can_run(self, dummy_ra2ce_input: ConfigWrapper): # 1. Define test data. dummy_ra2ce_input.analysis_config.config_data.analyses = [ diff --git a/tests/test_data/acceptance_test_data/output/analyses.ini b/tests/test_data/acceptance_test_data/output/analyses.ini deleted file mode 100644 index 07450f145..000000000 --- a/tests/test_data/acceptance_test_data/output/analyses.ini +++ /dev/null @@ -1,81 +0,0 @@ -[project] -name = test - -[analysis1] -name = single link redundancy test -analysis = single_link_redundancy -weighing = time -save_gpkg = True -save_csv = True - -[analysis2] -name = multi link redundancy test -analysis = multi_link_redundancy -threshold = 1 -weighing = distance -save_gpkg = True -save_csv = True - -[analysis3] -name = optimal origin dest test -analysis = optimal_route_origin_destination -weighing = distance -save_gpkg = True -save_csv = True - -[analysis4] -name = multilink origin dest test -analysis = multi_link_origin_destination -threshold = 1 -weighing = distance -save_gpkg = True -save_csv = True - -[analysis5] -name = multilink origin closest dest test -analysis = multi_link_origin_closest_destination -threshold = 1 -weighing = distance -save_gpkg = True -save_csv = True - -[analysis6] -name = multilink isolated locations -analysis = multi_link_isolated_locations -threshold = 1 -weighing = distance -buffer_meters = 40 -category_field_name = category -save_gpkg = True -save_csv = True - -[analysis7] -name = adaptation module -analysis = adaptation -losses_analysis = single_link_losses -discount_rate = 0.05 -time_horizon = 10 -climate_factor = 0.2 -initial_frequency = 0.2 -save_gpkg = True -save_csv = True - -[adaptation_option0] -id = AO0 -name = no adaptation - -[adaptation_option1] -id = AO1 -name = first adaptation option -construction_cost = 1000 -construction_interval = 50 -maintenance_cost = 100 -maintenance_interval = 5 - -[adaptation_option2] -id = AO2 -name = second adaptation option -construction_cost = 2000 -construction_interval = 10 -maintenance_cost = 400 -maintenance_interval = 2 \ No newline at end of file