diff --git a/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx b/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx
index 06dd83061cb..b1510aaf89c 100644
--- a/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx
+++ b/opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx
@@ -182,7 +182,8 @@ export function generateChatPrompt(
           .join('\n')
       : `- ${t(values.instruments.pipettes)}`
   const flexGripper =
-    values.instruments.flexGripper === FLEX_GRIPPER
+    values.instruments.flexGripper === FLEX_GRIPPER &&
+    values.instruments.robot === OPENTRONS_FLEX
       ? `\n- ${t('with_flex_gripper')}`
       : ''
   const modules = values.modules
diff --git a/opentrons-ai-server/api/domain/anthropic_predict.py b/opentrons-ai-server/api/domain/anthropic_predict.py
index abd94b631ba..ff392eefe7a 100644
--- a/opentrons-ai-server/api/domain/anthropic_predict.py
+++ b/opentrons-ai-server/api/domain/anthropic_predict.py
@@ -1,6 +1,6 @@
 import uuid
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 import requests
 import structlog
@@ -23,7 +23,7 @@ def __init__(self, settings: Settings) -> None:
         self.model_name: str = settings.anthropic_model_name
         self.system_prompt: str = SYSTEM_PROMPT
         self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
-        self._messages: List[MessageParam] = [
+        self.cached_docs: List[MessageParam] = [
             {
                 "role": "user",
                 "content": [
@@ -77,19 +77,26 @@ def get_docs(self) -> str:
        return "\n".join(xml_output)

    @tracer.wrap()
-    def generate_message(self, max_tokens: int = 4096) -> Message:
+    def _process_message(
+        self, user_id: str, messages: List[MessageParam], message_type: Literal["create", "update"], max_tokens: int = 4096
+    ) -> Message:
+        """
+        Internal method that handles message processing for both message types.
+        For now, the system prompt is the same for "create" and "update".
+ """ - response = self.client.messages.create( + response: Message = self.client.messages.create( model=self.model_name, system=self.system_prompt, max_tokens=max_tokens, - messages=self._messages, + messages=messages, tools=self.tools, # type: ignore extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, + metadata={"user_id": user_id}, ) logger.info( - "Token usage", + f"Token usage: {message_type.capitalize()}", extra={ "input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, @@ -100,15 +107,23 @@ def generate_message(self, max_tokens: int = 4096) -> Message: return response @tracer.wrap() - def predict(self, prompt: str) -> str | None: + def process_message( + self, user_id: str, prompt: str, history: List[MessageParam] | None = None, message_type: Literal["create", "update"] = "create" + ) -> str | None: + """Unified method for creating and updating messages""" try: - self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)}) - response = self.generate_message() + messages: List[MessageParam] = self.cashed_docs.copy() + if history: + messages += history + + messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)}) + response = self._process_message(user_id=user_id, messages=messages, message_type=message_type) + if response.content[-1].type == "tool_use": tool_use = response.content[-1] - self._messages.append({"role": "assistant", "content": response.content}) + messages.append({"role": "assistant", "content": response.content}) result = self.handle_tool_use(tool_use.name, tool_use.input) # type: ignore - self._messages.append( + messages.append( { "role": "user", "content": [ @@ -120,25 +135,26 @@ def predict(self, prompt: str) -> str | None: ], } ) - follow_up = self.generate_message() - response_text = follow_up.content[0].text # type: ignore - self._messages.append({"role": "assistant", "content": response_text}) - return response_text + follow_up = self._process_message(user_id=user_id, messages=messages, message_type=message_type) + return follow_up.content[0].text # type: ignore elif response.content[0].type == "text": - response_text = response.content[0].text - self._messages.append({"role": "assistant", "content": response_text}) - return response_text + return response.content[0].text logger.error("Unexpected response type") return None - except IndexError as e: - logger.error("Invalid response format", extra={"error": str(e)}) - return None except Exception as e: - logger.error("Error in predict method", extra={"error": str(e)}) + logger.error(f"Error in {message_type} method", extra={"error": str(e)}) return None + @tracer.wrap() + def create(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None: + return self.process_message(user_id, prompt, history, "create") + + @tracer.wrap() + def update(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None: + return self.process_message(user_id, prompt, history, "update") + @tracer.wrap() def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str: if func_name == "simulate_protocol": @@ -148,17 +164,6 @@ def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str: logger.error("Unknown tool", extra={"tool": func_name}) raise ValueError(f"Unknown tool: {func_name}") - @tracer.wrap() - def reset(self) -> None: - self._messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": 
diff --git a/opentrons-ai-server/api/domain/config_anthropic.py b/opentrons-ai-server/api/domain/config_anthropic.py
index 9d511012592..bab6a26c812 100644
--- a/opentrons-ai-server/api/domain/config_anthropic.py
+++ b/opentrons-ai-server/api/domain/config_anthropic.py
@@ -9,9 +9,6 @@
 4. Flag potential safety or compatibility issues
 5. Suggest protocol optimizations when appropriate
 
-Call protocol simulation tool to validate the code - only when it is called explicitly by the user.
-For all other queries, provide direct responses.
-
 Important guidelines:
 - Always verify labware compatibility before generating protocols
 - Include appropriate error handling in generated code
@@ -28,26 +25,25 @@
 """
 
 PROMPT = """
-Here are the inputs you will work with:
-
-<user_prompt>
-{USER_PROMPT}
-</user_prompt>
-
 Follow these instructions to handle the user's prompt:
 1. Analyze the user's prompt to determine if it's:
     a) A request to generate a protocol
-    b) A question about the Opentrons Python API v2
+    b) A question about the Opentrons Python API v2 or about details of a protocol
     c) A common task (e.g., value changes, OT-2 to Flex conversion, slot correction)
     d) An unrelated or unclear request
+    e) A tool call. If the user explicitly asks to simulate a protocol, then call the simulation tool.
+    f) A greeting. Respond kindly.
+
+    Note: when you respond, you do not need to mention the category or the type.
 
-2. If the prompt is unrelated or unclear, ask the user for clarification. For example:
-   I apologize, but your prompt seems unclear. Could you please provide more details?
+2. If the prompt is unrelated or unclear, ask the user for clarification.
+   I'm sorry, but your prompt seems unclear. Could you please provide more details?
+   You do not need to mention the category or the type.
 
-3. If the prompt is a question about the API, answer it using only the information
+3. If the prompt is a question about the API or protocol details, answer it using only the information
    provided in the <document> section. Provide references and place them under the <References> tag.
    Format your response like this:
    API answer:
@@ -86,8 +82,8 @@
 }}
 
 requirements = {{
-    'robotType': '[Robot type based on user prompt, OT-2 or Flex, default is OT-2]',
-    'apiLevel': '[apiLevel, default is 2.19 ]'
+    'robotType': '[Robot type: OT-2 (default) for Opentrons OT-2, Flex for Opentrons Flex]',
+    'apiLevel': '[apiLevel, default: 2.19]'
 }}
 
 def run(protocol: protocol_api.ProtocolContext):
@@ -214,4 +210,10 @@ def run(protocol: protocol_api.ProtocolContext):
 as a reference to generate a basic protocol.
 Remember to use only the information provided in the <document>.
 Do not introduce any external information or assumptions.
-"""
+
+Here are the inputs you will work with:
+
+<user_prompt>
+{USER_PROMPT}
+</user_prompt>
+"""
\ No newline at end of file
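Reviewer note: with the `<user_prompt>` block moved to the end of `PROMPT`, the user's text now lands after the instructions and template, closest to the model's answer. A quick composition check, assuming the tag names as reconstructed above:

```python
# Sanity-check sketch; assumes the <user_prompt> tag placement shown above.
from api.domain.config_anthropic import PROMPT  # import path as in this diff

rendered = PROMPT.format(USER_PROMPT="Transfer 10 uL from A1 to B1 on an OT-2.")
assert rendered.rstrip().endswith("</user_prompt>")  # user input is now the final block
```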
-""" + +Here are the inputs you will work with: + + +{USER_PROMPT} + +""" \ No newline at end of file diff --git a/opentrons-ai-server/api/handler/fast.py b/opentrons-ai-server/api/handler/fast.py index b93eb6580ce..6a94bad8733 100644 --- a/opentrons-ai-server/api/handler/fast.py +++ b/opentrons-ai-server/api/handler/fast.py @@ -199,10 +199,19 @@ async def create_chat_completion( return ChatResponse(reply="Default fake response. ", fake=body.fake) response: Optional[str] = None + + if "Write a protocol using" in body.history[0]["content"]: # type: ignore + protocol_option = "create" + else: + protocol_option = "update" + if "openai" in settings.model.lower(): response = openai.predict(prompt=body.message, chat_completion_message_params=body.history) else: - response = claude.predict(prompt=body.message) + if protocol_option == "create": + response = claude.create(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore + else: + response = claude.update(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore if response is None or response == "": return ChatResponse(reply="No response was generated", fake=bool(body.fake)) @@ -218,35 +227,36 @@ async def create_chat_completion( @tracer.wrap() @app.post( - "/api/chat/updateProtocol", + "/api/chat/createProtocol", response_model=Union[ChatResponse, ErrorResponse], - summary="Updates protocol", - description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.", + summary="Creates protocol", + description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.", ) -async def update_protocol( - body: UpdateProtocol, user: Annotated[User, Security(auth.verify)] +async def create_protocol( + body: CreateProtocol, user: Annotated[User, Security(auth.verify)] ) -> Union[ChatResponse, ErrorResponse]: # noqa: B008 """ Generate an updated protocol using LLM. - - **request**: The HTTP request containing the existing protocol and other relevant parameters. + - **request**: The HTTP request containing the chat message. - **returns**: A chat response or an error message. 
""" - logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user}) + logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user}) try: - if not body.protocol_text or body.protocol_text == "": + + if not body.prompt or body.prompt == "": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump() ) if body.fake: - return ChatResponse(reply="Fake response", fake=bool(body.fake)) + return ChatResponse(reply="Fake response", fake=body.fake) response: Optional[str] = None if "openai" in settings.model.lower(): - response = openai.predict(prompt=body.prompt, chat_completion_message_params=None) + response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None) else: - response = claude.predict(prompt=body.prompt) + response = claude.create(user_id=str(user.sub), prompt=body.prompt, history=None) if response is None or response == "": return ChatResponse(reply="No response was generated", fake=bool(body.fake)) @@ -254,7 +264,7 @@ async def update_protocol( return ChatResponse(reply=response, fake=bool(body.fake)) except Exception as e: - logger.exception("Error processing protocol update") + logger.exception("Error processing protocol creation") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump() ) from e @@ -262,36 +272,35 @@ async def update_protocol( @tracer.wrap() @app.post( - "/api/chat/createProtocol", + "/api/chat/updateProtocol", response_model=Union[ChatResponse, ErrorResponse], - summary="Creates protocol", - description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.", + summary="Updates protocol", + description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.", ) -async def create_protocol( - body: CreateProtocol, user: Annotated[User, Security(auth.verify)] +async def update_protocol( + body: UpdateProtocol, user: Annotated[User, Security(auth.verify)] ) -> Union[ChatResponse, ErrorResponse]: # noqa: B008 """ Generate an updated protocol using LLM. - - **request**: The HTTP request containing the chat message. + - **request**: The HTTP request containing the existing protocol and other relevant parameters. - **returns**: A chat response or an error message. 
""" - logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user}) + logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user}) try: - - if not body.prompt or body.prompt == "": + if not body.protocol_text or body.protocol_text == "": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump() ) if body.fake: - return ChatResponse(reply="Fake response", fake=body.fake) + return ChatResponse(reply="Fake response", fake=bool(body.fake)) response: Optional[str] = None if "openai" in settings.model.lower(): - response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None) + response = openai.predict(prompt=body.prompt, chat_completion_message_params=None) else: - response = claude.predict(prompt=str(body.model_dump())) + response = claude.update(user_id=str(user.sub), prompt=body.prompt, history=None) if response is None or response == "": return ChatResponse(reply="No response was generated", fake=bool(body.fake)) @@ -299,7 +308,7 @@ async def create_protocol( return ChatResponse(reply=response, fake=bool(body.fake)) except Exception as e: - logger.exception("Error processing protocol creation") + logger.exception("Error processing protocol update") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump() ) from e diff --git a/opentrons-ai-server/api/models/chat_request.py b/opentrons-ai-server/api/models/chat_request.py index fb8c0942c9d..94d265e113d 100644 --- a/opentrons-ai-server/api/models/chat_request.py +++ b/opentrons-ai-server/api/models/chat_request.py @@ -24,9 +24,13 @@ class Chat(BaseModel): Field(None, description="Chat history in the form of a list of messages. Type is from OpenAI's ChatCompletionMessageParam"), ] +ChatOptions = Literal["update", "create"] +ChatOptionsType = Annotated[Optional[ChatOptions], Field("create", description="which chat pathway did the user enter: create or update")] + class ChatRequest(BaseModel): message: str = Field(..., description="The latest message to be processed.") history: HistoryType fake: bool = Field(True, description="When set to true, the response will be a fake. OpenAI API is not used.") fake_key: FakeKeyType + chat_options: ChatOptionsType diff --git a/opentrons-ai-server/tests/helpers/client.py b/opentrons-ai-server/tests/helpers/client.py index bf5a7febb3c..3b3dcfa7511 100644 --- a/opentrons-ai-server/tests/helpers/client.py +++ b/opentrons-ai-server/tests/helpers/client.py @@ -65,7 +65,7 @@ def get_health(self) -> Response: @timeit def get_chat_completion(self, message: str, fake: bool = True, fake_key: Optional[FakeKeys] = None, bad_auth: bool = False) -> Response: """Call the /chat/completion endpoint and return the response.""" - request = ChatRequest(message=message, fake=fake, fake_key=fake_key, history=None) + request = ChatRequest(message=message, fake=fake, fake_key=fake_key, history=None, chat_options=None) headers = self.standard_headers if not bad_auth else self.invalid_auth_headers return self.httpx.post("/chat/completion", headers=headers, json=request.model_dump())