Commit 38a8f8f
Figured out how to implement DeepEval functions
NotBioWaste905 committed Jul 30, 2024
1 parent 61f302e commit 38a8f8f
Showing 2 changed files with 21 additions and 12 deletions.
14 changes: 14 additions & 0 deletions chatsky/llm/conditions.py
@@ -0,0 +1,14 @@
"""
LLM conditions.
---------------
This file stores unified functions for basic condition cases,
including regex search, semantic (cosine) distance, etc.
"""

from chatsky.script.core.message import Message

def regex_search(pattern: str) -> bool:
pass

def semantic_distance(target: str | Message, threshold: float) -> bool:
pass
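
Both conditions are left as stubs in this commit. A minimal sketch of what they could grow into, assuming the request text and embeddings are passed in explicitly (the request_text, target_vec, and request_vec parameters and the numpy dependency are illustrative assumptions, not part of this commit):

import re

import numpy as np


def regex_search(pattern: str, request_text: str) -> bool:
    # Assumed shape: match the pattern anywhere in the incoming request text.
    return bool(re.search(pattern, request_text))


def semantic_distance(target_vec: np.ndarray, request_vec: np.ndarray, threshold: float) -> bool:
    # Cosine distance between precomputed embeddings; True when the
    # request embedding is within `threshold` of the target.
    cosine = float(np.dot(target_vec, request_vec)
                   / (np.linalg.norm(target_vec) * np.linalg.norm(request_vec)))
    return (1.0 - cosine) <= threshold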
19 changes: 7 additions & 12 deletions chatsky/llm/llm_response.py → chatsky/llm/wrapper.py
@@ -11,13 +11,12 @@
from langchain_cohere import ChatCohere
from langchain_mistralai import ChatMistralAI
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
langchain_available = True
except ImportError:
langchain_available = False

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser

import base64
import httpx
import re
@@ -78,15 +77,11 @@ def condition(self, prompt, request):
return result

# Helper functions for DeepEval custom LLM usage
def generate(self, prompt: str, schema: BaseModel):
# TODO: Remake this
schema_parser = StructuredOutputParser.from_response_schemas([ResponseSchema(base_model=schema)])
chain = prompt | self.model | schema_parser
return chain.invoke({"input": prompt})
def generate(self, prompt: str):
return self.model.invoke(prompt).content

async def a_generate(self, prompt: str, schema: BaseModel):
# TODO: Remake this
return self.generate(HumanMessage(prompt), schema)
async def a_generate(self, prompt: str):
return self.generate(prompt)

def load_model(self):
return self.model
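
The simplified generate/a_generate pair matches the interface DeepEval expects from a custom model. For reference, a minimal sketch of a standalone DeepEvalBaseLLM subclass built the same way (the ChatOpenAI backend and the model name are placeholders, not part of this commit):

from deepeval.models.base_model import DeepEvalBaseLLM
from langchain_openai import ChatOpenAI


class ChatskyEvalLLM(DeepEvalBaseLLM):
    def __init__(self):
        self.model = ChatOpenAI(model="gpt-4o-mini")  # placeholder backend

    def load_model(self):
        return self.model

    def generate(self, prompt: str) -> str:
        # Same shape as the helper above: plain string in, string out.
        return self.model.invoke(prompt).content

    async def a_generate(self, prompt: str) -> str:
        return self.generate(prompt)

    def get_model_name(self) -> str:
        return "Chatsky LLM wrapper"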
@@ -174,4 +169,4 @@ def message_to_langchain(message: Message):
content.append(
{"type": "image_url", "image_url": {"url": __attachment_to_content(image)}}
)
return HumanMessage(content=content)
return HumanMessage(content=content)
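
The final return uses LangChain's multimodal content format. For reference, a text-plus-image HumanMessage built by hand looks like this (the data URL is a truncated placeholder; in the diff above, __attachment_to_content produces the real one):

from langchain_core.messages import HumanMessage

msg = HumanMessage(content=[
    {"type": "text", "text": "Describe this image."},
    # Placeholder data URL standing in for the attachment contents.
    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64>"}},
])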
