Skip to content

Commit

Permalink
Added float32 support
Browse files Browse the repository at this point in the history
  • Loading branch information
goord committed May 22, 2018
1 parent 63f07cc commit 2256736
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 90 deletions.
23 changes: 16 additions & 7 deletions grpc4bmi/bmi_client_docker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import shutil
import os

import docker

Expand All @@ -13,27 +14,35 @@ class BmiClientDocker(BmiClient):
destruction, this class terminates the corresponding docker server.
"""

input_mount_point = "/data/input"
output_mount_point = "/data/output"

def __init__(self, image, image_port=50051, host=None, input_dir=None, output_dir=None):
client = docker.from_env()
port = BmiClient.get_unique_port()
super(BmiClientDocker, self).__init__(BmiClient.create_grpc_channel(port=port, host=host))
client = docker.from_env()
volumes = {}
self.input_dir = input_dir
self.input_dir = None
if input_dir is not None:
volumes[input_dir] = {"bind": "/data", "mode": "ro"}
self.input_dir = os.path.abspath(input_dir)
volumes[self.input_dir] = {"bind": BmiClientDocker.input_mount_point, "mode": "rw"}
self.output_dir = None
if output_dir is not None:
volumes[output_dir] = {"bind": "/data/output", "mode": "rw"}
self.output_dir = os.path.abspath(output_dir)
volumes[self.output_dir] = {"bind": BmiClientDocker.output_mount_point, "mode": "rw"}
self.container = client.containers.run(image, ports={str(image_port) + "/tcp": port},
volumes=volumes,
detach=True)
super(BmiClientDocker, self).__init__(BmiClient.create_grpc_channel(port=port, host=host))

def __del__(self):
self.container.stop()
if hasattr(self,"container"):
self.container.stop()

def initialize(self, filename):
if self.input_dir is not None:
shutil.copy(filename, self.input_dir)
super(BmiClientDocker, self).initialize("/data/" + filename)
fname = os.path.basename(filename)
super(BmiClientDocker, self).initialize(os.path.join(BmiClientDocker.input_mount_point, fname))
else:
super(BmiClientDocker, self).initialize(filename)

Expand Down
30 changes: 23 additions & 7 deletions grpc4bmi/bmi_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import grpc
import numpy

import bmi_pb2
import bmi_pb2_grpc
from grpc4bmi import bmi_pb2, bmi_pb2_grpc

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,11 +128,14 @@ def get_value_at_indices(self, var_name, indices):
return BmiClient.make_array(response)

def set_value(self, var_name, src):
request = None
if src.dtype == numpy.int32:
request = bmi_pb2.SetValueRequest(name=var_name, values_int=src.flatten(), shape=src.shape)
elif src.dtype == numpy.float32:
request = bmi_pb2.SetValueRequest(name=var_name, values_float=src.flatten(), shape=src.shape)
elif src.dtype == numpy.float64:
request = bmi_pb2.SetValueRequest(name=var_name, values_double=src.flatten(), shape=src.shape)
else:
raise NotImplementedError("Arrays with type %s cannot be transmitted through this GRPC channel" % src.dtype)
self.stub.setValue(request)

def set_value_at_indices(self, var_name, indices, src):
Expand All @@ -146,13 +148,17 @@ def set_value_at_indices(self, var_name, indices, src):
index_size = index_array.shape[1]
else:
raise NotImplementedError("Index arrays should be either 1 or 2-dimensional, row-major ordering")
request = None
if src.dtype == numpy.int32:
request = bmi_pb2.SetValueAtIndicesRequest(name=var_name, indices=index_array.flatten(), values_int=src,
index_size=index_size)
elif src.dtype == numpy.float32:
request = bmi_pb2.SetValueAtIndicesRequest(name=var_name, indices=index_array.flatten(), values_float=src,
index_size=index_size)
elif src.dtype == numpy.float64:
request = bmi_pb2.SetValueAtIndicesRequest(name=var_name, indices=index_array.flatten(), values_double=src,
index_size=index_size)
else:
raise NotImplementedError("Arrays with type %s cannot be transmitted through this GRPC channel" % src.dtype)
self.stub.setValueAtIndices(request)

def get_grid_size(self, grid_id):
Expand Down Expand Up @@ -190,9 +196,19 @@ def get_grid_origin(self, grid_id):

@staticmethod
def make_array(response):
ints_in_buffer = any(response.values_int)
floats_in_buffer = any(response.values_float)
doubles_in_buffer = any(response.values_double)
code = (1 if ints_in_buffer else 0) + (1 if floats_in_buffer else 0) + (1 if doubles_in_buffer else 0)
if code == 0:
log.warning("No values found in server response buffer detected")
return numpy.array([])
if code > 1:
raise NotImplementedError("Multiple value types in single server response buffer detected")
shape = response.shape
if any(response.values_int):
if ints_in_buffer:
return numpy.reshape(response.values_int, shape)
if any(response.values_double):
if floats_in_buffer:
return numpy.reshape(response.values_float, shape)
if doubles_in_buffer:
return numpy.reshape(response.values_double, shape)
return numpy.array([])
44 changes: 31 additions & 13 deletions grpc4bmi/bmi_grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@

import numpy

import bmi_pb2
import bmi_pb2_grpc
from grpc4bmi import bmi_pb2, bmi_pb2_grpc

log = logging.getLogger(__name__)


class BmiServer(bmi_pb2_grpc.BmiServiceServicer):

"""
BMI Server class, wrapping an existing python implementation and exposing it via GRPC across the memory space (to
listening client processes). The class takes a package, module and class name and instantiates the BMI
Expand Down Expand Up @@ -99,6 +97,8 @@ def getValue(self, request, context):
vals = self.bmi_model_.get_value(request.name)
if vals.dtype == numpy.int32:
return bmi_pb2.GetValueResponse(shape=vals.shape, values_int=vals.flatten())
if vals.dtype == numpy.float32:
return bmi_pb2.GetValueResponse(shape=vals.shape, values_float=vals.flatten())
if vals.dtype == numpy.float64:
return bmi_pb2.GetValueResponse(shape=vals.shape, values_double=vals.flatten())
raise NotImplementedError("Arrays with type %s cannot be transmitted through this GRPC channel" % vals.dtype)
Expand All @@ -110,37 +110,45 @@ def getValueAtIndices(self, request, context):
indices = request.indices
index_size = request.index_size
if index_size == 2:
num_indices = len(request.indices)/index_size
indices = numpy.reshape(indices,(num_indices, index_size))
num_indices = len(request.indices) / index_size
indices = numpy.reshape(indices, (num_indices, index_size))
vals = self.bmi_model_.get_value_at_indices(request.name, indices)
if vals.dtype == numpy.int32:
return bmi_pb2.GetValueAtIndicesResponse(values_int=vals.flatten(), shape=vals.shape)
if vals.dtype == numpy.float32:
return bmi_pb2.GetValueAtIndicesResponse(values_float=vals.flatten(), shape=vals.shape)
if vals.dtype == numpy.float64:
return bmi_pb2.GetValueAtIndicesResponse(values_double=vals.flatten(), shape=vals.shape)
raise NotImplementedError("Arrays with type %s cannot be transmitted through this GRPC channel" % vals.dtype)

