Add Runpod Provider #157

Open: wants to merge 2 commits into main

Changes from all commits
10 changes: 10 additions & 0 deletions llama_stack/distribution/templates/local-runpod-build.yaml
@@ -0,0 +1,10 @@
name: local-runpod
distribution_spec:
  description: Use Runpod.io for running LLM inference
  providers:
    inference: remote::runpod
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference
image_type: conda
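For reviewers who want to sanity-check the template, here is a minimal sketch (assuming PyYAML is installed and the repo-relative path below) that loads the spec and confirms inference routes to the remote RunPod adapter:

```python
import yaml

# Repo-relative path to the template added in this PR.
path = "llama_stack/distribution/templates/local-runpod-build.yaml"

with open(path) as f:
    spec = yaml.safe_load(f)

# The distribution spec should route the inference API to remote::runpod.
assert spec["distribution_spec"]["providers"]["inference"] == "remote::runpod"
print(spec["name"])  # local-runpod
```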
16 changes: 16 additions & 0 deletions llama_stack/providers/adapters/inference/runpod/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import RunpodImplConfig
from .runpod import RunpodInferenceAdapter

async def get_adapter_impl(config: RunpodImplConfig, _deps):
    assert isinstance(
        config, RunpodImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = RunpodInferenceAdapter(config)
    await impl.initialize()
    return impl
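A minimal sketch of how the stack's provider loader would exercise this entry point (the URL and key are placeholders, not working credentials):

```python
import asyncio

from llama_stack.providers.adapters.inference.runpod import get_adapter_impl
from llama_stack.providers.adapters.inference.runpod.config import RunpodImplConfig

async def main():
    config = RunpodImplConfig(
        url="https://example.invalid/v1",  # placeholder endpoint
        api_key="<runpod-api-key>",        # placeholder key
    )
    adapter = await get_adapter_impl(config, None)
    print(type(adapter).__name__)  # RunpodInferenceAdapter

asyncio.run(main())
```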
16 changes: 16 additions & 0 deletions llama_stack/providers/adapters/inference/runpod/config.py
@@ -0,0 +1,16 @@
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class RunpodImplConfig(BaseModel):
    url: Optional[str] = Field(
        default=None,
        description="The URL for the Runpod model serving endpoint",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The Runpod API token",
    )
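Since the adapter hands both fields straight to an OpenAI client (see runpod.py below), url should be an OpenAI-compatible base URL. A hedged example follows; the endpoint shape is an assumption about RunPod's serverless OpenAI-compatible deployments, not something this diff pins down:

```python
from llama_stack.providers.adapters.inference.runpod.config import RunpodImplConfig

# Assumed URL shape for a RunPod OpenAI-compatible serverless endpoint;
# substitute a real endpoint ID and API key.
config = RunpodImplConfig(
    url="https://api.runpod.ai/v2/<endpoint-id>/openai/v1",
    api_key="<runpod-api-key>",
)
```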
254 changes: 254 additions & 0 deletions llama_stack/providers/adapters/inference/runpod/runpod.py
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator
import logging

from openai import OpenAI

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)

from .config import RunpodImplConfig

RUNPOD_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
}

logger = logging.getLogger(__name__)

class RunpodInferenceAdapter(Inference):
    def __init__(self, config: RunpodImplConfig) -> None:
        self.config = config
        tokenizer = Tokenizer.get_instance()
        self.formatter = ChatFormat(tokenizer)

    @property
    def client(self) -> OpenAI:
        return OpenAI(
            base_url=self.config.url,
            api_key=self.config.api_key,
        )

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()

    def _messages_to_runpod_messages(self, messages: list[Message]) -> list:
        runpod_messages = []
        for message in messages:
            if message.role == "ipython":
                role = "tool"
            else:
                role = message.role
            runpod_messages.append({"role": role, "content": message.content})

        return runpod_messages

    def resolve_runpod_model(self, model_name: str) -> str:
        model = resolve_model(model_name)
        assert (
            model is not None
            and model.descriptor(shorten_default_variant=True)
            in RUNPOD_SUPPORTED_MODELS
        ), f"Unsupported model: {model_name}, use one of the supported models: {', '.join(RUNPOD_SUPPORTED_MODELS.keys())}"

        return RUNPOD_SUPPORTED_MODELS.get(
            model.descriptor(shorten_default_variant=True)
        )

    def get_runpod_chat_options(self, request: ChatCompletionRequest) -> dict:
        options = {}
        if request.sampling_params is not None:
            for attr in {"temperature", "top_p", "max_tokens"}:
                if getattr(request.sampling_params, attr):
                    options[attr] = getattr(request.sampling_params, attr)
            # The OpenAI client exposes no top_k parameter; vLLM-style
            # OpenAI-compatible servers accept it through extra_body instead.
            if getattr(request.sampling_params, "top_k", None):
                options["extra_body"] = {"top_k": request.sampling_params.top_k}

        return options

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        messages = augment_messages_for_tools(request)
        options = self.get_runpod_chat_options(request)
        runpod_model = self.resolve_runpod_model(request.model)

        if not request.stream:
            r = self.client.chat.completions.create(
                model=runpod_model,
                messages=self._messages_to_runpod_messages(messages),
                stream=False,
                **options,
            )

            stop_reason = None
            if r.choices[0].finish_reason:
                if r.choices[0].finish_reason == "stop":
                    stop_reason = StopReason.end_of_turn
                elif r.choices[0].finish_reason == "length":
                    stop_reason = StopReason.out_of_tokens

            completion_message = self.formatter.decode_assistant_message_from_content(
                r.choices[0].message.content, stop_reason
            )
            yield ChatCompletionResponse(
                completion_message=completion_message,
                logprobs=None,
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta="",
                )
            )

buffer = ""
ipython = False
stop_reason = None

for chunk in self.client.chat.completions.create(
model=runpod_model,
messages=self._messages_to_runpod_messages(messages),
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None
and chunk.choices[0].finish_reason == "stop"
):
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
stop_reason = StopReason.out_of_tokens
break

text = chunk.choices[0].delta.content

if text is None:
continue

# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue

if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue

buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)

            # Parse tool calls and report errors.
            message = self.formatter.decode_assistant_message_from_content(
                buffer, stop_reason
            )
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.failure,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=tool_call,
                            parse_status=ToolCallParseStatus.success,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta="",
                    stop_reason=stop_reason,
                )
            )
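Both branches of chat_completion are generators, so callers consume the result with async for in streaming and non-streaming mode alike. A minimal usage sketch (placeholder config; UserMessage is assumed to come from llama_models, like the Message import above):

```python
import asyncio

from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.adapters.inference.runpod.config import RunpodImplConfig
from llama_stack.providers.adapters.inference.runpod.runpod import RunpodInferenceAdapter

async def main():
    adapter = RunpodInferenceAdapter(
        RunpodImplConfig(url="https://example.invalid/v1", api_key="<key>")
    )
    await adapter.initialize()
    # Non-streaming: the generator yields a single ChatCompletionResponse.
    async for response in adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
        stream=False,
    ):
        print(response.completion_message.content)

asyncio.run(main())
```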
11 changes: 11 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -105,4 +105,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="runpod",
pip_packages=[
"openai",
],
module="llama_stack.providers.adapters.inference.runpod",
config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
),
),
]
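A quick registry check to confirm the new spec is discoverable (assuming remote provider specs expose the adapter field, as the AdapterSpec usage above suggests):

```python
from llama_stack.providers.registry.inference import available_providers

# Collect adapter IDs from the remote provider specs; inline providers
# (meta-reference) carry no adapter and are skipped.
adapter_ids = [
    spec.adapter.adapter_id
    for spec in available_providers()
    if getattr(spec, "adapter", None) is not None
]
assert "runpod" in adapter_ids
```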