Skip to content

Commit

Permalink
Use STEMMUS_SCOPE's interactive mode, h5py for IO
Browse files Browse the repository at this point in the history
  • Loading branch information
BSchilperoort committed Nov 22, 2023
1 parent 1e3011d commit facb2b3
Showing 1 changed file with 59 additions and 39 deletions.
98 changes: 59 additions & 39 deletions PyStemmusScope/bmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import os
import sys
from pathlib import Path
from subprocess import run
from typing import Any
import hdf5storage
import numpy as np
from bmipy.bmi import Bmi
from PyStemmusScope.config_io import read_config
import subprocess
import h5py


MODEL_INPUT_VARNAMES: tuple[str] = ("soil_temperature",)
Expand Down Expand Up @@ -42,20 +41,7 @@ def ipython_info():
return ip


def run_cmd(args):
"""Run a command (i.e. run the matlab model executable).
Args:
args: Arguments to run.
"""
if ipython_info() == "notebook":
result = run(args, shell=True)
else:
result = run(args)
result.check_returncode()


def load_state(config: dict) -> dict[str, Any]:
def load_state(config: dict) -> h5py.File:
"""Load the STEMMUS_SCOPE model state.
Args:
Expand All @@ -65,10 +51,10 @@ def load_state(config: dict) -> dict[str, Any]:
Model state, as a dict.
"""
matfile = Path(config["OutputPath"]) / "STEMMUS_SCOPE_state.mat"
return hdf5storage.loadmat(matfile, appendmat=False)
return h5py.File(matfile, mode="a")


def get_variable(state: dict, varname: str) -> None:
def get_variable(state: h5py.File, varname: str) -> None:
"""Get a variable from the model state.
Args:
Expand All @@ -77,9 +63,9 @@ def get_variable(state: dict, varname: str) -> None:
"""
match varname:
case "respiration":
return state["fluxes"]["Resp"].flatten()
return state["fluxes"]["Resp"][0]
case "soil_temperature":
return state["TT"].flatten()[:-1]
return state["TT"][0,:-1]
case _:
if varname in MODEL_VARNAMES:
msg = "Varname is missing in get_variable! Contact devs."
Expand All @@ -88,7 +74,7 @@ def get_variable(state: dict, varname: str) -> None:
raise ValueError(msg)


def set_variable(state: dict, varname: str, value: np.ndarray) -> dict:
def set_variable(state: h5py.File, varname: str, value: np.ndarray) -> dict:
"""Set a variable in the model state.
Args:
Expand All @@ -101,9 +87,9 @@ def set_variable(state: dict, varname: str, value: np.ndarray) -> dict:
"""
match varname:
case "respiration":
state["fluxes"]["Resp"][0][0][0] = value
state["fluxes"]["Resp"][0] = value
case "soil_temperature":
state["TT"][:-1, 0] = value
state["TT"][0,:-1] = value
case _:
if varname in MODEL_VARNAMES:
msg = "Varname is missing in get_variable! Contact devs."
Expand All @@ -113,13 +99,30 @@ def set_variable(state: dict, varname: str, value: np.ndarray) -> dict:
return state


def is_alive(process: subprocess.Popen) -> bool:
"""Check if the process is alive, and raise an exception if it is not."""
if process.poll() is not None:
msg = f"Model terminated with return code {process.poll()}"
raise ConnectionError(msg)
return True


def wait_for_model(process: subprocess.Popen, phrase=b"Select run mode:") -> None:
"""Wait for model to be ready for interaction."""
output = b''
while is_alive(process) and phrase not in output:
output+=process.stdout.read(1)


class StemmusScopeBmi(Bmi):
"""STEMMUS_SCOPE Basic Model Interface."""

config_file: str = ""
config: dict = {}
state: dict = {}
state_file: Path
state: h5py.File | None = None
state_file: Path | None = None

matlab_process: subprocess.Popen | None = None

def initialize(self, config_file: str) -> None:
"""Perform startup tasks for the model.
Expand All @@ -132,19 +135,35 @@ def initialize(self, config_file: str) -> None:
self.exe_file = self.config["ExeFilePath"]
self.state_file = Path(self.config["OutputPath"]) / "STEMMUS_SCOPE_state.mat"

args = f"{self.exe_file} {self.config_file} initialize"
args = [self.exe_file, self.config_file, "interactive"]

# set matlab log dirc
os.environ["MATLAB_LOG_DIR"] = str(self.config["InputPath"])

run_cmd(args)
self.matlab_process = subprocess.Popen(
args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
bufsize=0,
)

wait_for_model(self.matlab_process)
is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"initialize\n")
wait_for_model(self.matlab_process)

def update(self) -> None:
"""Advance the model state by one time step."""
if self.state != {}:
hdf5storage.savemat(self.state_file, self.state, appendmat=False)
args = f"{self.exe_file} {self.config_file} update"
run_cmd(args)
if self.state is not None:
self.state = self.state.close() # Close file to allow matlab to write

if self.matlab_process is None:
msg = "Run initialize before trying to update the model."
raise AttributeError(msg)

is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"update\n")
wait_for_model(self.matlab_process)

self.state = load_state(self.config)

Expand All @@ -158,8 +177,9 @@ def update_until(self, time: float) -> None:

def finalize(self) -> None:
"""Finalize the STEMMUS_SCOPE model."""
args = f"{self.exe_file} {self.config_file} finalize"
run_cmd(args)
is_alive(self.matlab_process)
self.matlab_process.stdin.write(b"finalize\n")
wait_for_model(self.matlab_process, phrase=b"Finished clean up.")

def get_component_name(self) -> str:
"""Name of the component.
Expand Down Expand Up @@ -232,7 +252,7 @@ def get_time_units(self) -> str:

def get_time_step(self) -> float:
"""Return the current time step of the model."""
return self.state["KT"][0][0].flatten().astype("float")
return float(self.state["KT"][0])

def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:
"""Get a copy of values of the given variable.
Expand Down Expand Up @@ -321,7 +341,7 @@ def get_grid_size(self, grid: int) -> int:
if grid == 0:
return 1
if grid == 1:
return int(self.state["ModelSettings"]["mN"].flatten()[0]) - 1
return int(self.state["ModelSettings"]["mN"][0]) - 1

def get_grid_type(self, grid: int) -> str:
"""Get the grid type as a string.
Expand All @@ -345,7 +365,7 @@ def get_grid_x(self, grid: int, x: np.ndarray) -> np.ndarray:
Returns:
The input numpy array that holds the grid's column x-coordinates.
"""
x[:] = self.state["SiteProperties"]["latitude"][0][0].flatten().astype("float")
x[:] = self.state["SiteProperties"]["latitude"][0]
return x

def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray:
Expand All @@ -358,7 +378,7 @@ def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray:
Returns:
The input numpy array that holds the grid's column y-coordinates.
"""
y[:] = self.state["SiteProperties"]["latitude"].flatten().astype("float")
y[:] = self.state["SiteProperties"]["latitude"][0]
return y

def get_grid_z(self, grid: int, z: np.ndarray) -> np.ndarray:
Expand All @@ -375,8 +395,8 @@ def get_grid_z(self, grid: int, z: np.ndarray) -> np.ndarray:
z[:] = (
-np.hstack(
(
self.state["ModelSettings"]["DeltZ_R"][:,0].cumsum()[::-1],
np.array([0.0]),
self.state["ModelSettings"]["DeltZ_R"].flatten().cumsum(),
)
)
/ 100
Expand Down

0 comments on commit facb2b3

Please sign in to comment.