fix lint

liu1700 committed Dec 22, 2024
1 parent 4c2a903 commit d85c209
Showing 3 changed files with 88 additions and 6 deletions.
16 changes: 13 additions & 3 deletions mem0/memory/main.py
@@ -67,6 +67,7 @@ def add(
metadata=None,
filters=None,
prompt=None,
infer=True,
):
"""
Create a new memory.
@@ -79,7 +80,8 @@
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
infer (bool, optional): Whether to use inference to add the memory. Defaults to True.
Returns:
dict: A dictionary containing the result of the memory addition operation.
result: dict of affected events with each dict has the following key:
@@ -111,7 +113,7 @@ def add(
messages = [{"role": "user", "content": messages}]

with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer)
future2 = executor.submit(self._add_to_graph, messages, filters)

concurrent.futures.wait([future1, future2])
@@ -134,9 +136,17 @@ def add(
)
return vector_store_result

def _add_to_vector_store(self, messages, metadata, filters):
def _add_to_vector_store(self, messages, metadata, filters, infer=True):
parsed_messages = parse_messages(messages)

if not infer:
messages_embeddings = self.embedding_model.embed(parsed_messages)
new_message_embeddings = {parsed_messages: messages_embeddings}
memory_id = self._create_memory(
data=parsed_messages, existing_embeddings=new_message_embeddings, metadata=metadata
)
return [{"id": memory_id, "memory": parsed_messages, "event": "ADD"}]

if self.custom_prompt:
system_prompt = self.custom_prompt
user_prompt = f"Input: {parsed_messages}"
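The practical effect of the new flag: with infer=True (the default), add() first asks the LLM to distill the messages into discrete facts and stores each fact; with infer=False it embeds the parsed messages directly and stores them as a single memory, skipping the LLM entirely. A minimal usage sketch, assuming a default Memory configuration (the string-message shorthand and the user_id keyword come from the code above; the message text is illustrative):

from mem0 import Memory

m = Memory()

# Default path: the LLM extracts facts from the message, and each fact is
# embedded and stored as its own memory.
m.add("I live in Berlin and work as a chemist.", user_id="alice")

# New fast path: no LLM call; the parsed message is embedded and stored
# verbatim, returned as a single "ADD" event.
m.add("I live in Berlin and work as a chemist.", user_id="alice", infer=False)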
75 changes: 74 additions & 1 deletion tests/test_main.py
@@ -47,14 +47,87 @@ def test_add(memory_instance, version, enable_graph):
assert result["relations"] == []

memory_instance._add_to_vector_store.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, True
)

# Remove the conditional assertion for _add_to_graph
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}
)

@pytest.mark.parametrize("version, enable_graph, infer", [
("v1.0", False, False),
("v1.1", True, True),
("v1.1", True, False)
])
def test_add_with_inference(memory_instance, version, enable_graph, infer):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph

# Setup mocks
memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
memory_instance.vector_store.insert = Mock()
memory_instance.vector_store.search = Mock(return_value=[])
memory_instance.db.add_history = Mock()
memory_instance._add_to_graph = Mock(return_value=[])

# Mock LLM responses for inference case
if infer:
memory_instance.llm.generate_response = Mock(side_effect=[
'{"facts": ["Test fact 1", "Test fact 2"]}', # First call for fact retrieval
'{"memory": [{"event": "ADD", "text": "Test fact 1"},{"event": "ADD", "text": "Test fact 2"}]}' # Second call for memory actions
])
else:
memory_instance.llm.generate_response = Mock()

# Execute
result = memory_instance.add(
messages=[{"role": "user", "content": "Test fact 1 Text fact 2"}],
user_id="test_user",
infer=infer
)

# Verify basic structure of result
assert "results" in result
assert "relations" in result
assert isinstance(result["results"], list)
assert isinstance(result["relations"], list)

# Verify LLM behavior
if infer:
# Should be called twice: once for fact retrieval, once for memory actions
assert memory_instance.llm.generate_response.call_count == 2

# Verify first call (fact retrieval)
first_call = memory_instance.llm.generate_response.call_args_list[0]
assert len(first_call[1]['messages']) == 2
assert first_call[1]['messages'][0]['role'] == 'system'
assert first_call[1]['messages'][1]['role'] == 'user'

# Verify embedding was called for the facts
assert memory_instance.embedding_model.embed.call_count == 2

# Verify vector store operations
assert memory_instance.vector_store.insert.call_count == 2
else:
# For non-inference case, should directly create memory without LLM
memory_instance.llm.generate_response.assert_not_called()
# Should still embed the original message
memory_instance.embedding_model.embed.assert_called_once_with("user: Test fact 1 Test fact 2\n")
memory_instance.vector_store.insert.assert_called_once()

# Verify graph behavior
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test fact 1 Text fact 2"}], {"user_id": "test_user"}
)

if version == "v1.1":
assert isinstance(result, dict)
assert "results" in result
assert "relations" in result
else:
assert isinstance(result["results"], list)
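
For orientation, the new test asserts only the envelope of the return value; the per-event dicts inside "results" follow the format _add_to_vector_store returns in main.py above. A sketch of a v1.1-shaped result that would satisfy these assertions, with placeholder IDs and values (not taken from the commit):

# Illustrative only: a v1.1-style result for the infer=True case.
expected = {
    "results": [
        {"id": "<uuid>", "memory": "Test fact 1", "event": "ADD"},
        {"id": "<uuid>", "memory": "Test fact 2", "event": "ADD"},
    ],
    "relations": [],  # empty here because _add_to_graph is mocked to return []
}

The new cases can be run in isolation with pytest tests/test_main.py -k add_with_inference.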


def test_get(memory_instance):
mock_memory = Mock(
3 changes: 1 addition & 2 deletions tests/test_proxy.py
@@ -3,7 +3,6 @@
import pytest

from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
from mem0.proxy.main import Chat, Completions, Mem0


@@ -94,4 +93,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm

call_args = mock_litellm.completion.call_args[1]
assert call_args["messages"][0]["role"] == "system"
assert call_args["messages"][0]["content"] == MEMORY_ANSWER_PROMPT
assert call_args["messages"][0]["content"] == "You are a helpful assistant."
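The removed MEMORY_ANSWER_PROMPT import matches the updated assertion: the test now expects a caller-supplied system message to reach litellm verbatim rather than being replaced by the library prompt. A hypothetical sketch of such a call, assuming the OpenAI-style interface exposed by mem0.proxy.main (the API key, model, and message text are placeholders, not from the commit):

from mem0.proxy.main import Mem0

client = Mem0(api_key="your-mem0-key")  # placeholder credential
response = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[
        # A caller-provided system message should pass through unchanged.
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What do you remember about me?"},
    ],
    user_id="alice",
)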
