Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🤖 Implement Tool Message Validation in Client Logic #1456

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,14 @@ def generate_text(
system_prompt=system_prompt,
tools=tools,
)

# Validate message sequence before processing
self._validate_messages(messages)

anthropic_client = self.get_client()

completion = anthropic_client.messages.create(

system=system_prompt or NOT_GIVEN,
model=self.model_name,
tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
Expand Down Expand Up @@ -375,15 +379,22 @@ def _format_claude_response_to_message(completion: anthropic.types.Message) -> M
@staticmethod
def to_message_param(message: Message) -> MessageParam:
if message.role == "tool":
if message.tool_call_id:
# Only create tool_result if we have a valid tool_call_id
return MessageParam(
role="user",
content=[
ToolResultBlockParam(
type="tool_result",
content=message.content or "",
tool_use_id=message.tool_call_id,
)
],
)
# Fallback to regular user message if no tool_call_id
return MessageParam(
role="user",
content=[
ToolResultBlockParam(
type="tool_result",
content=message.content or "",
tool_use_id=message.tool_call_id or "",
)
],
role="user",
content=[TextBlockParam(type="text", text=message.content or "")]
)
elif message.role == "tool_use":
if not message.tool_calls:
Expand Down Expand Up @@ -550,6 +561,36 @@ def clean_tool_call_assistant_messages(messages: list[Message]) -> list[Message]
new_messages.append(message)
return new_messages

@staticmethod
def validate_tool_message_sequence(messages: list[Message] | None) -> list[Message]:
"""
Validates that tool result messages have corresponding tool use messages before them.
Returns a new message list with invalid tool messages converted to regular messages.
"""
if not messages:
return []

# Track valid tool use IDs we've seen
tool_use_ids = set()
validated_messages = []

for message in messages:
# Track valid tool use IDs
if message.role == "tool_use" and message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.id:
tool_use_ids.add(tool_call.id)
validated_messages.append(message)
elif message.role == "tool" and message.tool_call_id:
if message.tool_call_id not in tool_use_ids:
# Convert to regular user message if no matching tool use found
validated_messages.append(Message(role="user", content=message.content))
else:
validated_messages.append(message)
else:
validated_messages.append(message)
return validated_messages


@module.provider
def provide_llm_client() -> LlmClient:
Expand Down
65 changes: 65 additions & 0 deletions tests/automation/agent/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,68 @@ def test_anthropic_prep_message_and_tools():
assert "input_schema" in tool_dicts[0]

assert returned_system_prompt == system_prompt

def test_validate_tool_message_sequence():
messages = [
Message(role="user", content="Hello"),
Message(
role="tool_use",
content="Using search",
tool_calls=[ToolCall(id="search_1", function="search", args='{"query": "test"}')],
),
Message(role="tool", content="Search results", tool_call_id="search_1"),
Message(
role="tool_use",
content="Using another tool",
tool_calls=[ToolCall(id="tool_2", function="other", args='{"param": "value"}')],
),
# Invalid tool result without matching tool_use
Message(role="tool", content="Invalid tool result", tool_call_id="invalid_id"),
]

validated_messages = LlmClient.validate_tool_message_sequence(messages)

assert len(validated_messages) == 5
# First message unchanged
assert validated_messages[0].role == "user"
assert validated_messages[0].content == "Hello"

# Valid tool_use remains
assert validated_messages[1].role == "tool_use"
assert validated_messages[1].tool_calls[0].id == "search_1"

# Valid tool result remains
assert validated_messages[2].role == "tool"
assert validated_messages[2].tool_call_id == "search_1"

# Second valid tool_use remains
assert validated_messages[3].role == "tool_use"
assert validated_messages[3].tool_calls[0].id == "tool_2"

# Invalid tool result converted to user message
assert validated_messages[4].role == "user"
assert validated_messages[4].content == "Invalid tool result"
assert validated_messages[4].tool_call_id is None

def test_anthropic_provider_handles_invalid_tool_sequence(mock_anthropic_client):
llm_client = LlmClient()
model = AnthropicProvider.model("claude-3-sonnet-20240229")

messages = [
Message(role="user", content="Hello"),
# This would previously cause an error due to tool_result without tool_use
Message(role="tool", content="Invalid tool result", tool_call_id="invalid_id"),
]

mock_anthropic_client.messages.create.return_value = MockAnthropicResponse(
content="Response", role="assistant"
)

# This should now work without raising an error
response = llm_client.generate_text(
messages=messages,
model=model,
)

assert response.message.content == "Response"
assert response.message.role == "assistant"
Loading