Skip to content

Commit

Permalink
Add extra params to MLModel (#783)
Browse files: browse the repository at this point in the history
  • Loading branch information
njooma authored Nov 6, 2024
1 parent f7b35b8 commit b630e05
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
18 changes: 13 additions & 5 deletions src/viam/services/mlmodel/client.py
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
from typing import Dict, Optional
from typing import Dict, Mapping, Optional

from grpclib.client import Channel
from numpy.typing import NDArray

from viam.proto.service.mlmodel import InferRequest, InferResponse, MetadataRequest, MetadataResponse, MLModelServiceStub
from viam.resource.rpc_client_base import ReconfigurableResourceRPCClientBase
from viam.services.mlmodel.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
from viam.utils import ValueTypes, dict_to_struct

from .mlmodel import Metadata, MLModel

Expand All @@ -16,14 +17,21 @@ def __init__(self, name: str, channel: Channel):
self.client = MLModelServiceStub(channel)
super().__init__(name)

async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float] = None, **kwargs) -> Dict[str, NDArray]:
async def infer(
self,
input_tensors: Dict[str, NDArray],
*,
extra: Optional[Mapping[str, ValueTypes]] = None,
timeout: Optional[float] = None,
**kwargs,
) -> Dict[str, NDArray]:
md = kwargs.get("metadata", self.Metadata()).proto
request = InferRequest(name=self.name, input_tensors=ndarrays_to_flat_tensors(input_tensors))
request = InferRequest(name=self.name, input_tensors=ndarrays_to_flat_tensors(input_tensors), extra=dict_to_struct(extra))
response: InferResponse = await self.client.Infer(request, timeout=timeout, metadata=md)
return flat_tensors_to_ndarrays(response.output_tensors)

async def metadata(self, *, timeout: Optional[float] = None, **kwargs) -> Metadata:
async def metadata(self, *, extra: Optional[Mapping[str, ValueTypes]] = None, timeout: Optional[float] = None, **kwargs) -> Metadata:
md = kwargs.get("metadata", self.Metadata()).proto
request = MetadataRequest(name=self.name)
request = MetadataRequest(name=self.name, extra=dict_to_struct(extra))
response: MetadataResponse = await self.client.Metadata(request, timeout=timeout, metadata=md)
return response.metadata
13 changes: 10 additions & 3 deletions src/viam/services/mlmodel/mlmodel.py
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
import abc
from typing import Dict, Final, Optional
from typing import Dict, Final, Optional, Mapping

from numpy.typing import NDArray

from viam.proto.service.mlmodel import Metadata
from viam.resource.types import RESOURCE_NAMESPACE_RDK, RESOURCE_TYPE_SERVICE, Subtype
from viam.utils import ValueTypes

from ..service_base import ServiceBase

Expand All @@ -25,7 +26,13 @@ class MLModel(ServiceBase):
)

@abc.abstractmethod
async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float]) -> Dict[str, NDArray]:
async def infer(
self,
input_tensors: Dict[str, NDArray],
*,
extra: Optional[Mapping[str, ValueTypes]] = None,
timeout: Optional[float] = None,
) -> Dict[str, NDArray]:
"""Take an already ordered input tensor as an array, make an inference on the model, and return an output tensor map.
::
Expand All @@ -50,7 +57,7 @@ async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[fl
...

@abc.abstractmethod
async def metadata(self, *, timeout: Optional[float]) -> Metadata:
async def metadata(self, *, extra: Optional[Mapping[str, ValueTypes]] = None, timeout: Optional[float] = None) -> Metadata:
"""Get the metadata (such as name, type, expected tensor/array shape, inputs, and outputs) associated with the ML model.
::
Expand Down
7 changes: 5 additions & 2 deletions src/viam/services/mlmodel/service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,7 @@
from viam.proto.service.mlmodel import InferRequest, InferResponse, MetadataRequest, MetadataResponse, MLModelServiceBase
from viam.resource.rpc_service_base import ResourceRPCServiceBase
from viam.services.mlmodel.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
from viam.utils import struct_to_dict

from .mlmodel import MLModel

Expand All @@ -19,8 +20,9 @@ async def Infer(self, stream: Stream[InferRequest, InferResponse]) -> None:
assert request is not None
name = request.name
mlmodel = self.get_resource(name)
extra = struct_to_dict(request.extra)
timeout = stream.deadline.time_remaining() if stream.deadline else None
output_tensors = await mlmodel.infer(input_tensors=flat_tensors_to_ndarrays(request.input_tensors), timeout=timeout)
output_tensors = await mlmodel.infer(input_tensors=flat_tensors_to_ndarrays(request.input_tensors), extra=extra, timeout=timeout)
response = InferResponse(output_tensors=ndarrays_to_flat_tensors(output_tensors))
await stream.send_message(response)

Expand All @@ -29,7 +31,8 @@ async def Metadata(self, stream: Stream[MetadataRequest, MetadataResponse]) -> N
assert request is not None
name = request.name
mlmodel = self.get_resource(name)
extra = struct_to_dict(request.extra)
timeout = stream.deadline.time_remaining() if stream.deadline else None
metadata = await mlmodel.metadata(timeout=timeout)
metadata = await mlmodel.metadata(extra=extra, timeout=timeout)
response = MetadataResponse(metadata=metadata)
await stream.send_message(response)
10 changes: 8 additions & 2 deletions tests/mocks/services.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -515,13 +515,19 @@ def __init__(self, name: str):

super().__init__(name)

async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float] = None) -> Dict[str, NDArray]:
async def infer(
self,
input_tensors: Dict[str, NDArray],
*,
extra: Optional[Mapping[str, ValueTypes]] = None,
timeout: Optional[float] = None,
) -> Dict[str, NDArray]:
self.timeout = timeout
request_data = ndarrays_to_flat_tensors(input_tensors)
response_data = flat_tensors_to_ndarrays(request_data)
return response_data

async def metadata(self, *, timeout: Optional[float] = None) -> Metadata:
async def metadata(self, *, extra: Optional[Mapping[str, ValueTypes]] = None, timeout: Optional[float] = None) -> Metadata:
self.timeout = timeout
return self.META

Expand Down

0 comments on commit b630e05

Please sign in to comment.