diff --git a/chatsky/llm/conditions.py b/chatsky/llm/conditions.py
index c393d7a6c..d7a107674 100644
--- a/chatsky/llm/conditions.py
+++ b/chatsky/llm/conditions.py
@@ -6,9 +6,20 @@
 """
 
 from chatsky.script.core.message import Message
+from chatsky.script import Context
+from chatsky.pipeline import Pipeline
+import re
 
 
-def regex_search(pattern: str) -> bool:
-    pass
-def semantic_distance(target: str | Message, threshold: float) -> bool:
-    pass
\ No newline at end of file
+def regex_search(pattern: str):
+    def _(ctx: Context, _: Pipeline) -> bool:
+        return bool(re.search(pattern, ctx.last_request.text))
+
+    return _
+
+
+def semantic_distance(target: str | Message, threshold: float):
+    def _(ctx: Context, _: Pipeline) -> bool:
+        pass
+
+    return _
diff --git a/chatsky/llm/filters.py b/chatsky/llm/filters.py
new file mode 100644
index 000000000..a73b9b4a8
--- /dev/null
+++ b/chatsky/llm/filters.py
@@ -0,0 +1,13 @@
+"""
+Filters.
+---------
+This module contains basic functions for history filtering, to avoid cluttering the LLM's context window.
+"""
+
+from chatsky.script.core.message import Message
+
+def is_important(msg: Message) -> bool:
+    return bool(msg.misc.get("important", False))
+
+def from_the_model(msg: Message, model_name: str) -> bool:
+    return msg.annotation.__generated_by_model__ == model_name
diff --git a/chatsky/llm/wrapper.py b/chatsky/llm/wrapper.py
index 8e1c12a68..1facf43c1 100644
--- a/chatsky/llm/wrapper.py
+++ b/chatsky/llm/wrapper.py
@@ -96,7 +96,7 @@ def llm_response(
         model_name,
         prompt="",
         history=5,
-        filter: Callable=None
+        filter_func: Callable = None
     ):
     """
     Basic function for receiving LLM responses.
@@ -105,28 +105,16 @@
     :param model_name: Name of the model from the `Pipeline.models` dictionary.
     :param prompt: Prompt for the model.
     :param history: Number of messages to keep in history.
-    :param filter: filter function to filter messages that will go the models context.
+    :param filter_func: Filter function used to select which (request, response) pairs enter the model's context.
""" model = pipeline.get(model_name) history_messages = [] if history == 0: return model.respond([prompt + "\n" + ctx.last_request.text]) else: - for req, resp in filter(lambda x: filter(x[0], x[1]), zip(ctx.requests[-history:], ctx.responses[-history:])): - if req.attachments != []: - content = [{"type": "text", "text": prompt + "\n" + ctx.last_request.text}] - for image in ctx.last_request.attachments: - if image is not Image: - continue - content.append( - {"type": "image_url", "image_url": {"url": __attachment_to_content(image)}} - ) - req_message = HumanMessage(content=content) - else: - req_message = HumanMessage(req.text) - - history_messages.append(req_message) - history_messages.append(SystemMessage(resp.text)) + for req, resp in filter(lambda x: filter_func(x), zip(ctx.requests[-history:], ctx.responses[-history:])): + history_messages.append(message_to_langchain(req)) + history_messages.append(message_to_langchain(resp, human=False)) return model.respond(history_messages) @@ -160,7 +148,7 @@ def __attachment_to_content(attachment: Image) -> str: return image_b64 -def message_to_langchain(message: Message): +def message_to_langchain(message: Message, human=True): if message.attachments != []: content = [{"type": "text", "text": message.text}] for image in message.attachments: @@ -169,4 +157,7 @@ def message_to_langchain(message: Message): content.append( {"type": "image_url", "image_url": {"url": __attachment_to_content(image)}} ) - return HumanMessage(content=content) + if human: + return HumanMessage(content=content) + else: + return SystemMessage(content=content) \ No newline at end of file