
Commit 77fb6a2
fix black
jameswnl committed Oct 17, 2024
1 parent 7d0a210 commit 77fb6a2
Showing 3 changed files with 22 additions and 11 deletions.
10 changes: 6 additions & 4 deletions ols/app/models/config.py
@@ -902,9 +902,11 @@ def __init__(
         if data is None:
             return

-        self.conversation_cache = ConversationCacheConfig(
-            data.get("conversation_cache")
-        ) if data.get("conversation_cache") else None
+        self.conversation_cache = (
+            ConversationCacheConfig(data.get("conversation_cache"))
+            if data.get("conversation_cache")
+            else None
+        )
         self.logging_config = LoggingConfig(**data.get("logging_config", {}))
         if data.get("reference_content") is not None:
             self.reference_content = ReferenceContent(data.get("reference_content"))
@@ -933,7 +935,7 @@ def __init__(
         self.certificate_directory = data.get(
             "certificate_directory", constants.DEFAULT_CERTIFICATE_DIRECTORY
         )
-        self.customize = data.get('customize')
+        self.customize = data.get("customize")

     def __eq__(self, other: object) -> bool:
         """Compare two objects for equality."""
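Note: both config.py hunks are formatting-only. The interesting one reflows a conditional expression: when the one-line form exceeds black's default 88-character limit, black wraps the whole expression in parentheses with each clause on its own line. A stand-alone illustration of the same reflow (SomeConfig and the data values are hypothetical, not from the repository):

data = {"key": {"option": True}}


class SomeConfig:  # stand-in for ConversationCacheConfig
    def __init__(self, cfg):
        self.cfg = cfg


# One-line form, as in the removed lines above:
value = SomeConfig(data.get("key")) if data.get("key") else None

# black's parenthesized form, as in the added lines above:
value = (
    SomeConfig(data.get("key"))
    if data.get("key")
    else None
)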
6 changes: 3 additions & 3 deletions ols/customize/__init__.py
@@ -1,6 +1,6 @@
 import os
 import importlib

-project = os.getenv('PROJECT', 'ols')
-prompts = importlib.import_module(f'ols.customize.{project}.prompts')
-keywords = importlib.import_module(f'ols.customize.{project}.keywords')
+project = os.getenv("PROJECT", "ols")
+prompts = importlib.import_module(f"ols.customize.{project}.prompts")
+keywords = importlib.import_module(f"ols.customize.{project}.keywords")
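For orientation, ols/customize/__init__.py selects a customization package at import time via the PROJECT environment variable, defaulting to "ols". A minimal runnable sketch of the same importlib pattern; the "demo"/"acme" module names and the DESCRIPTION attribute are hypothetical stand-ins:

import importlib
import os
import sys
import types

# Register two stand-in customization modules in sys.modules so this
# sketch runs with no files on disk (all names here are hypothetical).
for name, text in [("demo.ols.prompts", "ols prompts"), ("demo.acme.prompts", "acme prompts")]:
    module = types.ModuleType(name)
    module.DESCRIPTION = text
    sys.modules[name] = module

# The same selection pattern as ols/customize/__init__.py: an environment
# variable picks which submodule supplies the implementation.
project = os.getenv("PROJECT", "ols")
prompts = importlib.import_module(f"demo.{project}.prompts")

print(prompts.DESCRIPTION)  # "ols prompts" unless PROJECT=acme is exported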
17 changes: 13 additions & 4 deletions ols/src/prompts/prompt_generator.py
@@ -12,6 +12,7 @@
 from ols.constants import ModelFamily
 from ols.customize import prompts

+
 def restructure_rag_context_pre(text: str, model: str) -> str:
     """Restructure rag text - pre truncation."""
     if ModelFamily.GRANITE in model:
@@ -62,7 +63,9 @@ def _generate_prompt_gpt(self) -> tuple[ChatPromptTemplate, dict]:

         if len(self._rag_context) > 0:
             llm_input_values["context"] = "".join(self._rag_context)
-            sys_intruction = sys_intruction + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
+            sys_intruction = (
+                sys_intruction + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
+            )

         if len(self._history) > 0:
             chat_history = []
@@ -73,7 +76,9 @@ def _generate_prompt_gpt(self) -> tuple[ChatPromptTemplate, dict]:
                 chat_history.append(AIMessage(content=h.removeprefix("ai: ")))
             llm_input_values["chat_history"] = chat_history

-            sys_intruction = sys_intruction + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
+            sys_intruction = (
+                sys_intruction + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
+            )

         if "context" in llm_input_values:
             sys_intruction = sys_intruction + "\n{context}"
@@ -93,10 +98,14 @@ def _generate_prompt_granite(self) -> tuple[PromptTemplate, dict]:

         if len(self._rag_context) > 0:
             llm_input_values["context"] = "".join(self._rag_context)
-            prompt_message = prompt_message + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
+            prompt_message = (
+                prompt_message + "\n" + prompts.USE_CONTEXT_INSTRUCTION.strip()
+            )

         if len(self._history) > 0:
-            prompt_message = prompt_message + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
+            prompt_message = (
+                prompt_message + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
+            )
             llm_input_values["chat_history"] = "".join(self._history)

         if "context" in llm_input_values:
