Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add New POST Endpoint /memory/recall #976

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions core/cat/routes/memory/points.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Dict, List
from pydantic import BaseModel
from fastapi import Query, Request, APIRouter, HTTPException, Depends
from fastapi import Query, Body, Request, APIRouter, HTTPException, Depends
import time

from cat.auth.connection import HTTPAuth
from cat.auth.permissions import AuthPermission, AuthResource
from cat.memory.vector_memory import VectorMemory
from cat.looking_glass.stray_cat import StrayCat

from cat.log import log

class MemoryPointBase(BaseModel):
content: str
Expand All @@ -24,14 +24,15 @@ class MemoryPoint(MemoryPointBase):


# GET memories from recall
@router.get("/recall")
@router.get("/recall", deprecated=True)
async def recall_memory_points_from_text(
request: Request,
text: str = Query(description="Find memories similar to this text."),
k: int = Query(default=100, description="How many memories to return."),
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
"""Search k memories similar to given text."""
log.warning("Deprecated: This endpoint will be removed in the next major version.")

# Embed the query to plot it in the Memory page
query_embedding = stray.embedder.embed_query(text)
Expand Down Expand Up @@ -76,6 +77,92 @@ async def recall_memory_points_from_text(
},
}

# POST memories from recall
@router.post("/recall")
async def recall_memory_points(
    request: Request,
    text: str = Body(description="Find memories similar to this text."),
    k: int = Body(default=100, description="How many memories to return."),
    metadata: Dict = Body(
        default={},
        description="Flat dictionary where each key-value pair represents a filter. "
        "The memory points returned will match the specified metadata criteria.",
    ),
    stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
    """Search k memories similar to given text with specified metadata criteria.

    The text is embedded once and every vector collection is queried with
    that embedding. For the `episodic` collection the `source` filter is
    always forced to the id of the authenticated user; for every other
    collection the filters are forwarded exactly as supplied by the caller.

    Returns a dictionary with the `query` (text and its embedding) and, under
    `vectors`, the embedder class name and the recalled points per collection.

    Example
    ----------
    ```
    collection = "episodic"
    content = "MIAO!"
    metadata = {"custom_key": "custom_value"}
    req_json = {
        "content": content,
        "metadata": metadata,
    }
    # create a point
    res = requests.post(
        f"http://localhost:1865/memory/collections/{collection}/points", json=req_json
    )

    # recall with metadata
    req_json = {
        "text": "CAT",
        "metadata": {"custom_key": "custom_value"},
    }
    res = requests.post(
        "http://localhost:1865/memory/recall", json=req_json
    )
    json = res.json()
    print(json)
    ```

    """

    # Embed the query to plot it in the Memory page
    query_embedding = stray.embedder.embed_query(text)
    query = {
        "text": text,
        "vector": query_embedding,
    }

    user_id = stray.user_id

    # Loop over collections and retrieve nearby memories
    recalled = {}
    for name, collection in stray.memory.vectors.collections.items():
        # Build a fresh filter for each collection: mutating the request
        # payload in place would make a caller-supplied "source" filter
        # survive or vanish depending on the iteration order.
        collection_filter = dict(metadata)

        # only episodic collection has users
        if name == "episodic":
            collection_filter["source"] = user_id

        memories = collection.recall_memories_from_embedding(
            query_embedding, k=k, metadata=collection_filter
        )

        points = []
        for document, score, vector, point_id in memories:
            memory_dict = dict(document)
            memory_dict.pop("lc_kwargs", None)  # langchain stuff, not needed
            memory_dict["id"] = point_id
            memory_dict["score"] = float(score)
            memory_dict["vector"] = vector
            points.append(memory_dict)
        recalled[name] = points

    return {
        "query": query,
        "vectors": {
            # TODO: should be the config class name
            "embedder": type(stray.embedder).__name__,
            "collections": recalled,
        },
    }

# CREATE a point in memory
@router.post("/collections/{collection_id}/points", response_model=MemoryPoint)
async def create_memory_point(
Expand Down
51 changes: 47 additions & 4 deletions core/tests/routes/memory/test_memory_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# search on default startup memory
def test_memory_recall_default_success(client):
params = {"text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200

Expand All @@ -30,7 +30,7 @@ def test_memory_recall_default_success(client):

# search without query should throw error
def test_memory_recall_without_query_error(client):
response = client.get("/memory/recall")
response = client.post("/memory/recall")
assert response.status_code == 400


Expand All @@ -42,7 +42,7 @@ def test_memory_recall_success(client):

# recall
params = {"text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
Expand All @@ -58,8 +58,51 @@ def test_memory_recall_with_k_success(client):
# recall at max k memories
max_k = 2
params = {"k": max_k, "text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == max_k # only 2 of 6 memories recalled

# search with query and metadata
def test_memory_recall_with_metadata(client):
    """Recall through POST /memory/recall must honour the metadata filters.

    Three episodic points are inserted: two share `key_1 == "v1"` (with
    different `key_2` values) and one carries no metadata. Filtering on one
    key must return the two matching points; filtering on both keys must
    return only the single point matching both.
    """
    messages = [
        {
            "content": "MIAO_1",
            "metadata": {"key_1": "v1", "key_2": "v2"},
        },
        {
            "content": "MIAO_2",
            "metadata": {"key_1": "v1", "key_2": "v3"},
        },
        {
            "content": "MIAO_3",
            "metadata": {},
        },
    ]

    # insert new points with metadata; fail right here if the setup breaks,
    # instead of surfacing later as a misleading count mismatch
    for req_json in messages:
        res = client.post("/memory/collections/episodic/points", json=req_json)
        assert res.status_code == 200

    # recall filtering on a single metadata key
    params = {"text": "MIAO", "metadata": {"key_1": "v1"}}
    response = client.post("/memory/recall/", json=params)
    assert response.status_code == 200
    episodic_memories = response.json()["vectors"]["collections"]["episodic"]
    assert len(episodic_memories) == 2
    recalled_contents = {memory["page_content"] for memory in episodic_memories}
    assert recalled_contents == {"MIAO_1", "MIAO_2"}

    # recall filtering on multiple metadata keys
    params = {"text": "MIAO", "metadata": {"key_1": "v1", "key_2": "v2"}}
    response = client.post("/memory/recall/", json=params)
    assert response.status_code == 200
    episodic_memories = response.json()["vectors"]["collections"]["episodic"]
    assert len(episodic_memories) == 1
    assert episodic_memories[0]["page_content"] == "MIAO_1"