customize ols
jameswnl committed Oct 29, 2024
1 parent a30427f commit b5c0e2e
Showing 14 changed files with 55 additions and 48 deletions.
19 changes: 5 additions & 14 deletions Containerfile
@@ -1,13 +1,11 @@
# vim: set filetype=dockerfile
ARG LIGHTSPEED_RAG_CONTENT_IMAGE=quay.io/openshift-lightspeed/lightspeed-rag-content@sha256:a91aca8224b1405e7c91576374c7bbc766b2009b2ef852895c27069fffc5b06f
ARG LIGHTSPEED_RAG_CONTENT_IMAGE=quay.io/openshift-lightspeed/lightspeed-rag-content@sha256:24699b4ebe31dfb09ba706e44140db48772b37590a1839e2c9f5de2005c8c385
ARG RAG_CONTENTS_SUB_FOLDER=vector_db/ocp_product_docs

FROM ${LIGHTSPEED_RAG_CONTENT_IMAGE} as lightspeed-rag-content

FROM registry.redhat.io/ubi9/ubi-minimal:latest
FROM registry.access.redhat.com/ubi9/ubi-minimal

ARG VERSION
# TODO: this is overridden by the ubi9/python-311 image; we hard-coded WORKDIR below to /app-root.
# Make sure the default value of the RAG content path is set according to APP_ROOT, then update the operator.
ARG APP_ROOT=/app-root

RUN microdnf install -y --nodocs --setopt=keepcache=0 --setopt=tsflags=nodocs \
@@ -26,7 +24,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \

WORKDIR /app-root

COPY --from=lightspeed-rag-content /rag/vector_db/ocp_product_docs ./vector_db/ocp_product_docs
COPY --from=lightspeed-rag-content /rag/${RAG_CONTENTS_SUB_FOLDER} ${APP_ROOT}/${RAG_CONTENTS_SUB_FOLDER}
COPY --from=lightspeed-rag-content /rag/embeddings_model ./embeddings_model

# Add explicit files and directories
@@ -45,14 +43,7 @@ EXPOSE 8080
EXPOSE 8443
CMD ["python3.11", "runner.py"]

LABEL io.k8s.display-name="OpenShift LightSpeed Service" \
io.k8s.description="AI-powered OpenShift Assistant Service." \
io.openshift.tags="openshift-lightspeed,ols" \
description="Red Hat OpenShift Lightspeed Service" \
summary="Red Hat OpenShift Lightspeed Service" \
com.redhat.component=openshift-lightspeed-service \
name=openshift-lightspeed-service \
vendor="Red Hat, Inc."
LABEL vendor="Red Hat, Inc."


# no-root user is checked in Konflux
6 changes: 3 additions & 3 deletions ols/app/endpoints/ols.py
@@ -27,13 +27,13 @@
SummarizerResponse,
UnauthorizedResponse,
)
from ols.customize import keywords, prompts
from ols.src.llms.llm_loader import LLMConfigurationError, resolve_provider_config
from ols.src.query_helpers.attachment_appender import append_attachments_to_query
from ols.src.query_helpers.docs_summarizer import DocsSummarizer
from ols.src.query_helpers.question_validator import QuestionValidator
from ols.utils import errors_parsing, suid
from ols.utils.auth_dependency import AuthDependency
from ols.utils.keywords import KEYWORDS
from ols.utils.token_handler import PromptTooLongError

logger = logging.getLogger(__name__)
@@ -130,7 +130,7 @@ def conversation_request(

if not valid:
summarizer_response = SummarizerResponse(
constants.INVALID_QUERY_RESP,
prompts.INVALID_QUERY_RESP,
[],
False,
)
@@ -496,7 +496,7 @@ def _validate_question_keyword(query: str) -> bool:
# Current implementation is without any tokenizer method, lemmatization/n-grams.
# Add valid keywords to keywords.py file.
query_temp = query.lower()
for kw in KEYWORDS:
for kw in keywords.KEYWORDS:
if kw in query_temp:
return True
# query_temp = {q_word.lower().strip(".?,") for q_word in query.split()}
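The keyword pre-check now iterates over keywords.KEYWORDS from the active customization package instead of the hard-coded ols.utils.keywords module. Below is a minimal sketch of a project-specific keywords module and the matching logic; the module path and the terms are hypothetical, only the KEYWORDS name is required by the validator.

# Hypothetical ols/customize/myproject/keywords.py -- terms are illustrative.
KEYWORDS = ["openshift", "kubernetes", "pod", "deployment", "operator"]

# Matching mirrors _validate_question_keyword(): lowercase the query and
# accept it as soon as any keyword appears as a substring.
query_temp = "How do I restart a Pod?".lower()
print(any(kw in query_temp for kw in KEYWORDS))  # True -- "pod" matches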
8 changes: 6 additions & 2 deletions ols/app/models/config.py
@@ -892,6 +892,7 @@ class OLSConfig(BaseModel):

extra_ca: list[FilePath] = []
certificate_directory: Optional[str] = None
customize: Optional[str] = None

def __init__(
self, data: Optional[dict] = None, ignore_missing_certs: bool = False
@@ -901,8 +902,10 @@ def __init__(
if data is None:
return

self.conversation_cache = ConversationCacheConfig(
data.get("conversation_cache", None)
self.conversation_cache = (
ConversationCacheConfig(data.get("conversation_cache"))
if data.get("conversation_cache")
else None
)
self.logging_config = LoggingConfig(**data.get("logging_config", {}))
if data.get("reference_content") is not None:
@@ -932,6 +935,7 @@ def __init__(
self.certificate_directory = data.get(
"certificate_directory", constants.DEFAULT_CERTIFICATE_DIRECTORY
)
self.customize = data.get("customize")

def __eq__(self, other: object) -> bool:
"""Compare two objects for equality."""
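Two behavioral changes land in OLSConfig: the conversation cache config is now built only when a conversation_cache section is present, and a new free-form customize key is read from the config data. A small sketch of the guard follows, with a plain dict standing in for ConversationCacheConfig and hypothetical values.

from typing import Optional

def build_conversation_cache(data: dict) -> Optional[dict]:
    # Mirrors the diff: construct the cache config only when the section exists.
    return dict(data["conversation_cache"]) if data.get("conversation_cache") else None

print(build_conversation_cache({}))                                          # None
print(build_conversation_cache({"conversation_cache": {"type": "memory"}}))  # {'type': 'memory'} (hypothetical section)
print({"customize": "ols"}.get("customize"))                                 # "ols" -- stored verbatim on the config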
7 changes: 0 additions & 7 deletions ols/constants.py
@@ -18,13 +18,6 @@ class QueryValidationMethod(StrEnum):
SUBJECT_ALLOWED = "ALLOWED"


# Default responses
INVALID_QUERY_RESP = (
"Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift, "
"please ask me a question related to OpenShift."
)


# providers
PROVIDER_BAM = "bam"
PROVIDER_OPENAI = "openai"
8 changes: 8 additions & 0 deletions ols/customize/__init__.py
@@ -0,0 +1,8 @@
"""Contains customization packages for individual projects (for prompts/keyvords)."""

import importlib
import os

project = os.getenv("PROJECT", "ols")
prompts = importlib.import_module(f"ols.customize.{project}.prompts")
keywords = importlib.import_module(f"ols.customize.{project}.keywords")
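This new package resolves the customization flavor once, at import time: the PROJECT environment variable (default "ols") names a sub-package whose prompts and keywords modules are re-exported as ols.customize.prompts and ols.customize.keywords. Below is a sketch of what a hypothetical downstream flavor would have to provide; the project name and all string values are assumptions, but the attribute names match what the rest of this commit imports.

# Layout for a hypothetical flavor selected with PROJECT=acme:
#   ols/customize/acme/__init__.py
#   ols/customize/acme/prompts.py    -- prompt constants used across the service
#   ols/customize/acme/keywords.py   -- KEYWORDS used by the keyword validator
#
# ols/customize/acme/prompts.py (illustrative values only):
INVALID_QUERY_RESP = "Hi, I'm the Acme assistant; please ask me a question about Acme."
QUERY_SYSTEM_INSTRUCTION = "You are the Acme assistant, answering questions about the Acme platform."
USE_CONTEXT_INSTRUCTION = "Use the retrieved document to answer the question."
USE_HISTORY_INSTRUCTION = "Use the previous chat history to interact and help the user."
QUESTION_VALIDATOR_PROMPT_TEMPLATE = "Decide whether this question is in scope: {query}"
# Run the service against this flavor (sketch): PROJECT=acme python runner.py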
1 change: 1 addition & 0 deletions ols/customize/ols/__init__.py
@@ -0,0 +1 @@
"""Customized prompts/keyvords for OpenShift Lightspeed Service (ols)."""
ols/utils/keywords.py → ols/customize/ols/keywords.py
File renamed without changes.
6 changes: 6 additions & 0 deletions ols/src/prompts/prompts.py → ols/customize/ols/prompts.py
@@ -14,6 +14,12 @@
# but that is not done as granite was adding role tags like `Human:` in the response.
# With PromptTemplate, we have more control over how we want to structure the prompt.

# Default responses
INVALID_QUERY_RESP = (
"Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift, "
"please ask me a question related to OpenShift."
)

QUERY_SYSTEM_INSTRUCTION = """
You are OpenShift Lightspeed - an intelligent assistant for question-answering tasks \
related to the OpenShift container orchestration platform.
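With INVALID_QUERY_RESP relocated from ols/constants.py into the customizable prompts module, the canned refusal text can differ per flavor. The call-site pattern used by the endpoint change above, in sketch form:

from ols.app.models.models import SummarizerResponse
from ols.customize import prompts

# Rejection path: canned response from the active flavor, no referenced
# documents, not truncated (argument order as in the endpoint diff).
summarizer_response = SummarizerResponse(prompts.INVALID_QUERY_RESP, [], False)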
25 changes: 14 additions & 11 deletions ols/src/prompts/prompt_generator.py
@@ -10,12 +10,7 @@
)

from ols.constants import ModelFamily

from .prompts import (
QUERY_SYSTEM_INSTRUCTION,
USE_CONTEXT_INSTRUCTION,
USE_HISTORY_INSTRUCTION,
)
from ols.customize import prompts


def restructure_rag_context_pre(text: str, model: str) -> str:
@@ -52,7 +47,7 @@ def __init__(
query: str,
rag_context: list[str] = [],
history: list[str] = [],
system_instruction: str = QUERY_SYSTEM_INSTRUCTION,
system_instruction: str = prompts.QUERY_SYSTEM_INSTRUCTION,
):
"""Initialize prompt generator."""
self._query = query
@@ -68,7 +63,9 @@ def _generate_prompt_gpt(self) -> tuple[ChatPromptTemplate, dict]:

if len(self._rag_context) > 0:
llm_input_values["context"] = "".join(self._rag_context)
sys_intruction = sys_intruction + "\n" + USE_CONTEXT_INSTRUCTION.strip()
sys_intruction = (
sys_intruction + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
)

if len(self._history) > 0:
chat_history = []
@@ -79,7 +76,9 @@ def _generate_prompt_gpt(self) -> tuple[ChatPromptTemplate, dict]:
chat_history.append(AIMessage(content=h.removeprefix("ai: ")))
llm_input_values["chat_history"] = chat_history

sys_intruction = sys_intruction + "\n" + USE_HISTORY_INSTRUCTION.strip()
sys_intruction = (
sys_intruction + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
)

if "context" in llm_input_values:
sys_intruction = sys_intruction + "\n{context}"
@@ -99,10 +98,14 @@ def _generate_prompt_granite(self) -> tuple[PromptTemplate, dict]:

if len(self._rag_context) > 0:
llm_input_values["context"] = "".join(self._rag_context)
prompt_message = prompt_message + "\n" + USE_CONTEXT_INSTRUCTION.strip()
prompt_message = (
prompt_message + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
)

if len(self._history) > 0:
prompt_message = prompt_message + "\n" + USE_HISTORY_INSTRUCTION.strip()
prompt_message = (
prompt_message + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
)
llm_input_values["chat_history"] = "".join(self._history)

if "context" in llm_input_values:
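The prompt generator no longer imports individual prompt constants; it reads them from ols.customize.prompts, appending the context and history instructions only when those inputs are present. A condensed sketch of that assembly logic (a simplification of the GPT and Granite branches above):

from ols.customize import prompts

def build_system_instruction(has_context: bool, has_history: bool) -> str:
    # Start from the project-specific base instruction ...
    instruction = prompts.QUERY_SYSTEM_INSTRUCTION
    # ... and append the optional instructions exactly as the generator does.
    if has_context:
        instruction = instruction + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
    if has_history:
        instruction = instruction + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
    return instruction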
4 changes: 2 additions & 2 deletions ols/src/query_helpers/docs_summarizer.py
@@ -11,8 +11,8 @@
from ols.app.models.config import ProviderConfig
from ols.app.models.models import SummarizerResponse
from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters
from ols.customize import prompts
from ols.src.prompts.prompt_generator import GeneratePrompt
from ols.src.prompts.prompts import QUERY_SYSTEM_INSTRUCTION
from ols.src.query_helpers.query_helper import QueryHelper
from ols.utils.token_handler import TokenHandler

@@ -31,7 +31,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: model_config.parameters.max_tokens_for_response # noqa: E501
}
# default system prompt fine-tuned for the service
self._system_prompt = QUERY_SYSTEM_INSTRUCTION
self._system_prompt = prompts.QUERY_SYSTEM_INSTRUCTION

# allow the system prompt to be customizable
if config.ols_config.system_prompt is not None:
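DocsSummarizer now takes its default system prompt from the customization package, while an explicitly configured system_prompt still overrides it. A minimal sketch of that precedence, using only names visible in the diff:

from ols import config
from ols.customize import prompts

# Default comes from the project-specific prompts module ...
system_prompt = prompts.QUERY_SYSTEM_INSTRUCTION
# ... but an operator-supplied system prompt (if configured) still takes precedence.
if config.ols_config.system_prompt is not None:
    system_prompt = config.ols_config.system_prompt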
4 changes: 2 additions & 2 deletions ols/src/query_helpers/question_validator.py
@@ -9,7 +9,7 @@
from ols import config
from ols.app.metrics import TokenMetricUpdater
from ols.constants import SUBJECT_REJECTED, GenericLLMParameters
from ols.src.prompts.prompts import QUESTION_VALIDATOR_PROMPT_TEMPLATE
from ols.customize import prompts
from ols.src.query_helpers.query_helper import QueryHelper
from ols.utils.token_handler import TokenHandler

@@ -54,7 +54,7 @@ def validate_question(
logger.info(f"{conversation_id} call settings: {settings_string}")

prompt_instructions = PromptTemplate.from_template(
QUESTION_VALIDATOR_PROMPT_TEMPLATE
prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE
)

bare_llm = self.llm_loader(self.provider, self.model, self.generic_llm_params)
5 changes: 2 additions & 3 deletions runner.py
@@ -13,7 +13,6 @@

import ols.app.models.config as config_model
from ols import constants
from ols.utils.auth_dependency import K8sClientSingleton
from ols.utils.logging import configure_logging


@@ -163,8 +162,8 @@ def start_uvicorn():

# Initialize the K8sClientSingleton with cluster id during module load.
# We want the application to fail early if the cluster ID is not available.
cluster_id = K8sClientSingleton.get_cluster_id()
logger.info(f"running on cluster with ID '{cluster_id}'")
# cluster_id = K8sClientSingleton.get_cluster_id()
# logger.info(f"running on cluster with ID '{cluster_id}'")

# init loading of query redactor
config.query_redactor
3 changes: 2 additions & 1 deletion tests/integration/test_ols.py
@@ -12,6 +12,7 @@
ProviderConfig,
QueryFilter,
)
from ols.customize import prompts
from ols.utils import suid
from ols.utils.errors_parsing import DEFAULT_ERROR_MESSAGE, DEFAULT_STATUS_CODE
from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
@@ -84,7 +85,7 @@ def test_post_question_on_invalid_question(_setup):

expected_json = {
"conversation_id": conversation_id,
"response": constants.INVALID_QUERY_RESP,
"response": prompts.INVALID_QUERY_RESP,
"referenced_documents": [],
"truncated": False,
}
7 changes: 4 additions & 3 deletions tests/unit/app/endpoints/test_ols.py
@@ -20,6 +20,7 @@
ReferencedDocument,
SummarizerResponse,
)
from ols.customize import prompts
from ols.src.llms.llm_loader import LLMConfigurationError
from ols.utils import suid
from ols.utils.errors_parsing import DEFAULT_ERROR_MESSAGE
@@ -651,7 +652,7 @@ def test_conversation_request(
mock_validate_question.return_value = False
llm_request = LLMRequest(query="Generate a yaml")
response = ols.conversation_request(llm_request, auth)
assert response.response == constants.INVALID_QUERY_RESP
assert response.response == prompts.INVALID_QUERY_RESP
assert suid.check_suid(
response.conversation_id
), "Improper conversation ID returned"
@@ -738,7 +739,7 @@ def test_question_validation_in_conversation_start(auth):

response = ols.conversation_request(llm_request, auth)

assert response.response.startswith(constants.INVALID_QUERY_RESP)
assert response.response.startswith(prompts.INVALID_QUERY_RESP)


@pytest.mark.usefixtures("_load_config")
Expand Down Expand Up @@ -778,7 +779,7 @@ def test_conversation_request_invalid_subject(mock_validate, auth):

mock_validate.return_value = False
response = ols.conversation_request(llm_request, auth)
assert response.response == constants.INVALID_QUERY_RESP
assert response.response == prompts.INVALID_QUERY_RESP
assert len(response.referenced_documents) == 0
assert not response.truncated

