use local cache for load_prompt (#503)
return cached prompt if server is unavailable
sachinpad authored Dec 27, 2024
1 parent 8b8fc9d commit 3ab0dd9
Showing 12 changed files with 853 additions and 12 deletions.
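At a glance, `load_prompt` now stores each successfully fetched prompt in a local cache and falls back to that cache when the server cannot be reached. Below is a rough sketch of the caller-visible behavior (illustrative project and slug names; the call shape follows the `load_prompt` signature in `py/src/braintrust/logger.py` further down):

```python
# Sketch of the fallback behavior added in this commit (hypothetical names).
import braintrust

# On a successful fetch, the prompt is also written to the local disk
# cache (by default under ~/.braintrust/prompt_cache).
prompt = braintrust.load_prompt(project="my-project", slug="greeting")

# If the server is unavailable on a later call, load_prompt logs the
# error and returns the cached copy instead. A ValueError is raised only
# when the prompt is found neither on the server nor in the local cache.
```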
3 changes: 2 additions & 1 deletion .github/workflows/py.yaml
@@ -57,4 +57,5 @@ jobs:
           pylint --errors-only $(git ls-files '*.py')
       - name: Run tests
         run: |
-          python -m unittest discover ./core/py/src "*_test.py"
+          python -m unittest discover ./core/py/src
+          python -m unittest discover ./py/src
43 changes: 43 additions & 0 deletions js/src/prompt-cache/disk-cache.test.ts
@@ -4,6 +4,7 @@ import { DiskCache } from "./disk-cache";
import { tmpdir } from "os";
import { beforeEach, describe, it, afterEach, expect } from "vitest";
import { configureNode } from "../node";
import iso from "../isomorph";

describe("DiskCache", () => {
configureNode();
@@ -75,4 +76,46 @@ describe("DiskCache", () => {
// Should throw on corrupted data.
await expect(cache.get("test-key")).rejects.toThrow();
});

it("should throw when eviction stat fails", async () => {
// Fill cache.
for (let i = 0; i < 3; i++) {
await cache.set(`key${i}`, { value: i });
}

// Fake stat to fail for one file.
const origStat = iso.stat;
iso.stat = async (path: string) => {
if (path.endsWith("key0")) {
throw new Error("stat error");
}
return origStat!(path);
};

// Should throw when trying to get stats during eviction.
await expect(cache.set("key3", { value: 3 })).rejects.toThrow("stat error");

iso.stat = origStat;
});

it("should throw when eviction unlink fails", async () => {
// Fill cache.
for (let i = 0; i < 3; i++) {
await cache.set(`key${i}`, { value: i });
await new Promise((resolve) => setTimeout(resolve, 100)); // Ensure different mtimes
}

// Fake unlink to fail.
const origUnlink = iso.unlink;
iso.unlink = async () => {
throw new Error("unlink error");
};

// Should throw when trying to remove files during eviction.
await expect(cache.set("key3", { value: 3 })).rejects.toThrow(
"unlink error",
);

iso.unlink = origUnlink;
});
});
3 changes: 3 additions & 0 deletions js/src/prompt-cache/prompt-cache.test.ts
@@ -207,6 +207,9 @@ describe("PromptCache", () => {
// Memory cache should still be updated.
const result = await cache.get(testKey);
expect(result).toEqual(testPrompt);

// Restore permissions so cleanup can happen.
await fs.chmod(cacheDir, 0o777);
});

it("should handle disk read errors", async () => {
64 changes: 53 additions & 11 deletions py/src/braintrust/logger.py
@@ -66,6 +66,9 @@
from .merge_row_batch import batch_items, merge_row_batch
from .object import DEFAULT_IS_LEGACY_DATASET, ensure_dataset_record, make_legacy_event
from .prompt import BRAINTRUST_PARAMS, PromptBlockData, PromptSchema
from .prompt_cache.disk_cache import DiskCache
from .prompt_cache.lru_cache import LRUCache
from .prompt_cache.prompt_cache import PromptCache
from .span_identifier_v3 import SpanComponentsV3, SpanObjectTypeV3
from .span_types import SpanTypeAttribute
from .types import (
@@ -294,6 +297,20 @@ def default_get_api_conn():

self.reset_login_info()

self._prompt_cache = PromptCache(
memory_cache=LRUCache(
max_size=int(os.environ.get("BRAINTRUST_PROMPT_CACHE_MEMORY_MAX_SIZE", str(1 << 10)))
),
disk_cache=DiskCache(
cache_dir=os.environ.get(
"BRAINTRUST_PROMPT_CACHE_DIR", f"{os.environ.get('HOME')}/.braintrust/prompt_cache"
),
max_size=int(os.environ.get("BRAINTRUST_PROMPT_CACHE_DISK_MAX_SIZE", str(1 << 20))),
serializer=lambda x: x.as_dict(),
deserializer=PromptSchema.from_dict_deep,
),
)

def reset_login_info(self):
self.app_url: Optional[str] = None
self.app_public_url: Optional[str] = None
@@ -1199,24 +1216,49 @@ def load_prompt(
     raise ValueError("Must specify slug")
 
     def compute_metadata():
-        login(org_name=org_name, api_key=api_key, app_url=app_url)
-        args = _populate_args(
-            {
-                "project_name": project,
-                "project_id": project_id,
-                "slug": slug,
-                "version": version,
-            },
-        )
-        response = _state.api_conn().get_json("/v1/prompt", args)
+        try:
+            login(org_name=org_name, api_key=api_key, app_url=app_url)
+            args = _populate_args(
+                {
+                    "project_name": project,
+                    "project_id": project_id,
+                    "slug": slug,
+                    "version": version,
+                },
+            )
+            response = _state.api_conn().get_json("/v1/prompt", args)
+        except Exception as server_error:
+            eprint(f"Failed to load prompt, attempting to fall back to cache: {server_error}")
+            try:
+                return _state._prompt_cache.get(
+                    slug,
+                    version=str(version) if version else "latest",
+                    project_id=project_id,
+                    project_name=project,
+                )
+            except Exception as cache_error:
+                raise ValueError(
+                    f"Prompt {slug} (version {version or 'latest'}) not found in {project or project_id} (not found on server or in local cache): {cache_error}"
+                ) from server_error
         if response is None or "objects" not in response or len(response["objects"]) == 0:
             raise ValueError(f"Prompt {slug} not found in project {project or project_id}.")
         elif len(response["objects"]) > 1:
             raise ValueError(
                 f"Multiple prompts found with slug {slug} in project {project or project_id}. This should never happen."
             )
         resp_prompt = response["objects"][0]
-        return PromptSchema.from_dict_deep(resp_prompt)
+        prompt = PromptSchema.from_dict_deep(resp_prompt)
+        try:
+            _state._prompt_cache.set(
+                slug,
+                str(version) if version else "latest",
+                prompt,
+                project_id=project_id,
+                project_name=project,
+            )
+        except Exception as e:
+            eprint(f"Failed to store prompt in cache: {e}")
+        return prompt
 
     return Prompt(
         lazy_metadata=LazyValue(compute_metadata, use_mutex=True), defaults=defaults or {}, no_trace=no_trace
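The cache constructed in the login state above is tunable through environment variables. A short sketch of the knobs and the defaults visible in this diff (both size limits are entry counts, not bytes):

```python
# Prompt-cache configuration read by the constructor above. These must
# be set before the login state (and with it the PromptCache) is built.
import os

# Max entries in the in-memory LRU cache (default 1 << 10 = 1024).
os.environ["BRAINTRUST_PROMPT_CACHE_MEMORY_MAX_SIZE"] = "2048"

# Max entries on disk before LRU eviction kicks in (default 1 << 20).
os.environ["BRAINTRUST_PROMPT_CACHE_DISK_MAX_SIZE"] = "500000"

# Cache location (default: $HOME/.braintrust/prompt_cache).
os.environ["BRAINTRUST_PROMPT_CACHE_DIR"] = "/tmp/my-prompt-cache"
```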
Empty file.
149 changes: 149 additions & 0 deletions py/src/braintrust/prompt_cache/disk_cache.py
@@ -0,0 +1,149 @@
"""
A module providing a persistent disk-based cache implementation.
This module contains a generic disk cache that can store serializable objects of any type.
The cache persists entries as compressed files on disk and implements an LRU (Least Recently Used)
eviction policy based on file modification times. File system errors are surfaced to the
caller as KeyError or RuntimeError rather than being swallowed.
"""

import gzip
import json
import os
from typing import Any, Callable, Generic, List, Optional, TypeVar

T = TypeVar("T")


class DiskCache(Generic[T]):
"""
A persistent filesystem-based cache implementation.
This cache stores entries as compressed files on disk and implements an LRU eviction
policy based on file modification times (mtime). While access times (atime) would be more
semantically accurate for LRU, we use mtime because:
1. Many modern filesystems mount with noatime for performance reasons.
2. Even when atime updates are enabled, they may be subject to update delays.
3. mtime updates are more reliably supported across different filesystems.
"""

def __init__(
self,
cache_dir: str,
max_size: Optional[int] = None,
serializer: Optional[Callable[[T], Any]] = None,
deserializer: Optional[Callable[[Any], T]] = None,
):
"""
Creates a new DiskCache instance.
Args:
cache_dir: Directory where cache files will be stored.
max_size: Maximum number of entries to store in the cache.
If not specified, the cache will grow unbounded.
serializer: Optional function to convert values to JSON-serializable format.
deserializer: Optional function to convert JSON-deserialized data back to original type.
Should be the inverse of serializer.
Example:
# Create a cache for PromptSchema objects using its serialization methods.
cache = DiskCache[PromptSchema](
cache_dir="cache",
serializer=lambda x: x.as_dict(),
deserializer=PromptSchema.from_dict_deep
)
"""
self._dir = cache_dir
self._max_size = max_size
self._serializer = serializer
self._deserializer = deserializer

def _get_entry_path(self, key: str) -> str:
"""Gets the file path for a cache entry."""
return os.path.join(self._dir, key)

def get(self, key: str) -> T:
"""
Retrieves a value from the cache.
Updates the entry's access time when read.
Args:
key: The key to look up in the cache.
Returns:
The cached value.
Raises:
KeyError: If the key is not found in the cache.
RuntimeError: If there is an error reading from the disk cache.
"""
try:
file_path = self._get_entry_path(key)
with gzip.open(file_path, "rb") as f:
data = json.loads(f.read().decode("utf-8"))
if self._deserializer is not None:
data = self._deserializer(data)

# Update both access and modification times.
os.utime(file_path, None)
return data
except FileNotFoundError:
raise KeyError(f"Cache key not found: {key}")
except Exception as e:
raise RuntimeError(f"Failed to read from disk cache: {e}") from e

def set(self, key: str, value: T) -> None:
"""
Stores a value in the cache.
If the cache is at its maximum size, the least recently used entries will be evicted.
Args:
key: The key to store the value under.
value: The value to store in the cache.
Raises:
RuntimeError: If there is an error writing to the disk cache.
"""
try:
os.makedirs(self._dir, exist_ok=True)
file_path = self._get_entry_path(key)

with gzip.open(file_path, "wb") as f:
if self._serializer is not None:
value = self._serializer(value)
f.write(json.dumps(value).encode("utf-8"))

if self._max_size:
entries = os.listdir(self._dir)
if len(entries) > self._max_size:
self._evict_oldest(entries)
except Exception as e:
raise RuntimeError(f"Failed to write to disk cache: {e}") from e

def _evict_oldest(self, entries: List[str]) -> None:
"""
Evicts the oldest entries from the cache until it is under the maximum size.
This method requires that self._max_size is not None, as it is only called when
evicting entries to maintain the maximum cache size.
Args:
entries: List of cache entry filenames.
Raises:
OSError: If there is an error getting file mtimes or removing entries.
"""
assert self._max_size is not None

stats = []
for entry in entries:
path = self._get_entry_path(entry)
mtime = os.path.getmtime(path)
stats.append({"name": entry, "mtime": mtime})

stats.sort(key=lambda x: x["mtime"])
to_remove = stats[0 : len(stats) - self._max_size]

for entry in to_remove:
os.unlink(self._get_entry_path(entry["name"]))
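To make the eviction policy concrete, here is a minimal standalone sketch of `DiskCache` (the import path is assumed from this diff's file layout; the sleeps mirror the JS test above, which spaces out writes so mtimes differ):

```python
import time

from braintrust.prompt_cache.disk_cache import DiskCache  # assumed path

# Plain dicts are JSON-serializable, so no serializer/deserializer needed.
cache = DiskCache(cache_dir="/tmp/example-prompt-cache", max_size=2)

cache.set("a", {"value": 1})
time.sleep(0.05)  # keep mtimes distinct between writes
cache.set("b", {"value": 2})
time.sleep(0.05)

cache.get("a")  # os.utime refreshes "a", so "b" is now least recently used
time.sleep(0.05)

cache.set("c", {"value": 3})  # 3 entries > max_size: evicts oldest mtime ("b")
try:
    cache.get("b")
except KeyError as err:
    print(err)  # 'Cache key not found: b'
```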
77 changes: 77 additions & 0 deletions py/src/braintrust/prompt_cache/lru_cache.py
@@ -0,0 +1,77 @@
"""
A module providing an LRU (Least Recently Used) cache implementation.
This module contains a generic LRU cache that can store key-value pairs of any type.
The cache maintains items in order of use and can optionally evict least recently
used items when it reaches a maximum size. The implementation uses an OrderedDict
for O(1) access and update operations.
"""

from typing import Generic, Optional, OrderedDict, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class LRUCache(Generic[K, V]):
"""
A Least Recently Used (LRU) cache implementation.
This cache maintains items in order of use, evicting the least recently used item
when the cache reaches its maximum size (if specified). Items are considered "used"
when they are either added to the cache or retrieved from it.
If no maximum size is specified, the cache will grow unbounded.
Args:
max_size: Maximum number of items to store in the cache.
If not specified, the cache will grow unbounded.
"""

def __init__(self, max_size: Optional[int] = None):
self._cache: OrderedDict[K, V] = OrderedDict()
self._max_size = max_size

def get(self, key: K) -> V:
"""
Retrieves a value from the cache.
If the key exists, the item is marked as most recently used.
Args:
key: The key to look up.
Returns:
The cached value.
Raises:
KeyError: If the key is not found in the cache.
"""
if key not in self._cache:
raise KeyError(f"Cache key not found: {key}")

# Refresh key by moving to end of OrderedDict.
value = self._cache.pop(key)
self._cache[key] = value
return value

def set(self, key: K, value: V) -> None:
"""
Stores a value in the cache.
If the key already exists, the value is updated and marked as most recently used.
If the cache is at its maximum size, the least recently used item is evicted.
Args:
key: The key to store.
value: The value to store.
"""
if key in self._cache:
self._cache.pop(key)
elif self._max_size and len(self._cache) >= self._max_size:
# Remove oldest item (first item in ordered dict).
self._cache.popitem(last=False)

self._cache[key] = value

def clear(self) -> None:
"""Removes all items from the cache."""
self._cache.clear()
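And the matching sketch for the in-memory `LRUCache` (import path again assumed from this diff's layout):

```python
from braintrust.prompt_cache.lru_cache import LRUCache  # assumed path

cache: LRUCache[str, int] = LRUCache(max_size=2)
cache.set("a", 1)
cache.set("b", 2)

cache.get("a")  # marks "a" most recently used
cache.set("c", 3)  # at max_size: popitem(last=False) evicts "b"

try:
    cache.get("b")
except KeyError as err:
    print(err)  # 'Cache key not found: b'
```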