diff --git a/chatsky/llm/llm_response.py b/chatsky/llm/llm_response.py
index 5af12e151..6aa2f8127 100644
--- a/chatsky/llm/llm_response.py
+++ b/chatsky/llm/llm_response.py
@@ -10,6 +10,7 @@
     from langchain_google_vertexai import ChatVertexAI
     from langchain_cohere import ChatCohere
     from langchain_mistralai import ChatMistralAI
+    from langchain.output_parsers import ResponseSchema, StructuredOutputParser
     langchain_available = True
 except ImportError:
     langchain_available = False
@@ -19,18 +20,23 @@
 import base64

 import httpx
+import re

 from chatsky.script.core.message import Image, Message
 from chatsky.script import Context
 from chatsky.pipeline import Pipeline

 from pydantic import BaseModel
-from typing import Union
+from typing import Union, Callable

-import re
+try:
+    from deepeval.models import DeepEvalBaseLLM
+    deepeval_available = True
+except ImportError:
+    deepeval_available = False


-class LLM_API(BaseModel):
+class LLM_API(BaseModel, DeepEvalBaseLLM):
     """
     This class acts as a wrapper for all LLMs from langchain
     and handles message exchange between remote model and chatsky classes.
@@ -57,6 +63,8 @@ def __init__(
     def __check_imports(self):
         if not langchain_available:
             raise ImportError("Langchain is not available. Please install it with `pip install chatsky[llm]`.")
+        if not deepeval_available:
+            raise ImportError("DeepEval is not available. Please install it with `pip install chatsky[llm]`.")


     def respond(self, history: list = []) -> Message:
@@ -68,6 +76,23 @@ def respond(self, history: list = []) -> Message:
     def condition(self, prompt, request):
         result = self.parser.invoke(self.model.invoke([prompt+'\n'+request.text]))
         return result
+
+    # Helper functions for DeepEval custom LLM usage
+    def generate(self, prompt: str, schema: BaseModel):
+        # TODO: Remake this
+        schema_parser = StructuredOutputParser.from_response_schemas([ResponseSchema(base_model=schema)])
+        chain = prompt | self.model | schema_parser
+        return chain.invoke({"input": prompt})
+
+    async def a_generate(self, prompt: str, schema: BaseModel):
+        # TODO: Remake this
+        return self.generate(HumanMessage(prompt), schema)
+
+    def load_model(self):
+        return self.model
+
+    def get_model_name(self):
+        return self.name


 def llm_response(
@@ -75,8 +100,8 @@ def llm_response(
     pipeline: Pipeline,
     model_name,
     prompt="",
-    history=10,
-    filter_non_llm=True
+    history=5,
+    filter: Callable=None
 ):
     """
     Basic function for receiving LLM responses.
@@ -85,16 +110,14 @@
     :param model_name: Name of the model from the `Pipeline.models` dictionary.
     :param prompt: Prompt for the model.
     :param history: Number of messages to keep in history.
-    :param filter_non_llm: Whether to filter non-LLM messages from the history.
+    :param filter: Filter function used to select which messages go into the model's context.
""" model = pipeline.get(model_name) history_messages = [] if history == 0: return model.respond([prompt + "\n" + ctx.last_request.text]) else: - for req, resp in zip(ctx.requests[-history:], ctx.responses[-history:]): - if filter_non_llm and resp.annotation.__generated_by_model__ != model_name: - continue + for req, resp in filter(lambda x: filter(x[0], x[1]), zip(ctx.requests[-history:], ctx.responses[-history:])): if req.attachments != []: content = [{"type": "text", "text": prompt + "\n" + ctx.last_request.text}] for image in ctx.last_request.attachments: @@ -117,15 +140,14 @@ def llm_condition( pipeline: Pipeline, model_name, prompt="", - method="regex", + method: Callable=None, threshold=0.9 ): """ Basic function for using LLM in condition cases. """ model = pipeline.get(model_name) - if method == "regex": - return re.match(r"True", model.condition(prompt, ctx.last_request)) + return method(model.condition(prompt, ctx.last_request)) def __attachment_to_content(attachment: Image) -> str: diff --git a/pyproject.toml b/pyproject.toml index f40b39cf7..799ceb04e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ python-telegram-bot = { version = "~=21.3", extras = ["all"], optional = true } opentelemetry-instrumentation = { version = "*", optional = true } sqlalchemy = { version = "*", extras = ["asyncio"], optional = true } opentelemetry-exporter-otlp = { version = ">=1.20.0", optional = true } # log body serialization is required +deepeval = { version = "^0.21.73", optional = true } [tool.poetry.extras] json = ["aiofiles"] @@ -87,7 +88,7 @@ ydb = ["ydb", "six"] telegram = ["python-telegram-bot"] stats = ["opentelemetry-exporter-otlp", "opentelemetry-instrumentation", "requests", "tqdm", "omegaconf"] benchmark = ["pympler", "humanize", "pandas", "altair", "tqdm"] -llm = ["httpx", "langchain", "langchain-openai", "langchain-anthropic", "langchain-google-vertexai", "langchain-cohere", "langchain-groq", "langchain-mistralai", "langchain-fireworks"] +llm = ["httpx", "langchain", "langchain-openai", "langchain-anthropic", "langchain-google-vertexai", "langchain-cohere", "langchain-groq", "langchain-mistralai", "langchain-fireworks", "deepeval"] [tool.poetry.group.lint] optional = true