Skip to content

Commit

Permalink
Fix type errors, comply with linters & formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
BSchilperoort committed Nov 22, 2023
1 parent facb2b3 commit 5f7f418
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 30 deletions.
83 changes: 58 additions & 25 deletions PyStemmusScope/bmi.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
"""BMI wrapper for the STEMMUS_SCOPE model."""
import os
import subprocess
import sys
from pathlib import Path
import h5py
import numpy as np
from bmipy.bmi import Bmi
from PyStemmusScope.config_io import read_config
import subprocess
import h5py


MODEL_INPUT_VARNAMES: tuple[str] = ("soil_temperature",)
MODEL_INPUT_VARNAMES: tuple[str, ...] = ("soil_temperature",)

MODEL_OUTPUT_VARNAMES: tuple[str] = (
MODEL_OUTPUT_VARNAMES: tuple[str, ...] = (
"soil_temperature",
"respiration",
)

MODEL_VARNAMES: tuple[str] = tuple(set(MODEL_INPUT_VARNAMES + MODEL_OUTPUT_VARNAMES))
MODEL_VARNAMES: tuple[str, ...] = tuple(
set(MODEL_INPUT_VARNAMES + MODEL_OUTPUT_VARNAMES)
)

VARNAME_UNITS: dict[str, str] = {"respiration": "unknown", "soil_temperature": "degC"}

Expand All @@ -30,6 +32,12 @@
"soil_temperature": 1,
}

NO_STATE_MSG = (
"The model state is not available. Please run `.update()` before requesting "
"\nthis model info. If you did run .update() before, something seems to have "
"\ngone wrong and you have to restart the model."
)


def ipython_info():
"""Get ipython info: if the code is being run from notebook or terminal."""
Expand All @@ -54,7 +62,7 @@ def load_state(config: dict) -> h5py.File:
return h5py.File(matfile, mode="a")


def get_variable(state: h5py.File, varname: str) -> None:
def get_variable(state: h5py.File, varname: str) -> np.ndarray:
"""Get a variable from the model state.
Args:
Expand All @@ -65,7 +73,7 @@ def get_variable(state: h5py.File, varname: str) -> None:
case "respiration":
return state["fluxes"]["Resp"][0]
case "soil_temperature":
return state["TT"][0,:-1]
return state["TT"][0, :-1]
case _:
if varname in MODEL_VARNAMES:
msg = "Varname is missing in get_variable! Contact devs."
Expand All @@ -89,7 +97,7 @@ def set_variable(state: h5py.File, varname: str, value: np.ndarray) -> dict:
case "respiration":
state["fluxes"]["Resp"][0] = value
case "soil_temperature":
state["TT"][0,:-1] = value
state["TT"][0, :-1] = value
case _:
if varname in MODEL_VARNAMES:
msg = "Varname is missing in get_variable! Contact devs."
Expand All @@ -99,19 +107,22 @@ def set_variable(state: h5py.File, varname: str, value: np.ndarray) -> dict:
return state


def is_alive(process: subprocess.Popen) -> bool:
"""Check if the process is alive, and raise an exception if it is not."""
def is_alive(process: subprocess.Popen | None) -> subprocess.Popen:
"""Return process if the process is alive, raise an exception if it is not."""
if process is None:
msg = "Model process does not seem to be open."
raise ConnectionError(msg)
if process.poll() is not None:
msg = f"Model terminated with return code {process.poll()}"
raise ConnectionError(msg)
return True
return process


def wait_for_model(process: subprocess.Popen, phrase=b"Select run mode:") -> None:
"""Wait for model to be ready for interaction."""
output = b''
output = b""
while is_alive(process) and phrase not in output:
output+=process.stdout.read(1)
output += str(process.stdout.read(1)) # type: ignore


class StemmusScopeBmi(Bmi):
Expand Down Expand Up @@ -148,8 +159,8 @@ def initialize(self, config_file: str) -> None:
)

wait_for_model(self.matlab_process)
is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"initialize\n")
self.matlab_process = is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"initialize\n") # type: ignore
wait_for_model(self.matlab_process)

def update(self) -> None:
Expand All @@ -160,9 +171,9 @@ def update(self) -> None:
if self.matlab_process is None:
msg = "Run initialize before trying to update the model."
raise AttributeError(msg)
is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"update\n")

self.matlab_process = is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"update\n") # type: ignore
wait_for_model(self.matlab_process)

self.state = load_state(self.config)
Expand All @@ -177,8 +188,8 @@ def update_until(self, time: float) -> None:

def finalize(self) -> None:
"""Finalize the STEMMUS_SCOPE model."""
is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"finalize\n")
self.matlab_process = is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"finalize\n") # type: ignore
wait_for_model(self.matlab_process, phrase=b"Finished clean up.")

def get_component_name(self) -> str:
Expand All @@ -205,13 +216,15 @@ def get_output_item_count(self) -> int:
"""
return len(MODEL_OUTPUT_VARNAMES)

