Skip to content

Commit

Permalink
[fix] remove unused memory class
Browse files Browse the repository at this point in the history
  • Loading branch information
sad-zero committed May 22, 2024
1 parent 9c88ce4 commit 0a05247
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 44 deletions.
39 changes: 36 additions & 3 deletions chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,49 @@
logging.basicConfig(level=logging.DEBUG)

if __name__ == "__main__":
embeddings = OllamaEmbeddings(model="all-minilm")
embeddings = OllamaEmbeddings(model="phi3:3.8b-mini-instruct-4k-q4_K_M")
client = PersistentClient(path="resources/chroma_db", settings=Settings(anonymized_telemetry=False))
collection = client.get_collection("books")
model = ChatOllama(model="phi3")
model = ChatOllama(model="phi3:3.8b-mini-instruct-4k-q4_K_M")

chatbot = Chatbot(embeddings=embeddings, db=collection, model=model)
chatbot.add_rule("You should find reasons of your answer in the given data")
chatbot.add_rule("You should answer with reasons")
chatbot.add_rule('You should answer only "모르겠습니다" if you can\'t find the reasons in the given data')
chatbot.add_rule("You should answer in korean")
chatbot.add_rule("You should answer in korean, should not answer in english")
chatbot.add_rule(
"""
[Examples]
Q1)
[Data]
자취 생활에는 다이소를 자주 이용합니다.
[Question]
자취 꿀팁 알려줘
A1)
제공된 내용에 따르면 자취시에는 다이소를 자주 이용해야 합니다.
Q2)
[Data]
계약 전에 집을 꼼꼼히 봐야 합니다.
[Question]
집을 계약하기 전에 주의할 점이 뭐야?
A2)
제공된 내용에 따르면 집을 계약하기 전에는 집을 꼼꼼히 봐야 합니다.
Q3)
[Data]
자취시에는 상비약을 챙겨야 합니다.
[Question]
배고파
A3)
모르겠습니다
Q4)
[Data]
자취시에는 채소를 꼭 먹어야 합니다.
[Question]
내가 방금 어떤 질문을 했어?
A4)
바로 직전에 했던 질문은 "배고파"입니다.
""".strip()
)
while True:
query = input("Query: ")
answer = chatbot.chat(query)
Expand Down
2 changes: 1 addition & 1 deletion etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def get_vector_store(coll_name: str = "books", persistent_path: str = "resources/chroma_db") -> Chroma:
embeddings = OllamaEmbeddings(model="all-minilm")
embeddings = OllamaEmbeddings(model="phi3:3.8b-mini-instruct-4k-q4_K_M")
store = InMemoryByteStore()
embeddings_func = CacheBackedEmbeddings.from_bytes_store(
underlying_embeddings=embeddings, document_embedding_cache=store, namespace="in-memory-chat"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dynamic = ["dependencies"]

name = "simple_chatbot"
authors = [{name = "sad-zero", email = "[email protected]"}]
version = "1.0.1"
version = "1.0.2"
readme = "README.md"
license = {file = "LICENSE"}

Expand Down
40 changes: 1 addition & 39 deletions simple_chatbot/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,18 @@
Chatbot
"""

from copy import copy
import logging
from typing import List
from chromadb import Collection
from langchain_core.embeddings.embeddings import Embeddings
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory


class _Memory:
__core_db: List[SystemMessage] = []
__chat_db: List[BaseMessage] = []
__limit: int = 20

def __init__(self, limit: int = 20):
if limit:
self.__limit = limit

def append_core_message(self, message: SystemMessage):
self.__core_db.append(message)

def append_chat_message(self, message: AIMessage | HumanMessage):
"""
Store latest n-messages
"""
if len(self.__chat_db) > self.__limit:
self.__chat_db.pop(0)
self.__chat_db.append(message)

def get_core(self) -> List[SystemMessage]:
return copy(self.__core_db)

def get_chat(self) -> List[HumanMessage | AIMessage]:
return copy(self.__chat_db)

def clear_chat(self):
self.__chat_db.clear()

def clear_core(self):
self.__core_db.clear()


from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
Expand Down Expand Up @@ -122,8 +86,6 @@ def __get_session_history(self, session_id: str) -> BaseChatMessageHistory:
if session_id not in self.__sessions:
self.__sessions[session_id] = ChatMessageHistory()
result: ChatMessageHistory = self.__sessions[session_id]
result.add_user_message(HumanMessage(content="안녕하세요"))
result.add_ai_message(AIMessage(content="반갑습니다"))
return result

def query(self, query: str) -> str:
Expand Down

0 comments on commit 0a05247

Please sign in to comment.