Commit

Adding conditions
NotBioWaste committed Jul 31, 2024
1 parent 38a8f8f commit 592267f
Showing 3 changed files with 40 additions and 23 deletions.
19 changes: 15 additions & 4 deletions chatsky/llm/conditions.py
@@ -6,9 +6,20 @@
"""

from chatsky.script.core.message import Message
from chatsky.script import Context
from chatsky.pipeline import Pipeline
import re

def regex_search(pattern: str) -> bool:
    pass

def semantic_distance(target: str | Message, threshold: float) -> bool:
    pass
def regex_search(pattern: str):
    def _(ctx: Context, _: Pipeline) -> bool:
        return bool(re.search(pattern, ctx.last_request.text))

    return _


def semantic_distance(target: str | Message, threshold: float):
    def _(ctx: Context, _: Pipeline) -> bool:
        pass

    return _
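In this commit the new `semantic_distance` factory still returns a stub that does nothing. Purely as an illustration of one direction it could take (an assumption, not part of the commit), a sketch using sentence-transformers embeddings and cosine distance might look like this:

# Illustrative sketch only: the encoder choice and the "1 - cosine similarity"
# distance are assumptions, not code from this repository.
import numpy as np
from sentence_transformers import SentenceTransformer

from chatsky.script import Context
from chatsky.script.core.message import Message
from chatsky.pipeline import Pipeline

_encoder = SentenceTransformer("all-MiniLM-L6-v2")

def semantic_distance(target: str | Message, threshold: float):
    def _(ctx: Context, _: Pipeline) -> bool:
        target_text = target.text if isinstance(target, Message) else target
        # Embed the last user request and the target phrase, then compare them.
        request_vec, target_vec = _encoder.encode([ctx.last_request.text, target_text])
        cosine = np.dot(request_vec, target_vec) / (
            np.linalg.norm(request_vec) * np.linalg.norm(target_vec)
        )
        # The condition passes when the request is semantically close enough to the target.
        return bool((1.0 - cosine) <= threshold)

    return _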
15 changes: 15 additions & 0 deletions chatsky/llm/filters.py
@@ -0,0 +1,15 @@
"""
Filters.
---------
This module contains a collection of basic functions for history filtering, used to avoid cluttering the LLM's context window.
"""

from chatsky.script.core.message import Message

def is_important(msg: Message) -> bool:
    # Use .get() so that messages without an "important" flag in `misc` simply count as not important.
    return bool(msg.misc.get("important"))

def from_the_model(msg: Message, model_name: str) -> bool:
    return msg.annotation.__generated_by_model__ == model_name
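As a usage illustration (assumed here, not something this commit adds), the per-message filters above could be adapted to the (request, response) pairs that `llm_response` hands to its `filter_func` argument:

# Sketch: wrap a single-message filter so it can judge a (request, response) pair.
from functools import partial

from chatsky.llm.filters import is_important

def pair_filter(pair, msg_filter):
    request, response = pair
    # Keep the pair if either side passes the per-message filter.
    return msg_filter(request) or msg_filter(response)

important_only = partial(pair_filter, msg_filter=is_important)
# important_only can then be passed to llm_response as filter_func=important_only.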
29 changes: 10 additions & 19 deletions chatsky/llm/wrapper.py
@@ -96,7 +96,7 @@ def llm_response(
    model_name,
    prompt="",
    history=5,
    filter: Callable=None
    filter_func: Callable=None
):
    """
    Basic function for receiving LLM responses.
@@ -105,28 +105,16 @@
    :param model_name: Name of the model from the `Pipeline.models` dictionary.
    :param prompt: Prompt for the model.
    :param history: Number of messages to keep in history.
    :param filter: filter function to filter messages that will go the models context.
    :param filter_func: Filter function used to select which messages go into the model's context.
    """
    model = pipeline.get(model_name)
    history_messages = []
    if history == 0:
        return model.respond([prompt + "\n" + ctx.last_request.text])
    else:
        for req, resp in filter(lambda x: filter(x[0], x[1]), zip(ctx.requests[-history:], ctx.responses[-history:])):
            if req.attachments != []:
                content = [{"type": "text", "text": prompt + "\n" + ctx.last_request.text}]
                for image in ctx.last_request.attachments:
                    if image is not Image:
                        continue
                    content.append(
                        {"type": "image_url", "image_url": {"url": __attachment_to_content(image)}}
                    )
                req_message = HumanMessage(content=content)
            else:
                req_message = HumanMessage(req.text)

            history_messages.append(req_message)
            history_messages.append(SystemMessage(resp.text))
        # Keep only the request/response pairs accepted by filter_func; if no filter is given, keep them all.
        pairs = zip(ctx.requests[-history:], ctx.responses[-history:])
        for req, resp in (filter(filter_func, pairs) if filter_func is not None else pairs):
            history_messages.append(message_to_langchain(req))
            history_messages.append(message_to_langchain(resp, human=False))
    return model.respond(history_messages)
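For orientation, an assumed usage sketch (not taken from the repository) of how the renamed filter_func argument might be bound when wiring a response:

# Assumption: llm_response also takes ctx/pipeline arguments supplied by the framework,
# so only the keyword arguments visible in this diff are bound here.
from functools import partial

from chatsky.llm.wrapper import llm_response
from chatsky.llm.filters import is_important

respond_with_llm = partial(
    llm_response,
    model_name="main_model",  # assumed key in the Pipeline.models dictionary
    prompt="Answer as a helpful support assistant.",
    history=5,
    filter_func=lambda pair: is_important(pair[0]) or is_important(pair[1]),
)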


@@ -160,7 +148,7 @@ def __attachment_to_content(attachment: Image) -> str:
    return image_b64


def message_to_langchain(message: Message):
def message_to_langchain(message: Message, human=True):
    if message.attachments != []:
        content = [{"type": "text", "text": message.text}]
        for image in message.attachments:
@@ -169,4 +157,7 @@ def message_to_langchain(message: Message):
            content.append(
                {"type": "image_url", "image_url": {"url": __attachment_to_content(image)}}
            )
    return HumanMessage(content=content)
    if human:
        return HumanMessage(content=content)
    else:
        return SystemMessage(content=content)
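A small assumed example of what the new human flag changes when converting a single dialogue turn:

# Assumption-based illustration: requests map to HumanMessage, responses to SystemMessage.
from chatsky.llm.wrapper import message_to_langchain
from chatsky.script.core.message import Message

user_turn = message_to_langchain(Message(text="What's the weather like?"))       # HumanMessage
bot_turn = message_to_langchain(Message(text="Sunny and warm."), human=False)    # SystemMessage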
