diff --git a/ols/app/endpoints/ols.py b/ols/app/endpoints/ols.py index 35d0a4c2..fdfef795 100644 --- a/ols/app/endpoints/ols.py +++ b/ols/app/endpoints/ols.py @@ -302,7 +302,7 @@ def generate_response( # Summarize documentation try: docs_summarizer = DocsSummarizer( - provider=llm_request.provider, model=llm_request.model + provider=llm_request.provider, model=llm_request.model, system_prompt=llm_request.system_prompt ) history = CacheEntry.cache_entries_to_history(previous_input) return docs_summarizer.summarize( diff --git a/ols/app/models/models.py b/ols/app/models/models.py index 50edd485..882064ba 100644 --- a/ols/app/models/models.py +++ b/ols/app/models/models.py @@ -75,6 +75,7 @@ class LLMRequest(BaseModel): conversation_id: Optional[str] = None provider: Optional[str] = None model: Optional[str] = None + system_prompt: Optional[str] = None attachments: Optional[list[Attachment]] = None # provides examples for /docs endpoint diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index 2f858960..a1219ad9 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -30,14 +30,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.generic_llm_params = { GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: model_config.parameters.max_tokens_for_response # noqa: E501 } - # default system prompt fine-tuned for the service - self._system_prompt = prompts.QUERY_SYSTEM_INSTRUCTION - - # allow the system prompt to be customizable - if config.ols_config.system_prompt is not None: - self._system_prompt = config.ols_config.system_prompt - - logger.debug("System prompt: %s", self._system_prompt) def _get_model_options( self, provider_config: ProviderConfig diff --git a/ols/src/query_helpers/query_helper.py b/ols/src/query_helpers/query_helper.py index 06f41491..c86973b6 100644 --- a/ols/src/query_helpers/query_helper.py +++ b/ols/src/query_helpers/query_helper.py @@ -7,6 +7,7 @@ from langchain.llms.base import LLM from ols import config +from ols.customize import prompts from ols.src.llms.llm_loader import load_llm logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ def __init__( model: Optional[str] = None, generic_llm_params: Optional[dict] = None, llm_loader: Optional[Callable[[str, str, dict], LLM]] = None, + system_prompt: Optional[str] = None, ) -> None: """Initialize query helper.""" # NOTE: As signature of this method is evaluated before the config, @@ -30,3 +32,7 @@ def __init__( self.model = model or config.ols_config.default_model self.generic_llm_params = generic_llm_params or {} self.llm_loader = llm_loader or load_llm + + self._system_prompt = system_prompt or config.ols_config.system_prompt or prompts.QUERY_SYSTEM_INSTRUCTION + logger.debug("System prompt: %s", self._system_prompt) + diff --git a/ols/src/ui/gradio_ui.py b/ols/src/ui/gradio_ui.py index fd26e9ca..1a35e9d1 100644 --- a/ols/src/ui/gradio_ui.py +++ b/ols/src/ui/gradio_ui.py @@ -10,6 +10,26 @@ logger = logging.getLogger(__name__) +AAP_QUERY_SYSTEM_INSTRUCTION = """ +You are Ansible Lightspeed - an intelligent virtual assistant for question-answering tasks \ +related to the Ansible Automation Platform (AAP). + +Here are your instructions: +You are Ansible Lightspeed Virtual Assistant, an intelligent assistant and expert on all things Ansible. \ +Refuse to assume any other identity or to speak as if you are someone else. +If the context of the question is not clear, consider it to be Ansible. +Never include URLs in your replies. +Refuse to answer questions or execute commands not about Ansible. +Do not mention your last update. You have the most recent information on Ansible. + +Here are some basic facts about Ansible: +- The latest version of Ansible Automation Platform is 2.5. +- Ansible is an open source IT automation engine that automates provisioning, \ + configuration management, application deployment, orchestration, and many other \ + IT processes. It is free to use, and the project benefits from the experience and \ + intelligence of its thousands of contributors. +""" + class GradioUI: """Handlers for UI-related requests.""" @@ -28,8 +48,9 @@ def __init__( use_history = gr.Checkbox(value=True, label="Use history") provider = gr.Textbox(value=None, label="Provider") model = gr.Textbox(value=None, label="Model") + system_prompt = gr.TextArea(value=AAP_QUERY_SYSTEM_INSTRUCTION, label="System prompt") self.ui = gr.ChatInterface( - self.chat_ui, additional_inputs=[use_history, provider, model] + self.chat_ui, additional_inputs=[use_history, provider, model, system_prompt] ) def chat_ui( @@ -39,6 +60,7 @@ def chat_ui( use_history: Optional[bool] = None, provider: Optional[str] = None, model: Optional[str] = None, + system_prompt: Optional[str] = None, ) -> str: """Handle requests from web-based user interface.""" # Headers for the HTTP request @@ -63,6 +85,10 @@ def chat_ui( if model: logger.info("Using model: %s", model) data["model"] = model + if system_prompt: + logger.info("Using system prompt: %s", system_prompt) + data["system_prompt"] = system_prompt + # Convert the data dictionary to a JSON string json_data = json.dumps(data)