diff --git a/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx b/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx
index 06dd83061cb..b1510aaf89c 100644
--- a/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx
+++ b/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx
@@ -182,7 +182,8 @@ export function generateChatPrompt(
.join('\n')
: `- ${t(values.instruments.pipettes)}`
const flexGripper =
- values.instruments.flexGripper === FLEX_GRIPPER
+ values.instruments.flexGripper === FLEX_GRIPPER &&
+ values.instruments.robot === OPENTRONS_FLEX
? `\n- ${t('with_flex_gripper')}`
: ''
const modules = values.modules
diff --git a/opentrons-ai-server/api/domain/anthropic_predict.py b/opentrons-ai-server/api/domain/anthropic_predict.py
index abd94b631ba..ff392eefe7a 100644
--- a/opentrons-ai-server/api/domain/anthropic_predict.py
+++ b/opentrons-ai-server/api/domain/anthropic_predict.py
@@ -1,6 +1,6 @@
import uuid
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
import requests
import structlog
@@ -23,7 +23,7 @@ def __init__(self, settings: Settings) -> None:
self.model_name: str = settings.anthropic_model_name
self.system_prompt: str = SYSTEM_PROMPT
self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
- self._messages: List[MessageParam] = [
+        self.cached_docs: List[MessageParam] = [
{
"role": "user",
"content": [
@@ -77,19 +77,26 @@ def get_docs(self) -> str:
return "\n".join(xml_output)
@tracer.wrap()
- def generate_message(self, max_tokens: int = 4096) -> Message:
+ def _process_message(
+ self, user_id: str, messages: List[MessageParam], message_type: Literal["create", "update"], max_tokens: int = 4096
+ ) -> Message:
+ """
+ Internal method to handle message processing with different system prompts.
+        For now, the system prompt is the same for both message types.
+ """
- response = self.client.messages.create(
+ response: Message = self.client.messages.create(
model=self.model_name,
system=self.system_prompt,
max_tokens=max_tokens,
- messages=self._messages,
+ messages=messages,
tools=self.tools, # type: ignore
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
+ metadata={"user_id": user_id},
)
logger.info(
- "Token usage",
+ f"Token usage: {message_type.capitalize()}",
extra={
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens,
@@ -100,15 +107,23 @@ def generate_message(self, max_tokens: int = 4096) -> Message:
return response
@tracer.wrap()
- def predict(self, prompt: str) -> str | None:
+ def process_message(
+ self, user_id: str, prompt: str, history: List[MessageParam] | None = None, message_type: Literal["create", "update"] = "create"
+ ) -> str | None:
+ """Unified method for creating and updating messages"""
try:
- self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
- response = self.generate_message()
+            messages: List[MessageParam] = self.cached_docs.copy()
+ if history:
+ messages += history
+
+ messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
+ response = self._process_message(user_id=user_id, messages=messages, message_type=message_type)
+
if response.content[-1].type == "tool_use":
tool_use = response.content[-1]
- self._messages.append({"role": "assistant", "content": response.content})
+ messages.append({"role": "assistant", "content": response.content})
result = self.handle_tool_use(tool_use.name, tool_use.input) # type: ignore
- self._messages.append(
+ messages.append(
{
"role": "user",
"content": [
@@ -120,25 +135,26 @@ def predict(self, prompt: str) -> str | None:
],
}
)
- follow_up = self.generate_message()
- response_text = follow_up.content[0].text # type: ignore
- self._messages.append({"role": "assistant", "content": response_text})
- return response_text
+ follow_up = self._process_message(user_id=user_id, messages=messages, message_type=message_type)
+ return follow_up.content[0].text # type: ignore
elif response.content[0].type == "text":
- response_text = response.content[0].text
- self._messages.append({"role": "assistant", "content": response_text})
- return response_text
+ return response.content[0].text
logger.error("Unexpected response type")
return None
- except IndexError as e:
- logger.error("Invalid response format", extra={"error": str(e)})
- return None
except Exception as e:
- logger.error("Error in predict method", extra={"error": str(e)})
+ logger.error(f"Error in {message_type} method", extra={"error": str(e)})
return None
+ @tracer.wrap()
+ def create(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
+ return self.process_message(user_id, prompt, history, "create")
+
+ @tracer.wrap()
+ def update(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
+ return self.process_message(user_id, prompt, history, "update")
+
@tracer.wrap()
def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
if func_name == "simulate_protocol":
@@ -148,17 +164,6 @@ def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
logger.error("Unknown tool", extra={"tool": func_name})
raise ValueError(f"Unknown tool: {func_name}")
- @tracer.wrap()
- def reset(self) -> None:
- self._messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}} # type: ignore
- ],
- }
- ]
-
@tracer.wrap()
def simulate_protocol(self, protocol: str) -> str:
url = "https://Opentrons-simulator.hf.space/protocol"
@@ -197,8 +202,9 @@ def main() -> None:
settings = Settings()
llm = AnthropicPredict(settings)
- prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
- completion = llm.predict(prompt)
+    prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
+
+    completion = llm.create(user_id="1", prompt=prompt, history=None)
print(completion)
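
The refactor above replaces the mutable self._messages state with per-request assembly: the cached documentation block, any caller-supplied history, and the new user turn are combined fresh on every call. A minimal sketch of that assembly, assuming the names from the diff (cached_docs, PROMPT, _process_message); build_messages is a hypothetical standalone helper, not part of the change:

    # Sketch only: mirrors how process_message builds its message list per request.
    def build_messages(cached_docs, history, prompt_template, user_prompt):
        messages = cached_docs.copy()      # cached docs block (Anthropic prompt caching)
        if history:
            messages += history            # prior turns passed in by the caller
        messages.append({"role": "user", "content": prompt_template.format(USER_PROMPT=user_prompt)})
        return messages

    # Hypothetical usage:
    #   messages = build_messages(llm.cached_docs, None, PROMPT, "Write a protocol using the Flex Gripper.")
    #   response = llm._process_message(user_id="1", messages=messages, message_type="create")
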
diff --git a/opentrons-ai-server/api/domain/config_anthropic.py b/opentrons-ai-server/api/domain/config_anthropic.py
index 9d511012592..bab6a26c812 100644
--- a/opentrons-ai-server/api/domain/config_anthropic.py
+++ b/opentrons-ai-server/api/domain/config_anthropic.py
@@ -9,9 +9,6 @@
4. Flag potential safety or compatibility issues
5. Suggest protocol optimizations when appropriate
-Call protocol simulation tool to validate the code - only when it is called explicitly by the user.
-For all other queries, provide direct responses.
-
Important guidelines:
- Always verify labware compatibility before generating protocols
- Include appropriate error handling in generated code
@@ -28,26 +25,25 @@
"""
PROMPT = """
-Here are the inputs you will work with:
-
-
-{USER_PROMPT}
-
-
Follow these instructions to handle the user's prompt:
-1. Analyze the user's prompt to determine if it's:
+1. Analyze the user's prompt to determine whether it is:
a) A request to generate a protocol
- b) A question about the Opentrons Python API v2
+   b) A question about the Opentrons Python API v2 or about protocol details
c) A common task (e.g., value changes, OT-2 to Flex conversion, slot correction)
d) An unrelated or unclear request
+   e) A tool call. If the user explicitly asks to simulate a protocol, call the simulation tool.
+ f) A greeting. Respond kindly.
+
+   Note: when you respond, you do not need to mention the category or the type.
-2. If the prompt is unrelated or unclear, ask the user for clarification. For example:
- I apologize, but your prompt seems unclear. Could you please provide more details?
+2. If the prompt is unrelated or unclear, ask the user for clarification.
+ I'm sorry, but your prompt seems unclear. Could you please provide more details?
+   You do not need to mention the category.
-3. If the prompt is a question about the API, answer it using only the information
+3. If the prompt is a question about the API or protocol details, answer it using only the information
provided in the section. Provide references and place them under the tag.
Format your response like this:
API answer:
@@ -86,8 +82,8 @@
}}
requirements = {{
- 'robotType': '[Robot type based on user prompt, OT-2 or Flex, default is OT-2]',
- 'apiLevel': '[apiLevel, default is 2.19 ]'
+    'robotType': '[Robot type: OT-2 (default) for Opentrons OT-2, Flex for Opentrons Flex]',
+ 'apiLevel': '[apiLevel, default: 2.19]'
}}
def run(protocol: protocol_api.ProtocolContext):
@@ -214,4 +210,10 @@ def run(protocol: protocol_api.ProtocolContext):
as a reference to generate a basic protocol.
Remember to use only the information provided in the . Do not introduce any external information or assumptions.
-"""
+
+Here are the inputs you will work with:
+
+
+{USER_PROMPT}
+
+"""
\ No newline at end of file
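
With {USER_PROMPT} moved to the end of the template, the instructions now precede the user's input when the prompt is rendered. A minimal sanity-check sketch, assuming the PROMPT constant above is importable from the server's config_anthropic module:

    from api.domain.config_anthropic import PROMPT  # assumed import path

    rendered = PROMPT.format(USER_PROMPT="Transfer 50 uL from A1 to B1 on an OT-2.")
    # The user's text is injected after the instruction block.
    assert rendered.rstrip().endswith("Transfer 50 uL from A1 to B1 on an OT-2.")
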
diff --git a/opentrons-ai-server/api/handler/fast.py b/opentrons-ai-server/api/handler/fast.py
index b93eb6580ce..6a94bad8733 100644
--- a/opentrons-ai-server/api/handler/fast.py
+++ b/opentrons-ai-server/api/handler/fast.py
@@ -199,10 +199,19 @@ async def create_chat_completion(
return ChatResponse(reply="Default fake response. ", fake=body.fake)
response: Optional[str] = None
+
+ if "Write a protocol using" in body.history[0]["content"]: # type: ignore
+ protocol_option = "create"
+ else:
+ protocol_option = "update"
+
if "openai" in settings.model.lower():
response = openai.predict(prompt=body.message, chat_completion_message_params=body.history)
else:
- response = claude.predict(prompt=body.message)
+ if protocol_option == "create":
+ response = claude.create(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore
+ else:
+ response = claude.update(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore
if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))
@@ -218,35 +227,36 @@ async def create_chat_completion(
@tracer.wrap()
@app.post(
- "/api/chat/updateProtocol",
+ "/api/chat/createProtocol",
response_model=Union[ChatResponse, ErrorResponse],
- summary="Updates protocol",
- description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
+ summary="Creates protocol",
+ description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
)
-async def update_protocol(
- body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
+async def create_protocol(
+ body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
) -> Union[ChatResponse, ErrorResponse]: # noqa: B008
"""
Generate an updated protocol using LLM.
- - **request**: The HTTP request containing the existing protocol and other relevant parameters.
+ - **request**: The HTTP request containing the chat message.
- **returns**: A chat response or an error message.
"""
- logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
+ logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
try:
- if not body.protocol_text or body.protocol_text == "":
+
+ if not body.prompt or body.prompt == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
)
if body.fake:
- return ChatResponse(reply="Fake response", fake=bool(body.fake))
+ return ChatResponse(reply="Fake response", fake=body.fake)
response: Optional[str] = None
if "openai" in settings.model.lower():
- response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
+ response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
else:
- response = claude.predict(prompt=body.prompt)
+ response = claude.create(user_id=str(user.sub), prompt=body.prompt, history=None)
if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))
@@ -254,7 +264,7 @@ async def update_protocol(
return ChatResponse(reply=response, fake=bool(body.fake))
except Exception as e:
- logger.exception("Error processing protocol update")
+ logger.exception("Error processing protocol creation")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
) from e
@@ -262,36 +272,35 @@ async def update_protocol(
@tracer.wrap()
@app.post(
- "/api/chat/createProtocol",
+ "/api/chat/updateProtocol",
response_model=Union[ChatResponse, ErrorResponse],
- summary="Creates protocol",
- description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
+ summary="Updates protocol",
+ description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
)
-async def create_protocol(
- body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
+async def update_protocol(
+ body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
) -> Union[ChatResponse, ErrorResponse]: # noqa: B008
"""
Generate an updated protocol using LLM.
- - **request**: The HTTP request containing the chat message.
+ - **request**: The HTTP request containing the existing protocol and other relevant parameters.
- **returns**: A chat response or an error message.
"""
- logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
+ logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
try:
-
- if not body.prompt or body.prompt == "":
+ if not body.protocol_text or body.protocol_text == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
)
if body.fake:
- return ChatResponse(reply="Fake response", fake=body.fake)
+ return ChatResponse(reply="Fake response", fake=bool(body.fake))
response: Optional[str] = None
if "openai" in settings.model.lower():
- response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
+ response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
else:
- response = claude.predict(prompt=str(body.model_dump()))
+ response = claude.update(user_id=str(user.sub), prompt=body.prompt, history=None)
if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))
@@ -299,7 +308,7 @@ async def create_protocol(
return ChatResponse(reply=response, fake=bool(body.fake))
except Exception as e:
- logger.exception("Error processing protocol creation")
+ logger.exception("Error processing protocol update")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
) from e
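
The /chat/completion handler now routes to claude.create or claude.update by checking the first history entry for the create-protocol marker phrase. A hedged sketch of that decision as a standalone helper (select_chat_flow is hypothetical, not part of the diff):

    from typing import Any, Literal, Mapping, Optional, Sequence

    def select_chat_flow(history: Optional[Sequence[Mapping[str, Any]]]) -> Literal["create", "update"]:
        """Return "create" when the first turn contains the create-protocol marker, else "update"."""
        if history and "Write a protocol using" in str(history[0].get("content", "")):
            return "create"
        return "update"

    # Hypothetical usage inside the handler:
    #   flow = select_chat_flow(body.history)
    #   response = claude.create(...) if flow == "create" else claude.update(...)
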
diff --git a/opentrons-ai-server/api/models/chat_request.py b/opentrons-ai-server/api/models/chat_request.py
index fb8c0942c9d..94d265e113d 100644
--- a/opentrons-ai-server/api/models/chat_request.py
+++ b/opentrons-ai-server/api/models/chat_request.py
@@ -24,9 +24,13 @@ class Chat(BaseModel):
Field(None, description="Chat history in the form of a list of messages. Type is from OpenAI's ChatCompletionMessageParam"),
]
+ChatOptions = Literal["update", "create"]
+ChatOptionsType = Annotated[Optional[ChatOptions], Field("create", description="Which chat pathway the user entered: create or update")]
+
class ChatRequest(BaseModel):
message: str = Field(..., description="The latest message to be processed.")
history: HistoryType
fake: bool = Field(True, description="When set to true, the response will be a fake. OpenAI API is not used.")
fake_key: FakeKeyType
+ chat_options: ChatOptionsType
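
A minimal usage sketch of the extended request model (values are illustrative; the import path is assumed from the server layout):

    from api.models.chat_request import ChatRequest

    request = ChatRequest(
        message="Change the aspiration volume to 50 uL.",
        history=None,
        fake=True,
        fake_key=None,
        chat_options="update",  # new field; the Field default is "create" when omitted
    )
    payload = request.model_dump()
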
diff --git a/opentrons-ai-server/tests/helpers/client.py b/opentrons-ai-server/tests/helpers/client.py
index bf5a7febb3c..3b3dcfa7511 100644
--- a/opentrons-ai-server/tests/helpers/client.py
+++ b/opentrons-ai-server/tests/helpers/client.py
@@ -65,7 +65,7 @@ def get_health(self) -> Response:
@timeit
def get_chat_completion(self, message: str, fake: bool = True, fake_key: Optional[FakeKeys] = None, bad_auth: bool = False) -> Response:
"""Call the /chat/completion endpoint and return the response."""
- request = ChatRequest(message=message, fake=fake, fake_key=fake_key, history=None)
+ request = ChatRequest(message=message, fake=fake, fake_key=fake_key, history=None, chat_options=None)
headers = self.standard_headers if not bad_auth else self.invalid_auth_headers
return self.httpx.post("/chat/completion", headers=headers, json=request.model_dump())