diff --git a/bolna/agent_types/contextual_conversational_agent.py b/bolna/agent_types/contextual_conversational_agent.py
index bd000758..12f0027e 100644
--- a/bolna/agent_types/contextual_conversational_agent.py
+++ b/bolna/agent_types/contextual_conversational_agent.py
@@ -6,10 +6,36 @@
 from bolna.llms import OpenAiLLM
 from bolna.prompts import CHECK_FOR_COMPLETION_PROMPT
 from bolna.helpers.logger_config import configure_logger
+from ..knowledgebase import Knowledgebase
 
 load_dotenv()
 logger = configure_logger(__name__)
 
+
+class ContextualConversationalAgent(BaseAgent):
+    """Conversational agent that augments each turn with retrieved knowledge."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.knowledgebase = Knowledgebase()
+
+    def add_to_knowledgebase(self, document):
+        # Index a raw document so later queries can retrieve passages from it.
+        self.knowledgebase.add_document(document)
+
+    async def get_response(self, input_text):
+        """Answer *input_text*, injecting retrieved context ahead of the user turn."""
+        # NOTE(review): assumes BaseAgent exposes self.history and self.llm — confirm.
+        relevant_info = self.knowledgebase.query(input_text)
+        messages = list(self.history)
+        if relevant_info:
+            # Only add a context message when retrieval found something, and put
+            # it before the user turn so the model reads it as framing, not reply.
+            context = "Relevant information: " + " ".join(relevant_info)
+            messages.append({"role": "system", "content": context})
+        messages.append({"role": "user", "content": input_text})
+        return await self.llm.get_chat_response(messages)
+
 
 class StreamingContextualAgent(BaseAgent):
     def __init__(self, llm):
@@ -22,9 +48,7 @@ async def check_for_completion(self, messages, check_for_completion_prompt = CHECK_FOR_COMPLETION_PROMPT):
         prompt = [
             {'role': 'system', 'content': check_for_completion_prompt},
             {'role': 'user', 'content': format_messages(messages, use_system_prompt=True)}]
-        answer = None
         response = await self.conversation_completion_llm.generate(prompt, True, False, request_json=True)
         answer = json.loads(response)
-
         logger.info('Agent: {}'.format(answer['answer']))
         return answer
diff --git a/bolna/knowledgebase.py b/bolna/knowledgebase.py
new file mode 100644
index 00000000..05c4c32b
--- /dev/null
+++ b/bolna/knowledgebase.py
@@ -0,0 +1,34 @@
+from sentence_transformers import SentenceTransformer
+import faiss
+import numpy as np
+
+
+class Knowledgebase:
+    """Tiny sentence-level vector store backed by a flat FAISS L2 index."""
+
+    def __init__(self):
+        # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
+        self.model = SentenceTransformer('all-MiniLM-L6-v2')
+        self.index = faiss.IndexFlatL2(384)
+        self.documents = []
+
+    def add_document(self, document):
+        """Split *document* into sentences and add their embeddings to the index."""
+        # Naive sentence split; drop empty/whitespace-only fragments
+        # (a trailing '.' would otherwise index an empty string).
+        sentences = [s.strip() for s in document.split('.') if s.strip()]
+        if not sentences:
+            return
+        embeddings = self.model.encode(sentences)
+        # FAISS requires contiguous float32 input.
+        self.index.add(np.asarray(embeddings, dtype='float32'))
+        self.documents.extend(sentences)
+
+    def query(self, question, k=5):
+        """Return up to *k* stored sentences most similar to *question*."""
+        if not self.documents:
+            return []
+        question_embedding = self.model.encode([question])
+        _, indices = self.index.search(np.asarray(question_embedding, dtype='float32'), k)
+        # FAISS pads the result with -1 when fewer than k vectors exist; skip those.
+        return [self.documents[i] for i in indices[0] if 0 <= i < len(self.documents)]
diff --git a/requirements.txt b/requirements.txt
index 534bde0f..ab10ba51 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,4 +17,6 @@ uvicorn==0.22.0
 websockets==10.4
 onnxruntime==1.16.3
 scipy==1.11.4
-uvloop==0.19.0
\ No newline at end of file
+uvloop==0.19.0
+sentence-transformers==2.2.2
+faiss-cpu==1.7.4