fix: Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage; serialize chat_template #8663

Merged · 8 commits · Dec 19, 2024
@@ -14,7 +14,7 @@
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)
@@ -11,7 +11,7 @@
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)
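These hunks (and the matching ones in the generator modules below) raise the minimum `huggingface_hub` pin from 0.23.0 to 0.27.0. The diff does not state the motivation, but the chat generator below starts importing newer chat-completion types. A minimal environment guard mirroring the LazyImport message might look like this (assuming `packaging` is installed):

```python
from importlib.metadata import version

from packaging.version import Version  # assumption: packaging is available

# Fail fast with the same hint the LazyImport message gives.
if Version(version("huggingface_hub")) < Version("0.27.0"):
    raise RuntimeError('Run \'pip install "huggingface_hub>=0.27.0"\'')
```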
152 changes: 107 additions & 45 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -5,30 +5,25 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import (
ChatCompletionInputTool,
ChatCompletionOutput,
ChatCompletionStreamOutput,
InferenceClient,
)


logger = logging.getLogger(__name__)


def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
"""
Convert a message to the format expected by Hugging Face APIs.

:returns: A dictionary with the following keys:
- `role`
- `content`
"""
return {"role": message.role.value, "content": message.text or ""}


@component
class HuggingFaceAPIChatGenerator:
"""
@@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Initialize the HuggingFaceAPIChatGenerator instance.
@@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments
- `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
- `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
`TEXT_GENERATION_INFERENCE`.
:param token: The Hugging Face token to use as HTTP bearer authorization.
:param token:
The Hugging Face token to use as HTTP bearer authorization.
Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
:param generation_kwargs:
A dictionary with keyword arguments to customize text generation.
Some examples: `max_tokens`, `temperature`, `top_p`.
For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
:param stop_words:
An optional list of strings representing the stop words.
:param streaming_callback:
An optional callable for handling streaming responses.
:param tools:
A list of tools for which the model can prepare calls.
The chosen model should support tool/function calling, according to the model card.
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
unexpected behavior.
"""

huggingface_hub_import.check()
@@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
msg = f"Unknown api_type {api_type}"
raise ValueError(msg)

if tools:
if streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
generation_kwargs["stop"] = generation_kwargs.get("stop", [])
@@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.generation_kwargs = generation_kwargs
self.streaming_callback = streaming_callback
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
self.tools = tools

def to_dict(self) -> Dict[str, Any]:
"""
@@ -180,13 +190,15 @@ def to_dict(self) -> Dict[str, Any]:
A dictionary containing the serialized component.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
return default_to_dict(
self,
api_type=str(self.api_type),
api_params=self.api_params,
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
tools=serialized_tools,
)

@classmethod
@@ -195,32 +207,53 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
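Given the `serialized_tools` handling in `to_dict` and the `deserialize_tools_inplace` call here, a serialization round-trip should preserve tools. A small sketch, reusing the `generator` from the construction example above:

```python
data = generator.to_dict()
restored = HuggingFaceAPIChatGenerator.from_dict(data)

assert restored.tools is not None and restored.tools[0].name == "weather"
# Caveat (assumption based on how callables are serialized by import path):
# the tool's `function` must be importable, e.g. defined in a module rather
# than an interactive session, for from_dict to rebuild it.
```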

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.

:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param messages:
A list of ChatMessage objects representing the input messages.
:param generation_kwargs:
Additional keyword arguments for text generation.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage objects.
"""

# update generation kwargs by merging with the default ones
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]
formatted_messages = [convert_message_to_hf_format(message) for message in messages]

tools = tools or self.tools
if tools:
if self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)

return self._run_non_streaming(formatted_messages, generation_kwargs)
hf_tools = None
if tools:
hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
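The wrapping above converts each `Tool` into the OpenAI-style tool envelope before handing it to the client. Roughly, based on `Tool.tool_spec` exposing name, description, and parameters (output shown as a hedged sketch):

```python
# Reusing weather_tool from the earlier construction example.
hf_tools = [{"type": "function", "function": {**weather_tool.tool_spec}}]
# -> [{"type": "function",
#      "function": {"name": "weather",
#                   "description": "Get the current weather for a city.",
#                   "parameters": {"type": "object", ...}}}]
```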

def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
@@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict

generated_text = ""

for chunk in api_output: # pylint: disable=not-an-iterable
text = chunk.choices[0].delta.content
for chunk in api_output:
# n is unused, so the API always returns only one choice
# the argument is probably allowed for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = chunk.choices[0]

text = choice.delta.content
if text:
generated_text += text
finish_reason = chunk.choices[0].finish_reason

finish_reason = choice.finish_reason

meta = {}
if finish_reason:
@@ -242,33 +281,56 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
stream_chunk = StreamingChunk(text, meta)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)

message = ChatMessage.from_assistant(generated_text)
message.meta.update(
meta.update(
{
"model": self._client.model,
"finish_reason": finish_reason,
"index": 0,
"usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming
}
)

message = ChatMessage.from_assistant(text=generated_text, meta=meta)

return {"replies": [message]}

def _run_non_streaming(
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
tools: Optional[List["ChatCompletionInputTool"]] = None,
) -> Dict[str, List[ChatMessage]]:
chat_messages: List[ChatMessage] = []

api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
for choice in api_chat_output.choices:
message = ChatMessage.from_assistant(choice.message.content)
message.meta.update(
{
"model": self._client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
"usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
}
)
chat_messages.append(message)

return {"replies": chat_messages}
api_chat_output: ChatCompletionOutput = self._client.chat_completion(
messages=messages, tools=tools, **generation_kwargs
)

if len(api_chat_output.choices) == 0:
return {"replies": []}

# n is unused, so the API always returns only one choice
# the argument is probably allowed for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = api_chat_output.choices[0]

text = choice.message.content
tool_calls = []

if hfapi_tool_calls := choice.message.tool_calls:
for hfapi_tc in hfapi_tool_calls:
tool_call = ToolCall(
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
)
tool_calls.append(tool_call)

meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}

usage = {"prompt_tokens": 0, "completion_tokens": 0}
if api_chat_output.usage:
usage = {
"prompt_tokens": api_chat_output.usage.prompt_tokens,
"completion_tokens": api_chat_output.usage.completion_tokens,
}
meta["usage"] = usage

message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
return {"replies": [message]}
6 changes: 5 additions & 1 deletion haystack/components/generators/chat/hugging_face_local.py
@@ -25,6 +25,7 @@
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
HFTokenStreamingHandler,
StopWordsCriteria,
convert_message_to_hf_format,
deserialize_hf_model_kwargs,
serialize_hf_model_kwargs,
)
@@ -201,6 +202,7 @@ def to_dict(self) -> Dict[str, Any]:
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
token=self.token.to_dict() if self.token else None,
chat_template=self.chat_template,
)

huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
@@ -270,9 +272,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)

hf_messages = [convert_message_to_hf_format(message) for message in messages]

# Prepare the prompt for the model
prepared_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)

# Avoid some unnecessary warnings in the generation pipeline call
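The `chat_template=self.chat_template` addition to `to_dict` is what the PR title's "serialize chat_template" refers to: before this change, a custom template did not survive serialization. A round-trip sketch (model ID and template are illustrative assumptions):

```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

gen = HuggingFaceLocalChatGenerator(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # assumption: any local chat model
    chat_template="{% for m in messages %}{{ m.content }}\n{% endfor %}",  # toy Jinja2 template
)
restored = HuggingFaceLocalChatGenerator.from_dict(gen.to_dict())
assert restored.chat_template == gen.chat_template
```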
2 changes: 1 addition & 1 deletion haystack/components/generators/hugging_face_api.py
@@ -12,7 +12,7 @@
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
from huggingface_hub import (
InferenceClient,
TextGenerationOutput,
15 changes: 14 additions & 1 deletion haystack/dataclasses/tool.py
@@ -4,7 +4,7 @@

import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

from pydantic import create_model

@@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
del property_schema[key]


def _check_duplicate_tool_names(tools: List[Tool]) -> None:
"""
Check for duplicate tool names and raise a ValueError if any are found.

:param tools: The list of tools to check.
:raises ValueError: If duplicate tool names are found.
"""
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")


def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
"""
Deserialize Tools in a dictionary inplace.