diff --git a/docs/configuring_cylc.md b/docs/configuring_cylc.md index eb36e963..30eee741 100644 --- a/docs/configuring_cylc.md +++ b/docs/configuring_cylc.md @@ -11,3 +11,5 @@ Additionally `cylc` uses a file called `$HOME/.cylc/flow/global.cylc` to control **WARNING:** The contents of the above two files will be platform specific. --- + +See [Configuring Cylc Discover](platforms/configuring_cylc_discover.md) for instructions on configuring cylc for Discover. diff --git a/docs/examples/description.md b/docs/examples/description.md index 756251dd..b66c8529 100644 --- a/docs/examples/description.md +++ b/docs/examples/description.md @@ -24,7 +24,7 @@ experiment_id: {suite_name}-suite ``` `swell create` also contains argument inputs. For instance, `-p` or `--platform` allows -user to pick which platform they would like to run on. +user to pick which platform they would like to run on. If unspecified, swell will run on SLES15 by default. ```bash swell create 3dvar -p nccs_discover_sles15 @@ -113,4 +113,4 @@ dependencies. For each suite, this will have a different structure and different `experiment.yaml`: This is the key configuration file that dictates the inputs for contain configuration variables that will be used for different scripts in the workflow. -For each JEDI bundle type (i.e., fv3-jedi, soca) and suite (3dvar, hofx etc.) in this section, we will display the `experiment.yaml` and talk about details. \ No newline at end of file +For each JEDI bundle type (i.e., fv3-jedi, soca) and suite (3dvar, hofx etc.) in this section, we will display the `experiment.yaml` and talk about details. diff --git a/docs/installing_swell.md b/docs/installing_swell.md index 39f6f24b..1bb739d7 100644 --- a/docs/installing_swell.md +++ b/docs/installing_swell.md @@ -29,4 +29,6 @@ cd swell pip install --prefix=/path/to/install/swell/ . ``` -To make the software useable ensure `/path/to/install/swell/bin` is in the `$PATH`. Also ensure that `/path/to/install/swell/lib/python/site-packages` is in the `$PYTHONPATH`, where `` denotes the version of Python used for the install, e.g. `3.9`. +To make the software usable ensure `/path/to/install/swell/bin` is in the `$PATH`. Also ensure that `/path/to/install/swell/lib/python/site-packages` is in the `$PYTHONPATH`, where `` denotes the version of Python used for the install, e.g. `3.9`. + +Swell makes use of additional packages which are located in shared directories on Discover, such as under `/discover/nobackup/projects/gmao`. When installed correctly, many of these libraries should be visible in the `$PYTHONPATH`. diff --git a/src/swell/deployment/create_experiment.py b/src/swell/deployment/create_experiment.py index 95aa66c6..6c283e22 100644 --- a/src/swell/deployment/create_experiment.py +++ b/src/swell/deployment/create_experiment.py @@ -14,6 +14,7 @@ import shutil import sys import yaml +from typing import Union, Optional from swell.deployment.prepare_config_and_suite.prepare_config_and_suite import \ PrepareExperimentConfigAndSuite @@ -27,8 +28,13 @@ # -------------------------------------------------------------------------------------------------- -def clone_config(configuration, experiment_id, method, platform, advanced): - +def clone_config( + configuration: str, + experiment_id: str, + method: str, + platform: str, + advanced: bool +) -> str: # Create a logger logger = Logger('SwellCloneExperiment') @@ -59,7 +65,14 @@ def clone_config(configuration, experiment_id, method, platform, advanced): # -------------------------------------------------------------------------------------------------- -def prepare_config(suite, method, platform, override, advanced, slurm): +def prepare_config( + suite: str, + method: str, + platform: str, + override: Union[dict, str, None], + advanced: bool, + slurm: str +) -> str: # Create a logger # --------------- @@ -136,13 +149,21 @@ def prepare_config(suite, method, platform, override, advanced, slurm): # Return path to dictionary file # ------------------------------ + return experiment_dict_string_comments # -------------------------------------------------------------------------------------------------- -def create_experiment_directory(suite, method, platform, override, advanced, slurm): +def create_experiment_directory( + suite: str, + method: str, + platform: str, + override: str, + advanced: bool, + slurm: Optional[str] +) -> None: # Create a logger # --------------- @@ -190,7 +211,7 @@ def create_experiment_directory(suite, method, platform, override, advanced, slu copy_platform_files(logger, exp_suite_path, platform) if os.path.exists(os.path.join(swell_suite_path, 'eva')): - copy_eva_files(logger, swell_suite_path, exp_suite_path) + copy_eva_files(swell_suite_path, exp_suite_path) # Set the swell paths in the modules file and create csh versions # --------------------------------------------------------------- @@ -217,7 +238,10 @@ def create_experiment_directory(suite, method, platform, override, advanced, slu # -------------------------------------------------------------------------------------------------- -def copy_eva_files(logger, swell_suite_path, exp_suite_path): +def copy_eva_files( + swell_suite_path: str, + exp_suite_path: str +) -> None: # Repo eva files eva_directory = os.path.join(swell_suite_path, 'eva') @@ -236,7 +260,11 @@ def copy_eva_files(logger, swell_suite_path, exp_suite_path): # -------------------------------------------------------------------------------------------------- -def copy_platform_files(logger, exp_suite_path, platform=None): +def copy_platform_files( + logger: Logger, + exp_suite_path: str, + platform: Optional[str] = None +) -> None: # Copy platform related files to the suite directory # -------------------------------------------------- @@ -256,7 +284,11 @@ def copy_platform_files(logger, exp_suite_path, platform=None): # -------------------------------------------------------------------------------------------------- -def template_modules_file(logger, experiment_dict, exp_suite_path): +def template_modules_file( + logger: Logger, + experiment_dict: dict, + exp_suite_path: str +) -> None: # Modules file # ------------ @@ -309,7 +341,10 @@ def template_modules_file(logger, experiment_dict, exp_suite_path): # -------------------------------------------------------------------------------------------------- -def create_modules_csh(logger, exp_suite_path): +def create_modules_csh( + logger: Logger, + exp_suite_path: str +) -> None: # Modules file # ------------ @@ -357,7 +392,13 @@ def create_modules_csh(logger, exp_suite_path): # -------------------------------------------------------------------------------------------------- -def prepare_cylc_suite_jinja2(logger, swell_suite_path, exp_suite_path, experiment_dict, platform): +def prepare_cylc_suite_jinja2( + logger: Logger, + swell_suite_path: str, + exp_suite_path: str, + experiment_dict: dict, + platform: str +) -> None: # Open suite file from swell # -------------------------- diff --git a/src/swell/deployment/launch_experiment.py b/src/swell/deployment/launch_experiment.py index eeebdb81..8917e85f 100644 --- a/src/swell/deployment/launch_experiment.py +++ b/src/swell/deployment/launch_experiment.py @@ -19,7 +19,13 @@ class DeployWorkflow(): - def __init__(self, suite_path, experiment_name, no_detach, log_path): + def __init__( + self, + suite_path: str, + experiment_name: str, + no_detach: bool, + log_path: str + ) -> None: self.logger = Logger('DeployWorkflow') self.suite_path = suite_path @@ -29,7 +35,7 @@ def __init__(self, suite_path, experiment_name, no_detach, log_path): # ---------------------------------------------------------------------------------------------- - def cylc_run_experiment(self): # NB: Could be a factory based on workflow_manager + def cylc_run_experiment(self) -> None: # NB: Could be a factory based on workflow_manager # Move to the suite path os.chdir(self.suite_path) @@ -93,7 +99,11 @@ def cylc_run_experiment(self): # NB: Could be a factory based on workflow_manag # -------------------------------------------------------------------------------------------------- -def launch_experiment(suite_path, no_detach, log_path): +def launch_experiment( + suite_path: str, + no_detach: bool, + log_path: str +) -> None: # Get the path to where the suite files are located # ------------------------------------------------- diff --git a/src/swell/deployment/platforms/platforms.py b/src/swell/deployment/platforms/platforms.py index 0d2d4b99..e5976b2c 100644 --- a/src/swell/deployment/platforms/platforms.py +++ b/src/swell/deployment/platforms/platforms.py @@ -17,7 +17,7 @@ # -------------------------------------------------------------------------------------------------- -def platform_path(): +def platform_path() -> str: return os.path.join(get_swell_path(), 'deployment', 'platforms') @@ -25,7 +25,7 @@ def platform_path(): # -------------------------------------------------------------------------------------------------- -def get_platforms(): +def get_platforms() -> list: # Get list of supported platforms platforms = [dir for dir in os.listdir(platform_path()) @@ -41,7 +41,7 @@ def get_platforms(): # -------------------------------------------------------------------------------------------------- -def login_or_compute(platform): +def login_or_compute(platform) -> str: # Open the properties file properties_file = os.path.join(platform_path(), 'properties.yaml') diff --git a/src/swell/deployment/prepare_config_and_suite/question_and_answer_cli.py b/src/swell/deployment/prepare_config_and_suite/question_and_answer_cli.py index dd11df1d..339efcee 100644 --- a/src/swell/deployment/prepare_config_and_suite/question_and_answer_cli.py +++ b/src/swell/deployment/prepare_config_and_suite/question_and_answer_cli.py @@ -10,6 +10,7 @@ import re import sys +from typing import Union import questionary from questionary import Choice @@ -20,7 +21,7 @@ class GetAnswerCli: - def get_answer(self, key, val): + def get_answer(self, key: str, val: dict) -> str: # Set questionary variable widget_type = val['type'] quest = val['prompt'] @@ -60,14 +61,14 @@ def get_answer(self, key, val): # ---------------------------------------------------------------------------------------------- - def make_string_widget(self, quest, default, prompt): + def make_string_widget(self, quest: str, default: str, prompt: questionary.text) -> str: answer = prompt(f"{quest} [{default}]", default=default).ask() return answer # ---------------------------------------------------------------------------------------------- - def make_int_widget(self, quest, default, prompt): + def make_int_widget(self, quest: str, default: str, prompt: questionary.text) -> str: default = str(default) answer = prompt(f"{quest} [{default}]", validate=lambda text: True if text.isdigit() @@ -78,7 +79,12 @@ def make_int_widget(self, quest, default, prompt): # ---------------------------------------------------------------------------------------------- - def make_float_widget(self, quest, default, prompt): + def make_float_widget( + self, + quest: str, + default: str, + prompt: questionary.text + ) -> str: default = str(default) answer = prompt(f"{quest} [{default}]", validate=lambda text: True if text.isdigit() @@ -89,7 +95,15 @@ def make_float_widget(self, quest, default, prompt): # ---------------------------------------------------------------------------------------------- - def make_drop_widget(self, method, quest, options, default, prompt): + def make_drop_widget( + self, + method: str, + quest: str, + options: list, + default: str, + prompt: questionary.text + ) -> Union[str, list]: + default = str(default) choices = [str(x) for x in options] answer = prompt(quest, choices=choices, default=default).ask() @@ -98,14 +112,25 @@ def make_drop_widget(self, method, quest, options, default, prompt): # ---------------------------------------------------------------------------------------------- - def make_boolean(self, quest, default, prompt): + def make_boolean( + self, + quest: str, + default: str, + prompt: questionary.text + ) -> str: + answer = prompt(quest, default=default, auto_enter=False).ask() return answer # ---------------------------------------------------------------------------------------------- - def make_datetime(self, quest, default, prompt): + def make_datetime( + self, + quest: str, + default: str, + prompt: questionary.text + ) -> str: class dtValidator(questionary.Validator): def validate(self, document): @@ -124,7 +149,12 @@ def validate(self, document): # ---------------------------------------------------------------------------------------------- - def make_duration(self, quest, default, prompt): + def make_duration( + self, + quest: str, + default: str, + prompt: questionary.text + ) -> str: class durValidator(questionary.Validator): def validate(self, document): @@ -157,7 +187,14 @@ def validate(self, document): # ---------------------------------------------------------------------------------------------- - def make_check_widget(self, quest, options, default, prompt): + def make_check_widget( + self, + quest: str, + options: list, + default: Union[str, list], + prompt: questionary.text + ) -> str: + choices = options.copy() if isinstance(default, list): diff --git a/src/swell/deployment/prepare_config_and_suite/question_and_answer_defaults.py b/src/swell/deployment/prepare_config_and_suite/question_and_answer_defaults.py index fb3e2082..998ec71c 100644 --- a/src/swell/deployment/prepare_config_and_suite/question_and_answer_defaults.py +++ b/src/swell/deployment/prepare_config_and_suite/question_and_answer_defaults.py @@ -8,9 +8,13 @@ # -------------------------------------------------------------------------------------------------- +from typing import Union +from datetime import datetime as dt + + class GetAnswerDefaults: - def get_answer(self, key, val): + def get_answer(self, key: str, val: dict) -> Union[int, float, str, dt]: return val['default_value'] # -------------------------------------------------------------------------------------------------- diff --git a/src/swell/swell.py b/src/swell/swell.py index bb96ea39..4c2ac016 100644 --- a/src/swell/swell.py +++ b/src/swell/swell.py @@ -9,6 +9,7 @@ import click +from typing import Union, Optional, Literal from swell.deployment.platforms.platforms import get_platforms from swell.deployment.create_experiment import clone_config, create_experiment_directory @@ -25,14 +26,14 @@ @click.group() -def swell_driver(): +def swell_driver() -> None: """ Welcome to swell! This is the top level driver for swell. It serves as a container for various commands related to experiment creation, launching, tasks, and utilities. - The normal process for createing and running an experiment is to issue: + The normal process for creating and running an experiment is to issue: swell create @@ -89,12 +90,19 @@ def swell_driver(): @click.argument('suite', type=click.Choice(get_suites())) @click.option('-m', '--input_method', 'input_method', default='defaults', type=click.Choice(['defaults', 'cli']), help=input_method_help) -@click.option('-p', '--platform', 'platform', default='nccs_discover', +@click.option('-p', '--platform', 'platform', default='nccs_discover_sles15', type=click.Choice(get_platforms()), help=platform_help) @click.option('-o', '--override', 'override', default=None, help=override_help) @click.option('-a', '--advanced', 'advanced', default=False, help=advanced_help) @click.option('-s', '--slurm', 'slurm', default=None, help=slurm_help) -def create(suite, input_method, platform, override, advanced, slurm): +def create( + suite: str, + input_method: str, + platform: str, + override: Union[dict, str, None], + advanced: bool, + slurm: str +) -> None: """ Create a new experiment @@ -118,7 +126,13 @@ def create(suite, input_method, platform, override, advanced, slurm): type=click.Choice(['defaults', 'cli']), help=input_method_help) @click.option('-p', '--platform', 'platform', default=None, help=platform_help) @click.option('-a', '--advanced', 'advanced', default=False, help=advanced_help) -def clone(configuration, experiment_id, input_method, platform, advanced): +def clone( + configuration: str, + experiment_id: str, + input_method: str, + platform: str, + advanced: bool +) -> None: """ Clone an existing experiment @@ -144,7 +158,11 @@ def clone(configuration, experiment_id, input_method, platform, advanced): @click.argument('suite_path') @click.option('-b', '--no-detach', 'no_detach', is_flag=True, default=False, help=no_detach_help) @click.option('-l', '--log_path', 'log_path', default=None, help=log_path_help) -def launch(suite_path, no_detach, log_path): +def launch( + suite_path: str, + no_detach: bool, + log_path: str +) -> None: """ Launch an experiment with the cylc workflow manager @@ -166,7 +184,13 @@ def launch(suite_path, no_detach, log_path): @click.option('-d', '--datetime', 'datetime', default=None, help=datetime_help) @click.option('-m', '--model', 'model', default=None, help=model_help) @click.option('-p', '--ensemblePacket', 'ensemblePacket', default=None, help=ensemble_help) -def task(task, config, datetime, model, ensemblePacket): +def task( + task: str, + config: str, + datetime: Optional[str], + model: Optional[str], + ensemblePacket: Optional[str] +) -> None: """ Run a workflow task @@ -185,7 +209,7 @@ def task(task, config, datetime, model, ensemblePacket): @swell_driver.command() @click.argument('utility', type=click.Choice(get_utilities())) -def utility(utility): +def utility(utility: str) -> None: """ Run a utility script @@ -203,7 +227,7 @@ def utility(utility): @swell_driver.command() @click.argument('test', type=click.Choice(valid_tests)) -def test(test): +def test(test: str) -> None: """ Run one of the test suites @@ -221,7 +245,7 @@ def test(test): @swell_driver.command() @click.argument('suite', type=click.Choice(("hofx", "3dvar", "ufo_testing"))) -def t1test(suite): +def t1test(suite: Literal["hofx", "3dvar", "ufo_testing"]) -> None: """ Run a particular swell suite from the tier 1 tests. @@ -234,7 +258,7 @@ def t1test(suite): # -------------------------------------------------------------------------------------------------- -def main(): +def main() -> None: """ Main Function diff --git a/src/swell/swell_path.py b/src/swell/swell_path.py index 357237ff..89caab65 100644 --- a/src/swell/swell_path.py +++ b/src/swell/swell_path.py @@ -14,7 +14,7 @@ # -------------------------------------------------------------------------------------------------- -def get_swell_path(): +def get_swell_path() -> str: return os.path.split(__file__)[0] diff --git a/src/swell/tasks/base/task_base.py b/src/swell/tasks/base/task_base.py index 3d2532e9..844c71ec 100644 --- a/src/swell/tasks/base/task_base.py +++ b/src/swell/tasks/base/task_base.py @@ -15,13 +15,15 @@ import importlib import os import time +from datetime import datetime as dt +from typing import Union, Optional # swell imports from swell.swell_path import get_swell_path from swell.utilities.case_switching import camel_case_to_snake_case, snake_case_to_camel_case from swell.utilities.config import Config from swell.utilities.data_assimilation_window_params import DataAssimilationWindowParams -from swell.utilities.datetime import Datetime +from swell.utilities.datetime_util import Datetime from swell.utilities.logger import Logger from swell.utilities.render_jedi_interface_files import JediConfigRendering from swell.utilities.geos import Geos @@ -33,7 +35,14 @@ class taskBase(ABC): # Base class constructor - def __init__(self, config_input, datetime_input, model, ensemblePacket, task_name): + def __init__( + self, + config_input: str, + datetime_input: Optional[str], + model: str, + ensemblePacket: Optional[str], + task_name: str + ) -> None: # Create message logger # --------------------- @@ -114,58 +123,58 @@ def __init__(self, config_input, datetime_input, model, ensemblePacket, task_nam # Execute is the place where a task does its work. It's defined as abstract in the base class # in order to force the sub classes (tasks) to implement it. @abstractmethod - def execute(self): + def execute(self) -> None: pass # ---------------------------------------------------------------------------------------------- # Method to get the experiment root - def experiment_root(self): + def experiment_root(self) -> str: return self.__experiment_root__ # ---------------------------------------------------------------------------------------------- # Method to get the experiment ID - def experiment_id(self): + def experiment_id(self) -> str: return self.__experiment_id__ # ---------------------------------------------------------------------------------------------- # Method to get the experiment directory - def experiment_path(self): + def experiment_path(self) -> str: return os.path.join(self.__experiment_root__, self.__experiment_id__) # ---------------------------------------------------------------------------------------------- # Method to get the experiment ID - def platform(self): + def platform(self) -> str: return self.__platform__ # ---------------------------------------------------------------------------------------------- # Method to get the experiment configuration directory - def experiment_config_path(self): + def experiment_config_path(self) -> str: swell_exp_path = self.experiment_path() return os.path.join(swell_exp_path, 'configuration') # ---------------------------------------------------------------------------------------------- - def get_ensemble_packet(self): + def get_ensemble_packet(self) -> Optional[str]: return self.__ensemble_packet__ # ---------------------------------------------------------------------------------------------- - def get_model(self): + def get_model(self) -> str: return self.__model__ # ---------------------------------------------------------------------------------------------- - def get_model_components(self): + def get_model_components(self) -> Union[str, list]: return self.__model_components__ # ---------------------------------------------------------------------------------------------- - def is_datetime_dependent(self): + def is_datetime_dependent(self) -> bool: if self.__datetime__ is None: return False else: @@ -173,7 +182,7 @@ def is_datetime_dependent(self): # ---------------------------------------------------------------------------------------------- - def cycle_dir(self): + def cycle_dir(self) -> str: # Check that model is set self.logger.assert_abort(self.__model__ is not None, 'In get_cycle_dir but this ' + @@ -188,7 +197,7 @@ def cycle_dir(self): # ---------------------------------------------------------------------------------------------- - def forecast_dir(self, paths=[]): + def forecast_dir(self, paths: Union[str, list[str]] = []) -> Optional[str]: # Make sure forecast directory exists # ----------------------------------- @@ -212,25 +221,25 @@ def forecast_dir(self, paths=[]): # ---------------------------------------------------------------------------------------------- - def cycle_time_dto(self): + def cycle_time_dto(self) -> dt: return self.__datetime__.dto() # ---------------------------------------------------------------------------------------------- - def cycle_time(self): + def cycle_time(self) -> str: return self.__datetime__.string_iso() # ---------------------------------------------------------------------------------------------- - def first_cycle_time(self): + def first_cycle_time(self) -> str: return self.__start_cycle_point__.string_iso() # ---------------------------------------------------------------------------------------------- - def first_cycle_time_dto(self): + def first_cycle_time_dto(self) -> dt: return self.__start_cycle_point__.dto() @@ -239,7 +248,14 @@ def first_cycle_time_dto(self): class taskFactory(): - def create_task(self, task, config, datetime, model, ensemblePacket): + def create_task( + self, + task: str, + config: str, + datetime: Union[str, dt, None], + model: str, + ensemblePacket: Optional[str] + ) -> taskBase: # Convert camel case string to snake case task_lower = camel_case_to_snake_case(task) @@ -253,7 +269,7 @@ def create_task(self, task, config, datetime, model, ensemblePacket): # -------------------------------------------------------------------------------------------------- -def get_tasks(): +def get_tasks() -> list: # Path to tasks tasks_directory = os.path.join(get_swell_path(), 'tasks', '*.py') @@ -274,7 +290,13 @@ def get_tasks(): # -------------------------------------------------------------------------------------------------- -def task_wrapper(task, config, datetime, model, ensemblePacket): +def task_wrapper( + task: str, + config: str, + datetime: Union[str, dt, None], + model: Optional[str], + ensemblePacket: Optional[str] +) -> None: # Create the object constrc_start = time.perf_counter() diff --git a/src/swell/tasks/build_geos.py b/src/swell/tasks/build_geos.py index 27c138b6..13dcf308 100644 --- a/src/swell/tasks/build_geos.py +++ b/src/swell/tasks/build_geos.py @@ -20,7 +20,7 @@ class BuildGeos(taskBase): - def execute(self): + def execute(self) -> None: # Get the experiment/geos directory # --------------------------------- diff --git a/src/swell/tasks/build_geos_by_linking.py b/src/swell/tasks/build_geos_by_linking.py index 1bc9610e..16fc91ee 100644 --- a/src/swell/tasks/build_geos_by_linking.py +++ b/src/swell/tasks/build_geos_by_linking.py @@ -19,7 +19,7 @@ class BuildGeosByLinking(taskBase): - def execute(self): + def execute(self) -> None: # Get the experiment/geos directory # --------------------------------- diff --git a/src/swell/tasks/build_jedi.py b/src/swell/tasks/build_jedi.py index b78a8200..c9be55ac 100644 --- a/src/swell/tasks/build_jedi.py +++ b/src/swell/tasks/build_jedi.py @@ -20,7 +20,7 @@ class BuildJedi(taskBase): - def execute(self): + def execute(self) -> None: # Get the experiment/jedi_bundle directory # ---------------------------------------- diff --git a/src/swell/tasks/build_jedi_by_linking.py b/src/swell/tasks/build_jedi_by_linking.py index cbd7bb9d..2d77290a 100644 --- a/src/swell/tasks/build_jedi_by_linking.py +++ b/src/swell/tasks/build_jedi_by_linking.py @@ -19,7 +19,7 @@ class BuildJediByLinking(taskBase): - def execute(self): + def execute(self) -> None: # Get the experiment/jedi_bundle directory # ---------------------------------------- diff --git a/src/swell/tasks/clean_cycle.py b/src/swell/tasks/clean_cycle.py index 8fba3a84..4c606898 100644 --- a/src/swell/tasks/clean_cycle.py +++ b/src/swell/tasks/clean_cycle.py @@ -25,7 +25,7 @@ class CleanCycle(taskBase): """ - def execute(self): + def execute(self) -> None: # Parse config clean_patterns = self.config.clean_patterns(None) diff --git a/src/swell/tasks/clone_geos.py b/src/swell/tasks/clone_geos.py index 36396022..6190a81a 100644 --- a/src/swell/tasks/clone_geos.py +++ b/src/swell/tasks/clone_geos.py @@ -20,7 +20,7 @@ class CloneGeos(taskBase): - def execute(self): + def execute(self) -> None: # Get the experiment/geos directory # --------------------------------- diff --git a/src/swell/tasks/clone_geos_mksi.py b/src/swell/tasks/clone_geos_mksi.py index 308c4d81..555b6638 100644 --- a/src/swell/tasks/clone_geos_mksi.py +++ b/src/swell/tasks/clone_geos_mksi.py @@ -17,7 +17,7 @@ class CloneGeosMksi(taskBase): - def execute(self): + def execute(self) -> None: """ Generate the satellite channel record from GEOSmksi files diff --git a/src/swell/tasks/clone_jedi.py b/src/swell/tasks/clone_jedi.py index ad2edaad..e69f4cb5 100644 --- a/src/swell/tasks/clone_jedi.py +++ b/src/swell/tasks/clone_jedi.py @@ -22,7 +22,7 @@ class CloneJedi(taskBase): - def execute(self): + def execute(self) -> None: # Get the experiment/jedi_bundle directory # ---------------------------------------- diff --git a/src/swell/tasks/eva_increment.py b/src/swell/tasks/eva_increment.py index f473e2fa..a6d20071 100644 --- a/src/swell/tasks/eva_increment.py +++ b/src/swell/tasks/eva_increment.py @@ -21,7 +21,7 @@ class EvaIncrement(taskBase): - def execute(self): + def execute(self) -> None: # Get the model and window type # ----------------------------- diff --git a/src/swell/tasks/eva_jedi_log.py b/src/swell/tasks/eva_jedi_log.py index 940b3759..5a73e255 100644 --- a/src/swell/tasks/eva_jedi_log.py +++ b/src/swell/tasks/eva_jedi_log.py @@ -22,7 +22,7 @@ class EvaJediLog(taskBase): - def execute(self): + def execute(self) -> None: # Get the model # ------------- diff --git a/src/swell/tasks/eva_observations.py b/src/swell/tasks/eva_observations.py index 1c1f49fa..7061fa20 100644 --- a/src/swell/tasks/eva_observations.py +++ b/src/swell/tasks/eva_observations.py @@ -25,7 +25,7 @@ # Pass through to avoid confusion with optional logger argument inside eva -def run_eva(eva_dict): +def run_eva(eva_dict: dict) -> eva: eva(eva_dict) @@ -34,7 +34,7 @@ def run_eva(eva_dict): class EvaObservations(taskBase): - def execute(self): + def execute(self) -> None: # Compute window beginning time # ----------------------------- diff --git a/src/swell/tasks/generate_b_climatology.py b/src/swell/tasks/generate_b_climatology.py index e6d9b912..35b6b6a0 100644 --- a/src/swell/tasks/generate_b_climatology.py +++ b/src/swell/tasks/generate_b_climatology.py @@ -11,13 +11,14 @@ from swell.tasks.base.task_base import taskBase from swell.utilities.shell_commands import run_subprocess, run_track_log_subprocess from swell.utilities.run_jedi_executables import jedi_dictionary_iterator +from swell.utilities.file_system_operations import check_if_files_exist_in_path # -------------------------------------------------------------------------------------------------- class GenerateBClimatology(taskBase): - def jedi_dictionary_iterator(self, jedi_config_dict): + def jedi_dictionary_iterator(self, jedi_config_dict: dict) -> None: # Loop over dictionary and replace if value is a dictionary # --------------------------------------------------------- @@ -32,7 +33,7 @@ def jedi_dictionary_iterator(self, jedi_config_dict): # ---------------------------------------------------------------------------------------------- - def generate_jedi_config(self): + def generate_jedi_config(self) -> dict: # Render StaticBInit (no templates needed) # ---------------------------------------- @@ -46,7 +47,7 @@ def generate_jedi_config(self): # ---------------------------------------------------------------------------------------------- - def initialize_background(self): + def initialize_background(self) -> None: if self.background_error_model == 'bump': @@ -61,7 +62,7 @@ def initialize_background(self): # ---------------------------------------------------------------------------------------------- - def generate_bump(self): + def generate_bump(self) -> None: self.logger.info(' Generating BUMP files.') @@ -104,7 +105,7 @@ def generate_bump(self): # ---------------------------------------------------------------------------------------------- - def generate_explicit_diffusion(self): + def generate_explicit_diffusion(self) -> None: self.logger.info(' Generating files required by EXPLICIT_DIFFUSION.') self.obtain_scales() @@ -112,7 +113,7 @@ def generate_explicit_diffusion(self): # ---------------------------------------------------------------------------------------------- - def obtain_scales(self): + def obtain_scales(self) -> None: # This executes calc_scales.py under SOCA/tools to obtain the vertical scale. # The output then will be used to generate the vertical correlation files via @@ -157,7 +158,7 @@ def obtain_scales(self): # ---------------------------------------------------------------------------------------------- - def parameters_diffusion_vt(self): + def parameters_diffusion_vt(self) -> None: # This generates the MLD dependent vertical correlation file using the # calculated_scales @@ -215,7 +216,7 @@ def parameters_diffusion_vt(self): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: """ Creates B Matrix files for background error model(s): - BUMP: @@ -239,7 +240,18 @@ def execute(self): window_offset = self.config.window_offset() window_type = self.config.window_type() background_error_model = self.config.background_error_model() + + swell_static_files_user = self.config.swell_static_files_user(None) self.swell_static_files = self.config.swell_static_files() + + # Use static_files_user if present in config and contains files + # ------------------------------------------------------------- + if swell_static_files_user is not None: + self.logger.info('swell_static_files_user specified, checking for files') + if check_if_files_exist_in_path(self.logger, swell_static_files_user): + self.logger.info(f'Using swell static files in {swell_static_files_user}') + self.swell_static_files = swell_static_files_user + self.horizontal_resolution = self.config.horizontal_resolution() self.vertical_resolution = self.config.vertical_resolution() diff --git a/src/swell/tasks/generate_b_climatology_by_linking.py b/src/swell/tasks/generate_b_climatology_by_linking.py index 4caafe0e..694acbfe 100644 --- a/src/swell/tasks/generate_b_climatology_by_linking.py +++ b/src/swell/tasks/generate_b_climatology_by_linking.py @@ -16,7 +16,7 @@ class GenerateBClimatologyByLinking(taskBase): - def execute(self): + def execute(self) -> None: """Acquires B Matrix files for background error model(s): - EXPLICIT_DIFFUSION: diff --git a/src/swell/tasks/generate_observing_system_records.py b/src/swell/tasks/generate_observing_system_records.py index 6a582846..146c43f9 100644 --- a/src/swell/tasks/generate_observing_system_records.py +++ b/src/swell/tasks/generate_observing_system_records.py @@ -18,7 +18,7 @@ class GenerateObservingSystemRecords(taskBase): - def execute(self): + def execute(self) -> None: """ Generate the observing system channel records from GEOS_mksi files diff --git a/src/swell/tasks/get_background.py b/src/swell/tasks/get_background.py index e11db228..d1c5b407 100644 --- a/src/swell/tasks/get_background.py +++ b/src/swell/tasks/get_background.py @@ -29,7 +29,7 @@ class GetBackground(taskBase): - def execute(self): + def execute(self) -> None: """Acquires background files for a given experiment and cycle diff --git a/src/swell/tasks/get_background_geos_experiment.py b/src/swell/tasks/get_background_geos_experiment.py index c5470868..8e2019f6 100644 --- a/src/swell/tasks/get_background_geos_experiment.py +++ b/src/swell/tasks/get_background_geos_experiment.py @@ -14,7 +14,7 @@ import tarfile from swell.tasks.base.task_base import taskBase -from swell.utilities.datetime import datetime_formats +from swell.utilities.datetime_util import datetime_formats # -------------------------------------------------------------------------------------------------- @@ -51,6 +51,7 @@ def execute(self): # Since this is an optional task, check if the geos_x_background_directory is # set to /dev/null, if so fail the task # --------------------------------------------------------------------- + if ( (geos_x_background_directory is None) or (geos_x_background_directory.startswith("/dev/null")) diff --git a/src/swell/tasks/get_ensemble.py b/src/swell/tasks/get_ensemble.py index 627724f9..08b77b97 100644 --- a/src/swell/tasks/get_ensemble.py +++ b/src/swell/tasks/get_ensemble.py @@ -19,7 +19,7 @@ class GetEnsemble(taskBase): - def execute(self): + def execute(self) -> None: """Acquires ensemble member files for a given experiment and cycle Parameters diff --git a/src/swell/tasks/get_geos_adas_background.py b/src/swell/tasks/get_geos_adas_background.py index 3a7f6e6e..354d5351 100644 --- a/src/swell/tasks/get_geos_adas_background.py +++ b/src/swell/tasks/get_geos_adas_background.py @@ -21,7 +21,7 @@ class GetGeosAdasBackground(taskBase): - def execute(self): + def execute(self) -> None: # Get the path and pattern for the background files # ------------------------------------------------- diff --git a/src/swell/tasks/get_geos_restart.py b/src/swell/tasks/get_geos_restart.py index 454e09c5..f8d3821a 100644 --- a/src/swell/tasks/get_geos_restart.py +++ b/src/swell/tasks/get_geos_restart.py @@ -11,7 +11,7 @@ import glob from swell.tasks.base.task_base import taskBase -from swell.utilities.file_system_operations import copy_to_dst_dir +from swell.utilities.file_system_operations import copy_to_dst_dir, check_if_files_exist_in_path # -------------------------------------------------------------------------------------------------- @@ -20,12 +20,21 @@ class GetGeosRestart(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: self.logger.info('Obtaining GEOS restarts for the coupled simulation') + swell_static_files_user = self.config.swell_static_files_user(None) self.swell_static_files = self.config.swell_static_files() + # Use static_files_user if present in config and contains files + # ------------------------------------------------------------- + if swell_static_files_user is not None: + self.logger.info('swell_static_files_user specified, checking for files') + if check_if_files_exist_in_path(self.logger, swell_static_files_user): + self.logger.info(f'Using swell static files in {swell_static_files_user}') + self.swell_static_files = swell_static_files_user + # Create forecast_dir and INPUT # ---------------------------- if not os.path.exists(self.forecast_dir('INPUT')): @@ -41,7 +50,7 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def initial_restarts(self, rst_path): + def initial_restarts(self, rst_path: str) -> None: # GEOS forecast checkpoint files are created in advance # TODO: check tile of restarts here for compatibility? diff --git a/src/swell/tasks/get_geovals.py b/src/swell/tasks/get_geovals.py index d4b63e94..eb8819ea 100644 --- a/src/swell/tasks/get_geovals.py +++ b/src/swell/tasks/get_geovals.py @@ -19,7 +19,7 @@ class GetGeovals(taskBase): - def execute(self): + def execute(self) -> None: # Parse config # ------------ diff --git a/src/swell/tasks/get_gsi_bc.py b/src/swell/tasks/get_gsi_bc.py index ead70322..9ca84ae2 100644 --- a/src/swell/tasks/get_gsi_bc.py +++ b/src/swell/tasks/get_gsi_bc.py @@ -22,7 +22,7 @@ class GetGsiBc(taskBase): - def execute(self): + def execute(self) -> None: # Get the build method # -------------------- diff --git a/src/swell/tasks/get_gsi_ncdiag.py b/src/swell/tasks/get_gsi_ncdiag.py index caf39f55..379badd6 100644 --- a/src/swell/tasks/get_gsi_ncdiag.py +++ b/src/swell/tasks/get_gsi_ncdiag.py @@ -19,7 +19,7 @@ class GetGsiNcdiag(taskBase): - def execute(self): + def execute(self) -> None: # Get the build method # -------------------- diff --git a/src/swell/tasks/get_observations.py b/src/swell/tasks/get_observations.py index 494bdaff..bdd49874 100644 --- a/src/swell/tasks/get_observations.py +++ b/src/swell/tasks/get_observations.py @@ -11,11 +11,12 @@ import numpy as np import os import netCDF4 as nc +from typing import Union from datetime import timedelta, datetime as dt from swell.tasks.base.task_base import taskBase from swell.utilities.r2d2 import create_r2d2_config -from swell.utilities.datetime import datetime_formats +from swell.utilities.datetime_util import datetime_formats from r2d2 import fetch @@ -24,7 +25,7 @@ class GetObservations(taskBase): - def execute(self): + def execute(self) -> None: """ Acquires observation files for a given experiment and cycle. @@ -271,7 +272,7 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def get_tlapse_files(self, observation_dict): + def get_tlapse_files(self, observation_dict: dict) -> Union[None, int]: # Function to locate instances of tlapse in the obs operator config @@ -326,7 +327,12 @@ def previous_cycle_bias(self, # Read and combine variable data from multiple files # -------------------------------------------------- - def create_obs_time_list(self, obs_times, window_begin_dto, window_end_dto): + def create_obs_time_list( + self, + obs_times: list, + window_begin_dto: dt, + window_end_dto: dt + ) -> list: day_before_dto = window_begin_dto-timedelta(days=1) day_after_dto = window_end_dto+timedelta(days=1) @@ -366,13 +372,13 @@ def create_obs_time_list(self, obs_times, window_begin_dto, window_end_dto): # Get the target data from the netcdf file # ---------------------------------------- - def get_data(self, input_file, group, var_name): + def get_data(self, input_file: str, group: str, var_name: str) -> object: with nc.Dataset(input_file, 'r') as ds: return ds[group][var_name][:] # ---------------------------------------------------------------------------------------------- - def read_and_combine(self, input_filenames, output_filename): + def read_and_combine(self, input_filenames: list, output_filename: str) -> None: ''' Combines multiple IODA v3 netcdf input files into a single output. Combining multiple files require final (total) location dimension size to be diff --git a/src/swell/tasks/gsi_bc_to_ioda.py b/src/swell/tasks/gsi_bc_to_ioda.py index 011e6afb..b64c146c 100644 --- a/src/swell/tasks/gsi_bc_to_ioda.py +++ b/src/swell/tasks/gsi_bc_to_ioda.py @@ -23,7 +23,7 @@ class GsiBcToIoda(taskBase): - def execute(self): + def execute(self) -> None: # Parse configuration # ------------------- diff --git a/src/swell/tasks/gsi_ncdiag_to_ioda.py b/src/swell/tasks/gsi_ncdiag_to_ioda.py index 2f01b7a8..c59f74fd 100644 --- a/src/swell/tasks/gsi_ncdiag_to_ioda.py +++ b/src/swell/tasks/gsi_ncdiag_to_ioda.py @@ -19,7 +19,7 @@ from pyiodaconv.combine_obsspace import combine_obsspace from swell.tasks.base.task_base import taskBase -from swell.utilities.datetime import datetime_formats +from swell.utilities.datetime_util import datetime_formats from swell.utilities.shell_commands import run_subprocess, create_executable_file @@ -28,7 +28,7 @@ class GsiNcdiagToIoda(taskBase): - def execute(self): + def execute(self) -> None: # Parse configuration # ------------------- diff --git a/src/swell/tasks/link_geos_output.py b/src/swell/tasks/link_geos_output.py index 67bcab30..320d6133 100644 --- a/src/swell/tasks/link_geos_output.py +++ b/src/swell/tasks/link_geos_output.py @@ -11,6 +11,7 @@ from netCDF4 import Dataset import numpy as np import xarray as xr +from typing import Tuple from swell.tasks.base.task_base import taskBase @@ -21,7 +22,7 @@ class LinkGeosOutput(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: """ Linking proper GEOS output files for JEDI to ingest and produce analysis. @@ -55,7 +56,7 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def link_mom6_history(self): + def link_mom6_history(self) -> Tuple[str, str]: # Create GEOS history to SOCA background link # TODO: this will only work for 3Dvar as FGAT requires multiple files @@ -69,7 +70,7 @@ def link_mom6_history(self): # ---------------------------------------------------------------------------------------------- - def link_mom6_restart(self): + def link_mom6_restart(self) -> Tuple[str, str]: # Create GEOS restart to SOCA background link # ------------------------------------------ @@ -97,7 +98,7 @@ def link_mom6_restart(self): # ---------------------------------------------------------------------------------------------- - def prepare_cice6(self): + def prepare_cice6(self) -> Tuple[str, str]: # CICE6 input in SOCA requires aggregation of multiple variables and # time dimension added to the dataset. diff --git a/src/swell/tasks/move_da_restart.py b/src/swell/tasks/move_da_restart.py index 92a4cf00..a48baa98 100644 --- a/src/swell/tasks/move_da_restart.py +++ b/src/swell/tasks/move_da_restart.py @@ -10,6 +10,7 @@ import glob import os import re +from typing import Union from swell.tasks.base.task_base import taskBase from swell.utilities.file_system_operations import move_files @@ -21,7 +22,7 @@ class MoveDaRestart(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: """ Moving restart files (i.e., _checkpoint) to the next cycle directory. @@ -55,7 +56,7 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def at_next_fcst_dir(self, paths): + def at_next_fcst_dir(self, paths: Union[str, list]) -> str: # Ensure what we have is a list (paths should be a list) # ------------------------------------------------------ @@ -69,7 +70,7 @@ def at_next_fcst_dir(self, paths): # ---------------------------------------------------------------------------------------------- - def cycling_restarts(self): + def cycling_restarts(self) -> None: # Move restarts (checkpoints) in the current cycle dir # ------------------------------------------------------ diff --git a/src/swell/tasks/move_forecast_restart.py b/src/swell/tasks/move_forecast_restart.py index 486e78d2..d4e29753 100644 --- a/src/swell/tasks/move_forecast_restart.py +++ b/src/swell/tasks/move_forecast_restart.py @@ -9,6 +9,7 @@ import os import glob +from typing import Union from swell.tasks.base.task_base import taskBase from swell.utilities.file_system_operations import move_files @@ -20,7 +21,7 @@ class MoveForecastRestart(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: """ Moving restart files (i.e., _checkpoint) to the next cycle geosdir. @@ -45,7 +46,7 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def at_next_fcst_dir(self, paths): + def at_next_fcst_dir(self, paths: Union[str, list]) -> str: # Ensure what we have is a list (paths should be a list) # ------------------------------------------------------ @@ -59,7 +60,7 @@ def at_next_fcst_dir(self, paths): # ---------------------------------------------------------------------------------------------- - def cycling_restarts(self): + def cycling_restarts(self) -> None: # Move restarts (checkpoints) in the current cycle dir # ------------------------------------------------------ diff --git a/src/swell/tasks/prep_geos_run_dir.py b/src/swell/tasks/prep_geos_run_dir.py index e67e3796..b28c2ec5 100644 --- a/src/swell/tasks/prep_geos_run_dir.py +++ b/src/swell/tasks/prep_geos_run_dir.py @@ -14,7 +14,7 @@ from datetime import datetime as dt from swell.tasks.base.task_base import taskBase -from swell.utilities.file_system_operations import copy_to_dst_dir +from swell.utilities.file_system_operations import copy_to_dst_dir, check_if_files_exist_in_path # -------------------------------------------------------------------------------------------------- @@ -23,7 +23,7 @@ class PrepGeosRunDir(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: """ Parses resource files in "geos_experiment_directory" to obtain required @@ -33,7 +33,17 @@ def execute(self): In GEOS speak, it creates the "scratch" directory. """ + swell_static_files_user = self.config.swell_static_files_user(None) self.swell_static_files = self.config.swell_static_files() + + # Use static_files_user if present in config and contains files + # ------------------------------------------------------------- + if swell_static_files_user is not None: + self.logger.info('swell_static_files_user specified, checking for files') + if check_if_files_exist_in_path(self.logger, swell_static_files_user): + self.logger.info(f'Using swell static files in {swell_static_files_user}') + self.swell_static_files = swell_static_files_user + # TODO: exp. directory location requires better handling self.geos_exp_dir = os.path.join(self.swell_static_files, 'geos', 'run_dirs', self.config.geos_experiment_directory()) @@ -152,7 +162,7 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def generate_extdata(self): + def generate_extdata(self) -> None: # Generate ExtData.rc according to emissions and EXTDATA2G options # 'w' option overwrites the contents or creates a new file @@ -186,7 +196,7 @@ def generate_extdata(self): # ---------------------------------------------------------------------------------------------- - def get_amip_emission(self): + def get_amip_emission(self) -> None: # Select proper AMIP GOCART Emission RC Files # ------------------------------------------- @@ -225,7 +235,7 @@ def get_amip_emission(self): # ---------------------------------------------------------------------------------------------- - def get_bcs(self): + def get_bcs(self) -> None: # This methods is highly dependent on the GEOSgcm version, currently # tested with GEOSgcm v11.6.0. It uses parsed .rc and .j files to define @@ -329,7 +339,7 @@ def get_bcs(self): # ---------------------------------------------------------------------------------------------- - def get_dynamic(self): + def get_dynamic(self) -> None: # Creating symlinks to BCs dictionary # Unlinks existing ones first @@ -355,7 +365,7 @@ def get_dynamic(self): # ---------------------------------------------------------------------------------------------- - def get_static(self): + def get_static(self) -> None: # Obtain experiment input files created by GEOS gcm_setup # -------------------------------------------------- @@ -374,7 +384,7 @@ def get_static(self): # ---------------------------------------------------------------------------------------------- - def link_replay(self): + def link_replay(self) -> None: # Linking REPLAY files according to AGCM.rc as in gcm_run.j # TODO: This needs another go over after GEOS Krok update @@ -402,7 +412,7 @@ def link_replay(self): # ---------------------------------------------------------------------------------------------- - def restructure_rc(self): + def restructure_rc(self) -> None: # 1MOM and GFDL microphysics do not use WSUB_NATURE # ------------------------------------------------- @@ -434,7 +444,7 @@ def restructure_rc(self): # ---------------------------------------------------------------------------------------------- - def rewrite_agcm(self, rcdict, rcfile): + def rewrite_agcm(self, rcdict: dict, rcfile: str) -> dict: # This part is relevant for move_da_restart task. Be mindful of your changes # and what impacts they might have on others (also a good motto in life). @@ -467,7 +477,7 @@ def rewrite_agcm(self, rcdict, rcfile): # ---------------------------------------------------------------------------------------------- - def rewrite_cap(self, rcdict, rcfile): + def rewrite_cap(self, rcdict: dict, rcfile: str) -> dict: # CAP.rc requires modifications before job submission # This method returns rcdict with the bool fix diff --git a/src/swell/tasks/prepare_analysis.py b/src/swell/tasks/prepare_analysis.py index 5a5de888..f8cb6507 100644 --- a/src/swell/tasks/prepare_analysis.py +++ b/src/swell/tasks/prepare_analysis.py @@ -11,6 +11,7 @@ import netCDF4 as nc import os import shutil +from typing import Union from swell.utilities.shell_commands import run_subprocess from swell.tasks.base.task_base import taskBase @@ -22,7 +23,7 @@ class PrepareAnalysis(taskBase): # -------------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: """ Updates variables in restart files with analysis variables. @@ -90,7 +91,7 @@ def execute(self): # ---------------------------------------------------------------------------------------- - def at_cycledir(self, paths=[]): + def at_cycledir(self, paths: Union[list, str] = []) -> str: # Ensure what we have is a list (paths should be a list) # ------------------------------------------------------ @@ -104,7 +105,7 @@ def at_cycledir(self, paths=[]): # -------------------------------------------------------------------------------------------------- - def mom6_increment(self, f_rst, ana_path, incr_path): + def mom6_increment(self, f_rst: str, ana_path: str, incr_path: str) -> None: # This method prepares MOM6 increment file for IAU during next cycle. # SOCA increment does not contain layer thickness (h) variable. Hence, @@ -133,7 +134,7 @@ def mom6_increment(self, f_rst, ana_path, incr_path): # -------------------------------------------------------------------------------------------------- - def replace_ocn(self, f_rst, ana_pth): + def replace_ocn(self, f_rst: str, ana_pth: str) -> None: # TODO: This will fail for multiple restart files and no IAU # ---------------------------------------------------------- diff --git a/src/swell/tasks/remove_forecast_dir.py b/src/swell/tasks/remove_forecast_dir.py index e7c96c82..a527e680 100644 --- a/src/swell/tasks/remove_forecast_dir.py +++ b/src/swell/tasks/remove_forecast_dir.py @@ -19,7 +19,7 @@ class RemoveForecastDir(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: self.logger.info(f"Removing old forecast directory: {self.forecast_dir()}") shutil.rmtree(self.forecast_dir()) diff --git a/src/swell/tasks/run_geos_executable.py b/src/swell/tasks/run_geos_executable.py index 926c5ad1..3e5af9f8 100644 --- a/src/swell/tasks/run_geos_executable.py +++ b/src/swell/tasks/run_geos_executable.py @@ -8,6 +8,7 @@ # -------------------------------------------------------------------------------------------------- import os +from typing import Optional from swell.tasks.base.task_base import taskBase from swell.utilities.shell_commands import run_track_log_subprocess @@ -17,7 +18,7 @@ class RunGeosExecutable(taskBase): - def execute(self): + def execute(self) -> None: # Obtain processor information from AGCM.rc # Strip is required in case AGCM.rc file was rewritten @@ -57,8 +58,15 @@ def execute(self): # ---------------------------------------------------------------------------------------------- - def run_executable(self, cycle_dir, np, geos_executable, geos_modules, output_log, - geos_lib_path=None): + def run_executable( + self, + cycle_dir: str, + np: int, + geos_executable: str, + geos_modules: str, + output_log: str, + geos_lib_path: Optional[str] = None + ) -> None: # Run the GEOS executable # ----------------------- diff --git a/src/swell/tasks/run_jedi_ensemble_mean_variance.py b/src/swell/tasks/run_jedi_ensemble_mean_variance.py index d9b2127f..c84e371d 100644 --- a/src/swell/tasks/run_jedi_ensemble_mean_variance.py +++ b/src/swell/tasks/run_jedi_ensemble_mean_variance.py @@ -22,7 +22,7 @@ class RunJediEnsembleMeanVariance(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: # Jedi application name # --------------------- diff --git a/src/swell/tasks/run_jedi_hofx_ensemble_executable.py b/src/swell/tasks/run_jedi_hofx_ensemble_executable.py index b8eb5fee..38d6b663 100644 --- a/src/swell/tasks/run_jedi_hofx_ensemble_executable.py +++ b/src/swell/tasks/run_jedi_hofx_ensemble_executable.py @@ -24,7 +24,7 @@ class RunJediHofxEnsembleExecutable(RunJediHofxExecutable, taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: # Jedi application name # --------------------- diff --git a/src/swell/tasks/run_jedi_hofx_executable.py b/src/swell/tasks/run_jedi_hofx_executable.py index d168e7b1..b675ab68 100644 --- a/src/swell/tasks/run_jedi_hofx_executable.py +++ b/src/swell/tasks/run_jedi_hofx_executable.py @@ -11,6 +11,7 @@ import glob import os import yaml +from typing import Optional from swell.tasks.base.task_base import taskBase from swell.utilities.netcdf_files import combine_files_without_groups @@ -24,7 +25,7 @@ class RunJediHofxExecutable(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self, ensemble_members=None): + def execute(self, ensemble_members: Optional[list] = None) -> None: # Jedi application name # --------------------- @@ -243,7 +244,13 @@ def execute(self, ensemble_members=None): # ---------------------------------------------------------------------------------------------- - def append_gomsaver(self, observations, jedi_config_dict, window_begin, mem=None): + def append_gomsaver( + self, + observations: list, + jedi_config_dict: dict, + window_begin: str, + mem: Optional[str] = None + ) -> None: # We may need to save the GeoVaLs for ensemble members. This will # prevent code repetition. diff --git a/src/swell/tasks/run_jedi_local_ensemble_da_executable.py b/src/swell/tasks/run_jedi_local_ensemble_da_executable.py index c3f6f450..3c52532e 100644 --- a/src/swell/tasks/run_jedi_local_ensemble_da_executable.py +++ b/src/swell/tasks/run_jedi_local_ensemble_da_executable.py @@ -22,7 +22,7 @@ class RunJediLocalEnsembleDaExecutable(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: # Jedi application name # --------------------- diff --git a/src/swell/tasks/run_jedi_ufo_tests_executable.py b/src/swell/tasks/run_jedi_ufo_tests_executable.py index 721bd0a5..b9eb33ee 100644 --- a/src/swell/tasks/run_jedi_ufo_tests_executable.py +++ b/src/swell/tasks/run_jedi_ufo_tests_executable.py @@ -24,7 +24,7 @@ class RunJediUfoTestsExecutable(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: # Jedi application name # --------------------- diff --git a/src/swell/tasks/run_jedi_variational_executable.py b/src/swell/tasks/run_jedi_variational_executable.py index ce94373e..8827ff1a 100644 --- a/src/swell/tasks/run_jedi_variational_executable.py +++ b/src/swell/tasks/run_jedi_variational_executable.py @@ -22,7 +22,7 @@ class RunJediVariationalExecutable(taskBase): # ---------------------------------------------------------------------------------------------- - def execute(self): + def execute(self) -> None: # Jedi application name # --------------------- diff --git a/src/swell/tasks/save_obs_diags.py b/src/swell/tasks/save_obs_diags.py index 02947311..a1c791c8 100644 --- a/src/swell/tasks/save_obs_diags.py +++ b/src/swell/tasks/save_obs_diags.py @@ -22,7 +22,7 @@ class SaveObsDiags(taskBase): Task to use R2D2 to save obs diag files from experiment to database """ - def execute(self): + def execute(self) -> None: # Parse config # ------------ diff --git a/src/swell/tasks/save_restart.py b/src/swell/tasks/save_restart.py index 8a880a83..5d9b86ad 100644 --- a/src/swell/tasks/save_restart.py +++ b/src/swell/tasks/save_restart.py @@ -16,7 +16,7 @@ class SaveRestart(taskBase): - def execute(self): + def execute(self) -> None: self.logger.info('SaveRestart') diff --git a/src/swell/tasks/stage_jedi.py b/src/swell/tasks/stage_jedi.py index e36d6f93..0bc76b46 100644 --- a/src/swell/tasks/stage_jedi.py +++ b/src/swell/tasks/stage_jedi.py @@ -13,6 +13,7 @@ from swell.tasks.base.task_base import taskBase from swell.utilities.filehandler import * from swell.utilities.exceptions import * +from swell.utilities.file_system_operations import check_if_files_exist_in_path # -------------------------------------------------------------------------------------------------- @@ -20,7 +21,7 @@ class StageJedi(taskBase): - def execute(self): + def execute(self) -> None: """Acquires listed files under the configuration/jedi/interface/model/stage.yaml file. Parameters @@ -31,7 +32,18 @@ def execute(self): # Extract potential template variables from config horizontal_resolution = self.config.horizontal_resolution() + + swell_static_files_user = self.config.swell_static_files_user(None) swell_static_files = self.config.swell_static_files() + + # Use static_files_user if present in config and contains files + # ------------------------------------------------------------- + if swell_static_files_user is not None: + self.logger.info('swell_static_files_user specified, checking for files') + if check_if_files_exist_in_path(self.logger, swell_static_files_user): + self.logger.info(f'Using swell static files in {swell_static_files_user}') + swell_static_files = swell_static_files_user + vertical_resolution = self.config.vertical_resolution() gsibec_configuration = self.config.gsibec_configuration(None) diff --git a/src/swell/tasks/store_background.py b/src/swell/tasks/store_background.py index b584ddde..ef49109b 100644 --- a/src/swell/tasks/store_background.py +++ b/src/swell/tasks/store_background.py @@ -15,7 +15,7 @@ from swell.tasks.base.task_base import taskBase -from swell.utilities.datetime import datetime_formats +from swell.utilities.datetime_util import datetime_formats from swell.utilities.r2d2 import create_r2d2_config @@ -24,7 +24,7 @@ class StoreBackground(taskBase): - def execute(self): + def execute(self) -> None: """Store background files for a given experiment and cycle in R2D2 diff --git a/src/swell/tasks/task_questions.yaml b/src/swell/tasks/task_questions.yaml index 1e7da753..19c2d535 100644 --- a/src/swell/tasks/task_questions.yaml +++ b/src/swell/tasks/task_questions.yaml @@ -848,7 +848,11 @@ swell_static_files_user: default_value: None prompt: What is the path to the user provided Swell Static Files directory? tasks: + - GenerateBClimatology - GenerateBClimatologyByLinking + - GetGeosRestart + - PrepGeosRunDir + - StageJedi type: string total_processors: diff --git a/src/swell/test/code_tests/code_tests.py b/src/swell/test/code_tests/code_tests.py index 9e919902..be6c3bb1 100644 --- a/src/swell/test/code_tests/code_tests.py +++ b/src/swell/test/code_tests/code_tests.py @@ -20,7 +20,7 @@ # -------------------------------------------------------------------------------------------------- -def code_tests(): +def code_tests() -> None: # Create a logger logger = Logger('TestSuite') diff --git a/src/swell/test/code_tests/missing_obs_test.py b/src/swell/test/code_tests/missing_obs_test.py index 7e9e609c..bb799abd 100644 --- a/src/swell/test/code_tests/missing_obs_test.py +++ b/src/swell/test/code_tests/missing_obs_test.py @@ -1,4 +1,4 @@ -from swell.utilities.datetime import Datetime +from swell.utilities.datetime_util import Datetime from swell.utilities.run_jedi_executables import check_obs from swell.utilities.render_jedi_interface_files import JediConfigRendering from swell.utilities.logger import Logger diff --git a/src/swell/test/code_tests/slurm_test.py b/src/swell/test/code_tests/slurm_test.py index 3f4973b7..f84977cc 100644 --- a/src/swell/test/code_tests/slurm_test.py +++ b/src/swell/test/code_tests/slurm_test.py @@ -11,7 +11,7 @@ import unittest from swell.utilities.slurm import prepare_scheduling_dict -from unittest.mock import patch +from unittest.mock import patch, Mock # -------------------------------------------------------------------------------------------------- @@ -22,7 +22,7 @@ class SLURMConfigTest(unittest.TestCase): # configuration and platform-specific settings @patch("swell.utilities.slurm.slurm_global_defaults") @patch("platform.platform") - def test_slurm_config(self, platform_mocked, mock_global_defaults): + def test_slurm_config(self, platform_mocked: Mock, mock_global_defaults: Mock) -> None: logger = logging.getLogger() diff --git a/src/swell/test/code_tests/unused_variables_test.py b/src/swell/test/code_tests/unused_variables_test.py index 865b1bb7..4d348b08 100644 --- a/src/swell/test/code_tests/unused_variables_test.py +++ b/src/swell/test/code_tests/unused_variables_test.py @@ -18,7 +18,7 @@ # -------------------------------------------------------------------------------------------------- -def run_flake8(file_path): +def run_flake8(file_path: str) -> str: flake8_cmd = ['flake8', '--select', 'F401,F841', file_path] result = subprocess.run(flake8_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) @@ -30,7 +30,7 @@ def run_flake8(file_path): class UnusedVariablesTest(unittest.TestCase): - def test_unused_variables(self): + def test_unused_variables(self) -> None: for root, _, files in os.walk(get_swell_path()): for filename in files: diff --git a/src/swell/test/test_driver.py b/src/swell/test/test_driver.py index 414e89b5..603a5925 100644 --- a/src/swell/test/test_driver.py +++ b/src/swell/test/test_driver.py @@ -15,7 +15,7 @@ # -------------------------------------------------------------------------------------------------- -def test_wrapper(test): +def test_wrapper(test: str) -> None: # Test script test_script_file = 'swell.test.'+test+'.'+test diff --git a/src/swell/utilities/build.py b/src/swell/utilities/build.py index e213d169..450a5077 100644 --- a/src/swell/utilities/build.py +++ b/src/swell/utilities/build.py @@ -9,6 +9,7 @@ import os import shutil +from typing import Tuple from jedi_bundle.bin.jedi_bundle import get_default_config from jedi_bundle.config.config import check_platform @@ -17,7 +18,7 @@ # -------------------------------------------------------------------------------------------------- -def build_and_source_dirs(package_path): +def build_and_source_dirs(package_path: str) -> Tuple[str, str]: # Make package directory # ---------------------- @@ -36,7 +37,7 @@ def build_and_source_dirs(package_path): # -------------------------------------------------------------------------------------------------- -def link_path(source, target): +def link_path(source: str, target: str) -> None: # Remove existing source path if present if os.path.islink(target): # Is a link @@ -51,8 +52,13 @@ def link_path(source, target): # -------------------------------------------------------------------------------------------------- -def set_jedi_bundle_config(bundles, path_to_source, path_to_build, platform, - cores_to_use_for_make=6): +def set_jedi_bundle_config( + bundles: list, + path_to_source: str, + path_to_build: str, + platform: str, + cores_to_use_for_make: int = 6 +) -> dict: # Start from the default jedi_bundle config file jedi_bundle_config = get_default_config() diff --git a/src/swell/utilities/case_switching.py b/src/swell/utilities/case_switching.py index 900962ed..7e8744a9 100644 --- a/src/swell/utilities/case_switching.py +++ b/src/swell/utilities/case_switching.py @@ -8,7 +8,7 @@ # -------------------------------------------------------------------------------------------------- -def camel_case_to_snake_case(CamelCaseString): +def camel_case_to_snake_case(CamelCaseString: str) -> str: # Convert a string that looks like e.g. ThisIsAString to this_is_a_string # ----------------------------------------------------------------------- @@ -24,7 +24,7 @@ def camel_case_to_snake_case(CamelCaseString): # -------------------------------------------------------------------------------------------------- -def snake_case_to_camel_case(snake_case_string): +def snake_case_to_camel_case(snake_case_string: str) -> str: # Convert a string that looks like e.g. this_is_a_string to ThisIsAString # ----------------------------------------------------------------------- diff --git a/src/swell/utilities/config.py b/src/swell/utilities/config.py index 1d24bf08..e21ed912 100644 --- a/src/swell/utilities/config.py +++ b/src/swell/utilities/config.py @@ -9,8 +9,10 @@ import os import yaml +from typing import Callable from swell.swell_path import get_swell_path +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- @@ -44,7 +46,7 @@ class Config(): # ---------------------------------------------------------------------------------------------- - def __init__(self, input_file, logger, task_name, model): + def __init__(self, input_file: str, logger: Logger, task_name: str, model: str) -> None: # Keep copy of owner's logger self.__logger__ = logger @@ -120,7 +122,7 @@ def __init__(self, input_file, logger, task_name, model): # ---------------------------------------------------------------------------------------------- - def get(self, experiment_key): + def get(self, experiment_key: str) -> Callable: def getter(default='None'): return getattr(self, f'__{experiment_key}__') return getter @@ -129,7 +131,7 @@ def getter(default='None'): # Implementation of __getattr__ to ensure there is no crash when a task requests a variable that # does not exist. This is valid so long as the task provides a default value. - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable: def variable_not_found(default='LrZRExPGcQ'): if default == 'LrZRExPGcQ': self.__logger__.abort(f'In config class, trying to get variable \'{name}\' but ' + diff --git a/src/swell/utilities/data_assimilation_window_params.py b/src/swell/utilities/data_assimilation_window_params.py index bda948f1..6857b363 100644 --- a/src/swell/utilities/data_assimilation_window_params.py +++ b/src/swell/utilities/data_assimilation_window_params.py @@ -9,8 +9,10 @@ import datetime import isodate +from typing import Union -from swell.utilities.datetime import datetime_formats +from swell.utilities.datetime_util import datetime_formats +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- @@ -18,7 +20,7 @@ class DataAssimilationWindowParams(): - def __init__(self, logger, cycle_time): + def __init__(self, logger: Logger, cycle_time: str) -> None: """ Defines cycle dependent parameters for the data assimilation window and adds to config @@ -33,14 +35,18 @@ def __init__(self, logger, cycle_time): # ---------------------------------------------------------------------------------------------- - def __get_window_begin_dto__(self, window_offset): + def __get_window_begin_dto__(self, window_offset: str) -> datetime.datetime: window_offset_dur = isodate.parse_duration(window_offset) return self.__current_cycle_dto__ - window_offset_dur # ---------------------------------------------------------------------------------------------- - def __get_local_background_time__(self, window_type, window_offset): + def __get_local_background_time__( + self, + window_type: str, + window_offset: str + ) -> datetime.datetime: # Background time for the window if window_type == '4D': @@ -52,7 +58,7 @@ def __get_local_background_time__(self, window_type, window_offset): # ---------------------------------------------------------------------------------------------- - def window_begin(self, window_offset, dto=False): + def window_begin(self, window_offset: str, dto: bool = False) -> Union[str, datetime.datetime]: window_begin_dto = self.__get_window_begin_dto__(window_offset) @@ -64,7 +70,7 @@ def window_begin(self, window_offset, dto=False): # ---------------------------------------------------------------------------------------------- - def window_begin_iso(self, window_offset, dto=False): + def window_begin_iso(self, window_offset: str, dto: bool = False): window_begin_dto = self.__get_window_begin_dto__(window_offset) @@ -76,7 +82,7 @@ def window_begin_iso(self, window_offset, dto=False): # ---------------------------------------------------------------------------------------------- - def window_end_iso(self, window_offset, window_length, dto=False): + def window_end_iso(self, window_offset: str, window_length: str, dto: bool = False) -> str: # Compute window length duration window_length_dur = isodate.parse_duration(window_length) @@ -95,7 +101,7 @@ def window_end_iso(self, window_offset, window_length, dto=False): # ---------------------------------------------------------------------------------------------- - def background_time(self, window_offset, background_time_offset): + def background_time(self, window_offset: str, background_time_offset: str) -> str: background_time_offset_dur = isodate.parse_duration(background_time_offset) background_time_dto = self.__current_cycle_dto__ - background_time_offset_dur @@ -103,14 +109,14 @@ def background_time(self, window_offset, background_time_offset): # ---------------------------------------------------------------------------------------------- - def local_background_time_iso(self, window_offset, window_type): + def local_background_time_iso(self, window_offset: str, window_type: str) -> str: local_background_time = self.__get_local_background_time__(window_type, window_offset) return local_background_time.strftime(datetime_formats['iso_format']) # ---------------------------------------------------------------------------------------------- - def local_background_time(self, window_offset, window_type): + def local_background_time(self, window_offset: str, window_type: str) -> str: local_background_time = self.__get_local_background_time__(window_type, window_offset) return local_background_time.strftime(datetime_formats['directory_format']) diff --git a/src/swell/utilities/datetime.py b/src/swell/utilities/datetime_util.py similarity index 94% rename from src/swell/utilities/datetime.py rename to src/swell/utilities/datetime_util.py index 904291c1..024409d4 100644 --- a/src/swell/utilities/datetime.py +++ b/src/swell/utilities/datetime_util.py @@ -11,7 +11,6 @@ import re import datetime as pydatetime - # -------------------------------------------------------------------------------------------------- datetime_formats = { @@ -26,7 +25,7 @@ class Datetime: - def __init__(self, datetime_input): + def __init__(self, datetime_input) -> None: # Convert input string to standard format yyyymmddHHMMSS datetime_str = re.sub('[^0-9]', '', datetime_input+'000000')[0:14] @@ -42,13 +41,13 @@ def dto(self): # ---------------------------------------------------------------------------------------------- - def string_iso(self): + def string_iso(self) -> str: return self.__datetime__.strftime(datetime_formats['iso_format']) # ---------------------------------------------------------------------------------------------- - def string_directory(self): + def string_directory(self) -> str: return self.__datetime__.strftime(datetime_formats['directory_format']) diff --git a/src/swell/utilities/dictionary.py b/src/swell/utilities/dictionary.py index e5d28826..82f0a73b 100644 --- a/src/swell/utilities/dictionary.py +++ b/src/swell/utilities/dictionary.py @@ -9,12 +9,19 @@ import yaml from collections.abc import Hashable +from typing import Union +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- -def dict_get(logger, dictionary, key, default='NODEFAULT'): +def dict_get( + logger: Logger, + dictionary: dict, + key: str, + default: str = 'NODEFAULT' +) -> str: if key in dictionary.keys(): @@ -32,7 +39,7 @@ def dict_get(logger, dictionary, key, default='NODEFAULT'): # -------------------------------------------------------------------------------------------------- -def remove_matching_keys(d, key): +def remove_matching_keys(d: Union[dict, list], key: str) -> None: """ Recursively locates and removes all dictionary items matching the supplied key. Parameters @@ -60,7 +67,11 @@ def remove_matching_keys(d, key): # -------------------------------------------------------------------------------------------------- -def add_comments_to_dictionary(logger, dictionary_string, comment_dictionary): +def add_comments_to_dictionary( + logger: Logger, + dictionary_string: str, + comment_dictionary: dict +) -> str: dict_str_items = dictionary_string.split('\n') @@ -107,7 +118,7 @@ def add_comments_to_dictionary(logger, dictionary_string, comment_dictionary): # -------------------------------------------------------------------------------------------------- -def replace_string_in_dictionary(dictionary, string_in, string_out): +def replace_string_in_dictionary(dictionary: dict, string_in: str, string_out: str) -> object: # Convert dictionary to string dictionary_string = yaml.dump(dictionary, default_flow_style=False, sort_keys=False) @@ -122,7 +133,7 @@ def replace_string_in_dictionary(dictionary, string_in, string_out): # -------------------------------------------------------------------------------------------------- -def write_dict_to_yaml(dictionary, file): +def write_dict_to_yaml(dictionary: dict, file: str) -> None: # Convert dictionary to YAML string dictionary_string = yaml.dump(dictionary, default_flow_style=False, sort_keys=False) @@ -135,7 +146,7 @@ def write_dict_to_yaml(dictionary, file): # -------------------------------------------------------------------------------------------------- -def update_dict(orig_dict, overwrite_dict): +def update_dict(orig_dict: dict, overwrite_dict: dict) -> dict: # Create output dictionary from original dictionary output_dict = orig_dict.copy() @@ -152,7 +163,7 @@ def update_dict(orig_dict, overwrite_dict): # -------------------------------------------------------------------------------------------------- -def dictionary_override(logger, orig_dict, override_dict): +def dictionary_override(logger: Logger, orig_dict: dict, override_dict: dict) -> dict: for key, value in override_dict.items(): if value == 'REMOVE': orig_dict.pop(key, None) diff --git a/src/swell/utilities/exceptions.py b/src/swell/utilities/exceptions.py index b0031d08..943ce607 100644 --- a/src/swell/utilities/exceptions.py +++ b/src/swell/utilities/exceptions.py @@ -5,10 +5,17 @@ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. # ----------------------------------------------------------------------------- +from typing import Optional +from logging import Logger as pyLogger + class SWELLError(Exception): - def __init__(self, message, logger=None): + def __init__( + self, + message: str, + logger: Optional[pyLogger] = None + ) -> None: self.message = message super().__init__(message) diff --git a/src/swell/utilities/file_system_operations.py b/src/swell/utilities/file_system_operations.py index e8aeae15..03a5f9b6 100644 --- a/src/swell/utilities/file_system_operations.py +++ b/src/swell/utilities/file_system_operations.py @@ -10,10 +10,12 @@ import os import shutil +from swell.utilities.logger import Logger + # -------------------------------------------------------------------------------------------------- -def copy_to_dst_dir(logger, src, dst_dir): +def copy_to_dst_dir(logger: Logger, src: str, dst_dir: str) -> None: """ Source could be a directory or single file which necesitates different handling """ @@ -32,7 +34,11 @@ def copy_to_dst_dir(logger, src, dst_dir): # -------------------------------------------------------------------------------------------------- -def link_all_files_from_first_in_hierarchy_of_sources(logger, source_paths, target_path): +def link_all_files_from_first_in_hierarchy_of_sources( + logger: Logger, + source_paths: list, + target_path: str +) -> None: """For a list of source paths check for the existence of the source paths and for files residing in at least one of the paths. For the first source path in the list that is found to contain files link them into the target path (remove first if existing) @@ -45,6 +51,7 @@ def link_all_files_from_first_in_hierarchy_of_sources(logger, source_paths, targ """ # First sweep to see if directories exist + found_paths = [] for source_path in source_paths: if os.path.exists(source_path): @@ -82,7 +89,35 @@ def link_all_files_from_first_in_hierarchy_of_sources(logger, source_paths, targ # -------------------------------------------------------------------------------------------------- -def link_file_existing_link_ok(logger, source_path_file, target_path_file): +def check_if_files_exist_in_path(logger: Logger, path_to_files: str) -> bool: + """Checks if path to directory exists, and if files are present within it. + + Parameters + ---------- + logger: Logger to output results to + path_to_files: path to target directory + """ + + files_exist = False + + if os.path.exists(path_to_files): + if os.listdir(path_to_files): + logger.info(f'Files found within {path_to_files}') + files_exist = True + else: + logger.info(f'No files found within {path_to_files}') + + return files_exist + + +# -------------------------------------------------------------------------------------------------- + + +def link_file_existing_link_ok( + logger: Logger, + source_path_file: str, + target_path_file: str +) -> None: """Create a symbolic link from a source location to a target location. If a symbolic link already exists it will be deleted. If a file already exists and it is not a link the code @@ -112,7 +147,7 @@ def link_file_existing_link_ok(logger, source_path_file, target_path_file): # ---------------------------------------------------------------------------------------------- -def move_files(logger, src_dir, dst_dir): +def move_files(logger: Logger, src_dir: str, dst_dir: str) -> None: try: logger.info(' Moving file(s) from: '+src_dir) diff --git a/src/swell/utilities/filehandler.py b/src/swell/utilities/filehandler.py index 6d32c5d1..d1adbae9 100755 --- a/src/swell/utilities/filehandler.py +++ b/src/swell/utilities/filehandler.py @@ -57,16 +57,18 @@ # is omitted, all files in "src" will be copied to "dst". # ----------------------------------------------------------------------------- +from __future__ import annotations import os import glob import copy import datetime as dt from shutil import copyfile +from typing import Union, Optional, Any from swell.utilities.exceptions import * -def get_file_handler(config, **kwargs): +def get_file_handler(config: list, **kwargs) -> Union[StageFileHandler, GetDataFileHandler]: """Factory for determining the file handler type for retrieving data. This method uses a heuristic algorithm to determine the staging @@ -106,7 +108,7 @@ def get_file_handler(config, **kwargs): class FileHandler(object): - def __init__(self, config, **kwargs): + def __init__(self, config: list, **kwargs) -> None: self.listing = [] self.config = copy.deepcopy(config) @@ -114,7 +116,7 @@ def __init__(self, config, **kwargs): # ------------------------------------------------------------------------------ - def is_ready(self, fc=None): + def is_ready(self, fc: Optional[FileCollection] = None) -> bool: """Determines if the file collection meets the criteria for readiness (e.g. minimum file count etc.) @@ -156,7 +158,7 @@ def is_ready(self, fc=None): # ------------------------------------------------------------------------ - def get(self, fc=None): + def get(self, fc: Optional[FileCollection] = None) -> None: """Retrieves the files in the specified file collection. Parameters @@ -187,7 +189,7 @@ def get(self, fc=None): # --------------------------------------------------------------------------- - def copy(self, src, dst): + def copy(self, src: str, dst: str) -> None: """File handler - copies a file Parameters @@ -209,7 +211,7 @@ def copy(self, src, dst): # --------------------------------------------------------------------------- - def link(self, src, dst): + def link(self, src: str, dst: str) -> None: """File handler - Symbolically links a file Parameters @@ -239,7 +241,7 @@ def link(self, src, dst): class StageFileHandler(FileHandler): - def list(self, force=False): + def list(self, force: bool = False) -> list: """Creates a list of file collections defined in configuration using the "stage" data structure convention. @@ -299,7 +301,7 @@ def list(self, force=False): class GetDataFileHandler(FileHandler): - def list(self, force=False): + def list(self, force: bool = False) -> list: """Creates a list of file collections defined in configuration using the "get_data" data structure convention. @@ -383,7 +385,7 @@ def list(self, force=False): class FileCollection(object): - def __init__(self, config): + def __init__(self, config: dict[Any, Any]) -> None: self.config = copy.deepcopy(config) @@ -394,13 +396,13 @@ def __init__(self, config): # ------------------------------------------------------------------------------ - def update(self, srcfile, dstfile): + def update(self, srcfile: list, dstfile: list) -> None: self.listing.append((srcfile, dstfile)) # ------------------------------------------------------------------------------ - def num_files(self): return len(self.listing) + def num_files(self) -> int: return len(self.listing) # ------------------------------------------------------------------------------ diff --git a/src/swell/utilities/geos.py b/src/swell/utilities/geos.py index 2e00ed18..c3e88117 100644 --- a/src/swell/utilities/geos.py +++ b/src/swell/utilities/geos.py @@ -7,16 +7,18 @@ # -------------------------------------------------------------------------------------------------- -from datetime import datetime +import datetime import f90nml import glob import isodate import netCDF4 import os import re +from typing import Tuple, Optional, Union from swell.utilities.shell_commands import run_subprocess -from swell.utilities.datetime import datetime_formats +from swell.utilities.datetime_util import datetime_formats +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- @@ -25,7 +27,7 @@ class Geos(): # ---------------------------------------------------------------------------------------------- - def __init__(self, logger, forecast_dir): + def __init__(self, logger: Logger, forecast_dir: Optional[str]) -> None: ''' Intention with GEOS class is to not have any model dependent methods. @@ -38,12 +40,16 @@ def __init__(self, logger, forecast_dir): # ---------------------------------------------------------------------------------------------- - def adjacent_cycle(self, offset, return_date=False): + def adjacent_cycle( + self, + offset: str, + return_date: bool = False + ) -> Union[str, datetime.datetime]: # Basename consists of swell datetime and model # --------------------------------------------- dt_str = os.path.basename(os.path.dirname(self.forecast_dir)) - dt_obj = datetime.strptime(dt_str, datetime_formats['directory_format']) + dt_obj = datetime.datetime.strptime(dt_str, datetime_formats['directory_format']) # Modify datetime by using date offset # ------------------------------------ @@ -64,7 +70,7 @@ def adjacent_cycle(self, offset, return_date=False): # ---------------------------------------------------------------------------------------------- - def chem_rename(self, rcdict): + def chem_rename(self, rcdict: dict) -> None: # Some files are renamed according to bool. switches in GEOS_ChemGridComp.rc # ------------------------------------------------------------------------- @@ -93,7 +99,7 @@ def chem_rename(self, rcdict): # ---------------------------------------------------------------------------------------------- - def exec_python(self, script_src, script, input=''): + def exec_python(self, script_src: str, script: str, input: str = '') -> None: # Source g5_modules and execute py scripts in a new shell process then # return to the current one @@ -109,7 +115,7 @@ def exec_python(self, script_src, script, input=''): # ---------------------------------------------------------------------------------------------- - def get_rst_time(self): + def get_rst_time(self) -> datetime.datetime: # Obtain time information from any of the rst files listed by glob # ---------------------------------------------------------------- @@ -132,7 +138,11 @@ def get_rst_time(self): # ---------------------------------------------------------------------------------------------- - def iso_to_time_str(self, iso_duration, half=False): + def iso_to_time_str( + self, + iso_duration: str, + half: bool = False + ) -> Tuple[str, int, datetime.timedelta]: # Parse the ISO duration string and get the total number of seconds # It is written to handle fcst_duration less than a day for now @@ -163,7 +173,7 @@ def iso_to_time_str(self, iso_duration, half=False): # ---------------------------------------------------------------------------------------------- - def linker(self, src, dst, dst_dir=None): + def linker(self, src: str, dst: str, dst_dir: str = None) -> None: # Link files from BC directories # ------------------------------ @@ -196,7 +206,7 @@ def linker(self, src, dst, dst_dir=None): # ---------------------------------------------------------------------------------------------- - def parse_gcmrun(self, jfile): + def parse_gcmrun(self, jfile: str) -> dict: # Parse gcm_run.j line by line and snatch setenv variables. gcm_setup # creates gcm_run.j and handles platform dependencies. @@ -234,7 +244,7 @@ def parse_gcmrun(self, jfile): # ---------------------------------------------------------------------------------------------- - def parse_rc(self, rcfile): + def parse_rc(self, rcfile: str) -> dict: # Parse AGCM.rc & CAP.rc line by line. It ignores comments and commented # out lines. Some values involve multiple ":" characters which required @@ -283,7 +293,7 @@ def parse_rc(self, rcfile): # ---------------------------------------------------------------------------------------------- - def process_nml(self, cold_restart=False): + def process_nml(self, cold_restart: bool = False) -> None: # In gcm_run.j, fvcore_layout.rc is concatenated with input.nml # ------------------------------------------------------------- @@ -310,7 +320,7 @@ def process_nml(self, cold_restart=False): # ---------------------------------------------------------------------------------------------- - def rc_assign(self, rcdict, key_inquiry): + def rc_assign(self, rcdict: dict, key_inquiry: str) -> None: # Some of the gcm_run.j steps involve setting environment values using # .rc files. These files may or may not have some of the key values used @@ -322,7 +332,7 @@ def rc_assign(self, rcdict, key_inquiry): # -------------------------------------------------------------------------------------------------- - def rc_to_bool(self, rcdict): + def rc_to_bool(self, rcdict: dict) -> dict: # .rc files have switch values in .TRUE. or .FALSE. format, some might # have T and F. @@ -361,7 +371,7 @@ def rename_checkpoints(self, next_geosdir): # -------------------------------------------------------------------------------------------------- - def resub(self, filename, pattern, replacement): + def resub(self, filename: str, pattern: str, replacement: str) -> None: # Replacing string values involving wildcards # ------------------------------------------- diff --git a/src/swell/utilities/get_channels.py b/src/swell/utilities/get_channels.py index a35e5edc..11aba2d1 100644 --- a/src/swell/utilities/get_channels.py +++ b/src/swell/utilities/get_channels.py @@ -11,11 +11,14 @@ import os from datetime import datetime as dt from itertools import groupby +from typing import Tuple, Optional + +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- -def process_channel_lists(channel_list): +def process_channel_lists(channel_list: list) -> list: ''' Function processes list of elements in channel list @@ -37,7 +40,7 @@ def process_channel_lists(channel_list): # -------------------------------------------------------------------------------------------------- -def create_range_string(avail_list): +def create_range_string(avail_list: list) -> str: ''' Function converts integer list into string of ranges ''' @@ -53,7 +56,7 @@ def create_range_string(avail_list): # -------------------------------------------------------------------------------------------------- -def get_channel_list(input_dict, dt_cycle_time): +def get_channel_list(input_dict: dict, dt_cycle_time: dt) -> list: ''' Function retrieves channel lists from dict loaded from a yaml file @@ -68,7 +71,12 @@ def get_channel_list(input_dict, dt_cycle_time): # -------------------------------------------------------------------------------------------------- -def get_channels(path_to_observing_sys_yamls, observation, dt_cycle_time, logger): +def get_channels( + path_to_observing_sys_yamls: str, + observation: str, + dt_cycle_time: dt, + logger: Logger +) -> Tuple[Optional[str], Optional[list[int]]]: ''' Comparing available channels and active channels from the observing @@ -107,7 +115,11 @@ def get_channels(path_to_observing_sys_yamls, observation, dt_cycle_time, logger # -------------------------------------------------------------------------------------------------- -def num_active_channels(path_to_observing_sys_yamls, observation, dt_cycle_time): +def num_active_channels( + path_to_observing_sys_yamls: str, + observation: str, + dt_cycle_time: dt +) -> Optional[int]: # Retrieve available and active channels from records yaml path_to_observing_sys_config = path_to_observing_sys_yamls + '/' + \ diff --git a/src/swell/utilities/git_utils.py b/src/swell/utilities/git_utils.py index 950073ac..9db9e7af 100644 --- a/src/swell/utilities/git_utils.py +++ b/src/swell/utilities/git_utils.py @@ -10,12 +10,12 @@ import os from swell.utilities.shell_commands import run_subprocess_dev_null - +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- -def git_change_branch(logger, git_branch, out_dir): +def git_change_branch(logger: Logger, git_branch: str, out_dir: str) -> None: # Change to a specific branch # --------------------------- @@ -34,7 +34,13 @@ def git_change_branch(logger, git_branch, out_dir): # -------------------------------------------------------------------------------------------------- -def git_clone(logger, git_url, git_branch, out_dir, change_branch=False): +def git_clone( + logger: Logger, + git_url: str, + git_branch: str, + out_dir: str, + change_branch: bool = False +) -> None: # Clone repo at git_url to out_dir # -------------------------------- @@ -55,7 +61,7 @@ def git_clone(logger, git_url, git_branch, out_dir, change_branch=False): # -------------------------------------------------------------------------------------------------- -def git_got(git_url, git_branch, out_dir, logger): +def git_got(git_url: str, git_branch: str, out_dir: str, logger: Logger): # Clone repo at git_url to out_dir # -------------------------------- diff --git a/src/swell/utilities/gsi_record_parser.py b/src/swell/utilities/gsi_record_parser.py index 50862b43..d11dabbc 100644 --- a/src/swell/utilities/gsi_record_parser.py +++ b/src/swell/utilities/gsi_record_parser.py @@ -2,7 +2,7 @@ import numpy as np -def check_end_time(end_time): +def check_end_time(end_time: str) -> str: ''' Checks end times for 24 hour strings and converts them to 18 ''' @@ -17,13 +17,13 @@ def check_end_time(end_time): class GSIRecordParser: - def __init__(self): + def __init__(self) -> None: self.instr_df = None self.return_df = None self.sat = None self.instr = None - def get_channel_list(self, start): + def get_channel_list(self, start: int) -> list: channel_list = [] rows = self.instr_df.loc[self.instr_df["start"] == start] for row_ch_list in rows["channels"].values: @@ -34,7 +34,7 @@ def get_channel_list(self, start): channel_list.sort(key=int) return channel_list - def run(self, instr_df): + def run(self, instr_df: pd.DataFrame) -> None: # Save instrument dataframe self.instr_df = instr_df @@ -118,7 +118,7 @@ def run(self, instr_df): channel_list, comment) done.append(inner_start[inner_idx]) - def update_return_df(self, start, end, channel_list, comment): + def update_return_df(self, start: str, end: str, channel_list: list, comment: str) -> None: # Fix end time if on the 24 hour mark end = check_end_time(end) @@ -134,7 +134,7 @@ def update_return_df(self, start, end, channel_list, comment): self.return_df = pd.concat([self.return_df, new_row], ignore_index=True) - def get_instr_df(self): + def get_instr_df(self) -> pd.DataFrame: ''' Returns the dataframe that the state machine generated! ''' diff --git a/src/swell/utilities/jinja2.py b/src/swell/utilities/jinja2.py index 1403d3e5..b45c2073 100644 --- a/src/swell/utilities/jinja2.py +++ b/src/swell/utilities/jinja2.py @@ -6,6 +6,8 @@ # -------------------------------------------------------------------------------------------------- +from __future__ import annotations +from typing import Union import jinja2 as j2 @@ -28,25 +30,25 @@ class SilentUndefined(j2.Undefined): See `ask_questions_and_configure_suite` method in `prepare_config_and_suite.py` for more details on Jinja2 passes. """ - def __getattr__(self, name): + def __getattr__(self, name: str) -> SilentUndefined: # Return a new SilentUndefined instance but append the attribute access to the name. return SilentUndefined(name=f"{self._undefined_name}.{name}") - def __getitem__(self, key): + def __getitem__(self, key: Union[str, int]) -> SilentUndefined: # Similar to __getattr__, return a new instance with the key access incorporated. if isinstance(key, str): return SilentUndefined(name=f"{self._undefined_name}['{key}']") return SilentUndefined(name=f"{self._undefined_name}[{key}]") - def items(self): + def items(self) -> list: # Return an empty list when items method is called. return [] - def __str__(self): + def __str__(self) -> str: # Ensure the name returned reflects the original template placeholder. return f"{{{{ {self._undefined_name} }}}}" - def __repr__(self): + def __repr__(self) -> str: return str(self) diff --git a/src/swell/utilities/logger.py b/src/swell/utilities/logger.py index 8b02bc89..66f6c575 100644 --- a/src/swell/utilities/logger.py +++ b/src/swell/utilities/logger.py @@ -35,7 +35,7 @@ class Logger: - def __init__(self, task_name): + def __init__(self, task_name: str) -> None: self.task_name = task_name @@ -61,7 +61,7 @@ def __init__(self, task_name): # ---------------------------------------------------------------------------------------------- - def send_message(self, level, message, wrap): + def send_message(self, level: str, message: str, wrap: bool) -> None: # Wrap the message if needed if wrap: @@ -97,37 +97,37 @@ def send_message(self, level, message, wrap): # ---------------------------------------------------------------------------------------------- - def info(self, message, wrap=True): + def info(self, message: str, wrap: bool = True) -> None: self.send_message('INFO', message, wrap) # ---------------------------------------------------------------------------------------------- - def test(self, message, wrap=True): + def test(self, message: str, wrap: bool = True) -> None: self.send_message('TEST', message, wrap) # ---------------------------------------------------------------------------------------------- - def trace(self, message, wrap=True): + def trace(self, message: str, wrap: bool = True) -> None: self.send_message('TRACE', message, wrap) # ---------------------------------------------------------------------------------------------- - def debug(self, message, wrap=True): + def debug(self, message: str, wrap: bool = True) -> None: self.send_message('DEBUG', message, wrap) # ---------------------------------------------------------------------------------------------- - def blank(self, message, wrap=True): + def blank(self, message: str, wrap: bool = True) -> None: self.send_message('BLANK', message, wrap) # ---------------------------------------------------------------------------------------------- - def abort(self, message, wrap=True): + def abort(self, message: str, wrap: bool = True) -> None: # Make the text red message = red + message + end @@ -147,7 +147,7 @@ def abort(self, message, wrap=True): # ---------------------------------------------------------------------------------------------- - def assert_abort(self, condition, message, wrap=True): + def assert_abort(self, condition: bool, message: str, wrap: bool = True) -> None: if condition: return @@ -156,7 +156,7 @@ def assert_abort(self, condition, message, wrap=True): # ---------------------------------------------------------------------------------------------- - def input(self, message): + def input(self, message: str) -> None: input(' '+self.task_name+': '+message + ". Press any key to continue...") diff --git a/src/swell/utilities/netcdf_files.py b/src/swell/utilities/netcdf_files.py index 8fef1a23..4f1925f9 100644 --- a/src/swell/utilities/netcdf_files.py +++ b/src/swell/utilities/netcdf_files.py @@ -10,13 +10,20 @@ import os import xarray as xr +from typing import Hashable, Union +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- -def combine_files_without_groups(logger, list_of_input_files, output_file, concat_dim, - delete_input=False): +def combine_files_without_groups( + logger: Logger, + list_of_input_files: list, + output_file: str, + concat_dim: Union[Hashable, xr.Variable, xr.DataArray], + delete_input: bool = False +) -> None: # Write some information logger.info('Combining the following netCDF files (using no-group combine): ') diff --git a/src/swell/utilities/observations.py b/src/swell/utilities/observations.py index ae22e7f8..f8fe8f7a 100644 --- a/src/swell/utilities/observations.py +++ b/src/swell/utilities/observations.py @@ -11,12 +11,13 @@ import yaml from swell.swell_path import get_swell_path +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- -def ioda_name_to_long_name(ioda_name, logger): +def ioda_name_to_long_name(ioda_name: str, logger: Logger) -> str: # Get configuration path jedi_configuration_path = os.path.join(get_swell_path(), 'configuration', 'jedi') diff --git a/src/swell/utilities/observing_system_records.py b/src/swell/utilities/observing_system_records.py index c0cd9722..fc402588 100644 --- a/src/swell/utilities/observing_system_records.py +++ b/src/swell/utilities/observing_system_records.py @@ -3,13 +3,15 @@ import pandas as pd import numpy as np import datetime as dt +from typing import Optional + from swell.utilities.logger import Logger from swell.utilities.gsi_record_parser import GSIRecordParser # -------------------------------------------------------------------------------------------------- -def format_date(old_date): +def format_date(old_date: str) -> str: ''' Formatting date into expected template ''' date = dt.datetime.strptime(old_date, '%Y%m%d%H%M%S') @@ -18,7 +20,7 @@ def format_date(old_date): # -------------------------------------------------------------------------------------------------- -def read_sat_db(path_to_sat_db, column_names): +def read_sat_db(path_to_sat_db: str, column_names: list[str]) -> pd.DataFrame: ''' Reading GSI observing system records row by row into @@ -83,7 +85,7 @@ class ObservingSystemRecords: yaml files. ''' - def __init__(self, record_type): + def __init__(self, record_type: str) -> None: ''' Supports either 'channel' or 'level' record type. This only affects naming conventions. @@ -96,7 +98,7 @@ def __init__(self, record_type): self.record_type = record_type self.logger = Logger('ObservingSystemRecords') - def parse_records(self, path_to_sat_db): + def parse_records(self, path_to_sat_db: str) -> None: ''' This method reads in the active.tbl and available.tbl files @@ -116,7 +118,7 @@ def parse_records(self, path_to_sat_db): 'level', 'comments'] file_ext_name = '.tbl' else: - logger.abort(f'Record type {self.record_type} not supported. \ + self.logger.abort(f'Record type {self.record_type} not supported. \ Use channel or level') parser = GSIRecordParser() @@ -144,9 +146,13 @@ def parse_records(self, path_to_sat_db): elif channel_type == 'available': self.available_df = df else: - logger.abort(f'record parsing unavailable for {channel_type}') + self.logger.abort(f'record parsing unavailable for {channel_type}') - def save_yamls(self, output_dir, observation_list=None): + def save_yamls( + self, + output_dir: str, + observation_list: Optional[list] = None + ) -> None: ''' Fields are taken from the internal dataframes populated @@ -202,7 +208,7 @@ def save_yamls(self, output_dir, observation_list=None): elif self.record_type == 'level': output_ext_name = '_level_info.yaml' else: - logger.abort(f'Record type {self.record_type} not supported. \ + self.logger.abort(f'Record type {self.record_type} not supported. \ Use channel or level') with open(output_dir + '/' + instr + '_' + sat + output_ext_name, 'w') as file: diff --git a/src/swell/utilities/render_jedi_interface_files.py b/src/swell/utilities/render_jedi_interface_files.py index b8701a5e..6cfb6926 100644 --- a/src/swell/utilities/render_jedi_interface_files.py +++ b/src/swell/utilities/render_jedi_interface_files.py @@ -9,18 +9,27 @@ import os import yaml +from typing import Union, Optional, Any from swell.utilities.jinja2 import template_string_jinja2 from swell.utilities.get_channels import get_channels - +from swell.utilities.logger import Logger +from swell.utilities.datetime_util import Datetime # -------------------------------------------------------------------------------------------------- class JediConfigRendering(): - def __init__(self, logger, experiment_root, experiment_id, cycle_dir, cycle_time, - jedi_interface=None): + def __init__( + self, + logger: Logger, + experiment_root: str, + experiment_id: str, + cycle_dir: Optional[str], + cycle_time: Optional[Datetime], + jedi_interface: Optional[str] = None + ) -> None: # Keep a copy of the logger self.logger = logger @@ -112,7 +121,7 @@ def __init__(self, logger, experiment_root, experiment_id, cycle_dir, cycle_time # ---------------------------------------------------------------------------------------------- # Function to add key to the template dictionary - def add_key(self, key, element): + def add_key(self, key: str, element: Any) -> None: # First assert that key is allowed self.logger.assert_abort(key in self.valid_template_keys, f'Trying to add key \'{key}\' ' + @@ -125,7 +134,7 @@ def add_key(self, key, element): # ---------------------------------------------------------------------------------------------- # Open the file at the provided path, use dictionary to complete templates and return dictionary - def __open_file_render_to_dict__(self, config_file): + def __open_file_render_to_dict__(self, config_file: str) -> dict[Any, Any]: # Check that config file exists self.logger.assert_abort(os.path.exists(config_file), f'In open_file_and_render failed ' + @@ -145,7 +154,7 @@ def __open_file_render_to_dict__(self, config_file): # ---------------------------------------------------------------------------------------------- # Prepare path to oops file and call rendering - def render_oops_file(self, config_name): + def render_oops_file(self, config_name: str) -> dict: # Path to configuration file config_file = os.path.join(self.jedi_config_path, 'oops', f'{config_name}.yaml') @@ -156,7 +165,7 @@ def render_oops_file(self, config_name): # ---------------------------------------------------------------------------------------------- # Prepare path to interface model file and call rendering - def render_interface_model(self, config_name): + def render_interface_model(self, config_name: str) -> dict[Any, Any]: # Assert that there is a jedi interface associated with the task self.logger.assert_abort(self.jedi_interface is not None, f'In order to render a ' + @@ -172,7 +181,7 @@ def render_interface_model(self, config_name): # ---------------------------------------------------------------------------------------------- - def set_obs_records_path(self, path): + def set_obs_records_path(self, path: str) -> None: # Never put a path that is string None in place if path == 'None': @@ -184,7 +193,7 @@ def set_obs_records_path(self, path): # ---------------------------------------------------------------------------------------------- # Prepare path to interface observations file and call rendering - def render_interface_observations(self, config_name): + def render_interface_observations(self, config_name: str) -> dict: # Assert that there is a jedi interface associated with the task self.logger.assert_abort(self.jedi_interface is not None, f'In order to render a ' + @@ -222,7 +231,7 @@ def render_interface_observations(self, config_name): # Prepare path to interface metadata file and call rendering - def render_interface_meta(self, model_component_in=None): + def render_interface_meta(self, model_component_in: Union[str, dict, None] = None) -> dict: # Optionally open a different model interface model_component = self.jedi_interface diff --git a/src/swell/utilities/run_jedi_executables.py b/src/swell/utilities/run_jedi_executables.py index ea2d2d12..66d39c33 100644 --- a/src/swell/utilities/run_jedi_executables.py +++ b/src/swell/utilities/run_jedi_executables.py @@ -10,12 +10,22 @@ import os import netCDF4 as nc +from typing import Optional +import datetime + from swell.utilities.shell_commands import run_track_log_subprocess +from swell.utilities.logger import Logger +from swell.tasks.base.task_base import JediConfigRendering # -------------------------------------------------------------------------------------------------- -def check_obs(path_to_observing_sys_yamls, observation, obs_dict, cycle_time): +def check_obs( + path_to_observing_sys_yamls: Optional[str], + observation: str, + obs_dict: dict, + cycle_time: Optional[str] +) -> bool: use_observation = False @@ -38,8 +48,14 @@ def check_obs(path_to_observing_sys_yamls, observation, obs_dict, cycle_time): # -------------------------------------------------------------------------------------------------- -def jedi_dictionary_iterator(jedi_config_dict, jedi_rendering, window_type=None, obs=None, - cycle_time=None, jedi_forecast_model=None): +def jedi_dictionary_iterator( + jedi_config_dict: dict, + jedi_rendering: JediConfigRendering, + window_type: Optional[str] = None, + obs: Optional[list[str]] = None, + cycle_time: Optional[datetime.datetime] = None, + jedi_forecast_model: Optional[str] = None +) -> None: # Assemble configuration YAML file # -------------------------------- @@ -90,7 +106,14 @@ def jedi_dictionary_iterator(jedi_config_dict, jedi_rendering, window_type=None, # ---------------------------------------------------------------------------------------------- -def run_executable(logger, cycle_dir, np, jedi_executable_path, jedi_config_file, output_log): +def run_executable( + logger: Logger, + cycle_dir: str, + np: int, + jedi_executable_path: str, + jedi_config_file: str, + output_log: str +) -> None: # Run the JEDI executable # ----------------------- diff --git a/src/swell/utilities/scripts/check_jedi_interface_templates.py b/src/swell/utilities/scripts/check_jedi_interface_templates.py index aad74256..252d3f28 100644 --- a/src/swell/utilities/scripts/check_jedi_interface_templates.py +++ b/src/swell/utilities/scripts/check_jedi_interface_templates.py @@ -20,7 +20,7 @@ # -------------------------------------------------------------------------------------------------- -def main(): +def main() -> None: # Create a logger logger = Logger('CheckJediInterfaceTemplates') diff --git a/src/swell/utilities/scripts/task_question_dicts_defaults.py b/src/swell/utilities/scripts/task_question_dicts_defaults.py index 24349d72..d8bd7a77 100644 --- a/src/swell/utilities/scripts/task_question_dicts_defaults.py +++ b/src/swell/utilities/scripts/task_question_dicts_defaults.py @@ -12,6 +12,7 @@ import random import string import yaml +from typing import Union # swell imports from swell.swell_path import get_swell_path @@ -21,7 +22,12 @@ # -------------------------------------------------------------------------------------------------- -def create_jedi_tq_dicts(logger, jedi_interface_name, tq_dicts, jedi_tq_dicts_str_in): +def create_jedi_tq_dicts( + logger: Logger, + jedi_interface_name: str, + tq_dicts: Union[list, dict], + jedi_tq_dicts_str_in: str +) -> str: # Convert string read from file to dictionary if jedi_tq_dicts_str_in == '': @@ -78,7 +84,12 @@ def create_jedi_tq_dicts(logger, jedi_interface_name, tq_dicts, jedi_tq_dicts_st # -------------------------------------------------------------------------------------------------- -def create_platform_tq_dicts(logger, platform_name, tq_dicts, platform_tq_dicts_str_in): +def create_platform_tq_dicts( + logger: Logger, + platform_name: str, + tq_dicts: Union[list, dict], + platform_tq_dicts_str_in: str +) -> str: # Convert string read from file to dictionary if platform_tq_dicts_str_in == '': @@ -124,7 +135,7 @@ def create_platform_tq_dicts(logger, platform_name, tq_dicts, platform_tq_dicts_ # -------------------------------------------------------------------------------------------------- -def main(): +def main() -> int: # Create a logger logger = Logger('ListOfTaskQuestions') diff --git a/src/swell/utilities/scripts/utility_driver.py b/src/swell/utilities/scripts/utility_driver.py index 7f58961c..bae99321 100644 --- a/src/swell/utilities/scripts/utility_driver.py +++ b/src/swell/utilities/scripts/utility_driver.py @@ -19,7 +19,7 @@ # -------------------------------------------------------------------------------------------------- -def get_utilities(): +def get_utilities() -> list: # Path to util scripts util_scripts_dir = os.path.join(get_swell_path(), 'utilities', 'scripts', '*.py') @@ -44,7 +44,7 @@ def get_utilities(): # -------------------------------------------------------------------------------------------------- -def utility_wrapper(utility): +def utility_wrapper(utility: str) -> None: # Convert utility to snake case utility_snake = camel_case_to_snake_case(utility) diff --git a/src/swell/utilities/shell_commands.py b/src/swell/utilities/shell_commands.py index f48a39af..8bcfa136 100644 --- a/src/swell/utilities/shell_commands.py +++ b/src/swell/utilities/shell_commands.py @@ -10,12 +10,19 @@ import os import stat import subprocess +from typing import Any, Optional, IO, Union + +from swell.utilities.logger import Logger # -------------------------------------------------------------------------------------------------- -def run_track_log_subprocess(logger, command, output_log=None): +def run_track_log_subprocess( + logger: Logger, + command: Union[list[str], str], + output_log: Optional[str] = None +) -> None: # Prepare output file # ------------------- @@ -56,7 +63,10 @@ def run_track_log_subprocess(logger, command, output_log=None): # -------------------------------------------------------------------------------------------------- -def run_subprocess_dev_null(logger, command): +def run_subprocess_dev_null( + logger: Logger, + command: Union[list[str], str] +) -> None: run_subprocess(logger, command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) @@ -64,7 +74,12 @@ def run_subprocess_dev_null(logger, command): # -------------------------------------------------------------------------------------------------- -def run_subprocess(logger, command, stdout=None, stderr=None): +def run_subprocess( + logger: Logger, + command: Union[list[str], str], + stdout: Union[int, IO[Any], None] = None, + stderr: Union[int, IO[Any], None] = None +) -> None: # Run subprocess try: @@ -77,7 +92,7 @@ def run_subprocess(logger, command, stdout=None, stderr=None): # -------------------------------------------------------------------------------------------------- -def create_executable_file(logger, file_name, file_contents): +def create_executable_file(logger: Logger, file_name: str, file_contents: str) -> None: # Write contents to file with open(os.path.join(file_name), "w") as file_name_open: diff --git a/src/swell/utilities/slurm.py b/src/swell/utilities/slurm.py index 607fa5e2..4d52250d 100644 --- a/src/swell/utilities/slurm.py +++ b/src/swell/utilities/slurm.py @@ -11,16 +11,19 @@ import platform as pltfrm import re import yaml +from typing import Union -from swell.utilities.logger import Logger from importlib import resources +from logging import Logger as pyLogger + +from swell.utilities.logger import Logger def prepare_scheduling_dict( - logger: Logger, + logger: Union[Logger, pyLogger], experiment_dict: dict, platform: str, -): +) -> dict: # Obtain platform-specific SLURM directives and set them as global defaults # Start by constructing the full platforms path @@ -184,7 +187,7 @@ def prepare_scheduling_dict( return scheduling_dict -def add_directives(target_dict, input_dict, key): +def add_directives(target_dict: dict, input_dict: dict, key: str) -> dict: if key in input_dict: return { **target_dict, @@ -194,7 +197,7 @@ def add_directives(target_dict, input_dict, key): return target_dict -def validate_directives(directive_dict): +def validate_directives(directive_dict: dict) -> None: directive_pattern = r'(?<=--)[a-zA-Z-]+' # Parse sbatch docs and extract all directives (e.g., `--account`) directive_list = { @@ -211,7 +214,7 @@ def validate_directives(directive_dict): def slurm_global_defaults( - logger: Logger, + logger: Union[Logger, pyLogger], yaml_path: str = "~/.swell/swell-slurm.yaml" ) -> dict: yaml_path = os.path.expanduser(yaml_path) diff --git a/src/swell/utilities/suite_utils.py b/src/swell/utilities/suite_utils.py index 9abc856a..fecb14be 100644 --- a/src/swell/utilities/suite_utils.py +++ b/src/swell/utilities/suite_utils.py @@ -17,7 +17,7 @@ # -------------------------------------------------------------------------------------------------- -def get_suites(): +def get_suites() -> list: # Path to platforms suites_directory = os.path.join(get_swell_path(), 'suites') @@ -36,7 +36,7 @@ def get_suites(): # -------------------------------------------------------------------------------------------------- -def get_suite_tests(): +def get_suite_tests() -> list: # Path to platforms suite_tests_directory = os.path.join(get_swell_path(), 'test', 'suite_tests', '*.yaml') diff --git a/src/swell/utilities/welcome_message.py b/src/swell/utilities/welcome_message.py index 3b0c1a93..949de1bd 100644 --- a/src/swell/utilities/welcome_message.py +++ b/src/swell/utilities/welcome_message.py @@ -14,7 +14,7 @@ # -------------------------------------------------------------------------------------------------- -def write_welcome_message(): +def write_welcome_message() -> None: logger = Logger('')