
Implementing deepeval integration
NotBioWaste committed Jul 29, 2024
1 parent a1884e5 commit 61f302e
Showing 2 changed files with 36 additions and 13 deletions.
46 changes: 34 additions & 12 deletions chatsky/llm/llm_response.py
@@ -10,6 +10,7 @@
from langchain_google_vertexai import ChatVertexAI
from langchain_cohere import ChatCohere
from langchain_mistralai import ChatMistralAI
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
langchain_available = True
except ImportError:
langchain_available = False
@@ -19,18 +20,23 @@

import base64
import httpx
import re

from chatsky.script.core.message import Image, Message
from chatsky.script import Context
from chatsky.pipeline import Pipeline

from pydantic import BaseModel
from typing import Union
from typing import Union, Callable

import re
try:
from deepeval.models import DeepEvalBaseLLM
deepeval_available = True
except ImportError:
deepeval_available = False


class LLM_API(BaseModel):
class LLM_API(BaseModel, DeepEvalBaseLLM):
"""
This class acts as a wrapper for all LLMs from langchain
and handles message exchange between remote model and chatsky classes.
@@ -57,6 +63,8 @@ def __init__(
def __check_imports(self):
if not langchain_available:
raise ImportError("Langchain is not available. Please install it with `pip install chatsky[llm]`.")
if not deepeval_available:
raise ImportError("DeepEval is not available. Please install it with `pip install chatsky[llm]`.")


def respond(self, history: list = []) -> Message:
@@ -68,15 +76,32 @@ def respond(self, history: list = []) -> Message:
def condition(self, prompt, request):
result = self.parser.invoke(self.model.invoke([prompt+'\n'+request.text]))
return result

# Helper functions for DeepEval custom LLM usage
def generate(self, prompt: str, schema: BaseModel):
# TODO: Remake this
schema_parser = StructuredOutputParser.from_response_schemas([ResponseSchema(base_model=schema)])
chain = prompt | self.model | schema_parser
return chain.invoke({"input": prompt})

async def a_generate(self, prompt: str, schema: BaseModel):
# TODO: Remake this
return self.generate(HumanMessage(prompt), schema)

def load_model(self):
return self.model

def get_model_name(self):
return self.name


def llm_response(
ctx: Context,
pipeline: Pipeline,
model_name,
prompt="",
history=10,
filter_non_llm=True
history=5,
filter: Callable=None
):
"""
Basic function for receiving LLM responses.
@@ -85,16 +110,14 @@ def llm_response(
:param model_name: Name of the model from the `Pipeline.models` dictionary.
:param prompt: Prompt for the model.
:param history: Number of messages to keep in history.
:param filter_non_llm: Whether to filter non-LLM messages from the history.
:param filter: Filter function used to select the messages that go into the model's context.
"""
model = pipeline.get(model_name)
history_messages = []
if history == 0:
return model.respond([prompt + "\n" + ctx.last_request.text])
else:
for req, resp in zip(ctx.requests[-history:], ctx.responses[-history:]):
if filter_non_llm and resp.annotation.__generated_by_model__ != model_name:
continue
for req, resp in [pair for pair in zip(ctx.requests[-history:], ctx.responses[-history:]) if filter is None or filter(*pair)]:
if req.attachments != []:
content = [{"type": "text", "text": prompt + "\n" + ctx.last_request.text}]
for image in ctx.last_request.attachments:
@@ -117,15 +140,14 @@ def llm_condition(
pipeline: Pipeline,
model_name,
prompt="",
method="regex",
method: Callable=None,
threshold=0.9
):
"""
Basic function for using LLM in condition cases.
"""
model = pipeline.get(model_name)
if method == "regex":
return re.match(r"True", model.condition(prompt, ctx.last_request))
return method(model.condition(prompt, ctx.last_request))


def __attachment_to_content(attachment: Image) -> str:
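Taken together, the changes to llm_response.py let LLM_API double as a deepeval custom model: deepeval expects a DeepEvalBaseLLM subclass that implements load_model, generate, a_generate, and get_model_name, which is exactly the set of helpers added here. Below is a minimal sketch of how the wrapper might be handed to a deepeval metric. The constructor arguments (model, name) are assumptions inferred from the self.model and self.name attributes used by load_model() and get_model_name(); the actual __init__ signature is collapsed in this diff.

# Sketch only: LLM_API constructor arguments are assumed, not shown in the diff above.
from langchain_openai import ChatOpenAI
from deepeval.metrics import AnswerRelevancyMetric
from chatsky.llm.llm_response import LLM_API

wrapper = LLM_API(model=ChatOpenAI(model="gpt-4o-mini"), name="gpt-4o-mini")

print(wrapper.get_model_name())   # "gpt-4o-mini", as returned by the new helper
chat_model = wrapper.load_model() # the underlying langchain chat model

# A deepeval metric can be pointed at the wrapper instead of a built-in provider.
metric = AnswerRelevancyMetric(model=wrapper, threshold=0.7)

Note that the schema-based generate()/a_generate() pair is still marked TODO in this commit, so structured-output evaluation may not work end to end yet.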
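The commit also replaces the filter_non_llm flag and the hard-coded "regex" condition method with user-supplied callables (filter and method). A hedged sketch of what those callables could look like follows; "my_model" is a hypothetical key in Pipeline.models, and the two helpers simply reproduce the behaviour of the removed code paths (keeping only model-generated turns, and matching a leading "True").

import re
from chatsky.llm.llm_response import llm_response, llm_condition

def only_my_model_turns(request, response):
    # Roughly the removed filter_non_llm behaviour: keep a (request, response) pair
    # only when the response was generated by the target model.
    return getattr(response.annotation, "__generated_by_model__", None) == "my_model"

def starts_with_true(result) -> bool:
    # Roughly the removed "regex" method: a leading "True" counts as a positive answer.
    return re.match(r"True", str(result)) is not None

def my_response(ctx, pipeline):
    return llm_response(ctx, pipeline, "my_model", prompt="Answer briefly.", history=5, filter=only_my_model_turns)

def my_condition(ctx, pipeline):
    return llm_condition(ctx, pipeline, "my_model", prompt="Reply 'True' if the user wants to end the dialogue.", method=starts_with_true)

llm_condition still accepts a threshold argument, but the committed body does not use it yet.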
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -74,6 +74,7 @@ python-telegram-bot = { version = "~=21.3", extras = ["all"], optional = true }
opentelemetry-instrumentation = { version = "*", optional = true }
sqlalchemy = { version = "*", extras = ["asyncio"], optional = true }
opentelemetry-exporter-otlp = { version = ">=1.20.0", optional = true } # log body serialization is required
deepeval = { version = "^0.21.73", optional = true }

[tool.poetry.extras]
json = ["aiofiles"]
@@ -87,7 +88,7 @@ ydb = ["ydb", "six"]
telegram = ["python-telegram-bot"]
stats = ["opentelemetry-exporter-otlp", "opentelemetry-instrumentation", "requests", "tqdm", "omegaconf"]
benchmark = ["pympler", "humanize", "pandas", "altair", "tqdm"]
llm = ["httpx", "langchain", "langchain-openai", "langchain-anthropic", "langchain-google-vertexai", "langchain-cohere", "langchain-groq", "langchain-mistralai", "langchain-fireworks"]
llm = ["httpx", "langchain", "langchain-openai", "langchain-anthropic", "langchain-google-vertexai", "langchain-cohere", "langchain-groq", "langchain-mistralai", "langchain-fireworks", "deepeval"]

[tool.poetry.group.lint]
optional = true
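With deepeval added to the optional dependencies and to the llm extra, the integration can be pulled in with pip install chatsky[llm], the same command the ImportError messages in llm_response.py point users to.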
