From 0a05247f18635e7379b0cb620d09297c123cec1e Mon Sep 17 00:00:00 2001
From: sad-zero
Date: Wed, 22 May 2024 13:35:02 +0900
Subject: [PATCH] [fix] remove unused memory class

---
 chat.py                   | 39 +++++++++++++++++++++++++++++++++++---
 etl.py                    |  2 +-
 pyproject.toml            |  2 +-
 simple_chatbot/chatbot.py | 40 +--------------------------------------
 4 files changed, 39 insertions(+), 44 deletions(-)

diff --git a/chat.py b/chat.py
index 30c4f1e..3825416 100644
--- a/chat.py
+++ b/chat.py
@@ -13,16 +13,49 @@
 logging.basicConfig(level=logging.DEBUG)
 
 
 if __name__ == "__main__":
-    embeddings = OllamaEmbeddings(model="all-minilm")
+    embeddings = OllamaEmbeddings(model="phi3:3.8b-mini-instruct-4k-q4_K_M")
     client = PersistentClient(path="resources/chroma_db", settings=Settings(anonymized_telemetry=False))
     collection = client.get_collection("books")
-    model = ChatOllama(model="phi3")
+    model = ChatOllama(model="phi3:3.8b-mini-instruct-4k-q4_K_M")
     chatbot = Chatbot(embeddings=embeddings, db=collection, model=model)
     chatbot.add_rule("You should find reasons of your answer in the given data")
     chatbot.add_rule("You should answer with reasons")
     chatbot.add_rule('You should answer only "모르겠습니다" if you can\'t find the reasons in the given data')
-    chatbot.add_rule("You should answer in korean")
+    chatbot.add_rule("You should answer in korean, should not answer in english")
+    chatbot.add_rule(
+        """
+[Examples]
+Q1)
+[Data]
+자취 생활에는 다이소를 자주 이용합니다.
+[Question]
+자취 꿀팁 알려줘
+A1)
+제공된 내용에 따르면 자취시에는 다이소를 자주 이용해야 합니다.
+Q2)
+[Data]
+계약 전에 집을 꼼꼼히 봐야 합니다.
+[Question]
+집을 계약하기 전에 주의할 점이 뭐야?
+A2)
+제공된 내용에 따르면 집을 계약하기 전에는 집을 꼼꼼히 봐야 합니다.
+Q3)
+[Data]
+자취시에는 상비약을 챙겨야 합니다.
+[Question]
+배고파
+A3)
+모르겠습니다
+Q4)
+[Data]
+자취시에는 채소를 꼭 먹어야 합니다.
+[Question]
+내가 방금 어떤 질문을 했어?
+A4)
+바로 직전에 했던 질문은 "배고파"입니다.
+""".strip()
+    )
     while True:
         query = input("Query: ")
         answer = chatbot.chat(query)
diff --git a/etl.py b/etl.py
index 20104eb..ca9a4b2 100644
--- a/etl.py
+++ b/etl.py
@@ -13,7 +13,7 @@
 
 def get_vector_store(coll_name: str = "books", persistent_path: str = "resources/chroma_db") -> Chroma:
-    embeddings = OllamaEmbeddings(model="all-minilm")
+    embeddings = OllamaEmbeddings(model="phi3:3.8b-mini-instruct-4k-q4_K_M")
     store = InMemoryByteStore()
     embeddings_func = CacheBackedEmbeddings.from_bytes_store(
         underlying_embeddings=embeddings, document_embedding_cache=store, namespace="in-memory-chat"
     )
diff --git a/pyproject.toml b/pyproject.toml
index 8afc2db..9829167 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@
 dynamic = ["dependencies"]
 
 name = "simple_chatbot"
 authors = [{name = "sad-zero", email = "zeroro.yun@gmail.com"}]
-version = "1.0.1"
+version = "1.0.2"
 readme = "README.md"
 license = {file = "LICENSE"}
diff --git a/simple_chatbot/chatbot.py b/simple_chatbot/chatbot.py
index 0a47f69..6939619 100644
--- a/simple_chatbot/chatbot.py
+++ b/simple_chatbot/chatbot.py
@@ -2,54 +2,18 @@
 Chatbot
 """
 
-from copy import copy
 import logging
 from typing import List
 from chromadb import Collection
 from langchain_core.embeddings.embeddings import Embeddings
 from langchain_core.language_models.llms import BaseLLM
-from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage
+from langchain_core.messages import SystemMessage
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import Runnable
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
-
-
-class _Memory:
-    __core_db: List[SystemMessage] = []
-    __chat_db: List[BaseMessage] = []
-    __limit: int = 20
-
-    def __init__(self, limit: int = 20):
-        if limit:
-            self.__limit = limit
-
-    def append_core_message(self, message: SystemMessage):
-        self.__core_db.append(message)
-
-    def append_chat_message(self, message: AIMessage | HumanMessage):
-        """
-        Store latest n-messages
-        """
-        if len(self.__chat_db) > self.__limit:
-            self.__chat_db.pop(0)
-        self.__chat_db.append(message)
-
-    def get_core(self) -> List[SystemMessage]:
-        return copy(self.__core_db)
-
-    def get_chat(self) -> List[HumanMessage | AIMessage]:
-        return copy(self.__chat_db)
-
-    def clear_chat(self):
-        self.__chat_db.clear()
-
-    def clear_core(self):
-        self.__core_db.clear()
-
-
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
@@ -122,8 +86,6 @@ def __get_session_history(self, session_id: str) -> BaseChatMessageHistory:
         if session_id not in self.__sessions:
             self.__sessions[session_id] = ChatMessageHistory()
         result: ChatMessageHistory = self.__sessions[session_id]
-        result.add_user_message(HumanMessage(content="안녕하세요"))
-        result.add_ai_message(AIMessage(content="반갑습니다"))
         return result
 
     def query(self, query: str) -> str:
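
Reviewer note (illustrative sketch, not part of the patch): with the unused _Memory class and the hard-coded greeting messages removed, chat history is kept per session entirely through RunnableWithMessageHistory plus ChatMessageHistory, which Chatbot.__get_session_history now returns unmodified. The sketch below shows how that pattern is typically wired up; the prompt text, the module-level _sessions dict, and the standalone function names are placeholders for illustration and are not taken from this repository.

    from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
    from langchain_community.chat_models import ChatOllama
    from langchain_core.chat_history import BaseChatMessageHistory
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
    from langchain_core.runnables.history import RunnableWithMessageHistory

    # One in-memory history per session id, mirroring Chatbot.__get_session_history
    # after the seed messages were removed.
    _sessions: dict[str, ChatMessageHistory] = {}


    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in _sessions:
            _sessions[session_id] = ChatMessageHistory()
        return _sessions[session_id]


    # Placeholder prompt: the real Chatbot adds its own rules and retrieved context.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant."),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )
    chain = prompt | ChatOllama(model="phi3:3.8b-mini-instruct-4k-q4_K_M")

    chatbot = RunnableWithMessageHistory(
        chain,
        get_session_history,
        input_messages_key="question",
        history_messages_key="history",
    )

    # Calls that share a session_id reuse the same ChatMessageHistory instance.
    answer = chatbot.invoke(
        {"question": "자취 꿀팁 알려줘"},
        config={"configurable": {"session_id": "demo"}},
    )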