-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #97 from eWaterCycle/memoized_bmi_client
Memoized bmi client
- Loading branch information
Showing
2 changed files
with
244 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
from basic_modeling_interface.bmi import Bmi | ||
|
||
|
||
class MemoizedBmi(Bmi): | ||
"""Wrapper around Bmi object that caches the return values of almost all methods. | ||
Most BMI methods return same value each time it is called, so the results can be cached. | ||
gRPC communication is expensive and can be sped up by caching. | ||
The following methods are not cached: | ||
* initialize | ||
* update_* | ||
* finalize | ||
* get_current_time | ||
* get_value_* | ||
* set_value_* | ||
The cache is cleared when initialize() is called. | ||
Example: | ||
A gRPC BMI server is running on localhost:1234, to cache it use the following. | ||
>>> import grpc | ||
>>> from grpc4bmi.bmi_grpc_client import BmiClient | ||
>>> from grpc4bmi.bmi_memoized import MemoizedBmi | ||
>>> slow_model = BmiClient(grpc.insecure_channel("localhost:1234")) | ||
>>> model = MemoizedBmi(slow_model) | ||
>>> print(model.get_component_name()) | ||
Hello world | ||
>>> # Calling second time will return cached value | ||
>>> # and not talk to server on "localhost:1234" | ||
>>> print(model.get_component_name()) | ||
Hello world | ||
""" | ||
def __init__(self, origin: Bmi): | ||
self.origin = origin | ||
self.cache = dict() | ||
|
||
def _cache(self, fn, arg=None): | ||
if fn not in self.cache: | ||
self.cache[fn] = dict() | ||
if arg not in self.cache[fn]: | ||
if arg is None: | ||
self.cache[fn][arg] = getattr(self.origin, fn)() | ||
else: | ||
self.cache[fn][arg] = getattr(self.origin, fn)(arg) | ||
return self.cache[fn][arg] | ||
|
||
def initialize(self, filename): | ||
self.cache = dict() | ||
return self.origin.initialize(filename) | ||
|
||
def update(self): | ||
self.origin.update() | ||
|
||
def update_until(self, time): | ||
self.origin.update_until(time) | ||
|
||
def update_frac(self, time_frac): | ||
self.origin.update_frac(time_frac) | ||
|
||
def finalize(self): | ||
self.origin.finalize() | ||
|
||
def get_component_name(self): | ||
return self._cache('get_component_name') | ||
|
||
def get_input_var_names(self): | ||
return self._cache('get_input_var_names') | ||
|
||
def get_output_var_names(self): | ||
return self._cache('get_output_var_names') | ||
|
||
def get_start_time(self): | ||
return self._cache('get_start_time') | ||
|
||
def get_current_time(self): | ||
return self.origin.get_current_time() | ||
|
||
def get_end_time(self): | ||
return self._cache('get_end_time') | ||
|
||
def get_time_step(self): | ||
return self._cache('get_time_step') | ||
|
||
def get_time_units(self): | ||
return self._cache('get_time_units') | ||
|
||
def get_var_type(self, var_name): | ||
return self._cache('get_var_type', var_name) | ||
|
||
def get_var_units(self, var_name): | ||
return self._cache('get_var_units', var_name) | ||
|
||
def get_var_itemsize(self, var_name): | ||
return self._cache('get_var_itemsize', var_name) | ||
|
||
def get_var_nbytes(self, var_name): | ||
return self._cache('get_var_nbytes', var_name) | ||
|
||
def get_var_grid(self, var_name): | ||
return self._cache('get_var_grid', var_name) | ||
|
||
def get_value(self, var_name): | ||
return self.origin.get_value(var_name) | ||
|
||
def get_value_ref(self, var_name): | ||
return self.origin.get_value_ref(var_name) | ||
|
||
def get_value_at_indices(self, var_name, indices): | ||
return self.origin.get_value_at_indices(var_name, indices) | ||
|
||
def set_value(self, var_name, src): | ||
return self.origin.set_value(var_name, src) | ||
|
||
def set_value_at_indices(self, var_name, indices, src): | ||
return self.origin.set_value_at_indices(var_name, indices, src) | ||
|
||
def get_grid_shape(self, grid_id): | ||
return self._cache('get_grid_shape', grid_id) | ||
|
||
def get_grid_x(self, grid_id): | ||
return self._cache('get_grid_x', grid_id) | ||
|
||
def get_grid_y(self, grid_id): | ||
return self._cache('get_grid_y', grid_id) | ||
|
||
def get_grid_z(self, grid_id): | ||
return self._cache('get_grid_z', grid_id) | ||
|
||
def get_grid_spacing(self, grid_id): | ||
return self._cache('get_grid_spacing', grid_id) | ||
|
||
def get_grid_origin(self, grid_id): | ||
return self._cache('get_grid_origin', grid_id) | ||
|
||
def get_grid_connectivity(self, grid_id): | ||
return self._cache('get_grid_connectivity', grid_id) | ||
|
||
def get_grid_offset(self, grid_id): | ||
return self._cache('get_grid_offset', grid_id) | ||
|
||
def get_grid_rank(self, grid_id): | ||
return self._cache('get_grid_rank', grid_id) | ||
|
||
def get_grid_size(self, grid_id): | ||
return self._cache('get_grid_size', grid_id) | ||
|
||
def get_grid_type(self, grid_id): | ||
return self._cache('get_grid_type', grid_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from grpc4bmi.bmi_memoized import MemoizedBmi | ||
from test.flatbmiheat import FlatBmiHeat | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'mut_name,mut_args', | ||
[ | ||
('get_component_name', tuple()), | ||
('get_input_var_names', tuple()), | ||
('get_output_var_names', tuple()), | ||
('get_start_time', tuple()), | ||
('get_end_time', tuple()), | ||
('get_time_step', tuple()), | ||
('get_time_units', tuple()), | ||
('get_var_type', ['plate_surface__temperature']), | ||
('get_var_units', ['plate_surface__temperature']), | ||
('get_var_itemsize', ['plate_surface__temperature']), | ||
('get_var_nbytes', ['plate_surface__temperature']), | ||
('get_var_grid', ['plate_surface__temperature']), | ||
('get_grid_shape', [0]), | ||
('get_grid_x', [0]), | ||
('get_grid_y', [0]), | ||
('get_grid_z', [0]), | ||
('get_grid_spacing', [0]), | ||
('get_grid_origin', [0]), | ||
('get_grid_connectivity', [0]), | ||
('get_grid_offset', [0]), | ||
('get_grid_rank', [0]), | ||
('get_grid_size', [0]), | ||
('get_grid_type', [0]), | ||
] | ||
) | ||
def test_memoized_methods(mut_name, mut_args): | ||
model = FlatBmiHeat() | ||
with patch.object(model, mut_name, wraps=getattr(model, mut_name)) as mock_method: | ||
client = MemoizedBmi(model) | ||
client.initialize(None) | ||
mot = getattr(client, mut_name) | ||
|
||
first_result = mot(*mut_args) | ||
second_result = mot(*mut_args) | ||
|
||
assert first_result == second_result | ||
assert mock_method.call_count == 1 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'mut_name,mut_args', | ||
[ | ||
('update', tuple()), | ||
('update_until', [2]), | ||
('update_frac', [0.5]), | ||
('finalize', tuple()), | ||
('get_current_time', tuple()), | ||
('get_value', ['plate_surface__temperature']), | ||
('get_value_ref', ['plate_surface__temperature']), | ||
('get_value_at_indices', ['plate_surface__temperature', [1, 2, 3]]), | ||
('set_value', ['plate_surface__temperature', np.ones((10, 20))]), | ||
('set_value_at_indices', ['plate_surface__temperature', [1, 2, 3], [4, 5, 6]]), | ||
] | ||
) | ||
def test_nonmemoized_methods(mut_name, mut_args): | ||
model = FlatBmiHeat() | ||
with patch.object(model, mut_name, wraps=getattr(model, mut_name)) as mock_method: | ||
client = MemoizedBmi(model) | ||
client.initialize(None) | ||
mot = getattr(client, mut_name) | ||
|
||
mot(*mut_args) | ||
mot(*mut_args) | ||
|
||
assert mock_method.call_count == 2 | ||
|
||
|
||
def test_initialize_clears_cache(): | ||
model = FlatBmiHeat() | ||
client = MemoizedBmi(model) | ||
client.initialize(None) | ||
# Fill cache | ||
client.get_start_time() | ||
# Clear cache | ||
client.initialize(None) | ||
|
||
with patch.object(model, 'get_start_time', wraps=model.get_start_time) as mock_method: | ||
client.get_start_time() | ||
assert mock_method.call_count == 1 |