def get_input_var_names(self) -> list[str]:
# The types of the following two methods are wrong in python-bmi
# see: https://github.com/csdms/bmi-python/issues/38
def get_input_var_names(self) -> tuple[str, ...]: # type: ignore
"""List of the model's input variables (as CSDMS Standard Names)."""
return list(MODEL_INPUT_VARNAMES.keys())
return MODEL_INPUT_VARNAMES

def get_output_var_names(self) -> list[str]:
def get_output_var_names(self) -> tuple[str, ...]: # type: ignore
"""List of the model's output variables (as CSDMS Standard Names)."""
return list(MODEL_OUTPUT_VARNAMES.keys())
return MODEL_OUTPUT_VARNAMES

def get_var_grid(self, name: str) -> int:
"""Get grid identifier for the given variable."""
Expand Down Expand Up @@ -252,6 +265,8 @@ def get_time_units(self) -> str:

def get_time_step(self) -> float:
"""Return the current time step of the model."""
if self.state is None:
raise ValueError(NO_STATE_MSG)
return float(self.state["KT"][0])

def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:
Expand All @@ -264,6 +279,8 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:
Returns:
The same numpy array that was passed as an input buffer.
"""
if self.state is None:
raise ValueError(NO_STATE_MSG)
dest[:] = get_variable(self.state, name)
return dest

Expand Down Expand Up @@ -296,6 +313,8 @@ def set_value(self, name: str, src: np.ndarray) -> None:
name: Input or output variable name, a CSDMS Standard Name.
src: The new value for the specified variable.
"""
if self.state is None:
raise ValueError(NO_STATE_MSG)
self.state = set_variable(self.state, name, src)

def set_value_at_indices(
Expand Down Expand Up @@ -328,6 +347,8 @@ def get_grid_rank(self, grid: int) -> int:
return 2
if grid == 1:
return 3
msg = f"Invalid grid identifier '{grid}'"
raise ValueError(msg)

def get_grid_size(self, grid: int) -> int:
"""Get the total number of elements in the computational grid.
Expand All @@ -338,11 +359,17 @@ def get_grid_size(self, grid: int) -> int:
Returns:
Size of the grid.
"""
if self.state is None:
raise ValueError(NO_STATE_MSG)

if grid == 0:
return 1
if grid == 1:
return int(self.state["ModelSettings"]["mN"][0]) - 1

msg = f"Invalid grid identifier '{grid}'"
raise ValueError(msg)

def get_grid_type(self, grid: int) -> str:
"""Get the grid type as a string.
Expand All @@ -365,6 +392,8 @@ def get_grid_x(self, grid: int, x: np.ndarray) -> np.ndarray:
Returns:
The input numpy array that holds the grid's column x-coordinates.
"""
if self.state is None:
raise ValueError(NO_STATE_MSG)
x[:] = self.state["SiteProperties"]["latitude"][0]
return x

Expand All @@ -378,6 +407,8 @@ def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray:
Returns:
The input numpy array that holds the grid's column y-coordinates.
"""
if self.state is None:
raise ValueError(NO_STATE_MSG)
y[:] = self.state["SiteProperties"]["latitude"][0]
return y

Expand All @@ -391,11 +422,13 @@ def get_grid_z(self, grid: int, z: np.ndarray) -> np.ndarray:
Returns:
The input numpy array that holds the grid's column z-coordinates.
"""
if self.state is None:
raise ValueError(NO_STATE_MSG)
if grid == 1:
z[:] = (
-np.hstack(
(
self.state["ModelSettings"]["DeltZ_R"][:,0].cumsum()[::-1],
self.state["ModelSettings"]["DeltZ_R"][:, 0].cumsum()[::-1],
np.array([0.0]),
)
)
Expand Down
8 changes: 4 additions & 4 deletions PyStemmusScope/global_data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def assert_time_within_bounds(
raise MissingDataError(
"\nThe available data cannot cover the specified start and end time.\n"
f" Specified model time range:\n"
f" {np.datetime_as_string(start_time, unit='m')}"
f" - {np.datetime_as_string(end_time, unit='m')}\n"
f" Data start: {np.datetime_as_string(data[time_dim].min(), unit='m')}\n"
f" Data end: {np.datetime_as_string(data[time_dim].max(), unit='m')}"
f" {np.datetime_as_string(start_time, unit='m')}" # type: ignore
f" - {np.datetime_as_string(end_time, unit='m')}\n" # type: ignore
f" Data start: {np.datetime_as_string(data[time_dim].min(), unit='m')}\n" # type: ignore
f" Data end: {np.datetime_as_string(data[time_dim].max(), unit='m')}" # type: ignore
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_to_absolute_path_with_relative_input_and_relative_parent():
# care for windows, see issue 22
Path(input_path).mkdir(exist_ok=True)

parsed = utils.to_absolute_path(input_path, parent=Path("."))
parsed = utils.to_absolute_path(input_path, parent=Path.cwd())
expected = Path.cwd() / "input_dir"
assert parsed == expected

Expand Down

0 comments on commit 5f7f418

Please sign in to comment.