Skip to content

Commit

Permalink
Merge pull request #976 from dave90/issue_899
Browse files Browse the repository at this point in the history
Add New POST Endpoint /memory/recall
  • Loading branch information
pieroit authored Nov 19, 2024
2 parents 4e74923 + 8c24b05 commit b55773d
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 7 deletions.
93 changes: 90 additions & 3 deletions core/cat/routes/memory/points.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Dict, List
from pydantic import BaseModel
from fastapi import Query, Request, APIRouter, HTTPException, Depends
from fastapi import Query, Body, Request, APIRouter, HTTPException, Depends
import time

from cat.auth.connection import HTTPAuth
from cat.auth.permissions import AuthPermission, AuthResource
from cat.memory.vector_memory import VectorMemory
from cat.looking_glass.stray_cat import StrayCat

from cat.log import log

class MemoryPointBase(BaseModel):
content: str
Expand All @@ -24,14 +24,15 @@ class MemoryPoint(MemoryPointBase):


# GET memories from recall
@router.get("/recall")
@router.get("/recall", deprecated=True)
async def recall_memory_points_from_text(
request: Request,
text: str = Query(description="Find memories similar to this text."),
k: int = Query(default=100, description="How many memories to return."),
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
"""Search k memories similar to given text."""
log.warning("Deprecated: This endpoint will be removed in the next major version.")

# Embed the query to plot it in the Memory page
query_embedding = stray.embedder.embed_query(text)
Expand Down Expand Up @@ -76,6 +77,92 @@ async def recall_memory_points_from_text(
},
}

# POST memories from recall
@router.post("/recall")
async def recall_memory_points(
request: Request,
text: str = Body(description="Find memories similar to this text."),
k: int = Body(default=100, description="How many memories to return."),
metadata: Dict = Body(default={},
description="Flat dictionary where each key-value pair represents a filter."
"The memory points returned will match the specified metadata criteria."
),
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
"""Search k memories similar to given text with specified metadata criteria.
Example
----------
```
collection = "episodic"
content = "MIAO!"
metadata = {"custom_key": "custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# create a point
res = requests.post(
f"http://localhost:1865/memory/collections/{collection}/points", json=req_json
)
# recall with metadata
req_json = {
"text": "CAT",
"metadata":{"custom_key":"custom_value"}
}
res = requests.post(
f"http://localhost:1865/memory/recall", json=req_json
)
json = res.json()
print(json)
```
"""

# Embed the query to plot it in the Memory page
query_embedding = stray.embedder.embed_query(text)
query = {
"text": text,
"vector": query_embedding,
}

# Loop over collections and retrieve nearby memories
collections = list(
stray.memory.vectors.collections.keys()
)
recalled = {}
for c in collections:
# only episodic collection has users
user_id = stray.user_id
if c == "episodic":
metadata["source"] = user_id
else:
metadata.pop("source", None)

memories = stray.memory.vectors.collections[c].recall_memories_from_embedding(
query_embedding, k=k, metadata=metadata
)

recalled[c] = []
for metadata_memories, score, vector, id in memories:
memory_dict = dict(metadata_memories)
memory_dict.pop("lc_kwargs", None) # langchain stuff, not needed
memory_dict["id"] = id
memory_dict["score"] = float(score)
memory_dict["vector"] = vector
recalled[c].append(memory_dict)

return {
"query": query,
"vectors": {
"embedder": str(
stray.embedder.__class__.__name__
), # TODO: should be the config class name
"collections": recalled,
},
}

# CREATE a point in memory
@router.post("/collections/{collection_id}/points", response_model=MemoryPoint)
async def create_memory_point(
Expand Down
51 changes: 47 additions & 4 deletions core/tests/routes/memory/test_memory_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# search on default startup memory
def test_memory_recall_default_success(client):
params = {"text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200

Expand All @@ -30,7 +30,7 @@ def test_memory_recall_default_success(client):

# search without query should throw error
def test_memory_recall_without_query_error(client):
response = client.get("/memory/recall")
response = client.post("/memory/recall")
assert response.status_code == 400


Expand All @@ -42,7 +42,7 @@ def test_memory_recall_success(client):

# recall
params = {"text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
Expand All @@ -58,8 +58,51 @@ def test_memory_recall_with_k_success(client):
# recall at max k memories
max_k = 2
params = {"k": max_k, "text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == max_k # only 2 of 6 memories recalled

# search with query and metadata
def test_memory_recall_with_metadata(client):
messages = [
{
"content": "MIAO_1",
"metadata": {"key_1":"v1","key_2":"v2"},
},
{
"content": "MIAO_2",
"metadata": {"key_1":"v1","key_2":"v3"},
},
{
"content": "MIAO_3",
"metadata": {},
}
]

# insert a new points with metadata
for req_json in messages:
client.post(
"/memory/collections/episodic/points", json=req_json
)

# recall with metadata
params = {"text": "MIAO", "metadata":{"key_1":"v1"}}
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == 2

# recall with metadata multiple keys in metadata
params = {"text": "MIAO", "metadata":{"key_1":"v1","key_2":"v2"}}
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == 1
assert episodic_memories[0]["page_content"] == "MIAO_1"



0 comments on commit b55773d

Please sign in to comment.