-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
test_cache.py
131 lines (98 loc) · 4.07 KB
/
test_cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import unittest
import json
from unittest.mock import patch, MagicMock
from application.cache import gen_cache_key, stream_cache, gen_cache
from application.utils import get_hash
# --- gen_cache_key -----------------------------------------------------------
def test_make_gen_cache_key():
    """Verify the key equals ``get_hash`` of ``"<model>_<sorted-key JSON>"``.

    The expected key is rebuilt by hand from the model name and the
    ``json.dumps(..., sort_keys=True)`` serialisation of the messages.
    """
    model_name = "test_docgpt"
    conversation = [
        {'role': 'user', 'content': 'test_user_message'},
        {'role': 'system', 'content': 'test_system_message'},
    ]

    # Recompute the key independently of gen_cache_key.
    serialized = json.dumps(conversation, sort_keys=True)
    expected_hash = get_hash(f"{model_name}_{serialized}")

    assert gen_cache_key(*conversation, model=model_name) == expected_hash
def test_gen_cache_key_invalid_message_format():
    """A message that is not a dict must make ``gen_cache_key`` raise.

    The bare string below is forwarded as a single positional message (not a
    dict), so the call is expected to fail with
    ``ValueError("All messages must be dictionaries.")``.
    """
    # Use a real TestCase instance. Calling ``assertRaises`` unbound with the
    # class as ``self`` only works by accident: if nothing is raised, the
    # failure path calls ``_formatMessage`` unbound and crashes with a
    # TypeError instead of reporting "ValueError not raised".
    checker = unittest.TestCase()
    with checker.assertRaises(ValueError) as context:
        gen_cache_key("This is not a list", model="docgpt")
    assert str(context.exception) == "All messages must be dictionaries."
# --- gen_cache decorator -----------------------------------------------------
@patch('application.cache.get_redis_instance')  # replace the Redis client factory
def test_gen_cache_hit(mock_make_redis):
    """On a hit the stored bytes come back as str and nothing is re-cached."""
    fake_redis = MagicMock()
    fake_redis.get.return_value = b"cached_result"  # pretend the key exists
    mock_make_redis.return_value = fake_redis

    @gen_cache
    def mock_function(self, model, messages):
        return "new_result"

    outcome = mock_function(
        None,
        "test_docgpt",
        [{'role': 'user', 'content': 'test_user_message'}],
    )

    # The cached value wins over what the wrapped function would return.
    assert outcome == "cached_result"
    fake_redis.get.assert_called_once()
    fake_redis.set.assert_not_called()
@patch('application.cache.get_redis_instance')  # Mock the Redis client
def test_gen_cache_miss(mock_make_redis):
    """On a cache miss the wrapped function runs and its result is stored.

    Redis ``get`` returns ``None``, so ``gen_cache`` must call through to the
    wrapped function, return its value, and write it back exactly once.
    """
    # Arrange
    mock_redis_instance = MagicMock()
    mock_make_redis.return_value = mock_redis_instance
    mock_redis_instance.get.return_value = None  # Simulate a cache miss

    @gen_cache
    def mock_function(self, model, messages):
        return "new_result"

    messages = [
        {'role': 'user', 'content': 'test_user_message'},
        {'role': 'system', 'content': 'test_system_message'},
    ]
    model = "test_docgpt"

    # Act
    result = mock_function(None, model, messages)

    # Assert
    assert result == "new_result"
    mock_redis_instance.get.assert_called_once()
    # A miss must populate the cache; without this check the test would
    # still pass if the decorator never wrote to Redis at all.
    mock_redis_instance.set.assert_called_once()
@patch('application.cache.get_redis_instance')
def test_stream_cache_hit(mock_make_redis):
    """Cached chunks (a UTF-8 JSON list) are replayed and nothing is re-cached."""
    expected_chunks = ["chunk1", "chunk2"]

    redis_stub = MagicMock()
    redis_stub.get.return_value = json.dumps(expected_chunks).encode('utf-8')
    mock_make_redis.return_value = redis_stub

    @stream_cache
    def mock_function(self, model, messages, stream):
        yield "new_chunk"

    conversation = [{'role': 'user', 'content': 'test_user_message'}]
    replayed = list(mock_function(None, "test_docgpt", conversation, stream=True))

    # The live generator's "new_chunk" must not appear on a hit.
    assert replayed == expected_chunks
    redis_stub.get.assert_called_once()
    redis_stub.set.assert_not_called()
@patch('application.cache.get_redis_instance')
def test_stream_cache_miss(mock_make_redis):
    """On a miss the live chunks are streamed through and Redis is written once."""
    redis_stub = MagicMock()
    redis_stub.get.return_value = None  # nothing cached for this key
    mock_make_redis.return_value = redis_stub

    @stream_cache
    def mock_function(self, model, messages, stream):
        yield "new_chunk"

    conversation = [
        {'role': 'user', 'content': 'This is the context'},
        {'role': 'system', 'content': 'Some other message'},
        {'role': 'user', 'content': 'What is the answer?'},
    ]

    streamed = [
        chunk
        for chunk in mock_function(None, "test_docgpt", conversation, stream=True)
    ]

    assert streamed == ["new_chunk"]
    redis_stub.get.assert_called_once()
    redis_stub.set.assert_called_once()