Improved chat command 🗣️💬 #6

Merged
merged 8 commits on Dec 21, 2023
32 changes: 32 additions & 0 deletions docs/decisions/001-RAG-pipeline.md
@@ -0,0 +1,32 @@
# Initial RAG Pipeline

Date: 2023-12-19

Status: proposed

## Context

At the time of writing, MDChat uses [LangChain](https://www.langchain.com/) and [FAISS](https://github.com/facebookresearch/faiss) for the majority of its AI pipeline to create a [retrieval-augmented generation (RAG)](https://python.langchain.com/docs/use_cases/question_answering/#what-is-rag) pipeline. This pipeline gives LLMs access to each user's specific notes and all of the information within them.

Here's a rundown of how we're currently using these tools:

- **FAISS for indexing**: FAISS is a similarity search tool that "vectorizes" the content you feed into it and lets you store the result as a database. Essentially, FAISS converts your text into vectors (arrays of numbers) that can then be compared to each other for word/word-pairing similarity. MDChat uses FAISS to vectorize the notes you provide so they can be compared computationally with relative ease.
- **LangChain for RAG**: LangChain is a tool that simplifies working with multiple LLMs through a single interface. Specifically, MDChat uses LangChain to split notes into more digestible chunks, convert those chunks into embeddings that LLMs can interface with, and finally wrap the embeddings and the LLM together via [`RetrievalQAWithSourcesChain`](https://api.python.langchain.com/en/latest/chains/langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain.html?highlight=retrievalqawithsourceschain#). A minimal sketch of this flow follows below.
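
For illustration, here's a minimal sketch of that flow end to end, using the same LangChain APIs the codebase already imports; the note content, source paths, and API key are placeholders:

```python
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain

# Placeholder notes standing in for the user's markdown files.
notes = {"notes/groceries.md": "- buy milk\n- call Sam about the trip"}

# 1. Split each note into chunks that fit the LLM's context window.
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200, separator="\n")
docs, meta = [], []
for source, text in notes.items():
    chunks = splitter.split_text(text)
    docs.extend(chunks)
    meta.extend([{"source": source}] * len(chunks))

# 2. Embed the chunks and index them in a FAISS vector store.
store = FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_key="sk-..."), metadatas=meta)

# 3. Wrap the retriever and an LLM into a chain that answers with sources.
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key="sk-..."),
    retriever=store.as_retriever(),
)
print(chain({"question": "What do my notes say about milk?"}))
```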

### Successes with this approach

In my testing, I've found this method to be quite good at referencing specific individual notes, surfacing relevant information, and "connecting dots" I hadn't previously considered.

### Issues with this approach

We've had some good success with this method in general, but:
- It seems to have trouble with more "general" questions about your notes overall; it can answer questions about a single note but struggles with questions about your notes as a whole.
- Another issue is speed and compatibility: FAISS in particular is *very slow* without a dedicated GPU and hard to serialize (currently the index is held in memory). This makes initial prompts slow even on expensive computers; one possible mitigation is sketched below.
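
One avenue worth experimenting with (not something this PR ships): LangChain's FAISS wrapper can persist the index to disk with `save_local`/`load_local`, which would avoid rebuilding it on every launch. A minimal sketch, with placeholder paths and key:

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

embeddings = OpenAIEmbeddings(openai_api_key="sk-...")  # placeholder key

# Build the index once, then write it (index + docstore) to disk...
store = FAISS.from_texts(["an example note chunk"], embeddings)
store.save_local("mdchat_index")

# ...and on later runs, load it back instead of re-embedding every note.
restored = FAISS.load_local("mdchat_index", embeddings)
```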

## Decision

We will continue to use this approach until we find a more effective pipeline, but experimentation and testing are encouraged!

## Consequences

Issues with comprehensive understanding of a user's notes as a whole will continue until prompt engineering, better embeddings, or other changes to the pipeline yield better results.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mdchat"
version = "0.2.0"
version = "1.0.0"
description = "a CLI that lets you chat with your markdown notes"
authors = ["Mykal Machon <[email protected]>"]
license = "see LICENSE"
2 changes: 1 addition & 1 deletion src/mdchat/__init__.py
@@ -1,2 +1,2 @@
__app_name__ = "MDChat"
__version__ = "0.2.0"
__version__ = "1.0.0"
49 changes: 18 additions & 31 deletions src/mdchat/chatbot.py
@@ -10,49 +10,42 @@
- load the content of those similar notes into the LLM context chain
- allow the LLM to generate a response based on the context chain (RAG)
"""
import os
import pickle
from datetime import datetime
from pathlib import Path

import faiss
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings


# I know this is weird, but trust me it helps.
chat_context = """
chat_context = f"""
You are a chatbot that helps people search through their notes.
The content you're aware of is a set of notes stored on your user's filesystem.
The current date is {datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}
The content you're aware of is a set of notes stored on your user's computer.
Your goal is to summarize and discuss the content of these files and share your sources.
You are to do this while being kind and with a great attitude.
Always take a deep breath before searching; good searches will result in a $2000 cash tip!
"""

def pickle_store(store, index, db_path):
""" Pickle the store and index to disk """
faiss.write_index(index, f"{db_path}/docs.index")

with open(f"{db_path}/store.pkl", "wb") as store_file:
pickle.dump(store, store_file)

class Chatbot:
def __init__(self, notes_folder, db_path, open_ai_key, open_ai_model, force_new: bool = False):
def __init__(self, notes_folder, db_path, open_ai_key, open_ai_model):
# TODO: validate data passed in here
self.notes_folder = notes_folder
self.db_path = db_path
self.open_ai_key = open_ai_key
self.open_ai_model = open_ai_model

self.chain = None
self.chat_history = []

self.index = None
self.store = None

# TODO: this is a hack to get around the fact that we can't pickle self
self.load_db_and_index(force_new=True)
self.load_db_and_index()
self._create_chain()

def _create_chain(self):
@@ -82,7 +75,7 @@ def _create_store_and_index(self):

# split notes into chunks and store them with metadata for the source
# the chunk size ensures each note fits into the context of the LLM prompt.
text_splitter = CharacterTextSplitter(chunk_size=1500, separator="\n")
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200, separator="\n")
docs = []
meta = []

@@ -94,31 +87,25 @@ def _create_store_and_index(self):
meta.extend([{"source": sources[idx]}] * len(splits))

# finally create the store and index; save them to disk
# TODO: we have to move this out of the class cause it can't pickle self
store = FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_key=self.open_ai_key), metadatas=meta)
# pickle_store(store, store.index, self.db_path)

return [store, store.index]

def load_db_and_index(self, force_new):
""" Load the index from disk """
# check if index exists at db_path if not, create it and load it into memory
if force_new or (not os.path.exists(f"{self.db_path}/store.pkl") and not os.path.exists(f"{self.db_path}/docs.index")):
[store, index] = self._create_store_and_index()
self.store = store
self.index = index
return
def load_db_and_index(self):
""" create a new index and vector store from the note files """
[store, index] = self._create_store_and_index()
self.store = store
self.index = index
return


self.index = faiss.read_index(f"{self.db_path}/docs.index")
with open(f"{self.db_path}/store.pkl", "rb") as store_file:
self.store = pickle.load(store_file)
self.store.index = self.index

def query(self, query):
""" Query the index for similar notes """
if not self.chain or not query:
raise TypeError

response = self.chain({"question": f"{chat_context} {query}"})
response = self.chain({"question": f"{chat_context} {query}", "chat_history": self.chat_history})
self.chat_history.extend([HumanMessage(content=query), AIMessage(content=response.get("answer", "no response found"))])
return response
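
For context on the change above: each call to `query` now threads the accumulated history into the chain and appends a `HumanMessage`/`AIMessage` pair per turn. A hypothetical two-turn session (`bot` is assumed to be an already-configured `Chatbot`):

```python
# `bot` is a hypothetical, already-constructed Chatbot instance.
first = bot.query("What did I write about FAISS?")
follow_up = bot.query("Summarize that note in one sentence.")  # sees turn 1 via chat_history

assert len(bot.chat_history) == 4  # one HumanMessage + one AIMessage per turn
```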

18 changes: 3 additions & 15 deletions src/mdchat/cli.py
@@ -16,6 +16,7 @@
from mdchat.chatbot import Chatbot

from mdchat.commands.config import cli_config
from mdchat.commands.chat import cli_chat

app = typer.Typer()

@@ -51,18 +52,5 @@ def config():


@app.command()
def chat(question: str):
"""
Chat with your notes.
"""
typer.echo("loading...\n")
bot = Chatbot(
notes_folder=get_config("note_path"),
db_path=CONFIG_DIR_PATH,
open_ai_key=get_config("open_ai_key"),
open_ai_model=get_config("open_ai_model"),
)
result = bot.query(question)
typer.echo(result.get("answer"))
typer.echo(result.get("sources", "no sources found with this information"))
typer.Exit()
def chat():
cli_chat(typer)
8 changes: 0 additions & 8 deletions src/mdchat/commands/__init__.py
@@ -7,14 +7,6 @@
from typer import Typer
from rich import print


def validate_open_ai_key(api_key):
if api_key is None:
return False
pattern = r"sk-[A-Za-z0-9_-]{32}"
return re.match(pattern, api_key) is not None


def config_prompt(
key: str,
curr_prompt: str,
51 changes: 51 additions & 0 deletions src/mdchat/commands/chat.py
@@ -0,0 +1,51 @@
from typer import Typer
from rich import print
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn

from mdchat.config import get_config, CONFIG_DIR_PATH, check_if_config_is_valid
from mdchat.chatbot import Chatbot


def cli_show_progress(task_description: str, task_func: callable):
"""Show progress for a task"""
with Progress(
SpinnerColumn(),
TextColumn("[bold green]{task.description}[/bold green]"),
transient=True,
) as progress:
task = progress.add_task(task_description, total=1)
ret_val = task_func()
progress.update(task, advance=1, completed=1)
return ret_val


def cli_chat(typer: Typer):
"""
This initiates a continuous chat with mdchat.
It will create a new index from your notes and then allow
you to chat back and forth in a continuous loop.
"""
# validate that config is valid
config_valid = check_if_config_is_valid()
if not config_valid:
print(f"Config is invalid.\nPlease run [bold blue]mdchat config[/bold blue]")
typer.Exit()
return

# load initial model
bot = cli_show_progress(
"Indexing your notes...",
lambda: Chatbot(
notes_folder=get_config("note_path"),
db_path=CONFIG_DIR_PATH,
open_ai_key=get_config("open_ai_key"),
open_ai_model=get_config("open_ai_model"),
),
)

while True:
    query = typer.prompt("you")
    # bail out before sending "exit" itself to the model
    if query == "exit":
        break
    result = cli_show_progress("Generating a response...", lambda: bot.query(query))
    print(Panel.fit(f"[bold blue]mdchat[/bold blue]: {result.get('answer', 'no response found')}\n[bold blue]sources[/bold blue]: {result.get('sources', 'no sources found')}"))
4 changes: 2 additions & 2 deletions src/mdchat/commands/config.py
@@ -1,7 +1,7 @@
import os

from mdchat.commands import config_prompt, validate_open_ai_key

from mdchat.commands import config_prompt
from mdchat.utils import validate_open_ai_key

def cli_config(typer):
"""
30 changes: 30 additions & 0 deletions src/mdchat/config.py
@@ -9,14 +9,24 @@
import os
import pathlib
import json
import platform

from mdchat.utils import validate_open_ai_key

# TODO: use ~ for home directory on linux/mac, use %USERPROFILE% on windows
NOTE_PATH_DEFAULT = None
if platform.system() == "Windows":
NOTE_PATH_DEFAULT = os.path.expandvars(r"%USERPROFILE%\Documents")
else:
NOTE_PATH_DEFAULT = os.path.expanduser("~/notes")

NOTE_PATH = os.getenv("NOTECHAT_NOTE_PATH", "~/notes")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None)

CONFIG_DIR_PATH = pathlib.Path.home() / ".mdchat"
CONFIG_FILE_PATH = CONFIG_DIR_PATH / "config.json"


def init_app():
"""Initialize the app"""
if not _check_if_config_exists():
Expand Down Expand Up @@ -47,6 +57,26 @@ def _check_if_config_exists():
return os.path.exists(CONFIG_FILE_PATH)


def check_if_config_is_valid():
"""Check if the config file is valid"""
# TODO: generalize validators here and rework config in general
# config should be more general and easier to extend.
if not _check_if_config_exists():
return False
if get_config("note_path") is None or not os.path.exists(get_config("note_path")):
return False
if get_config("open_ai_key") is None or not validate_open_ai_key(
get_config("open_ai_key")
):
return False
if get_config("open_ai_model") is None or get_config("open_ai_model") not in [
"gpt-3.5-turbo",
"gpt-4",
]:
return False
return True


def _init_config_files():
"""Initialize the config file"""
try:
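For reference, here's a config that would pass the new `check_if_config_is_valid` above; the key names come from the checks, and the values are placeholders:

```python
# Hypothetical contents of ~/.mdchat/config.json, shown as a Python dict.
valid_config = {
    "note_path": "/home/me/notes",     # must exist on disk
    "open_ai_key": "sk-" + "A" * 32,   # must match the sk-... key pattern
    "open_ai_model": "gpt-3.5-turbo",  # or "gpt-4"
}
```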
12 changes: 12 additions & 0 deletions src/mdchat/utils.py
@@ -0,0 +1,12 @@
import re

# validators
def validate_open_ai_key(api_key):
"""
validate the general format of an OpenAI API key;
this does not guarantee that the key is valid.
"""
if api_key is None:
return False
pattern = r"sk-[A-Za-z0-9_-]{32}"
return re.match(pattern, api_key) is not None
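
A few quick checks of the validator's behavior (the keys here are made up):

```python
from mdchat.utils import validate_open_ai_key

assert validate_open_ai_key("sk-" + "a" * 32)  # well-formed (but fake) key
assert not validate_open_ai_key("not-a-key")   # wrong prefix
assert not validate_open_ai_key(None)          # missing entirely
```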