From 4db6f40d5b2dc1d7969d0bc5852dd225426b8110 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Dec 2024 10:31:38 +0100 Subject: [PATCH 1/5] message conversion function --- haystack/utils/hf.py | 40 ++++++++++++++++++++++++++++- test/utils/test_hf.py | 59 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 537b05e232..dbff3f22dc 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from haystack import logging -from haystack.dataclasses import StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.lazy_imports import LazyImport from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice @@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte ) +def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a message to the format expected by Hugging Face. + """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + elif len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + # HF always expects a content field, even if it is empty + hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""} + + if tool_call_results: + result = tool_call_results[0] + hf_msg["content"] = result.result + if tc_id := result.origin.id: + hf_msg["tool_call_id"] = tc_id + # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field + return hf_msg + + if text_contents: + hf_msg["content"] = text_contents[0] + if tool_calls: + hf_tool_calls = [] + for tc in tool_calls: + hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} + if tc.id is not None: + hf_tool_call["id"] = tc.id + hf_tool_calls.append(hf_tool_call) + hf_msg["tool_calls"] = hf_tool_calls + + return hf_msg + + with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer diff --git a/test/utils/test_hf.py b/test/utils/test_hf.py index 4350fb9fbb..d75e0b7501 100644 --- a/test/utils/test_hf.py +++ b/test/utils/test_hf.py @@ -2,8 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -from haystack.utils.hf import resolve_hf_device_map + +import pytest + +from haystack.utils.hf import resolve_hf_device_map, convert_message_to_hf_format from haystack.utils.device import ComponentDevice +from haystack.dataclasses import ChatMessage, ToolCall, ChatRole, TextContent def test_resolve_hf_device_map_only_device(): @@ -23,3 +27,56 @@ def test_resolve_hf_device_map_device_and_device_map(caplog): ) assert "The parameters `device` and `device_map` from `model_kwargs` are both provided." 
in caplog.text assert model_kwargs["device_map"] == "cuda:0" + + +def test_convert_message_to_hf_format(): + message = ChatMessage.from_system("You are good assistant") + assert convert_message_to_hf_format(message) == {"role": "system", "content": "You are good assistant"} + + message = ChatMessage.from_user("I have a question") + assert convert_message_to_hf_format(message) == {"role": "user", "content": "I have a question"} + + message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"}) + assert convert_message_to_hf_format(message) == {"role": "assistant", "content": "I have an answer"} + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + assert convert_message_to_hf_format(message) == { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}} + ], + } + + message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather", arguments={"city": "Paris"})]) + assert convert_message_to_hf_format(message) == { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}], + } + + tool_result = {"weather": "sunny", "temperature": "25"} + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ) + assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result, "tool_call_id": "123"} + + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(tool_name="weather", arguments={"city": "Paris"}) + ) + assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result} + + +def test_convert_message_to_hf_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + convert_message_to_hf_format(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")], + ) + with pytest.raises(ValueError): + convert_message_to_hf_format(message) From 44b010338ba228cfeb627ba92bb63df30eb6a03f Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Dec 2024 11:11:34 +0100 Subject: [PATCH 2/5] hfapi w tools --- .../generators/chat/hugging_face_api.py | 150 ++++-- haystack/dataclasses/tool.py | 15 +- haystack/utils/hf.py | 2 +- .../generators/test_hugging_face_api.py | 499 +++++++++++++----- test/dataclasses/test_tool.py | 16 + 5 files changed, 502 insertions(+), 180 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 8711a9175a..cc6462018e 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -5,30 +5,25 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from haystack.utils.hf import 
HFGenerationAPIType, HFModelType, check_valid_model +from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format from haystack.utils.url_validation import is_valid_http_url with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import: - from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient + from huggingface_hub import ( + ChatCompletionInputTool, + ChatCompletionOutput, + ChatCompletionStreamOutput, + InferenceClient, + ) logger = logging.getLogger(__name__) -def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: - """ - Convert a message to the format expected by Hugging Face APIs. - - :returns: A dictionary with the following keys: - - `role` - - `content` - """ - return {"role": message.role.value, "content": message.text or ""} - - @component class HuggingFaceAPIChatGenerator: """ @@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, ): """ Initialize the HuggingFaceAPIChatGenerator instance. @@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`. - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_GENERATION_INFERENCE`. - :param token: The Hugging Face token to use as HTTP bearer authorization. + :param token: + The Hugging Face token to use as HTTP bearer authorization. Check your HF token in your [account settings](https://huggingface.co/settings/tokens). :param generation_kwargs: A dictionary with keyword arguments to customize text generation. Some examples: `max_tokens`, `temperature`, `top_p`. For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion). - :param stop_words: An optional list of strings representing the stop words. - :param streaming_callback: An optional callable for handling streaming responses. + :param stop_words: + An optional list of strings representing the stop words. + :param streaming_callback: + An optional callable for handling streaming responses. + :param tools: + A list of tools for which the model can prepare calls. + The chosen model should support tool/function calling, according to the model card. + Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience + unexpected behavior. """ huggingface_hub_import.check() @@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments msg = f"Unknown api_type {api_type}" raise ValueError(msg) + if tools: + if streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") + _check_duplicate_tool_names(tools) + # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} generation_kwargs["stop"] = generation_kwargs.get("stop", []) @@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.generation_kwargs = generation_kwargs self.streaming_callback = streaming_callback self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None) + self.tools = tools def to_dict(self) -> Dict[str, Any]: """ @@ -180,6 +190,7 @@ def to_dict(self) -> Dict[str, Any]: A dictionary containing the serialized component. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, api_type=str(self.api_type), @@ -187,6 +198,7 @@ def to_dict(self) -> Dict[str, Any]: token=self.token.to_dict() if self.token else None, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -195,6 +207,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": Deserialize this component from a dictionary. """ deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: @@ -202,12 +215,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): """ Invoke the text generation inference based on the provided messages and generation parameters. - :param messages: A list of ChatMessage objects representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. + :param messages: + A list of ChatMessage objects representing the input messages. + :param generation_kwargs: + Additional keyword arguments for text generation. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ @@ -215,12 +238,22 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # update generation kwargs by merging with the default ones generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages] + formatted_messages = [convert_message_to_hf_format(message) for message in messages] + + tools = tools or self.tools + if tools: + if self.streaming_callback: + raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") + _check_duplicate_tool_names(tools) if self.streaming_callback: return self._run_streaming(formatted_messages, generation_kwargs) - return self._run_non_streaming(formatted_messages, generation_kwargs) + hf_tools = None + if tools: + hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( @@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict generated_text = "" - for chunk in api_output: # pylint: disable=not-an-iterable - text = chunk.choices[0].delta.content + for chunk in api_output: + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = chunk.choices[0] + + text = choice.delta.content if text: generated_text += text - finish_reason = chunk.choices[0].finish_reason + + finish_reason = choice.finish_reason meta = {} if finish_reason: @@ -242,8 +281,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict stream_chunk = StreamingChunk(text, meta) self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) - message = ChatMessage.from_assistant(generated_text) - message.meta.update( + meta.update( { "model": self._client.model, "finish_reason": finish_reason, @@ -251,24 +289,48 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming } ) + + message = ChatMessage.from_assistant(text=generated_text, meta=meta) + return {"replies": [message]} def _run_non_streaming( - self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any] + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + tools: Optional[List["ChatCompletionInputTool"]] = None, ) -> Dict[str, List[ChatMessage]]: - chat_messages: List[ChatMessage] = [] - - api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs) - for choice in api_chat_output.choices: - message = ChatMessage.from_assistant(choice.message.content) - message.meta.update( - { - "model": self._client.model, - "finish_reason": choice.finish_reason, - "index": choice.index, - "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0}, - } - ) - chat_messages.append(message) - - return {"replies": chat_messages} + api_chat_output: ChatCompletionOutput = self._client.chat_completion( + messages=messages, tools=tools, **generation_kwargs + ) + + if len(api_chat_output.choices) == 0: + return {"replies": []} + + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = api_chat_output.choices[0] + + text = choice.message.content + tool_calls = [] + + if hfapi_tool_calls := choice.message.tool_calls: + for hfapi_tc in hfapi_tool_calls: + tool_call = ToolCall( + tool_name=hfapi_tc.function.name, 
arguments=hfapi_tc.function.arguments, id=hfapi_tc.id + ) + tool_calls.append(tool_call) + + meta = { + "model": self._client.model, + "finish_reason": choice.finish_reason, + "index": choice.index, + "usage": { + "prompt_tokens": api_chat_output.usage.prompt_tokens, + "completion_tokens": api_chat_output.usage.completion_tokens, + }, + } + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) + return {"replies": [message]} diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py index 3df3fd18f2..4aaf1e2bd1 100644 --- a/haystack/dataclasses/tool.py +++ b/haystack/dataclasses/tool.py @@ -4,7 +4,7 @@ import inspect from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from pydantic import create_model @@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]): del property_schema[key] +def _check_duplicate_tool_names(tools: List[Tool]) -> None: + """ + Check for duplicate tool names. + + :param tools: The list of tools to check. + :raises ValueError: If duplicate tool names are found. + """ + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): """ Deserialize Tools in a dictionary inplace. diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index dbff3f22dc..6bc8169685 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -280,7 +280,7 @@ def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]: if not text_contents and not tool_calls and not tool_call_results: raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") - elif len(text_contents) + len(tool_call_results) > 1: + if len(text_contents) + len(tool_call_results) > 1: raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") # HF always expects a content field, even if it is empty diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py index 0f4be2f9cb..0d0857e22a 100644 --- a/test/components/generators/test_hugging_face_api.py +++ b/test/components/generators/test_hugging_face_api.py @@ -5,38 +5,78 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from haystack import Pipeline +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( - TextGenerationOutputToken, - TextGenerationStreamOutput, - TextGenerationStreamOutputStreamDetails, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, ) from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators import HuggingFaceAPIGenerator -from haystack.dataclasses import StreamingChunk -from haystack.utils.auth import Secret -from haystack.utils.hf import HFGenerationAPIType +from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator +from haystack.dataclasses import ChatMessage, 
Tool, ToolCall + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] @pytest.fixture def mock_check_valid_model(): with patch( - "haystack.components.generators.hugging_face_api.check_valid_model", MagicMock(return_value=None) + "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None) ) as mock: yield mock @pytest.fixture -def mock_text_generation(): - with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation: - mock_response = Mock() - mock_response.generated_text = "I'm fine, thanks." - details = Mock() - details.finish_reason = MagicMock(field1="value") - details.tokens = [1, 2, 3] - mock_response.details = details - mock_text_generation.return_value = mock_response - yield mock_text_generation +def mock_chat_completion(): + # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example + + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"), + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), + created=1710498360, + ) + + mock_chat_completion.return_value = completion + yield mock_chat_completion # used to test serialization of streaming_callback @@ -44,10 +84,10 @@ def streaming_callback_handler(x): return x -class TestHuggingFaceAPIGenerator: +class TestHuggingFaceAPIChatGenerator: def test_init_invalid_api_type(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator(api_type="invalid_api_type", api_params={}) + HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) def test_init_serverless(self, mock_check_valid_model): model = "HuggingFaceH4/zephyr-7b-alpha" @@ -55,7 +95,7 @@ def test_init_serverless(self, mock_check_valid_model): stop_words = ["stop"] streaming_callback = None - generator = HuggingFaceAPIGenerator( + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}, token=None, @@ -66,23 +106,42 @@ def test_init_serverless(self, mock_check_valid_model): assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator.api_params == {"model": model} - assert generator.generation_kwargs == { - **generation_kwargs, - **{"stop_sequences": ["stop"]}, - **{"max_new_tokens": 512}, - } + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None + + def test_init_serverless_with_tools(self, mock_check_valid_model, tools): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + 
streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": model}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + tools=tools, + ) + + assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator.api_params == {"model": model} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools == tools def test_init_serverless_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") with pytest.raises(RepositoryNotFoundError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} ) def test_init_serverless_no_model(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} ) @@ -92,7 +151,7 @@ def test_init_tgi(self): stop_words = ["stop"] streaming_callback = None - generator = HuggingFaceAPIGenerator( + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": url}, token=None, @@ -103,31 +162,49 @@ def test_init_tgi(self): assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE assert generator.api_params == {"url": url} - assert generator.generation_kwargs == { - **generation_kwargs, - **{"stop_sequences": ["stop"]}, - **{"max_new_tokens": 512}, - } + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None def test_init_tgi_invalid_url(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"} ) def test_init_tgi_no_url(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} ) + def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=duplicate_tools, + ) + + def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=tools, + streaming_callback=streaming_callback_handler, + ) + def test_to_dict(self, mock_check_valid_model): - generator = HuggingFaceAPIGenerator( + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], + tools=[tool], ) result = generator.to_dict() @@ -136,101 +213,118 @@ def test_to_dict(self, 
mock_check_valid_model): assert init_params["api_type"] == "serverless_inference_api" assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} - assert init_params["generation_kwargs"] == { - "temperature": 0.6, - "stop_sequences": ["stop", "words"], - "max_new_tokens": 512, - } + assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert init_params["streaming_callback"] is None + assert init_params["tools"] == [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + } + ] def test_from_dict(self, mock_check_valid_model): - generator = HuggingFaceAPIGenerator( + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, + tools=[tool], ) result = generator.to_dict() # now deserialize, call from_dict - generator_2 = HuggingFaceAPIGenerator.from_dict(result) + generator_2 = HuggingFaceAPIChatGenerator.from_dict(result) assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False) - assert generator_2.generation_kwargs == { - "temperature": 0.6, - "stop_sequences": ["stop", "words"], - "max_new_tokens": 512, - } - assert generator_2.streaming_callback is streaming_callback_handler + assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert generator_2.streaming_callback is None + assert generator_2.tools == [tool] - def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_text_generation - ): - generator = HuggingFaceAPIGenerator( + def test_serde_in_pipeline(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - streaming_callback=None, + tools=[tool], ) - prompt = "Hello, how are you?" 
- response = generator.run(prompt) - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == { - "details": True, - "temperature": 0.6, - "stop_sequences": ["stop", "words"], - "stream": False, - "max_new_tokens": 512, + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", + "init_parameters": { + "api_type": "serverless_inference_api", + "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, + "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, + "streaming_callback": None, + "tools": [ + { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + } + ], + }, + } + }, + "connections": [], } - assert isinstance(response, dict) - assert "replies" in response - assert "meta" in response - assert isinstance(response["replies"], list) - assert isinstance(response["meta"], list) - assert len(response["replies"]) == 1 - assert len(response["meta"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] + pipeline_yaml = pipeline.dumps() - def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation): - generator = HuggingFaceAPIGenerator( - api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"} + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + streaming_callback=None, ) - generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} - response = generator.run("How are you?", generation_kwargs=generation_kwargs) + response = generator.run(messages=chat_messages) - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args + # check kwargs passed to chat_completion + _, kwargs = mock_chat_completion.call_args + hf_messages = [ + {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, + {"role": "user", "content": "Tell me about Berlin"}, + ] assert kwargs == { - "details": True, - "max_new_tokens": 100, - "stop_sequences": [], - "stream": False, - "temperature": 0.8, + "temperature": 0.6, + "stop": ["stop", "words"], + "max_tokens": 512, + "tools": None, + "messages": hf_messages, } - # Assert that the response contains the generated replies and the right response + assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] - assert response["replies"][0] == "I'm fine, thanks." 
- - # Assert that the response contains the metadata - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation): + def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): streaming_call_count = 0 # Define the streaming callback function @@ -239,38 +333,50 @@ def streaming_callback_fn(chunk: StreamingChunk): streaming_call_count += 1 assert isinstance(chunk, StreamingChunk) - generator = HuggingFaceAPIGenerator( + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, - api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, streaming_callback=streaming_callback_fn, ) # Create a fake streamed response - # Don't remove self + # self needed here, don't remove def mock_iter(self): - yield TextGenerationStreamOutput( - index=0, - generated_text=None, - token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False), + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), + index=0, + finish_reason=None, + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, ) - yield TextGenerationStreamOutput( - index=1, - generated_text=None, - token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False), - details=TextGenerationStreamOutputStreamDetails( - finish_reason="length", generated_tokens=5, seed=None, input_length=10 - ), + + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, ) mock_response = Mock(**{"__iter__": mock_iter}) - mock_text_generation.return_value = mock_response + mock_chat_completion.return_value = mock_response # Generate text response with streaming callback - response = generator.run("prompt") + response = generator.run(chat_messages) # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512} + _, kwargs = mock_chat_completion.call_args + assert kwargs == {"stop": [], "stream": True, "max_tokens": 512} # Assert that the streaming callback was called twice assert streaming_call_count == 2 @@ -279,36 +385,161 @@ def mock_iter(self): assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model): + component = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + streaming_callback=streaming_callback_handler, + ) + + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + 
component.run([message], tools=tools) + + def test_run_with_tools(self, mock_check_valid_model, tools): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, + tools=tools, + ) - # Assert that the response contains the metadata - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) > 0 - assert [isinstance(meta, dict) for meta in response["meta"]] + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="stop", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionOutputToolCall( + function=ChatCompletionOutputFunctionDefinition( + arguments={"city": "Paris"}, name="weather", description=None + ), + id="0", + type="function", + ) + ], + ), + logprobs=None, + ) + ], + created=1729074760, + id="", + model="meta-llama/Llama-3.1-70B-Instruct", + system_fingerprint="2.3.2-dev0-sha-28bb7ae", + usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), + ) + mock_chat_completion.return_value = completion + + messages = [ChatMessage.from_user("What is the weather in Paris?")] + response = generator.run(messages=messages) + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert response["replies"][0].tool_calls[0].tool_name == "weather" + assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} + assert response["replies"][0].tool_calls[0].id == "0" + assert response["replies"][0].meta == { + "finish_reason": "stop", + "index": 0, + "model": "meta-llama/Llama-3.1-70B-Instruct", + "usage": {"completion_tokens": 30, "prompt_tokens": 426}, + } - @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - def test_run_serverless(self): - generator = HuggingFaceAPIGenerator( + def test_live_run_serverless(self): + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, - generation_kwargs={"max_new_tokens": 20}, + generation_kwargs={"max_tokens": 20}, ) - response = generator.run("How are you?") - # Assert that the response contains the generated replies + messages = [ChatMessage.from_user("What is the capital of France?")] + response = generator.run(messages=messages) + + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert "usage" in response["replies"][0].meta + assert "prompt_tokens" in response["replies"][0].meta["usage"] + assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + def test_live_run_serverless_streaming(self): + generator = HuggingFaceAPIChatGenerator( + 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"max_tokens": 20}, + streaming_callback=streaming_callback_handler, + ) + + messages = [ChatMessage.from_user("What is the capital of France?")] + response = generator.run(messages=messages) + assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert "usage" in response["replies"][0].meta + assert "prompt_tokens" in response["replies"][0].meta["usage"] + assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools(self, tools): + """ + We test the round trip: generate tool call, pass tool message, generate response. + + The model used here (zephyr-7b-beta) is always available and not gated. + Even if it does not officially support tools, TGI+HF API make it work. + """ + + chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")] + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"temperature": 0.5}, + ) + + results = generator.run(chat_messages, tools=tools) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert "city" in tool_call.arguments + assert "Paris" in tool_call.arguments["city"] + assert message.meta["finish_reason"] == "stop" + + new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + + # the model tends to make tool calls if provided with tools, so we don't pass them here + results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) - # Assert that the response contains the metadata - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) > 0 - assert [isinstance(meta, dict) for meta in response["meta"]] + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py index db9719a7f3..9e112853f3 100644 --- a/test/dataclasses/test_tool.py +++ b/test/dataclasses/test_tool.py @@ -12,6 +12,7 @@ ToolInvocationError, _remove_title_from_schema, deserialize_tools_inplace, + _check_duplicate_tool_names, ) try: @@ -303,3 +304,18 @@ def test_remove_title_from_schema_handle_no_title_in_top_level(): "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}}, "type": "object", } + + +def test_check_duplicate_tool_names(): + tools = [ + Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report), + Tool(name="weather", description="A different description", parameters=parameters, function=get_weather_report), + ] + with pytest.raises(ValueError): + _check_duplicate_tool_names(tools) + + tools = [ 
+ Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report), + Tool(name="weather2", description="Get weather report", parameters=parameters, function=get_weather_report), + ] + _check_duplicate_tool_names(tools) From 3bc537463370386195371677b90b6b17f3ce811d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Dec 2024 11:40:32 +0100 Subject: [PATCH 3/5] right test file + hf_hub version --- .../hugging_face_api_document_embedder.py | 2 +- .../hugging_face_api_text_embedder.py | 2 +- .../generators/chat/hugging_face_api.py | 2 +- .../components/generators/hugging_face_api.py | 2 +- haystack/utils/hf.py | 2 +- pyproject.toml | 2 +- .../generators/chat/test_hugging_face_api.py | 297 +++++++++-- .../generators/test_hugging_face_api.py | 499 +++++------------- 8 files changed, 406 insertions(+), 402 deletions(-) diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 43f719e27d..459e386976 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -14,7 +14,7 @@ from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index f60a9e5fd7..2cd68d34da 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -11,7 +11,7 @@ from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index cc6462018e..8e7a6dac18 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -12,7 +12,7 @@ from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import ( ChatCompletionInputTool, ChatCompletionOutput, diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index a164c8c56c..a44ad94575 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -12,7 +12,7 @@ from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model 
from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import ( InferenceClient, TextGenerationOutput, diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 6bc8169685..6a83594ada 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -16,7 +16,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import: import torch -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import HfApi, InferenceClient, model_info from huggingface_hub.utils import RepositoryNotFoundError diff --git a/pyproject.toml b/pyproject.toml index c1fddc8704..6a76a2e9c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ extra-dependencies = [ "numpy>=2", # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x "transformers[torch,sentencepiece]==4.44.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... - "huggingface_hub>=0.23.0", # Hugging Face API Generators and Embedders + "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index e60ec863ab..0d0857e22a 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -5,23 +5,46 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from haystack import Pipeline +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( ChatCompletionOutput, - ChatCompletionStreamOutput, ChatCompletionOutputComplete, - ChatCompletionStreamOutputChoice, + ChatCompletionOutputFunctionDefinition, ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputDelta, ) from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators.chat.hugging_face_api import ( - HuggingFaceAPIChatGenerator, - _convert_message_to_hfapi_format, -) -from haystack.dataclasses import ChatMessage, StreamingChunk -from haystack.utils.auth import Secret -from haystack.utils.hf import HFGenerationAPIType +from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator +from haystack.dataclasses import ChatMessage, Tool, ToolCall + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + 
+        parameters=tool_parameters,
+        function=lambda x: x,
+    )
+
+    return [tool]
 
 
 @pytest.fixture
@@ -48,7 +71,7 @@ def mock_chat_completion():
         id="some_id",
         model="some_model",
         system_fingerprint="some_fingerprint",
-        usage={"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15},
+        usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25),
         created=1710498360,
     )
 
@@ -61,15 +84,7 @@ def streaming_callback_handler(x):
     return x
 
 
-def test_convert_message_to_hfapi_format():
-    message = ChatMessage.from_system("You are good assistant")
-    assert _convert_message_to_hfapi_format(message) == {"role": "system", "content": "You are good assistant"}
-
-    message = ChatMessage.from_user("I have a question")
-    assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"}
-
-
-class TestHuggingFaceAPIGenerator:
+class TestHuggingFaceAPIChatGenerator:
     def test_init_invalid_api_type(self):
         with pytest.raises(ValueError):
             HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={})
@@ -93,6 +108,29 @@ def test_init_serverless(self, mock_check_valid_model):
         assert generator.api_params == {"model": model}
         assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
         assert generator.streaming_callback == streaming_callback
+        assert generator.tools is None
+
+    def test_init_serverless_with_tools(self, mock_check_valid_model, tools):
+        model = "HuggingFaceH4/zephyr-7b-alpha"
+        generation_kwargs = {"temperature": 0.6}
+        stop_words = ["stop"]
+        streaming_callback = None
+
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": model},
+            token=None,
+            generation_kwargs=generation_kwargs,
+            stop_words=stop_words,
+            streaming_callback=streaming_callback,
+            tools=tools,
+        )
+
+        assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
+        assert generator.api_params == {"model": model}
+        assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
+        assert generator.streaming_callback == streaming_callback
+        assert generator.tools == tools
 
     def test_init_serverless_invalid_model(self, mock_check_valid_model):
         mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
@@ -126,6 +164,7 @@ def test_init_tgi(self):
         assert generator.api_params == {"url": url}
         assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}}
         assert generator.streaming_callback == streaming_callback
+        assert generator.tools is None
 
     def test_init_tgi_invalid_url(self):
         with pytest.raises(ValueError):
@@ -139,12 +178,33 @@ def test_init_tgi_no_url(self):
                 api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"}
             )
 
+    def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools):
+        duplicate_tools = [tools[0], tools[0]]
+        with pytest.raises(ValueError):
+            HuggingFaceAPIChatGenerator(
+                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "irrelevant"},
+                tools=duplicate_tools,
+            )
+
+    def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools):
+        with pytest.raises(ValueError):
+            HuggingFaceAPIChatGenerator(
+                api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "irrelevant"},
+                tools=tools,
+                streaming_callback=streaming_callback_handler,
+            )
+
     def test_to_dict(self, mock_check_valid_model):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
+
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
+            tools=[tool],
         )
 
         result = generator.to_dict()
@@ -154,15 +214,26 @@ def test_to_dict(self, mock_check_valid_model):
         assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
+        assert init_params["streaming_callback"] is None
+        assert init_params["tools"] == [
+            {
+                "description": "description",
+                "function": "builtins.print",
+                "name": "name",
+                "parameters": {"x": {"type": "string"}},
+            }
+        ]
 
     def test_from_dict(self, mock_check_valid_model):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
+
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
-            streaming_callback=streaming_callback_handler,
+            tools=[tool],
         )
 
         result = generator.to_dict()
@@ -172,11 +243,57 @@ def test_from_dict(self, mock_check_valid_model):
         assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
         assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
-        assert generator_2.streaming_callback is streaming_callback_handler
+        assert generator_2.streaming_callback is None
+        assert generator_2.tools == [tool]
+
+    def test_serde_in_pipeline(self, mock_check_valid_model):
+        tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
 
-    def test_generate_text_response_with_valid_prompt_and_generation_parameters(
-        self, mock_check_valid_model, mock_chat_completion, chat_messages
-    ):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            token=Secret.from_env_var("ENV_VAR", strict=False),
+            generation_kwargs={"temperature": 0.6},
+            stop_words=["stop", "words"],
+            tools=[tool],
+        )
+
+        pipeline = Pipeline()
+        pipeline.add_component("generator", generator)
+
+        pipeline_dict = pipeline.to_dict()
+        assert pipeline_dict == {
+            "metadata": {},
+            "max_runs_per_component": 100,
+            "components": {
+                "generator": {
+                    "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator",
+                    "init_parameters": {
+                        "api_type": "serverless_inference_api",
+                        "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"},
+                        "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False},
+                        "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512},
+                        "streaming_callback": None,
+                        "tools": [
+                            {
+                                "name": "name",
+                                "description": "description",
+                                "parameters": {"x": {"type": "string"}},
+                                "function": "builtins.print",
+                            }
+                        ],
+                    },
+                }
+            },
+            "connections": [],
+        }
+
+        pipeline_yaml = pipeline.dumps()
+
+        new_pipeline = Pipeline.loads(pipeline_yaml)
+        assert new_pipeline == pipeline
+
+    def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
@@ -187,9 +304,19 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
 
         response = generator.run(messages=chat_messages)
 
-        # check kwargs passed to text_generation
+        # check kwargs passed to chat_completion
         _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
+        hf_messages = [
+            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
+            {"role": "user", "content": "Tell me about Berlin"},
+        ]
+        assert kwargs == {
+            "temperature": 0.6,
+            "stop": ["stop", "words"],
+            "max_tokens": 512,
+            "tools": None,
+            "messages": hf_messages,
+        }
 
         assert isinstance(response, dict)
         assert "replies" in response
@@ -197,7 +324,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
         assert len(response["replies"]) == 1
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
 
-    def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
+    def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
         streaming_call_count = 0
 
         # Define the streaming callback function
@@ -260,13 +387,78 @@ def mock_iter(self):
         assert len(response["replies"]) > 0
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
 
-    @pytest.mark.flaky(reruns=5, reruns_delay=5)
+    def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model):
+        component = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            streaming_callback=streaming_callback_handler,
+        )
+
+        with pytest.raises(ValueError):
+            message = ChatMessage.from_user("irrelevant")
+            component.run([message], tools=tools)
+
+    def test_run_with_tools(self, mock_check_valid_model, tools):
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
+            tools=tools,
+        )
+
+        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
+            completion = ChatCompletionOutput(
+                choices=[
+                    ChatCompletionOutputComplete(
+                        finish_reason="stop",
+                        index=0,
+                        message=ChatCompletionOutputMessage(
+                            role="assistant",
+                            content=None,
+                            tool_calls=[
+                                ChatCompletionOutputToolCall(
+                                    function=ChatCompletionOutputFunctionDefinition(
+                                        arguments={"city": "Paris"}, name="weather", description=None
+                                    ),
+                                    id="0",
+                                    type="function",
+                                )
+                            ],
+                        ),
+                        logprobs=None,
+                    )
+                ],
+                created=1729074760,
+                id="",
+                model="meta-llama/Llama-3.1-70B-Instruct",
+                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
+                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
+            )
+            mock_chat_completion.return_value = completion
+
+            messages = [ChatMessage.from_user("What is the weather in Paris?")]
+            response = generator.run(messages=messages)
+
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) == 1
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert response["replies"][0].tool_calls[0].tool_name == "weather"
+        assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
+        assert response["replies"][0].tool_calls[0].id == "0"
+        assert response["replies"][0].meta == {
+            "finish_reason": "stop",
+            "index": 0,
+            "model": "meta-llama/Llama-3.1-70B-Instruct",
+            "usage": {"completion_tokens": 30, "prompt_tokens": 426},
+        }
+
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    def test_run_serverless(self):
+    def test_live_run_serverless(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
@@ -284,13 +476,12 @@ def test_run_serverless(self):
         assert "prompt_tokens" in response["replies"][0].meta["usage"]
         assert "completion_tokens" in response["replies"][0].meta["usage"]
 
-    @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    def test_run_serverless_streaming(self):
+    def test_live_run_serverless_streaming(self):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
@@ -308,3 +499,47 @@ def test_run_serverless_streaming(self):
         assert "usage" in response["replies"][0].meta
         assert "prompt_tokens" in response["replies"][0].meta["usage"]
         assert "completion_tokens" in response["replies"][0].meta["usage"]
+
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    def test_live_run_with_tools(self, tools):
+        """
+        We test the round trip: generate tool call, pass tool message, generate response.
+
+        The model used here (zephyr-7b-beta) is always available and not gated.
+        Even though it does not officially support tools, TGI and the HF API make them work.
+ """ + + chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")] + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"temperature": 0.5}, + ) + + results = generator.run(chat_messages, tools=tools) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert "city" in tool_call.arguments + assert "Paris" in tool_call.arguments["city"] + assert message.meta["finish_reason"] == "stop" + + new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + + # the model tends to make tool calls if provided with tools, so we don't pass them here + results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py index 0d0857e22a..0f4be2f9cb 100644 --- a/test/components/generators/test_hugging_face_api.py +++ b/test/components/generators/test_hugging_face_api.py @@ -5,78 +5,38 @@ from unittest.mock import MagicMock, Mock, patch import pytest -from haystack import Pipeline -from haystack.dataclasses import StreamingChunk -from haystack.utils.auth import Secret -from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( - ChatCompletionOutput, - ChatCompletionOutputComplete, - ChatCompletionOutputFunctionDefinition, - ChatCompletionOutputMessage, - ChatCompletionOutputToolCall, - ChatCompletionOutputUsage, - ChatCompletionStreamOutput, - ChatCompletionStreamOutputChoice, - ChatCompletionStreamOutputDelta, + TextGenerationOutputToken, + TextGenerationStreamOutput, + TextGenerationStreamOutputStreamDetails, ) from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator -from haystack.dataclasses import ChatMessage, Tool, ToolCall - - -@pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), - ChatMessage.from_user("Tell me about Berlin"), - ] - - -@pytest.fixture -def tools(): - tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} - tool = Tool( - name="weather", - description="useful to determine the weather in a given location", - parameters=tool_parameters, - function=lambda x: x, - ) - - return [tool] +from haystack.components.generators import HuggingFaceAPIGenerator +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType @pytest.fixture def mock_check_valid_model(): with patch( - "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None) + "haystack.components.generators.hugging_face_api.check_valid_model", MagicMock(return_value=None) ) as mock: yield mock @pytest.fixture -def mock_chat_completion(): - # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example - - with 
patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: - completion = ChatCompletionOutput( - choices=[ - ChatCompletionOutputComplete( - finish_reason="eos_token", - index=0, - message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"), - ) - ], - id="some_id", - model="some_model", - system_fingerprint="some_fingerprint", - usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), - created=1710498360, - ) - - mock_chat_completion.return_value = completion - yield mock_chat_completion +def mock_text_generation(): + with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation: + mock_response = Mock() + mock_response.generated_text = "I'm fine, thanks." + details = Mock() + details.finish_reason = MagicMock(field1="value") + details.tokens = [1, 2, 3] + mock_response.details = details + mock_text_generation.return_value = mock_response + yield mock_text_generation # used to test serialization of streaming_callback @@ -84,10 +44,10 @@ def streaming_callback_handler(x): return x -class TestHuggingFaceAPIChatGenerator: +class TestHuggingFaceAPIGenerator: def test_init_invalid_api_type(self): with pytest.raises(ValueError): - HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) + HuggingFaceAPIGenerator(api_type="invalid_api_type", api_params={}) def test_init_serverless(self, mock_check_valid_model): model = "HuggingFaceH4/zephyr-7b-alpha" @@ -95,7 +55,7 @@ def test_init_serverless(self, mock_check_valid_model): stop_words = ["stop"] streaming_callback = None - generator = HuggingFaceAPIChatGenerator( + generator = HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}, token=None, @@ -106,42 +66,23 @@ def test_init_serverless(self, mock_check_valid_model): assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator.api_params == {"model": model} - assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} - assert generator.streaming_callback == streaming_callback - assert generator.tools is None - - def test_init_serverless_with_tools(self, mock_check_valid_model, tools): - model = "HuggingFaceH4/zephyr-7b-alpha" - generation_kwargs = {"temperature": 0.6} - stop_words = ["stop"] - streaming_callback = None - - generator = HuggingFaceAPIChatGenerator( - api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, - api_params={"model": model}, - token=None, - generation_kwargs=generation_kwargs, - stop_words=stop_words, - streaming_callback=streaming_callback, - tools=tools, - ) - - assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API - assert generator.api_params == {"model": model} - assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.generation_kwargs == { + **generation_kwargs, + **{"stop_sequences": ["stop"]}, + **{"max_new_tokens": 512}, + } assert generator.streaming_callback == streaming_callback - assert generator.tools == tools def test_init_serverless_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") with pytest.raises(RepositoryNotFoundError): - HuggingFaceAPIChatGenerator( + HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} ) def 
test_init_serverless_no_model(self): with pytest.raises(ValueError): - HuggingFaceAPIChatGenerator( + HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} ) @@ -151,7 +92,7 @@ def test_init_tgi(self): stop_words = ["stop"] streaming_callback = None - generator = HuggingFaceAPIChatGenerator( + generator = HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": url}, token=None, @@ -162,49 +103,31 @@ def test_init_tgi(self): assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE assert generator.api_params == {"url": url} - assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.generation_kwargs == { + **generation_kwargs, + **{"stop_sequences": ["stop"]}, + **{"max_new_tokens": 512}, + } assert generator.streaming_callback == streaming_callback - assert generator.tools is None def test_init_tgi_invalid_url(self): with pytest.raises(ValueError): - HuggingFaceAPIChatGenerator( + HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"} ) def test_init_tgi_no_url(self): with pytest.raises(ValueError): - HuggingFaceAPIChatGenerator( + HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} ) - def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): - duplicate_tools = [tools[0], tools[0]] - with pytest.raises(ValueError): - HuggingFaceAPIChatGenerator( - api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, - api_params={"model": "irrelevant"}, - tools=duplicate_tools, - ) - - def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): - with pytest.raises(ValueError): - HuggingFaceAPIChatGenerator( - api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, - api_params={"model": "irrelevant"}, - tools=tools, - streaming_callback=streaming_callback_handler, - ) - def test_to_dict(self, mock_check_valid_model): - tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) - - generator = HuggingFaceAPIChatGenerator( + generator = HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - tools=[tool], ) result = generator.to_dict() @@ -213,118 +136,101 @@ def test_to_dict(self, mock_check_valid_model): assert init_params["api_type"] == "serverless_inference_api" assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} - assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} - assert init_params["streaming_callback"] is None - assert init_params["tools"] == [ - { - "description": "description", - "function": "builtins.print", - "name": "name", - "parameters": {"x": {"type": "string"}}, - } - ] + assert init_params["generation_kwargs"] == { + "temperature": 0.6, + "stop_sequences": ["stop", "words"], + "max_new_tokens": 512, + } def test_from_dict(self, mock_check_valid_model): - tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) - - generator = HuggingFaceAPIChatGenerator( + generator = HuggingFaceAPIGenerator( 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - tools=[tool], + streaming_callback=streaming_callback_handler, ) result = generator.to_dict() # now deserialize, call from_dict - generator_2 = HuggingFaceAPIChatGenerator.from_dict(result) + generator_2 = HuggingFaceAPIGenerator.from_dict(result) assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False) - assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} - assert generator_2.streaming_callback is None - assert generator_2.tools == [tool] - - def test_serde_in_pipeline(self, mock_check_valid_model): - tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + assert generator_2.generation_kwargs == { + "temperature": 0.6, + "stop_sequences": ["stop", "words"], + "max_new_tokens": 512, + } + assert generator_2.streaming_callback is streaming_callback_handler - generator = HuggingFaceAPIChatGenerator( + def test_generate_text_response_with_valid_prompt_and_generation_parameters( + self, mock_check_valid_model, mock_text_generation + ): + generator = HuggingFaceAPIGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - tools=[tool], + streaming_callback=None, ) - pipeline = Pipeline() - pipeline.add_component("generator", generator) - - pipeline_dict = pipeline.to_dict() - assert pipeline_dict == { - "metadata": {}, - "max_runs_per_component": 100, - "components": { - "generator": { - "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", - "init_parameters": { - "api_type": "serverless_inference_api", - "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, - "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, - "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, - "streaming_callback": None, - "tools": [ - { - "name": "name", - "description": "description", - "parameters": {"x": {"type": "string"}}, - "function": "builtins.print", - } - ], - }, - } - }, - "connections": [], - } + prompt = "Hello, how are you?" 
+        response = generator.run(prompt)
 
-        pipeline_yaml = pipeline.dumps()
+        # check kwargs passed to text_generation
+        _, kwargs = mock_text_generation.call_args
+        assert kwargs == {
+            "details": True,
+            "temperature": 0.6,
+            "stop_sequences": ["stop", "words"],
+            "stream": False,
+            "max_new_tokens": 512,
+        }
 
-        new_pipeline = Pipeline.loads(pipeline_yaml)
-        assert new_pipeline == pipeline
+        assert isinstance(response, dict)
+        assert "replies" in response
+        assert "meta" in response
+        assert isinstance(response["replies"], list)
+        assert isinstance(response["meta"], list)
+        assert len(response["replies"]) == 1
+        assert len(response["meta"]) == 1
+        assert [isinstance(reply, str) for reply in response["replies"]]
 
-    def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages):
-        generator = HuggingFaceAPIChatGenerator(
-            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
-            generation_kwargs={"temperature": 0.6},
-            stop_words=["stop", "words"],
-            streaming_callback=None,
+    def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation):
+        generator = HuggingFaceAPIGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}
         )
 
-        response = generator.run(messages=chat_messages)
+        generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
+        response = generator.run("How are you?", generation_kwargs=generation_kwargs)
 
-        # check kwargs passed to chat_completion
-        _, kwargs = mock_chat_completion.call_args
-        hf_messages = [
-            {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"},
-            {"role": "user", "content": "Tell me about Berlin"},
-        ]
+        # check kwargs passed to text_generation
+        _, kwargs = mock_text_generation.call_args
         assert kwargs == {
-            "temperature": 0.6,
-            "stop": ["stop", "words"],
-            "max_tokens": 512,
-            "tools": None,
-            "messages": hf_messages,
+            "details": True,
+            "max_new_tokens": 100,
+            "stop_sequences": [],
+            "stream": False,
+            "temperature": 0.8,
         }
 
-        assert isinstance(response, dict)
+        # Assert that the response contains the generated replies and the right response
         assert "replies" in response
         assert isinstance(response["replies"], list)
-        assert len(response["replies"]) == 1
-        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert len(response["replies"]) > 0
+        assert [isinstance(reply, str) for reply in response["replies"]]
+        assert response["replies"][0] == "I'm fine, thanks."
+
+        # Assert that the response contains the metadata
+        assert "meta" in response
+        assert isinstance(response["meta"], list)
+        assert len(response["meta"]) > 0
+        assert [isinstance(reply, str) for reply in response["replies"]]
 
-    def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages):
+    def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation):
         streaming_call_count = 0
 
         # Define the streaming callback function
@@ -333,50 +239,38 @@ def streaming_callback_fn(chunk: StreamingChunk):
             streaming_call_count += 1
             assert isinstance(chunk, StreamingChunk)
 
-        generator = HuggingFaceAPIChatGenerator(
+        generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             streaming_callback=streaming_callback_fn,
         )
 
         # Create a fake streamed response
-        # self needed here, don't remove
+        # Don't remove self
         def mock_iter(self):
-            yield ChatCompletionStreamOutput(
-                choices=[
-                    ChatCompletionStreamOutputChoice(
-                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
-                        index=0,
-                        finish_reason=None,
-                    )
-                ],
-                id="some_id",
-                model="some_model",
-                system_fingerprint="some_fingerprint",
-                created=1710498504,
+            yield TextGenerationStreamOutput(
+                index=0,
+                generated_text=None,
+                token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
             )
-
-            yield ChatCompletionStreamOutput(
-                choices=[
-                    ChatCompletionStreamOutputChoice(
-                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
-                    )
-                ],
-                id="some_id",
-                model="some_model",
-                system_fingerprint="some_fingerprint",
-                created=1710498504,
+            yield TextGenerationStreamOutput(
+                index=1,
+                generated_text=None,
+                token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
+                details=TextGenerationStreamOutputStreamDetails(
+                    finish_reason="length", generated_tokens=5, seed=None, input_length=10
+                ),
             )
 
         mock_response = Mock(**{"__iter__": mock_iter})
-        mock_chat_completion.return_value = mock_response
+        mock_text_generation.return_value = mock_response
 
         # Generate text response with streaming callback
-        response = generator.run(chat_messages)
+        response = generator.run("prompt")
 
         # check kwargs passed to text_generation
-        _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        _, kwargs = mock_text_generation.call_args
        assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512}
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
@@ -385,161 +279,36 @@ def mock_iter(self):
         assert "replies" in response
         assert isinstance(response["replies"], list)
         assert len(response["replies"]) > 0
-        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
-
-    def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model):
-        component = HuggingFaceAPIChatGenerator(
-            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
-            streaming_callback=streaming_callback_handler,
-        )
-
-        with pytest.raises(ValueError):
-            message = ChatMessage.from_user("irrelevant")
-            component.run([message], tools=tools)
-
-    def test_run_with_tools(self, mock_check_valid_model, tools):
-        generator = HuggingFaceAPIChatGenerator(
-            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},
-            tools=tools,
-        )
-
-        with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
-            completion = ChatCompletionOutput(
-                choices=[
-                    ChatCompletionOutputComplete(
-                        finish_reason="stop",
-                        index=0,
-                        message=ChatCompletionOutputMessage(
-                            role="assistant",
-                            content=None,
-                            tool_calls=[
-                                ChatCompletionOutputToolCall(
-                                    function=ChatCompletionOutputFunctionDefinition(
-                                        arguments={"city": "Paris"}, name="weather", description=None
-                                    ),
-                                    id="0",
-                                    type="function",
-                                )
-                            ],
-                        ),
-                        logprobs=None,
-                    )
-                ],
-                created=1729074760,
-                id="",
-                model="meta-llama/Llama-3.1-70B-Instruct",
-                system_fingerprint="2.3.2-dev0-sha-28bb7ae",
-                usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456),
-            )
-            mock_chat_completion.return_value = completion
-
-            messages = [ChatMessage.from_user("What is the weather in Paris?")]
-            response = generator.run(messages=messages)
-
-        assert isinstance(response, dict)
-        assert "replies" in response
-        assert isinstance(response["replies"], list)
-        assert len(response["replies"]) == 1
-        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
-        assert response["replies"][0].tool_calls[0].tool_name == "weather"
-        assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
-        assert response["replies"][0].tool_calls[0].id == "0"
-        assert response["replies"][0].meta == {
-            "finish_reason": "stop",
-            "index": 0,
-            "model": "meta-llama/Llama-3.1-70B-Instruct",
-            "usage": {"completion_tokens": 30, "prompt_tokens": 426},
-        }
-
-    @pytest.mark.integration
-    @pytest.mark.skipif(
-        not os.environ.get("HF_API_TOKEN", None),
-        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
-    )
-    def test_live_run_serverless(self):
-        generator = HuggingFaceAPIChatGenerator(
-            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
-            generation_kwargs={"max_tokens": 20},
-        )
+        assert [isinstance(reply, str) for reply in response["replies"]]
 
-        messages = [ChatMessage.from_user("What is the capital of France?")]
-        response = generator.run(messages=messages)
-
-        assert "replies" in response
-        assert isinstance(response["replies"], list)
-        assert len(response["replies"]) > 0
-        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
-        assert "usage" in response["replies"][0].meta
-        assert "prompt_tokens" in response["replies"][0].meta["usage"]
-        assert "completion_tokens" in response["replies"][0].meta["usage"]
+        # Assert that the response contains the metadata
+        assert "meta" in response
+        assert isinstance(response["meta"], list)
+        assert len(response["meta"]) > 0
+        assert [isinstance(meta, dict) for meta in response["meta"]]
 
+    @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
     )
-    def test_live_run_serverless_streaming(self):
-        generator = HuggingFaceAPIChatGenerator(
+    def test_run_serverless(self):
+        generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
             api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
-            generation_kwargs={"max_tokens": 20},
-            streaming_callback=streaming_callback_handler,
+            generation_kwargs={"max_new_tokens": 20},
        )
 
-        messages = [ChatMessage.from_user("What is the capital of France?")]
-        response = generator.run(messages=messages)
-
+        response = generator.run("How are you?")
+        # Assert that the response contains the generated replies
         assert "replies" in response
         assert isinstance(response["replies"], list)
         assert len(response["replies"]) > 0
-        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
-        assert "usage" in response["replies"][0].meta
-        assert "prompt_tokens" in response["replies"][0].meta["usage"]
-        assert "completion_tokens" in response["replies"][0].meta["usage"]
-
-    @pytest.mark.integration
-    @pytest.mark.skipif(
-        not os.environ.get("HF_API_TOKEN", None),
-        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
-    )
-    def test_live_run_with_tools(self, tools):
-        """
-        We test the round trip: generate tool call, pass tool message, generate response.
-
-        The model used here (zephyr-7b-beta) is always available and not gated.
-        Even though it does not officially support tools, TGI and the HF API make them work.
-        """
-
-        chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")]
-        generator = HuggingFaceAPIChatGenerator(
-            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
-            generation_kwargs={"temperature": 0.5},
-        )
-
-        results = generator.run(chat_messages, tools=tools)
-        assert len(results["replies"]) == 1
-        message = results["replies"][0]
-
-        assert message.tool_calls
-        tool_call = message.tool_call
-        assert isinstance(tool_call, ToolCall)
-        assert tool_call.tool_name == "weather"
-        assert "city" in tool_call.arguments
-        assert "Paris" in tool_call.arguments["city"]
-        assert message.meta["finish_reason"] == "stop"
-
-        new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]
-
-        # the model tends to make tool calls if provided with tools, so we don't pass them here
-        results = generator.run(new_messages, generation_kwargs={"max_tokens": 50})
+        assert [isinstance(reply, str) for reply in response["replies"]]
 
-        assert len(results["replies"]) == 1
-        final_message = results["replies"][0]
-        assert not final_message.tool_calls
-        assert len(final_message.text) > 0
-        assert "paris" in final_message.text.lower()
+        # Assert that the response contains the metadata
+        assert "meta" in response
+        assert isinstance(response["meta"], list)
+        assert len(response["meta"]) > 0
+        assert [isinstance(meta, dict) for meta in response["meta"]]

From 39c338364a136370d422b5882504867f857ead85 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Thu, 19 Dec 2024 11:42:41 +0100
Subject: [PATCH 4/5] release note

---
 releasenotes/notes/hfapi-tools-a7224150bce52564.yaml | 4 ++++
 1 file changed, 4 insertions(+)
 create mode 100644 releasenotes/notes/hfapi-tools-a7224150bce52564.yaml

diff --git a/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml b/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml
new file mode 100644
index 0000000000..085ed35931
--- /dev/null
+++ b/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for Tools in the Hugging Face API Chat Generator.
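[Editor's illustrative note — not part of the patch series. The sketch below shows how the tools support added in the patches above might be used. It is a minimal sketch under assumptions: the get_weather function is hypothetical, and the model name is borrowed from the mocked test_run_with_tools test; any tool-call-capable model served by the Serverless Inference API should behave similarly.]

    from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
    from haystack.dataclasses import ChatMessage, Tool
    from haystack.utils.hf import HFGenerationAPIType

    def get_weather(city: str) -> str:
        # Hypothetical tool implementation; any callable matching the schema works
        return f"Sunny, 22° C in {city}"

    weather_tool = Tool(
        name="weather",
        description="useful to determine the weather in a given location",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        function=get_weather,
    )

    generator = HuggingFaceAPIChatGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"},  # assumed tool-capable model
        tools=[weather_tool],
    )

    result = generator.run(messages=[ChatMessage.from_user("What is the weather in Paris?")])
    reply = result["replies"][0]
    if reply.tool_calls:
        # The model chose to call a tool; the caller executes it and sends the
        # result back via ChatMessage.from_tool, as the live round-trip test does.
        print(reply.tool_calls[0].tool_name, reply.tool_calls[0].arguments)

Note that tools and a streaming_callback are mutually exclusive: both the constructor and run() raise a ValueError when the two are combined, as covered by test_init_fail_with_tools_and_streaming and test_run_fail_with_tools_and_streaming above.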
From 6832e3155154cb42ada01d2d2bf5e280073cabc1 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Thu, 19 Dec 2024 14:40:59 +0100
Subject: [PATCH 5/5] feedback

---
 .../components/generators/chat/hugging_face_api.py | 14 +++++++-------
 haystack/dataclasses/tool.py                       |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py
index 8e7a6dac18..dab61e4d93 100644
--- a/haystack/components/generators/chat/hugging_face_api.py
+++ b/haystack/components/generators/chat/hugging_face_api.py
@@ -322,15 +322,15 @@ def _run_non_streaming(
             )
             tool_calls.append(tool_call)
 
-        meta = {
-            "model": self._client.model,
-            "finish_reason": choice.finish_reason,
-            "index": choice.index,
-            "usage": {
+        meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
+
+        usage = {"prompt_tokens": 0, "completion_tokens": 0}
+        if api_chat_output.usage:
+            usage = {
                 "prompt_tokens": api_chat_output.usage.prompt_tokens,
                 "completion_tokens": api_chat_output.usage.completion_tokens,
-            },
-        }
+            }
+        meta["usage"] = usage
 
         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
         return {"replies": [message]}

diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py
index 4aaf1e2bd1..c6606d51e8 100644
--- a/haystack/dataclasses/tool.py
+++ b/haystack/dataclasses/tool.py
@@ -218,7 +218,7 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
 
 def _check_duplicate_tool_names(tools: List[Tool]) -> None:
     """
-    Check for duplicate tool names.
+    Check for duplicate tool names and raise a ValueError if they are found.
 
     :param tools: The list of tools to check.
     :raises ValueError: If duplicate tool names are found.
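[Editor's illustrative note — not part of the patch series. A minimal sketch of the behavior documented in the tool.py docstring above, assuming direct use of the private helper; in normal use the chat generator calls it internally during __init__, as test_init_fail_with_duplicate_tool_names shows.]

    from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names

    # Two tools sharing the name "weather"; the schemas here are illustrative
    tool_a = Tool(name="weather", description="first", parameters={"city": {"type": "string"}}, function=print)
    tool_b = Tool(name="weather", description="second", parameters={"city": {"type": "string"}}, function=print)

    # Raises ValueError because the tool names are not unique
    _check_duplicate_tool_names([tool_a, tool_b])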