Skip to content

Commit

Permalink
Make static
Browse files Browse the repository at this point in the history
  • Loading branch information
roaga committed Nov 21, 2024
1 parent 081cd2c commit 4f050ea
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
14 changes: 8 additions & 6 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,9 @@ def generate_text(
default_temperature = defaults.temperature if defaults else None
# More defaults to come

messages = self.clean_message_content(messages if messages else [])
messages = LlmClient.clean_message_content(messages if messages else [])
if not tools:
messages = self.clean_tool_call_assistant_messages(messages)
messages = LlmClient.clean_tool_call_assistant_messages(messages)

if model.provider_name == LlmProviderType.OPENAI:
model = cast(OpenAiProvider, model)
Expand Down Expand Up @@ -520,8 +520,8 @@ def generate_structured(
if run_name:
langfuse_context.update_current_observation(name=run_name + " - Generate Structured")

messages = self.clean_message_content(messages if messages else [])
messages = self.clean_tool_call_assistant_messages(messages)
messages = LlmClient.clean_message_content(messages if messages else [])
messages = LlmClient.clean_tool_call_assistant_messages(messages)

if model.provider_name == LlmProviderType.OPENAI:
model = cast(OpenAiProvider, model)
Expand All @@ -539,7 +539,8 @@ def generate_structured(
else:
raise ValueError(f"Invalid provider: {model.provider_name}")

def clean_tool_call_assistant_messages(self, messages: list[Message]) -> list[Message]:
@staticmethod
def clean_tool_call_assistant_messages(messages: list[Message]) -> list[Message]:
new_messages = []
for message in messages:
if message.role == "assistant" and message.tool_calls:
Expand All @@ -556,7 +557,8 @@ def clean_tool_call_assistant_messages(self, messages: list[Message]) -> list[Me
new_messages.append(message)
return new_messages

def clean_message_content(self, messages: list[Message]) -> list[Message]:
@staticmethod
def clean_message_content(messages: list[Message]) -> list[Message]:
new_messages = []
for message in messages:
if not message.content:
Expand Down
4 changes: 2 additions & 2 deletions tests/automation/agent/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_clean_tool_call_assistant_messages():
Message(role="assistant", content="Final response"),
]

cleaned_messages = LlmClient().clean_tool_call_assistant_messages(messages)
cleaned_messages = LlmClient.clean_tool_call_assistant_messages(messages)

assert len(cleaned_messages) == 5
assert cleaned_messages[0].role == "user"
Expand All @@ -302,7 +302,7 @@ def test_clean_message_content():
Message(role="user", content=""),
]

cleaned_messages = LlmClient().clean_message_content(messages)
cleaned_messages = LlmClient.clean_message_content(messages)

assert len(cleaned_messages) == 1
assert cleaned_messages[0].role == "user"
Expand Down

0 comments on commit 4f050ea

Please sign in to comment.