Skip to content

Commit

Permalink
Add semantic query builder
Browse files Browse the repository at this point in the history
  • Loading branch information
postrational committed Nov 10, 2024
1 parent c7f9914 commit b56bc77
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "ragamuffin"
version = "0.3.5"
version = "0.4.0"
description = ""
authors = ["Michal Karzynski <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions src/ragamuffin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from rich.logging import RichHandler

from ragamuffin.settings import get_settings
from ragamuffin import settings

if get_settings().get("debug_mode"):
if settings.get_settings().get("debug_mode"):
logging.basicConfig(
level="DEBUG", format="[%(name)s] %(message)s", datefmt="[%X]", handlers=[RichHandler(show_path=True)]
)
Expand Down
2 changes: 1 addition & 1 deletion src/ragamuffin/cli/muffin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ragamuffin.libraries.files import LocalLibrary
from ragamuffin.libraries.git_repo import GitLibrary
from ragamuffin.libraries.zotero import ZoteroLibrary
from ragamuffin.models.select import configure_llamaindex_embedding_model, get_llm_by_name
from ragamuffin.models.model_picker import configure_llamaindex_embedding_model, get_llm_by_name
from ragamuffin.settings import get_settings
from ragamuffin.storage.utils import get_storage

Expand Down
39 changes: 39 additions & 0 deletions src/ragamuffin/models/enhancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from ragamuffin.error_handling import ensure_string
from ragamuffin.models.model_picker import get_llm_by_name
from ragamuffin.settings import get_settings


class QueryEnhancer:
    """Rewrite the user's latest chat query into a better semantic-search query.

    Uses the LLM configured under the ``llm_model`` setting to add keywords and
    rephrase the last user message, improving recall in the RAG retrieval step.
    """

    def __init__(self):
        """Load the LLM named by the ``llm_model`` setting."""
        settings = get_settings()
        llm_model = ensure_string(settings.get("llm_model"))
        self.model = get_llm_by_name(llm_model)

    def __call__(self, chat_history: list[dict]) -> str:
        """Enhance the last query in the chat history."""
        return self.enhance(chat_history)

    def enhance(self, chat_history: list[dict]) -> str:
        """Enhance the last query in the chat history.

        Args:
            chat_history: Messages as dicts with ``role`` and ``content`` keys.
                The last entry is treated as the query to enhance; earlier
                ``user`` entries are passed to the LLM as conversation context.

        Returns:
            The rewritten query text produced by the LLM.
        """
        # Number only the user turns (1-based) as context for the rewrite.
        user_queries = [msg["content"] for msg in chat_history if msg["role"] == "user"]
        context_str = "\n".join(f"{number}. {content}" for number, content in enumerate(user_queries, start=1))
        query_str = chat_history[-1]["content"]
        prompt = (
            "You are an expert Q&A system that is trusted around the world.\n"
            "The user has provided a query and wants to search for matching sources.\n"
            # Fixed typos in the instructions ("changes" -> "chances", "rule" -> "rules").
            "Please rewrite the query and add keywords to improve the chances of finding relevant sources.\n"
            "The search is based on semantic similarity as part of a RAG-based system.\n"
            "There are some rules: \n"
            "1. Only return the enhanced query, no other output. \n"
            "2. Focus on the meaning of the last query and don't mix in previous queries unless it's needed. \n"
            "All queries in the conversation:\n"
            "---------------------\n"
            f"{context_str}\n"
            "---------------------\n"
            "The query to enhance is:\n"
            f"{query_str}\n"
        )
        response = self.model.complete(prompt)
        return response.text
File renamed without changes.
2 changes: 1 addition & 1 deletion src/ragamuffin/storage/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from llama_index.vector_stores.cassandra import CassandraVectorStore

from ragamuffin.error_handling import ensure_int
from ragamuffin.models.select import configure_llamaindex_embedding_model
from ragamuffin.models.model_picker import configure_llamaindex_embedding_model
from ragamuffin.settings import get_settings
from ragamuffin.storage.interface import Storage

Expand Down
2 changes: 1 addition & 1 deletion src/ragamuffin/storage/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from llama_index.core.indices.base import BaseIndex
from llama_index.core.readers.base import BaseReader

from ragamuffin.models.select import configure_llamaindex_embedding_model
from ragamuffin.models.model_picker import configure_llamaindex_embedding_model
from ragamuffin.settings import get_settings
from ragamuffin.storage.interface import Storage

Expand Down
63 changes: 36 additions & 27 deletions src/ragamuffin/webui/gradio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from gradio.themes.utils import colors, fonts
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.llama_pack import BaseLlamaPack
from llama_index.core.schema import NodeWithScore

from ragamuffin.models.enhancer import QueryEnhancer
from ragamuffin.models.highlighter import SemanticHighlighter


Expand All @@ -22,6 +24,7 @@ def __init__(
"""Init params."""
self.agent = agent
self.semantic_highlighter = SemanticHighlighter()
self.query_enhancer = QueryEnhancer()
self.title = f"Ragamuffin {snake_to_title_case(name)} Chat"

def get_modules(self) -> dict[str, Any]:
Expand Down Expand Up @@ -84,7 +87,7 @@ def respond(self, chat_history: list[dict]) -> Generator[tuple[list[dict], str],
query = chat_history[-1]["content"]
response = self.agent.stream_chat(query)

sources_html = self.generate_sources_html(query, response.sources)
sources_html = self.generate_sources_html(query, response.source_nodes)

chat_history.append({"role": "assistant", "content": ""})
for token in response.response_gen:
Expand All @@ -94,46 +97,51 @@ def respond(self, chat_history: list[dict]) -> Generator[tuple[list[dict], str],
def accept_message(self, user_message: str, chat_history: list[dict]) -> tuple[str, list[dict]]:
    """Accept the user message."""
    # Record the user's turn first so the enhancer sees it as the latest query.
    chat_history.append({"role": "user", "content": user_message})
    enhanced_query = self.query_enhancer(chat_history)
    # Surface the rewritten search query as a collapsible assistant step.
    enhancement_entry = {
        "role": "assistant",
        "content": enhanced_query,
        "metadata": {"title": "🧠 Building semantic search query"},
    }
    chat_history.append(enhancement_entry)
    # Empty string clears the input textbox in the UI.
    return "", chat_history

def reset_chat(self) -> tuple[str, str, str]:
    """Reset the agent's chat history. And clear all dialogue boxes."""
    # Drop the agent's conversational memory before blanking the UI fields.
    self.agent.reset()
    cleared = ""
    return cleared, cleared, cleared

def generate_sources_html(self, query: str, sources: list) -> str:
def generate_sources_html(self, query: str, source_nodes: list[NodeWithScore]) -> str:
"""Generate HTML for the sources."""
output_html = ""
sources_text = []
nodes_info = []

if not sources:
if not source_nodes:
return "<p>No sources found.</p>"

# Collect all texts and their associated metadata
for source in sources:
source_nodes = source.raw_output.source_nodes if hasattr(source.raw_output, "source_nodes") else []
for node_with_score in source_nodes:
text_node = node_with_score.node
metadata = text_node.metadata
score = node_with_score.score
page = metadata.get("page_label")

filename = metadata.get("file_name", "Unknown Filename")
name = metadata.get("name", filename)
url = metadata.get("url")
filename_html = f"<a href='{url}' target='_blank'>{name}</a>" if url else f"<b>{name}</b>"

# Append text and metadata to lists
if text_node.text:
sources_text.append(text_node.text)
nodes_info.append(
{
"filename_html": filename_html,
"page": page,
"score": score,
}
)
for node_with_score in source_nodes:
text_node = node_with_score.node
metadata = text_node.metadata
score = node_with_score.score
page = metadata.get("page_label")

filename = metadata.get("file_name", "Unknown Filename")
name = metadata.get("name", filename)
url = metadata.get("url")
filename_html = f"<a href='{url}' target='_blank'>{name}</a>" if url else f"<b>{name}</b>"

# Append text and metadata to lists
if text_node.text:
sources_text.append(text_node.text)
nodes_info.append(
{
"filename_html": filename_html,
"page": page,
"score": score,
}
)

# Highlight the texts
sources_text = [html.escape(text) for text in sources_text]
Expand All @@ -142,7 +150,8 @@ def generate_sources_html(self, query: str, sources: list) -> str:
# Construct the output using the highlighted texts and metadata
for highlighted_text, info in zip(highlighted_texts, nodes_info, strict=False):
source_footer = f"<br>Page {info['page']}" if info["page"] else "<br>"
source_footer += f" ({info['score']:.2f})"
similarity_class = int(min(score * 10, 9))
source_footer += f" <span class='badge similarity-{similarity_class}'>{info['score']:.2f}</span>"
output_html += f"<p><b>{info['filename_html']}</b><br>{highlighted_text}{source_footer}</p>"

return output_html
Expand Down
12 changes: 11 additions & 1 deletion src/ragamuffin/webui/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
font-size: 12px;
overflow: auto;
overflow-x: hidden;
background-color: var(--background-fill-secondary);
border-radius: 0.5em;
}

#sources p {
Expand All @@ -29,11 +31,19 @@
text-decoration: underline;
}

a:visited {
#sources a:visited {
color: var(--link-text-color-visited) !important;
text-decoration: none;
}

/* Similarity-score badge rendered next to each source. The .similarity-N
   rules later in this file share the same specificity and appear after this
   rule, so they override the default green background per score bucket. */
.badge {
  font-size: 0.9em;
  padding: 0.2em 0.4em;
  border-radius: 0.2em;
  color: var(--block-label-text-color);
  background-color: #28a745;
}

/*.similarity-0 { background-color: rgba(0, 255, 255, 0.1); }*/
/*.similarity-1 { background-color: rgba(0, 255, 255, 0.15); }*/
.similarity-2 { background-color: rgba(0, 255, 255, 0.2); }
Expand Down

0 comments on commit b56bc77

Please sign in to comment.