From abe32f5cde6a742672f34ece9e577004819db3ca Mon Sep 17 00:00:00 2001
From: ggroup3
Date: Wed, 2 Oct 2024 17:24:03 -0500
Subject: [PATCH] [functionality] Implement completion() methods #168

- Implemented the completion() method for the meta-reference, fireworks, together, ollama, and bedrock providers
- Each implementation handles both streaming and non-streaming responses
- Converted the single prompt content into a message list where necessary
- Ensured each implementation yields CompletionResponse or CompletionResponseStreamChunk objects as appropriate
---
Note: a brief usage sketch of the new completion() surface is appended after the diff.

 .gitignore                                      |  1 +
 .../adapters/inference/bedrock/bedrock.py       | 27 +++++-
 .../adapters/inference/fireworks/fireworks.py   | 47 ++++++++++-
 .../adapters/inference/ollama/ollama.py         |  2 +-
 .../adapters/inference/together/together.py     | 52 +++++++++++-
 .../meta_reference/inference/inference.py       | 84 +------------------
 6 files changed, 120 insertions(+), 93 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2465d2d4..d0772c2b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ xcuserdata/
 Package.resolved
 *.pte
 *.ipynb_checkpoints*
+venv/
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 9c1db4bd..d5f619df 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -95,8 +95,31 @@ async def completion(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
+    ) -> AsyncGenerator[Union[CompletionResponse, CompletionResponseStreamChunk], None]:
+        # Convert the content to a single message
+        messages = [UserMessage(content=content)]
+
+        # Use the existing chat_completion method
+        async for response in self.chat_completion(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        ):
+            if isinstance(response, ChatCompletionResponse):
+                yield CompletionResponse(
+                    completion=response.completion_message.content[0],
+                    logprobs=response.logprobs,
+                )
+            elif isinstance(response, ChatCompletionResponseStreamChunk):
+                yield CompletionResponseStreamChunk(
+                    event=CompletionResponseEvent(
+                        event_type=CompletionResponseEventType(response.event.event_type.value),
+                        delta=response.event.delta,
+                        stop_reason=response.event.stop_reason,
+                    )
+                )
 
     @staticmethod
     def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index f6949cbd..50890992 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Union
 
 from fireworks.client import Fireworks
 
@@ -56,8 +56,49 @@ async def completion(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
-        raise NotImplementedError()
+    ) -> AsyncGenerator[Union[CompletionResponse, CompletionResponseStreamChunk], None]:
+        fireworks_model = self.map_to_provider_model(model)
+        options = self.get_fireworks_chat_options(ChatCompletionRequest(
+            model=model,
+            messages=[],
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        ))
+
+        if not stream:
+            response = await self.client.completions.create(
+                model=fireworks_model,
+                prompt=content,
+                stream=False,
+                **options,
+            )
+            yield CompletionResponse(
+                completion=response.choices[0].text,
+                logprobs=None,  # Fireworks doesn't provide logprobs
+            )
+        else:
+            async for chunk in self.client.completions.create(
+                model=fireworks_model,
+                prompt=content,
+                stream=True,
+                **options,
+            ):
+                if chunk.choices[0].text:
+                    yield CompletionResponseStreamChunk(
+                        event=CompletionResponseEvent(
+                            event_type=CompletionResponseEventType.progress,
+                            delta=chunk.choices[0].text,
+                        )
+                    )
+
+            yield CompletionResponseStreamChunk(
+                event=CompletionResponseEvent(
+                    event_type=CompletionResponseEventType.complete,
+                    delta="",
+                    stop_reason=StopReason.end_of_turn,
+                )
+            )
 
     def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
         fireworks_messages = []
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index c4d48af8..1b87a46e 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -262,4 +262,4 @@ async def chat_completion(
                 delta="",
                 stop_reason=stop_reason,
             )
-        )
+        )
\ No newline at end of file
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 9f73a81d..944a553f 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Union
 
 from llama_models.llama3.api.chat_format import ChatFormat
 
@@ -61,8 +61,52 @@ async def completion(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
-        raise NotImplementedError()
+    ) -> AsyncGenerator[Union[CompletionResponse, CompletionResponseStreamChunk], None]:
+        together_api_key = self.config.api_key or self.get_request_provider_data().together_api_key
+        client = Together(api_key=together_api_key)
+
+        together_model = self.map_to_provider_model(model)
+        options = self.get_together_chat_options(ChatCompletionRequest(
+            model=model,
+            messages=[],
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        ))
+
+        if not stream:
+            response = client.completions.create(
+                model=together_model,
+                prompt=content,
+                stream=False,
+                **options,
+            )
+            yield CompletionResponse(
+                completion=response.choices[0].text,
+                logprobs=None,  # Together doesn't provide logprobs
+            )
+        else:
+            for chunk in client.completions.create(
+                model=together_model,
+                prompt=content,
+                stream=True,
+                **options,
+            ):
+                if chunk.choices[0].text:
+                    yield CompletionResponseStreamChunk(
+                        event=CompletionResponseEvent(
+                            event_type=CompletionResponseEventType.progress,
+                            delta=chunk.choices[0].text,
+                        )
+                    )
+
+            yield CompletionResponseStreamChunk(
+                event=CompletionResponseEvent(
+                    event_type=CompletionResponseEventType.complete,
+                    delta="",
+                    stop_reason=StopReason.end_of_turn,
+                )
+            )
 
     def _messages_to_together_messages(self, messages: list[Message]) -> list:
         together_messages = []
@@ -262,4 +306,4 @@ async def chat_completion(
                 delta="",
                 stop_reason=stop_reason,
             )
-        )
+        )
\ No newline at end of file
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index e89d8ec4..cc6f22af 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -127,86 +127,4 @@ async def chat_completion(
                         ),
                     )
                 )
-                buffer = buffer[len("<|python_tag|>") :]
-                continue
-
-            if not request.stream:
-                if request.logprobs:
-                    logprobs.append(token_result.logprob)
-
-                continue
-
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-                text = ""
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-                text = ""
-            else:
-                text = token_result.text
-
-            if ipython:
-                delta = ToolCallDelta(
-                    content=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-            else:
-                delta = text
-
-            if stop_reason is None:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                        stop_reason=stop_reason,
-                    )
-                )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        # TODO(ashwin): parse tool calls separately here and report errors?
-        # if someone breaks the iteration before coming here we are toast
-        message = self.generator.formatter.decode_assistant_message(
-            tokens, stop_reason
-        )
-        if request.stream:
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
-
-            # TODO(ashwin): what else do we need to send out here when everything finishes?
-        else:
-            yield ChatCompletionResponse(
-                completion_message=message,
-                logprobs=logprobs if request.logprobs else None,
-            )
+                buffer = buffer[len("
\ No newline at end of file
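
Usage sketch (not part of the patch): a minimal example of how a caller might consume the new completion() method, assuming `provider` is an already-constructed adapter instance (bedrock, fireworks, together, ...) and that CompletionResponse / CompletionResponseStreamChunk are importable from llama_stack.apis.inference. The import path, the `provider` object, and the model name below are illustrative assumptions, not taken from this diff.

# Illustrative only -- not part of the patch.
# Assumes `provider` is any adapter exposing the completion() method added above,
# and that the response types live in llama_stack.apis.inference (adjust to your tree).
from llama_stack.apis.inference import (
    CompletionResponse,
    CompletionResponseStreamChunk,
)


async def run_completion(provider, model: str, prompt: str, stream: bool = False) -> None:
    # Every implementation above is an async generator, so streaming and
    # non-streaming calls are consumed the same way.
    async for response in provider.completion(model=model, content=prompt, stream=stream):
        if isinstance(response, CompletionResponse):
            # Non-streaming path: a single object carrying the full completion text.
            print(response.completion)
        elif isinstance(response, CompletionResponseStreamChunk):
            # Streaming path: progress chunks carry text deltas; the final chunk
            # carries an empty delta and a stop_reason.
            print(response.event.delta, end="", flush=True)

# Drive it with e.g. asyncio.run(run_completion(provider, "Llama3.1-8B-Instruct", "Hello", stream=True)).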