Skip to content

Commit

Permalink
Forward rest parameters to model (#1921)
Browse files Browse the repository at this point in the history
* Forwards rest parameters to model

Fixes #1660

* Add unit tests for forwarding params
  • Loading branch information
idlefella authored Nov 19, 2024
1 parent c5cc74a commit 7063618
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
5 changes: 4 additions & 1 deletion runtimes/mlflow/mlserver_mlflow/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,8 @@ def _sync_metadata(self) -> None:

async def predict(self, payload: InferenceRequest) -> InferenceResponse:
    """Run inference with the wrapped MLflow model.

    Decodes the incoming request, forwards any extra ("rest") request-level
    parameters to the MLflow model's ``predict`` via its ``params=`` keyword,
    and encodes the model output back into an ``InferenceResponse``.

    Parameters
    ----------
    payload:
        The inference request; ``payload.parameters.model_extra`` holds any
        fields beyond the declared ``Parameters`` schema (pydantic v2).

    Returns
    -------
    InferenceResponse
        The encoded model output (dict-of-tensors codec by default).
    """
    decoded_payload = self.decode_request(payload)
    # Fix: call the model exactly once. The original ran
    # `self._model.predict(decoded_payload)` and then called predict again
    # with params, discarding the first result (double inference).
    params = None
    if payload.parameters and payload.parameters.model_extra:
        # Only the extra (undeclared) request-level parameters are
        # forwarded; declared fields like content_type are not.
        params = payload.parameters.model_extra
    model_output = self._model.predict(decoded_payload, params=params)
    return self.encode_response(model_output, default_codec=TensorDictCodec)
49 changes: 48 additions & 1 deletion runtimes/mlflow/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pandas as pd

from typing import Any

from mlserver.codecs import NumpyCodec, PandasCodec, StringCodec
from mlserver.types import (
InferenceRequest,
Expand Down Expand Up @@ -257,3 +256,51 @@ async def test_invocation_with_params(
predict_mock.call_args[0][0].get("foo"), expected["data"]["foo"]
)
assert predict_mock.call_args.kwargs["params"] == expected["params"]


@pytest.mark.parametrize(
    "inference_request, expected_params",
    [
        # A request carrying an extra (undeclared) request-level parameter:
        # it must be forwarded to the model as `params=`. Note the
        # input-level extra (`extra_param2`) and the declared content_type
        # must NOT be forwarded.
        (
            InferenceRequest(
                parameters=Parameters(
                    content_type=NumpyCodec.ContentType, extra_param="extra_value"
                ),
                inputs=[
                    RequestInput(
                        name="predict",
                        shape=[1, 10],
                        data=[range(0, 10)],
                        datatype="INT64",
                        parameters=Parameters(extra_param2="extra_value2"),
                    )
                ],
            ),
            {"extra_param": "extra_value"},
        ),
        # A request with no extra parameters: `params` must be None.
        (
            InferenceRequest(
                parameters=Parameters(content_type=NumpyCodec.ContentType),
                inputs=[
                    RequestInput(
                        name="predict",
                        shape=[1, 10],
                        data=[range(0, 10)],
                        datatype="INT64",
                    )
                ],
            ),
            None,
        ),
    ],
)
async def test_predict_with_params(
    runtime: MLflowRuntime,
    inference_request: InferenceRequest,
    expected_params: dict,
):
    """Extra request-level parameters are forwarded to ``model.predict``.

    Renamed from ``input``/``params`` to avoid shadowing the ``input``
    builtin and to distinguish the expectation from the ``params=`` kwarg
    being asserted.
    """
    with mock.patch.object(
        runtime._model, "predict", return_value={"test": np.array([1, 2, 3])}
    ) as predict_mock:
        await runtime.predict(inference_request)
        assert predict_mock.call_args.kwargs == {"params": expected_params}

0 comments on commit 7063618

Please sign in to comment.