Skip to content

Commit

Permalink
Add some unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
postrational committed Nov 12, 2024
1 parent f993eb5 commit 504f41a
Show file tree
Hide file tree
Showing 13 changed files with 255 additions and 8 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Please set this as an environment variable before running the `muffin` commands.

$ export OPENAI_API_KEY=sk-proj-XXXX........

### Create a Chat Agent based on a directory of documents
### Create a chat agent based on a directory of documents

You can generate a RAG index based on a directory of files (e.g. TXT, PDF, EPUB, etc.).

Expand All @@ -42,7 +42,7 @@ Start the chat agent using the following command:

(venv) $ muffin chat my_agent

### Generate a RAG index based on your Zotero library
### Create a chat agent based on your Zotero library

In order to use Ragamuffin with Zotero, you need to generate a [Zotero API key][zotero-key] and
an [OpenAI API key][openai-key]. Set these as environment variables before running `muffin`.
Expand All @@ -62,7 +62,7 @@ Later, you can chat with Ragamuffin using the `muffin chat` command:

(venv) $ muffin chat zotero_agent

### Generate a RAG index based on a Git repository
### Create a chat agent based on a Git repository

If you want to learn about a specific codebase, you can generate a RAG index based on a GitHub repository.

Expand Down
50 changes: 49 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "ragamuffin"
version = "0.4.1"
version = "0.4.2"
description = ""
authors = ["Michal Karzynski <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -43,6 +43,7 @@ types-python-dateutil = "^2.9.0.20241003"
types-redis = "^4.6.0.20241004"
types-requests = "^2.32.0.20241016"
types-tabulate = "^0.9.0.20240106"
pytest = "^8.3.3"

[tool.mypy]
ignore_missing_imports = true
Expand Down
2 changes: 1 addition & 1 deletion src/ragamuffin/models/model_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_embedding_model_by_name(name: str) -> BaseEmbedding:
except ValueError as e:
raise ConfigurationError(f"Unrecognized embedding model name: {name}") from e

if provider == "huggingface":
if provider == "huggingface.co":
return HuggingFaceEmbedding(model_name=model_name)

if provider == "openai":
Expand Down
9 changes: 8 additions & 1 deletion src/ragamuffin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_settings() -> dict[str, str | int | bool | None]:
"storage_type": os.environ.get("RAGAMUFFIN_STORAGE_TYPE", "file"),
"data_dir": os.environ.get("RAGAMUFFIN_DATA_DIR", user_data_dir("ragamuffin")),
"llm_model": os.environ.get("RAGAMUFFIN_LLM_MODEL", "openai/gpt-4o-mini"),
# Local model: "huggingface/BAAI/bge-m3", uses 1024-dimensional embeddings
# Local model: "huggingface.co/BAAI/bge-m3", uses 1024-dimensional embeddings
"embedding_model": os.environ.get("RAGAMUFFIN_EMBEDDING_MODEL", "openai/text-embedding-ada-002"),
"embedding_dimension": os.environ.get("RAGAMUFFIN_EMBEDDING_DIMENSION", 1536),
"debug_mode": os.environ.get("RAGAMUFFIN_DEBUG", False),
Expand All @@ -19,10 +19,17 @@ def get_settings() -> dict[str, str | int | bool | None]:
"zotero_api_key": os.environ.get("ZOTERO_API_KEY"),
"openai_api_key": os.environ.get("OPENAI_API_KEY"),
}

# Handle boolean values
for key in ["debug_mode"]:
value = settings[key]
if isinstance(value, str):
settings[key] = value.lower() in ["true", "1", "yes"]

# Handle integer values
for key in ["embedding_dimension"]:
value = settings[key]
if isinstance(value, str):
settings[key] = int(value)

return settings
4 changes: 3 additions & 1 deletion src/ragamuffin/storage/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@ def _validate_agent_name(self, agent_name: str) -> None:
def generate_index(self, agent_name: str, reader: BaseReader) -> BaseIndex:
"""Load the documents and create a RAG index."""
self._validate_agent_name(agent_name)
logger.info("Loading documents...")
documents = reader.load_data()

