From 2bc58d298749a28372bb98a7b3c902786380ea69 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 19 Dec 2024 15:04:37 +0100 Subject: [PATCH] feat: support for tools in `HuggingFaceAPIChatGenerator` (#8661) * message conversion function * hfapi w tools * right test file + hf_hub version * release note * feedback --- .../hugging_face_api_document_embedder.py | 2 +- .../hugging_face_api_text_embedder.py | 2 +- .../generators/chat/hugging_face_api.py | 152 ++++++--- .../components/generators/hugging_face_api.py | 2 +- haystack/dataclasses/tool.py | 15 +- haystack/utils/hf.py | 42 ++- pyproject.toml | 2 +- .../notes/hfapi-tools-a7224150bce52564.yaml | 4 + .../generators/chat/test_hugging_face_api.py | 297 ++++++++++++++++-- test/dataclasses/test_tool.py | 16 + test/utils/test_hf.py | 59 +++- 11 files changed, 509 insertions(+), 84 deletions(-) create mode 100644 releasenotes/notes/hfapi-tools-a7224150bce52564.yaml diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py index 43f719e27d..459e386976 100644 --- a/haystack/components/embedders/hugging_face_api_document_embedder.py +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -14,7 +14,7 @@ from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) diff --git a/haystack/components/embedders/hugging_face_api_text_embedder.py b/haystack/components/embedders/hugging_face_api_text_embedder.py index f60a9e5fd7..2cd68d34da 100644 --- a/haystack/components/embedders/hugging_face_api_text_embedder.py +++ b/haystack/components/embedders/hugging_face_api_text_embedder.py @@ -11,7 +11,7 @@ from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 8711a9175a..dab61e4d93 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -5,30 +5,25 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model +from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import: - from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient +with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import: + from huggingface_hub import ( + ChatCompletionInputTool, + ChatCompletionOutput, + ChatCompletionStreamOutput, + InferenceClient, + ) logger = logging.getLogger(__name__) -def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: - """ - Convert a message to the format expected by Hugging Face APIs. - - :returns: A dictionary with the following keys: - - `role` - - `content` - """ - return {"role": message.role.value, "content": message.text or ""} - - @component class HuggingFaceAPIChatGenerator: """ @@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, ): """ Initialize the HuggingFaceAPIChatGenerator instance. @@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`. - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_GENERATION_INFERENCE`. - :param token: The Hugging Face token to use as HTTP bearer authorization. + :param token: + The Hugging Face token to use as HTTP bearer authorization. Check your HF token in your [account settings](https://huggingface.co/settings/tokens). :param generation_kwargs: A dictionary with keyword arguments to customize text generation. Some examples: `max_tokens`, `temperature`, `top_p`. For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion). - :param stop_words: An optional list of strings representing the stop words. - :param streaming_callback: An optional callable for handling streaming responses. + :param stop_words: + An optional list of strings representing the stop words. + :param streaming_callback: + An optional callable for handling streaming responses. + :param tools: + A list of tools for which the model can prepare calls. + The chosen model should support tool/function calling, according to the model card. + Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience + unexpected behavior. """ huggingface_hub_import.check() @@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments msg = f"Unknown api_type {api_type}" raise ValueError(msg) + if tools: + if streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) + # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} generation_kwargs["stop"] = generation_kwargs.get("stop", []) @@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.generation_kwargs = generation_kwargs self.streaming_callback = streaming_callback self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None) + self.tools = tools def to_dict(self) -> Dict[str, Any]: """ @@ -180,6 +190,7 @@ def to_dict(self) -> Dict[str, Any]: A dictionary containing the serialized component. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, api_type=str(self.api_type), @@ -187,6 +198,7 @@ def to_dict(self) -> Dict[str, Any]: token=self.token.to_dict() if self.token else None, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -195,6 +207,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": Deserialize this component from a dictionary. """ deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: @@ -202,12 +215,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): """ Invoke the text generation inference based on the provided messages and generation parameters. - :param messages: A list of ChatMessage objects representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. + :param messages: + A list of ChatMessage objects representing the input messages. + :param generation_kwargs: + Additional keyword arguments for text generation. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ @@ -215,12 +238,22 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # update generation kwargs by merging with the default ones generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages] + formatted_messages = [convert_message_to_hf_format(message) for message in messages] + + tools = tools or self.tools + if tools: + if self.streaming_callback: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) if self.streaming_callback: return self._run_streaming(formatted_messages, generation_kwargs) - return self._run_non_streaming(formatted_messages, generation_kwargs) + hf_tools = None + if tools: + hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( @@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict generated_text = "" - for chunk in api_output: # pylint: disable=not-an-iterable - text = chunk.choices[0].delta.content + for chunk in api_output: + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = chunk.choices[0] + + text = choice.delta.content if text: generated_text += text - finish_reason = chunk.choices[0].finish_reason + + finish_reason = choice.finish_reason meta = {} if finish_reason: @@ -242,8 +281,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict stream_chunk = StreamingChunk(text, meta) self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) - message = ChatMessage.from_assistant(generated_text) - message.meta.update( + meta.update( { "model": self._client.model, "finish_reason": finish_reason, @@ -251,24 +289,48 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming } ) + + message = ChatMessage.from_assistant(text=generated_text, meta=meta) + return {"replies": [message]} def _run_non_streaming( - self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any] + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + tools: Optional[List["ChatCompletionInputTool"]] = None, ) -> Dict[str, List[ChatMessage]]: - chat_messages: List[ChatMessage] = [] - - api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs) - for choice in api_chat_output.choices: - message = ChatMessage.from_assistant(choice.message.content) - message.meta.update( - { - "model": self._client.model, - "finish_reason": choice.finish_reason, - "index": choice.index, - "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0}, - } - ) - chat_messages.append(message) - - return {"replies": chat_messages} + api_chat_output: ChatCompletionOutput = self._client.chat_completion( + messages=messages, tools=tools, **generation_kwargs + ) + + if len(api_chat_output.choices) == 0: + return {"replies": []} + + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = api_chat_output.choices[0] + + text = choice.message.content + tool_calls = [] + + if hfapi_tool_calls := choice.message.tool_calls: + for hfapi_tc in hfapi_tool_calls: + tool_call = ToolCall( + tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id + ) + tool_calls.append(tool_call) + + meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index} + + usage = {"prompt_tokens": 0, "completion_tokens": 0} + if api_chat_output.usage: + usage = { + "prompt_tokens": api_chat_output.usage.prompt_tokens, + "completion_tokens": api_chat_output.usage.completion_tokens, + } + meta["usage"] = usage + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) + return {"replies": [message]} diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index a164c8c56c..a44ad94575 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -12,7 +12,7 @@ from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model from haystack.utils.url_validation import is_valid_http_url -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import ( InferenceClient, TextGenerationOutput, diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py index 3df3fd18f2..c6606d51e8 100644 --- a/haystack/dataclasses/tool.py +++ b/haystack/dataclasses/tool.py @@ -4,7 +4,7 @@ import inspect from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from pydantic import create_model @@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]): del property_schema[key] +def _check_duplicate_tool_names(tools: List[Tool]) -> None: + """ + Check for duplicate tool names and raises a ValueError if they are found. + + :param tools: The list of tools to check. + :raises ValueError: If duplicate tool names are found. + """ + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): """ Deserialize Tools in a dictionary inplace. diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 537b05e232..6a83594ada 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from haystack import logging -from haystack.dataclasses import StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.lazy_imports import LazyImport from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice @@ -16,7 +16,7 @@ with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_import: import torch -with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import: +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: from huggingface_hub import HfApi, InferenceClient, model_info from huggingface_hub.utils import RepositoryNotFoundError @@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte ) +def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a message to the format expected by Hugging Face. + """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + if len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + # HF always expects a content field, even if it is empty + hf_msg: Dict[str, Any] = {"role": message._role.value, "content": ""} + + if tool_call_results: + result = tool_call_results[0] + hf_msg["content"] = result.result + if tc_id := result.origin.id: + hf_msg["tool_call_id"] = tc_id + # HF does not provide a way to communicate errors in tool invocations, so we ignore the error field + return hf_msg + + if text_contents: + hf_msg["content"] = text_contents[0] + if tool_calls: + hf_tool_calls = [] + for tc in tool_calls: + hf_tool_call = {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} + if tc.id is not None: + hf_tool_call["id"] = tc.id + hf_tool_calls.append(hf_tool_call) + hf_msg["tool_calls"] = hf_tool_calls + + return hf_msg + + with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import: from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer diff --git a/pyproject.toml b/pyproject.toml index c1fddc8704..6a76a2e9c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ extra-dependencies = [ "numpy>=2", # Haystack is compatible both with numpy 1.x and 2.x, but we test with 2.x "transformers[torch,sentencepiece]==4.44.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators... - "huggingface_hub>=0.23.0", # Hugging Face API Generators and Embedders + "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder "langdetect", # TextLanguageRouter and DocumentLanguageClassifier "openai-whisper>=20231106", # LocalWhisperTranscriber diff --git a/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml b/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml new file mode 100644 index 0000000000..085ed35931 --- /dev/null +++ b/releasenotes/notes/hfapi-tools-a7224150bce52564.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add support for Tools in the Hugging Face API Chat Generator. diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index e60ec863ab..0d0857e22a 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -5,23 +5,46 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from haystack import Pipeline +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( ChatCompletionOutput, - ChatCompletionStreamOutput, ChatCompletionOutputComplete, - ChatCompletionStreamOutputChoice, + ChatCompletionOutputFunctionDefinition, ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputDelta, ) from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators.chat.hugging_face_api import ( - HuggingFaceAPIChatGenerator, - _convert_message_to_hfapi_format, -) -from haystack.dataclasses import ChatMessage, StreamingChunk -from haystack.utils.auth import Secret -from haystack.utils.hf import HFGenerationAPIType +from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator +from haystack.dataclasses import ChatMessage, Tool, ToolCall + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] @pytest.fixture @@ -48,7 +71,7 @@ def mock_chat_completion(): id="some_id", model="some_model", system_fingerprint="some_fingerprint", - usage={"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15}, + usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), created=1710498360, ) @@ -61,15 +84,7 @@ def streaming_callback_handler(x): return x -def test_convert_message_to_hfapi_format(): - message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_hfapi_format(message) == {"role": "system", "content": "You are good assistant"} - - message = ChatMessage.from_user("I have a question") - assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"} - - -class TestHuggingFaceAPIGenerator: +class TestHuggingFaceAPIChatGenerator: def test_init_invalid_api_type(self): with pytest.raises(ValueError): HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) @@ -93,6 +108,29 @@ def test_init_serverless(self, mock_check_valid_model): assert generator.api_params == {"model": model} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None + + def test_init_serverless_with_tools(self, mock_check_valid_model, tools): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": model}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + tools=tools, + ) + + assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator.api_params == {"model": model} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools == tools def test_init_serverless_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") @@ -126,6 +164,7 @@ def test_init_tgi(self): assert generator.api_params == {"url": url} assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None def test_init_tgi_invalid_url(self): with pytest.raises(ValueError): @@ -139,12 +178,33 @@ def test_init_tgi_no_url(self): api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} ) + def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=duplicate_tools, + ) + + def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=tools, + streaming_callback=streaming_callback_handler, + ) + def test_to_dict(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], + tools=[tool], ) result = generator.to_dict() @@ -154,15 +214,26 @@ def test_to_dict(self, mock_check_valid_model): assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert init_params["streaming_callback"] is None + assert init_params["tools"] == [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + } + ] def test_from_dict(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, + tools=[tool], ) result = generator.to_dict() @@ -172,11 +243,57 @@ def test_from_dict(self, mock_check_valid_model): assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False) assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} - assert generator_2.streaming_callback is streaming_callback_handler + assert generator_2.streaming_callback is None + assert generator_2.tools == [tool] + + def test_serde_in_pipeline(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) - def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_chat_completion, chat_messages - ): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + token=Secret.from_env_var("ENV_VAR", strict=False), + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + tools=[tool], + ) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", + "init_parameters": { + "api_type": "serverless_inference_api", + "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, + "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, + "streaming_callback": None, + "tools": [ + { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + } + ], + }, + } + }, + "connections": [], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, @@ -187,9 +304,19 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( response = generator.run(messages=chat_messages) - # check kwargs passed to text_generation + # check kwargs passed to chat_completion _, kwargs = mock_chat_completion.call_args - assert kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + hf_messages = [ + {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, + {"role": "user", "content": "Tell me about Berlin"}, + ] + assert kwargs == { + "temperature": 0.6, + "stop": ["stop", "words"], + "max_tokens": 512, + "tools": None, + "messages": hf_messages, + } assert isinstance(response, dict) assert "replies" in response @@ -197,7 +324,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): + def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): streaming_call_count = 0 # Define the streaming callback function @@ -260,13 +387,78 @@ def mock_iter(self): assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - @pytest.mark.flaky(reruns=5, reruns_delay=5) + def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model): + component = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + streaming_callback=streaming_callback_handler, + ) + + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + component.run([message], tools=tools) + + def test_run_with_tools(self, mock_check_valid_model, tools): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, + tools=tools, + ) + + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="stop", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionOutputToolCall( + function=ChatCompletionOutputFunctionDefinition( + arguments={"city": "Paris"}, name="weather", description=None + ), + id="0", + type="function", + ) + ], + ), + logprobs=None, + ) + ], + created=1729074760, + id="", + model="meta-llama/Llama-3.1-70B-Instruct", + system_fingerprint="2.3.2-dev0-sha-28bb7ae", + usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), + ) + mock_chat_completion.return_value = completion + + messages = [ChatMessage.from_user("What is the weather in Paris?")] + response = generator.run(messages=messages) + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert response["replies"][0].tool_calls[0].tool_name == "weather" + assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} + assert response["replies"][0].tool_calls[0].id == "0" + assert response["replies"][0].meta == { + "finish_reason": "stop", + "index": 0, + "model": "meta-llama/Llama-3.1-70B-Instruct", + "usage": {"completion_tokens": 30, "prompt_tokens": 426}, + } + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - def test_run_serverless(self): + def test_live_run_serverless(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, @@ -284,13 +476,12 @@ def test_run_serverless(self): assert "prompt_tokens" in response["replies"][0].meta["usage"] assert "completion_tokens" in response["replies"][0].meta["usage"] - @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - def test_run_serverless_streaming(self): + def test_live_run_serverless_streaming(self): generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, @@ -308,3 +499,47 @@ def test_run_serverless_streaming(self): assert "usage" in response["replies"][0].meta assert "prompt_tokens" in response["replies"][0].meta["usage"] assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools(self, tools): + """ + We test the round trip: generate tool call, pass tool message, generate response. + + The model used here (zephyr-7b-beta) is always available and not gated. + Even if it does not officially support tools, TGI+HF API make it work. + """ + + chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")] + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"temperature": 0.5}, + ) + + results = generator.run(chat_messages, tools=tools) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert "city" in tool_call.arguments + assert "Paris" in tool_call.arguments["city"] + assert message.meta["finish_reason"] == "stop" + + new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + + # the model tends to make tool calls if provided with tools, so we don't pass them here + results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py index db9719a7f3..9e112853f3 100644 --- a/test/dataclasses/test_tool.py +++ b/test/dataclasses/test_tool.py @@ -12,6 +12,7 @@ ToolInvocationError, _remove_title_from_schema, deserialize_tools_inplace, + _check_duplicate_tool_names, ) try: @@ -303,3 +304,18 @@ def test_remove_title_from_schema_handle_no_title_in_top_level(): "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}}, "type": "object", } + + +def test_check_duplicate_tool_names(): + tools = [ + Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report), + Tool(name="weather", description="A different description", parameters=parameters, function=get_weather_report), + ] + with pytest.raises(ValueError): + _check_duplicate_tool_names(tools) + + tools = [ + Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report), + Tool(name="weather2", description="Get weather report", parameters=parameters, function=get_weather_report), + ] + _check_duplicate_tool_names(tools) diff --git a/test/utils/test_hf.py b/test/utils/test_hf.py index 4350fb9fbb..d75e0b7501 100644 --- a/test/utils/test_hf.py +++ b/test/utils/test_hf.py @@ -2,8 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -from haystack.utils.hf import resolve_hf_device_map + +import pytest + +from haystack.utils.hf import resolve_hf_device_map, convert_message_to_hf_format from haystack.utils.device import ComponentDevice +from haystack.dataclasses import ChatMessage, ToolCall, ChatRole, TextContent def test_resolve_hf_device_map_only_device(): @@ -23,3 +27,56 @@ def test_resolve_hf_device_map_device_and_device_map(caplog): ) assert "The parameters `device` and `device_map` from `model_kwargs` are both provided." in caplog.text assert model_kwargs["device_map"] == "cuda:0" + + +def test_convert_message_to_hf_format(): + message = ChatMessage.from_system("You are good assistant") + assert convert_message_to_hf_format(message) == {"role": "system", "content": "You are good assistant"} + + message = ChatMessage.from_user("I have a question") + assert convert_message_to_hf_format(message) == {"role": "user", "content": "I have a question"} + + message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"}) + assert convert_message_to_hf_format(message) == {"role": "assistant", "content": "I have an answer"} + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + assert convert_message_to_hf_format(message) == { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}} + ], + } + + message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather", arguments={"city": "Paris"})]) + assert convert_message_to_hf_format(message) == { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}], + } + + tool_result = {"weather": "sunny", "temperature": "25"} + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ) + assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result, "tool_call_id": "123"} + + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(tool_name="weather", arguments={"city": "Paris"}) + ) + assert convert_message_to_hf_format(message) == {"role": "tool", "content": tool_result} + + +def test_convert_message_to_hf_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + convert_message_to_hf_format(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")], + ) + with pytest.raises(ValueError): + convert_message_to_hf_format(message)