# TODO: warn if both ints and doubles are in the buffer
def setValue(self, request, context):
if any(request.values_int):
ints, floats, doubles = BmiServer.check_request_values(request)
if ints:
array = numpy.reshape(numpy.array(request.values_int, dtype=numpy.int32), request.shape)
self.bmi_model_.set_value(request.name, array)
elif any(request.values_double):
if floats:
array = numpy.reshape(numpy.array(request.values_float, dtype=numpy.float32), request.shape)
self.bmi_model_.set_value(request.name, array)
if doubles:
array = numpy.reshape(numpy.array(request.values_double, dtype=numpy.float64), request.shape)
self.bmi_model_.set_value(request.name, array)
return bmi_pb2.Empty()

def setValuePtr(self, request, context):
raise NotImplementedError("Array references cannot be transmitted through this GRPC channel")

# TODO: warn if both ints and doubles are in the buffer
def setValueAtIndices(self, request, context):
index_size = request.index_size
num_indices = len(request.indices)/index_size
num_indices = len(request.indices) / index_size
index_array = numpy.reshape(request.indices, newshape=(num_indices, index_size))
if any(request.values_int):
ints, floats, doubles = BmiServer.check_request_values(request)
if ints:
array = numpy.array(request.values_int, dtype=numpy.int32)
self.bmi_model_.set_value_at_indices(request.name, indices=index_array, src=array)
elif any(request.values_double):
if floats:
array = numpy.array(request.values_int, dtype=numpy.float32)
self.bmi_model_.set_value_at_indices(request.name, indices=index_array, src=array)
if doubles:
array = numpy.array(request.values_double, dtype=numpy.float64)
self.bmi_model_.set_value_at_indices(request.name, indices=index_array, src=array)
return bmi_pb2.Empty()
Expand Down Expand Up @@ -187,4 +195,14 @@ def getGridConnectivity(self, request, context):
def getGridOffset(self, request, context):
return bmi_pb2.GetGridOffsetResponse(offsets=self.bmi_model_.get_grid_offset(request.grid_id))


@staticmethod
def check_request_values(request):
ints_in_buffer = any(request.values_int)
floats_in_buffer = any(request.values_float)
doubles_in_buffer = any(request.values_double)
code = (1 if ints_in_buffer else 0) + (1 if floats_in_buffer else 0) + (1 if doubles_in_buffer else 0)
if code == 0:
log.warning("No values found in message buffer detected")
if code > 1:
raise NotImplementedError("Multiple value types in single message buffer detected")
return ints_in_buffer, floats_in_buffer, doubles_in_buffer
Loading

0 comments on commit 2256736

Please sign in to comment.