From 5f7f418720f06026331bcb8724a8f612794aadec Mon Sep 17 00:00:00 2001 From: Bart Schilperoort Date: Wed, 22 Nov 2023 16:39:22 +0100 Subject: [PATCH] Fix type errors, comply with linters & formatter --- PyStemmusScope/bmi.py | 83 ++++++++++++++++++++--------- PyStemmusScope/global_data/utils.py | 8 +-- tests/test_utils.py | 2 +- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/PyStemmusScope/bmi.py b/PyStemmusScope/bmi.py index a29124ad..f5aa2d1d 100644 --- a/PyStemmusScope/bmi.py +++ b/PyStemmusScope/bmi.py @@ -1,22 +1,24 @@ """BMI wrapper for the STEMMUS_SCOPE model.""" import os +import subprocess import sys from pathlib import Path +import h5py import numpy as np from bmipy.bmi import Bmi from PyStemmusScope.config_io import read_config -import subprocess -import h5py -MODEL_INPUT_VARNAMES: tuple[str] = ("soil_temperature",) +MODEL_INPUT_VARNAMES: tuple[str, ...] = ("soil_temperature",) -MODEL_OUTPUT_VARNAMES: tuple[str] = ( +MODEL_OUTPUT_VARNAMES: tuple[str, ...] = ( "soil_temperature", "respiration", ) -MODEL_VARNAMES: tuple[str] = tuple(set(MODEL_INPUT_VARNAMES + MODEL_OUTPUT_VARNAMES)) +MODEL_VARNAMES: tuple[str, ...] = tuple( + set(MODEL_INPUT_VARNAMES + MODEL_OUTPUT_VARNAMES) +) VARNAME_UNITS: dict[str, str] = {"respiration": "unknown", "soil_temperature": "degC"} @@ -30,6 +32,12 @@ "soil_temperature": 1, } +NO_STATE_MSG = ( + "The model state is not available. Please run `.update()` before requesting " + "\nthis model info. If you did run .update() before, something seems to have " + "\ngone wrong and you have to restart the model." +) + def ipython_info(): """Get ipython info: if the code is being run from notebook or terminal.""" @@ -54,7 +62,7 @@ def load_state(config: dict) -> h5py.File: return h5py.File(matfile, mode="a") -def get_variable(state: h5py.File, varname: str) -> None: +def get_variable(state: h5py.File, varname: str) -> np.ndarray: """Get a variable from the model state. Args: @@ -65,7 +73,7 @@ def get_variable(state: h5py.File, varname: str) -> None: case "respiration": return state["fluxes"]["Resp"][0] case "soil_temperature": - return state["TT"][0,:-1] + return state["TT"][0, :-1] case _: if varname in MODEL_VARNAMES: msg = "Varname is missing in get_variable! Contact devs." @@ -89,7 +97,7 @@ def set_variable(state: h5py.File, varname: str, value: np.ndarray) -> dict: case "respiration": state["fluxes"]["Resp"][0] = value case "soil_temperature": - state["TT"][0,:-1] = value + state["TT"][0, :-1] = value case _: if varname in MODEL_VARNAMES: msg = "Varname is missing in get_variable! Contact devs." @@ -99,19 +107,22 @@ def set_variable(state: h5py.File, varname: str, value: np.ndarray) -> dict: return state -def is_alive(process: subprocess.Popen) -> bool: - """Check if the process is alive, and raise an exception if it is not.""" +def is_alive(process: subprocess.Popen | None) -> subprocess.Popen: + """Return process if the process is alive, raise an exception if it is not.""" + if process is None: + msg = "Model process does not seem to be open." + raise ConnectionError(msg) if process.poll() is not None: msg = f"Model terminated with return code {process.poll()}" raise ConnectionError(msg) - return True + return process def wait_for_model(process: subprocess.Popen, phrase=b"Select run mode:") -> None: """Wait for model to be ready for interaction.""" - output = b'' + output = b"" while is_alive(process) and phrase not in output: - output+=process.stdout.read(1) + output += str(process.stdout.read(1)) # type: ignore class StemmusScopeBmi(Bmi): @@ -148,8 +159,8 @@ def initialize(self, config_file: str) -> None: ) wait_for_model(self.matlab_process) - is_alive(self.matlab_process) - self.matlab_process.stdin.write(b"initialize\n") + self.matlab_process = is_alive(self.matlab_process) + self.matlab_process.stdin.write(b"initialize\n") # type: ignore wait_for_model(self.matlab_process) def update(self) -> None: @@ -160,9 +171,9 @@ def update(self) -> None: if self.matlab_process is None: msg = "Run initialize before trying to update the model." raise AttributeError(msg) - - is_alive(self.matlab_process) - self.matlab_process.stdin.write(b"update\n") + + self.matlab_process = is_alive(self.matlab_process) + self.matlab_process.stdin.write(b"update\n") # type: ignore wait_for_model(self.matlab_process) self.state = load_state(self.config) @@ -177,8 +188,8 @@ def update_until(self, time: float) -> None: def finalize(self) -> None: """Finalize the STEMMUS_SCOPE model.""" - is_alive(self.matlab_process) - self.matlab_process.stdin.write(b"finalize\n") + self.matlab_process = is_alive(self.matlab_process) + self.matlab_process.stdin.write(b"finalize\n") # type: ignore wait_for_model(self.matlab_process, phrase=b"Finished clean up.") def get_component_name(self) -> str: @@ -205,13 +216,15 @@ def get_output_item_count(self) -> int: """ return len(MODEL_OUTPUT_VARNAMES) - def get_input_var_names(self) -> list[str]: + # The types of the following two methods are wrong in python-bmi + # see: https://github.com/csdms/bmi-python/issues/38 + def get_input_var_names(self) -> tuple[str, ...]: # type: ignore """List of the model's input variables (as CSDMS Standard Names).""" - return list(MODEL_INPUT_VARNAMES.keys()) + return MODEL_INPUT_VARNAMES - def get_output_var_names(self) -> list[str]: + def get_output_var_names(self) -> tuple[str, ...]: # type: ignore """List of the model's output variables (as CSDMS Standard Names).""" - return list(MODEL_OUTPUT_VARNAMES.keys()) + return MODEL_OUTPUT_VARNAMES def get_var_grid(self, name: str) -> int: """Get grid identifier for the given variable.""" @@ -252,6 +265,8 @@ def get_time_units(self) -> str: def get_time_step(self) -> float: """Return the current time step of the model.""" + if self.state is None: + raise ValueError(NO_STATE_MSG) return float(self.state["KT"][0]) def get_value(self, name: str, dest: np.ndarray) -> np.ndarray: @@ -264,6 +279,8 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray: Returns: The same numpy array that was passed as an input buffer. """ + if self.state is None: + raise ValueError(NO_STATE_MSG) dest[:] = get_variable(self.state, name) return dest @@ -296,6 +313,8 @@ def set_value(self, name: str, src: np.ndarray) -> None: name: Input or output variable name, a CSDMS Standard Name. src: The new value for the specified variable. """ + if self.state is None: + raise ValueError(NO_STATE_MSG) self.state = set_variable(self.state, name, src) def set_value_at_indices( @@ -328,6 +347,8 @@ def get_grid_rank(self, grid: int) -> int: return 2 if grid == 1: return 3 + msg = f"Invalid grid identifier '{grid}'" + raise ValueError(msg) def get_grid_size(self, grid: int) -> int: """Get the total number of elements in the computational grid. @@ -338,11 +359,17 @@ def get_grid_size(self, grid: int) -> int: Returns: Size of the grid. """ + if self.state is None: + raise ValueError(NO_STATE_MSG) + if grid == 0: return 1 if grid == 1: return int(self.state["ModelSettings"]["mN"][0]) - 1 + msg = f"Invalid grid identifier '{grid}'" + raise ValueError(msg) + def get_grid_type(self, grid: int) -> str: """Get the grid type as a string. @@ -365,6 +392,8 @@ def get_grid_x(self, grid: int, x: np.ndarray) -> np.ndarray: Returns: The input numpy array that holds the grid's column x-coordinates. """ + if self.state is None: + raise ValueError(NO_STATE_MSG) x[:] = self.state["SiteProperties"]["latitude"][0] return x @@ -378,6 +407,8 @@ def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray: Returns: The input numpy array that holds the grid's column y-coordinates. """ + if self.state is None: + raise ValueError(NO_STATE_MSG) y[:] = self.state["SiteProperties"]["latitude"][0] return y @@ -391,11 +422,13 @@ def get_grid_z(self, grid: int, z: np.ndarray) -> np.ndarray: Returns: The input numpy array that holds the grid's column z-coordinates. """ + if self.state is None: + raise ValueError(NO_STATE_MSG) if grid == 1: z[:] = ( -np.hstack( ( - self.state["ModelSettings"]["DeltZ_R"][:,0].cumsum()[::-1], + self.state["ModelSettings"]["DeltZ_R"][:, 0].cumsum()[::-1], np.array([0.0]), ) ) diff --git a/PyStemmusScope/global_data/utils.py b/PyStemmusScope/global_data/utils.py index 835a75ed..8403c469 100644 --- a/PyStemmusScope/global_data/utils.py +++ b/PyStemmusScope/global_data/utils.py @@ -70,10 +70,10 @@ def assert_time_within_bounds( raise MissingDataError( "\nThe available data cannot cover the specified start and end time.\n" f" Specified model time range:\n" - f" {np.datetime_as_string(start_time, unit='m')}" - f" - {np.datetime_as_string(end_time, unit='m')}\n" - f" Data start: {np.datetime_as_string(data[time_dim].min(), unit='m')}\n" - f" Data end: {np.datetime_as_string(data[time_dim].max(), unit='m')}" + f" {np.datetime_as_string(start_time, unit='m')}" # type: ignore + f" - {np.datetime_as_string(end_time, unit='m')}\n" # type: ignore + f" Data start: {np.datetime_as_string(data[time_dim].min(), unit='m')}\n" # type: ignore + f" Data end: {np.datetime_as_string(data[time_dim].max(), unit='m')}" # type: ignore ) diff --git a/tests/test_utils.py b/tests/test_utils.py index b3453832..dc3d1dc4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -59,7 +59,7 @@ def test_to_absolute_path_with_relative_input_and_relative_parent(): # care for windows, see issue 22 Path(input_path).mkdir(exist_ok=True) - parsed = utils.to_absolute_path(input_path, parent=Path(".")) + parsed = utils.to_absolute_path(input_path, parent=Path.cwd()) expected = Path.cwd() / "input_dir" assert parsed == expected