logger.info("Generating RAG embeddings...")
settings = get_settings()
embed_dim = ensure_int(settings.get("embedding_dimension"))
vector_store = CassandraVectorStore(table=agent_name, embedding_dimension=embed_dim)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
configure_llamaindex_embedding_model()

logger.info("Generating RAG embeddings...")
logger.info("Storing the index in Cassandra...")
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
index.storage_context.persist()
return index
Expand Down
2 changes: 2 additions & 0 deletions src/ragamuffin/storage/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_agent_storage_dir(self, agent_name: str) -> Path:

def generate_index(self, agent_name: str, reader: BaseReader) -> BaseIndex:
"""Load the documents and create a RAG index."""
logger.info("Loading documents...")
documents = reader.load_data()

# Configure chunking settings
Expand All @@ -37,6 +38,7 @@ def generate_index(self, agent_name: str, reader: BaseReader) -> BaseIndex:
# Build the index from documents and persist to disk
logger.info("Generating RAG embeddings...")
index = VectorStoreIndex.from_documents(documents)
logger.info("Storing the index in the file system...")
index.storage_context.persist(persist_dir=self.get_agent_storage_dir(agent_name))
return index

Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/conftest.py
Empty file.
42 changes: 42 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging
from pathlib import Path

from click.testing import CliRunner

from ragamuffin.cli.muffin import cli
from ragamuffin.storage.utils import get_storage
from tests.utils import env_vars


@env_vars(
    RAGAMUFFIN_STORAGE_TYPE="file",
    RAGAMUFFIN_EMBEDDING_DIMENSION="312",
    RAGAMUFFIN_EMBEDDING_MODEL="huggingface.co/huawei-noah/TinyBERT_General_4L_312D",
)
def test_muffin_cli(caplog):
    """Drive the `generate`, `agents` and `delete` commands of the CLI end to end.

    Every invocation must exit with code 0; after each step the storage
    backend's agent list is checked for the presence/absence of the agent.
    """
    caplog.set_level(logging.INFO, logger="ragamuffin")
    cli_runner = CliRunner()

    name = "test_agent"
    udhr_dir = Path(__file__).parent / "data" / "udhr"

    def invoke_ok(*argv):
        # Run one CLI command and require a clean exit.
        outcome = cli_runner.invoke(cli, list(argv))
        assert outcome.exit_code == 0
        return outcome

    # Build an agent from local files, then remove it again.
    invoke_ok("generate", "from_files", name, str(udhr_dir))
    assert name in get_storage().list_agents()

    invoke_ok("delete", name)
    assert name not in get_storage().list_agents()

    # Build the agent a second time, now from a Git repository.
    invoke_ok("generate", "from_git", name, "https://github.com/postrational/ragamuffin/")
    assert name in get_storage().list_agents()

    # The `agents` listing is checked through the captured log records.
    caplog.clear()
    invoke_ok("agents")
    assert name in caplog.text

    invoke_ok("delete", name)
    assert name not in get_storage().list_agents()
37 changes: 37 additions & 0 deletions tests/test_libraries_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path

from llama_index.core import Document, SimpleDirectoryReader

from ragamuffin.libraries.files import LocalLibrary
from tests.utils import seed


@seed(42)
def test_local_directory():
    """Check the reader a LocalLibrary produces for the `udhr` fixture directory."""
    udhr_dir = Path(__file__).parent / "data" / "udhr"
    reader = LocalLibrary(str(udhr_dir)).get_reader()

    assert isinstance(reader, SimpleDirectoryReader)
    resources = reader.list_resources()
    assert len(resources) == 2

    loaded = reader.load_data()
    assert len(loaded) == 9

    # Spot-check the first loaded document's type, metadata and text.
    first = loaded[0]
    assert isinstance(first, Document)
    assert first.metadata["file_name"] == "udhr-en.pdf"
    assert first.metadata["file_type"] == "application/pdf"
    expected_phrase = "progress and better standards of life in larger freedom"
    assert expected_phrase in first.text


