Skip to content

Commit

Permalink
Merge pull request #872 from PrefectHQ/tools
Browse files Browse the repository at this point in the history
Tool → FunctionTool
  • Loading branch information
jlowin authored Mar 15, 2024
2 parents a77dfa1 + 3443daf commit 166c88d
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 59 deletions.
6 changes: 3 additions & 3 deletions cookbook/slackbot/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ async def handle_message(payload: SlackPayload) -> Completed:
ai_response_text,
"green",
)
messages = await assistant_thread.get_messages_async()

event = emit_assistant_completed_event(
child_assistant=ai,
parent_app=get_parent_app() if ENABLE_PARENT_APP else None,
payload={
"messages": await assistant_thread.get_messages_async(
json_compatible=True
),
"messages": [m.model_dump() for m in messages],
"metadata": assistant_thread.metadata,
"user": {
"id": event.user,
Expand Down
6 changes: 3 additions & 3 deletions src/marvin/_mappings/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode

from marvin.types import Function, Tool, ToolSet
from marvin.types import Function, FunctionTool, ToolSet


class FunctionSchema(GenerateJsonSchema):
Expand All @@ -15,10 +15,10 @@ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"):

def cast_model_to_tool(
model: type[BaseModel],
) -> Tool[BaseModel]:
) -> FunctionTool[BaseModel]:
model_name = model.__name__
model_description = model.__doc__
return Tool[BaseModel](
return FunctionTool[BaseModel](
type="function",
function=Function[BaseModel](
name=model_name,
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/_mappings/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic.fields import FieldInfo

from marvin.settings import settings
from marvin.types import Grammar, Tool, ToolSet
from marvin.types import FunctionTool, Grammar, ToolSet

from .base_model import cast_model_to_tool, cast_model_to_toolset

Expand Down Expand Up @@ -46,7 +46,7 @@ def cast_type_to_tool(
field_name: str,
field_description: str,
python_function: Optional[Callable[..., Any]] = None,
) -> Tool[BaseModel]:
) -> FunctionTool[BaseModel]:
return cast_model_to_tool(
model=cast_type_to_model(
_type,
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/beta/applications/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jsonpatch import JsonPatch
from pydantic import BaseModel, Field, PrivateAttr, SerializeAsAny

from marvin.types import Tool
from marvin.types import FunctionTool
from marvin.utilities.tools import tool_from_function


Expand Down Expand Up @@ -66,7 +66,7 @@ def update_state_jsonpatches(self, patches: list[JSONPatchModel]):
self.set_state(state)
return "Application state updated successfully!"

def as_tool(self, name: str = None) -> "Tool":
def as_tool(self, name: str = None) -> "FunctionTool":
if name is None:
name = "state"
schema = self.get_schema()
Expand Down
30 changes: 10 additions & 20 deletions src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from openai.types.beta.threads.run import Run as OpenAIRun
from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic import BaseModel, Field, field_validator

import marvin.utilities.openai
import marvin.utilities.tools
Expand Down Expand Up @@ -39,6 +39,7 @@ class Run(BaseModel, ExposeSyncMethodsMixin):
data (Any): Any additional data associated with the run.
"""

id: Optional[str] = None
thread: Thread
assistant: Assistant
instructions: Optional[str] = Field(
Expand Down Expand Up @@ -77,15 +78,15 @@ async def refresh_async(self):
"""Refreshes the run."""
client = marvin.utilities.openai.get_openai_client()
self.run = await client.beta.threads.runs.retrieve(
run_id=self.run.id, thread_id=self.thread.id
run_id=self.run.id if self.run else self.id, thread_id=self.thread.id
)

@expose_sync_method("cancel")
async def cancel_async(self):
"""Cancels the run."""
client = marvin.utilities.openai.get_openai_client()
await client.beta.threads.runs.cancel(
run_id=self.run.id, thread_id=self.thread.id
run_id=self.run.id if self.run else self.id, thread_id=self.thread.id
)

async def _handle_step_requires_action(
Expand Down Expand Up @@ -156,6 +157,10 @@ async def run_async(self) -> "Run":
if self.tools is not None or self.additional_tools is not None:
create_kwargs["tools"] = self.get_tools()

if self.id is not None:
raise ValueError(
"This run object was provided an ID; can not create a new run."
)
async with self.assistant:
self.run = await client.beta.threads.runs.create(
thread_id=self.thread.id,
Expand Down Expand Up @@ -195,25 +200,10 @@ async def run_async(self) -> "Run":


class RunMonitor(BaseModel):
run_id: str
thread_id: str
_run: Run = PrivateAttr()
_thread: Thread = PrivateAttr()
run: Run
thread: Thread
steps: list[OpenAIRunStep] = []

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._thread = Thread(**kwargs["thread_id"])
self._run = Run(**kwargs["run_id"], thread=self.thread)

@property
def thread(self):
return self._thread

@property
def run(self):
return self._run

async def refresh_run_steps_async(self):
"""
Asynchronously refreshes and updates the run steps list.
Expand Down
27 changes: 8 additions & 19 deletions src/marvin/beta/assistants/threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import time
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional

# for openai < 1.14.0
try:
Expand All @@ -18,7 +18,6 @@
run_sync,
)
from marvin.utilities.logging import get_logger
from marvin.utilities.pydantic import parse_as

logger = get_logger("Threads")

Expand Down Expand Up @@ -100,25 +99,18 @@ async def get_messages_async(
limit: int = None,
before_message: Optional[str] = None,
after_message: Optional[str] = None,
json_compatible: bool = False,
) -> list[Union[Message, dict]]:
) -> list[Message]:
"""
Asynchronously retrieves messages from the thread.
Args:
limit (int, optional): The maximum number of messages to return.
before_message (str, optional): The ID of the message to start the list from,
retrieving messages sent before this one.
after_message (str, optional): The ID of the message to start the list from,
retrieving messages sent after this one.
json_compatible (bool, optional): If True, returns messages as dictionaries.
If False, returns messages as Message
objects. Default is False.
before_message (str, optional): The ID of the message to start the
list from, retrieving messages sent before this one.
after_message (str, optional): The ID of the message to start the
list from, retrieving messages sent after this one.
Returns:
list[Union[Message, dict]]: A list of messages from the thread, either
as dictionaries or Message objects,
depending on the value of json_compatible.
            list[Message]: A list of messages from the thread
"""

if self.id is None:
Expand All @@ -134,10 +126,7 @@ async def get_messages_async(
limit=limit,
order="desc",
)

T = dict if json_compatible else Message

return parse_as(list[T], reversed(response.model_dump()["data"]))
return response.data

@expose_sync_method("delete")
async def delete_async(self):
Expand Down
13 changes: 8 additions & 5 deletions src/marvin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,24 @@ def create(
return instance


class Tool(MarvinType, Generic[T]):
class Tool(MarvinType):
type: str


class FunctionTool(Tool, Generic[T]):
function: Optional[Function[T]] = None


class ToolSet(MarvinType, Generic[T]):
tools: Optional[list[Tool[T]]] = None
tools: Optional[list[Union[FunctionTool[T], Tool]]] = None
tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None


class RetrievalTool(Tool[T]):
class RetrievalTool(Tool):
type: Literal["retrieval"] = "retrieval"


class CodeInterpreterTool(Tool[T]):
class CodeInterpreterTool(Tool):
type: Literal["code_interpreter"] = "code_interpreter"


Expand Down Expand Up @@ -244,7 +247,7 @@ class Run(MarvinType, Generic[T]):
status: str
model: str
instructions: Optional[str]
tools: Optional[list[Tool[T]]] = None
tools: Optional[list[FunctionTool[T]]] = None
metadata: dict[str, str]


Expand Down
10 changes: 5 additions & 5 deletions src/marvin/utilities/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode

from marvin.types import Function, Tool
from marvin.types import Function, FunctionTool
from marvin.utilities.asyncio import run_sync
from marvin.utilities.logging import get_logger

Expand Down Expand Up @@ -63,7 +63,7 @@ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"):
return json_schema


def tool_from_type(type_: U, tool_name: str = None) -> Tool[U]:
def tool_from_type(type_: U, tool_name: str = None) -> FunctionTool[U]:
"""
Creates an OpenAI-compatible tool from a Python type.
"""
Expand Down Expand Up @@ -99,7 +99,7 @@ def tool_from_model(model: type[M], python_fn: Callable[[str], M] = None):
def tool_fn(**data) -> M:
return TypeAdapter(model).validate_python(data)

return Tool[M](
return FunctionTool[M](
type="function",
function=Function[M].create(
name=model.__name__,
Expand Down Expand Up @@ -130,7 +130,7 @@ def tool_from_function(
fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True)
).json_schema()

return Tool[T](
return FunctionTool[T](
type="function",
function=Function[T].create(
name=name or fn.__name__,
Expand All @@ -142,7 +142,7 @@ def tool_from_function(


def call_function_tool(
tools: list[Tool],
tools: list[FunctionTool],
function_name: str,
function_arguments_json: str,
return_string: bool = False,
Expand Down

0 comments on commit 166c88d

Please sign in to comment.