added unit test for extra inference kwargs
Nanbo Liu committed Dec 5, 2023
1 parent 83fe202 commit 70f0de8
Showing 2 changed files with 65 additions and 3 deletions.
5 changes: 3 additions & 2 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -181,8 +181,9 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:

             value = get_decoded_or_raw(item)
             values[item.name] = value
-        if request.parameters.extra is not None:
-            values.update(request.parameters.extra)
+        if request.parameters is not None:
+            if request.parameters.extra is not None:
+                values.update(request.parameters.extra)
         return values


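The extra nesting matters because parameters is itself optional on an
InferenceRequest: the old code dereferenced request.parameters.extra
unconditionally and raised AttributeError for any request sent without a
parameters block. A minimal sketch of the case the new guard handles (the
request construction below is illustrative, not taken from the commit):

    from mlserver.types import InferenceRequest, RequestInput

    # A request with no parameters block at all: request.parameters is None.
    request = InferenceRequest(
        inputs=[
            RequestInput(name="args", shape=[1], datatype="BYTES", data=["hi"]),
        ]
    )

    values = {}
    # Old code: request.parameters.extra raised AttributeError here.
    # The patched version checks request.parameters first:
    if request.parameters is not None:
        if request.parameters.extra is not None:
            values.update(request.parameters.extra)
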
63 changes: 62 additions & 1 deletion runtimes/huggingface/tests/test_common.py
@@ -1,5 +1,5 @@
 from unittest.mock import MagicMock, patch
-
+import json
 import pytest
 import torch
 from typing import Dict, Optional
@@ -13,6 +13,9 @@
 from mlserver_huggingface.runtime import HuggingFaceRuntime
 from mlserver_huggingface.settings import HuggingFaceSettings
 from mlserver_huggingface.common import load_pipeline_from_settings
+from mlserver.types import InferenceRequest, RequestInput
+from mlserver.settings import ModelSettings, ModelParameters
+from mlserver.types.dataplane import Parameters


 @pytest.mark.parametrize(
@@ -210,3 +213,61 @@ def test_pipeline_checks_for_eos_and_pad_token(
     m = load_pipeline_from_settings(hf_settings, model_settings)

     assert m._batch_size == expected_batch_size


+@pytest.mark.parametrize(
+    "inference_kwargs1, inference_kwargs2, expected",
+    [
+        (
+            {"max_length": 20},
+            {"max_length": 10},
+            True,
+        )
+    ],
+)
+async def test_pipeline_uses_inference_kwargs(
+    inference_kwargs1: Optional[dict],
+    inference_kwargs2: Optional[dict],
+    expected: bool,
+):
+    model_settings = ModelSettings(
+        name="foo",
+        implementation=HuggingFaceRuntime,
+        parameters=ModelParameters(
+            extra={
+                "pretrained_model": "Maykeye/TinyLLama-v0",
+                "task": "text-generation",
+            }
+        ),
+    )
+    runtime = HuggingFaceRuntime(model_settings)
+    runtime.ready = await runtime.load()
+    payload1 = InferenceRequest(
+        inputs=[
+            RequestInput(
+                name="args",
+                shape=[1],
+                datatype="BYTES",
+                data=["This is a test"],
+            )
+        ],
+        parameters=Parameters(extra=inference_kwargs1),
+    )
+    payload2 = InferenceRequest(
+        inputs=[
+            RequestInput(
+                name="args",
+                shape=[1],
+                datatype="BYTES",
+                data=["This is a test"],
+            )
+        ],
+        parameters=Parameters(extra=inference_kwargs2),
+    )
+
+    result1 = await runtime.predict(payload1)
+    result1 = json.loads(result1.outputs[0].data[0])["generated_text"]
+
+    result2 = await runtime.predict(payload2)
+    result2 = json.loads(result2.outputs[0].data[0])["generated_text"]
+    assert (len(result1) > len(result2)) == expected
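Here max_length is a standard transformers generation kwarg that caps the
total sequence length (prompt plus generated tokens), so the run with
max_length=20 is expected to yield a longer generated_text than the run with
max_length=10. A rough sketch of how the extra kwargs reach the underlying
pipeline (this is an assumption about the runtime's internals, not code from
this commit):

    from transformers import pipeline

    # Assumed flow: decode_request merges the decoded inputs with
    # Parameters.extra into one dict, and the runtime forwards the extras
    # as keyword arguments to the pipeline call.
    pipe = pipeline("text-generation", model="Maykeye/TinyLLama-v0")

    decoded = {"args": ["This is a test"], "max_length": 10}
    args = decoded.pop("args")
    outputs = pipe(args, **decoded)  # max_length caps the sequence at 10 tokens
    print(outputs[0][0]["generated_text"])

The new test can be run on its own with
pytest runtimes/huggingface/tests/test_common.py -k test_pipeline_uses_inference_kwargs
(assuming the HuggingFace runtime's test dependencies are installed).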