def test_local_file():
    """Check the reader a LocalLibrary produces when pointed at a single PDF file."""
    pdf_path = Path(__file__).parent / "data" / "udhr" / "udhr-en.pdf"
    reader = LocalLibrary(str(pdf_path)).get_reader()

    assert isinstance(reader, SimpleDirectoryReader)
    resources = reader.list_resources()
    assert len(resources) == 1

    loaded = reader.load_data()
    assert len(loaded) == 8
31 changes: 31 additions & 0 deletions tests/test_storage_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

from ragamuffin.libraries.files import LocalLibrary
from ragamuffin.storage.file import FileStorage
from tests.utils import env_vars, seed


@seed(42)
def test_file_storage_create_agent():
    """Round-trip an agent through FileStorage: generate, list, load, delete."""
    udhr_dir = Path(__file__).parent / "data" / "udhr"
    reader = LocalLibrary(str(udhr_dir)).get_reader()

    storage = FileStorage()
    name = "test_agent"

    # The embedding overrides are only active while the index is generated.
    embedding_env = {
        "RAGAMUFFIN_EMBEDDING_DIMENSION": "312",
        "RAGAMUFFIN_EMBEDDING_MODEL": "huggingface.co/huawei-noah/TinyBERT_General_4L_312D",
    }
    with env_vars(**embedding_env):
        storage.generate_index(name, reader=reader)

    assert name in storage.list_agents()

    # Inspect the metadata of the first ingested reference document.
    index = storage.load_index(name)
    doc_infos = list(index.ref_doc_info.values())
    first_metadata = doc_infos[0].metadata
    assert first_metadata["file_name"] == "udhr-en.pdf"
    assert first_metadata["file_type"] == "application/pdf"
    assert first_metadata["page_label"] == "1"

    storage.delete_agent(name)
    assert name not in storage.list_agents()
77 changes: 77 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
import random
from collections.abc import Callable
from contextlib import contextmanager
from functools import wraps
from typing import Any

import numpy as np
import torch


def seed(seed_value: int):
    """Build a decorator that runs the wrapped callable under fixed RNG seeds.

    Before the call, the `random`, NumPy and PyTorch generators are seeded with
    ``seed_value``. Afterwards their previous states are put back, whether the
    callable returned or raised.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Snapshot each generator, paired with the function that restores it.
            restorers = (
                (random.setstate, random.getstate()),
                (np.random.set_state, np.random.get_state()),
                (torch.set_rng_state, torch.get_rng_state()),
            )

            # Switch all three generators to the requested seed.
            random.seed(seed_value)
            np.random.seed(seed_value)
            torch.manual_seed(seed_value)

            try:
                return func(*args, **kwargs)
            finally:
                # Put the generators back so the seeding does not leak out.
                for restore, saved_state in restorers:
                    restore(saved_state)

        return wrapper

    return decorator


@contextmanager
def _env_vars(vars: dict[str, str]):
"""Context manager which temporarily sets environment variables."""
original_values = {key: os.getenv(key) for key in vars}

try:
for key, value in vars.items():
os.environ[key] = value
yield
finally:
for key, original_value in original_values.items():
if original_value is None:
del os.environ[key]
else:
os.environ[key] = original_value


def env_vars(func: Callable | None = None, **vars) -> Any:
"""Can be used as both a context manager and a decorator to set environment variables.
Usage:
- As a decorator: @env_vars(VAR1="value1", VAR2="value2")
- As a context manager: with env_vars(VAR1="value1", VAR2="value2"):
"""
if func is None:
# Used as a context manager
return _env_vars(vars)

# Used as a decorator
@wraps(func)
def wrapper(*args, **kwargs):
with _env_vars(vars):
return func(*args, **kwargs)

return wrapper

0 comments on commit 504f41a

Please sign in to comment.