From 0dadbb03f6725c0b3de609ccd7182f6fd3d066e2 Mon Sep 17 00:00:00 2001 From: Vasyl Dizhak Date: Sat, 22 Jun 2024 20:52:46 +0200 Subject: [PATCH] Issue #598 extend support for the hashmap functions --- changelog.d/731.feature | 1 + django_redis/cache.py | 20 ++++++ django_redis/client/default.py | 109 ++++++++++++++++++++++++++--- django_redis/client/sharded.py | 122 +++++++++++++++++++++++++++++++++ tests/test_backend.py | 58 +++++++++++++--- 5 files changed, 290 insertions(+), 20 deletions(-) create mode 100644 changelog.d/731.feature diff --git a/changelog.d/731.feature b/changelog.d/731.feature new file mode 100644 index 00000000..574ee606 --- /dev/null +++ b/changelog.d/731.feature @@ -0,0 +1 @@ +Support for the hashmaps functions hsetnx, hget, hgetall, hmget hincrby \ No newline at end of file diff --git a/django_redis/cache.py b/django_redis/cache.py index f7b943a3..bff8ec29 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -257,6 +257,22 @@ def sunionstore(self, *args, **kwargs): def hset(self, *args, **kwargs): return self.client.hset(*args, **kwargs) + @omit_exception + def hsetnx(self, *args, **kwargs): + return self.client.hsetnx(*args, **kwargs) + + @omit_exception + def hget(self, *args, **kwargs): + return self.client.hget(*args, **kwargs) + + @omit_exception + def hgetall(self, *args, **kwargs): + return self.client.hgetall(*args, **kwargs) + + @omit_exception + def hmget(self, *args, **kwargs): + return self.client.hmget(*args, **kwargs) + @omit_exception def hdel(self, *args, **kwargs): return self.client.hdel(*args, **kwargs) @@ -272,3 +288,7 @@ def hkeys(self, *args, **kwargs): @omit_exception def hexists(self, *args, **kwargs): return self.client.hexists(*args, **kwargs) + + @omit_exception + def hincrby(self, *args, **kwargs): + return self.client.hincrby(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 3219f7c9..423063de 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -621,7 +621,7 @@ def _incr( # if cached value or total value is greater than 64 bit signed # integer. # elif int is encoded. so redis sees the data as string. - # In this situations redis will throw ResponseError + # In these situations redis will throw ResponseError # try to keep TTL of key timeout = self.ttl(key, version=version, client=client) @@ -1103,7 +1103,7 @@ def touch( def hset( self, - name: str, + name: KeyT, key: KeyT, value: EncodableT, version: Optional[int] = None, @@ -1115,13 +1115,76 @@ def hset( """ if client is None: client = self.get_client(write=True) + key_name = self.make_key(name, version=version) nkey = self.make_key(key, version=version) nvalue = self.encode(value) - return int(client.hset(name, nkey, nvalue)) + return int(client.hset(key_name, nkey, nvalue)) + + def hsetnx( + self, + name: KeyT, + key: KeyT, + value: EncodableT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + nkey = self.make_key(key, version=version) + key_name = self.make_key(name, version=version) + nvalue = self.encode(value) + return int(client.hsetnx(key_name, nkey, nvalue)) + + def hget( + self, + name: KeyT, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Any: + """ + Return the value of hash name at key. + """ + if client is None: + client = self.get_client(write=False) + name = self.make_key(name, version=version) + nkey = self.make_key(key, version=version) + value = client.hget(name, nkey) + if value is None: + return None + return self.decode(value) + + def hgetall( + self, + name: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Dict[str, Any]: + + if client is None: + client = self.get_client(write=False) + name = self.make_key(name, version=version) + data = client.hgetall(name) + return {self.reverse_key(k.decode()): self.decode(v) for k, v in data.items()} + + def hmget( + self, + name: KeyT, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> List[Any]: + if client is None: + client = self.get_client(write=False) + name = self.make_key(name, version=version) + nkeys = [self.make_key(k, version=version) for k in keys] + return [ + self.decode(v) if v is not None else None for v in client.hmget(name, nkeys) + ] def hdel( self, - name: str, + name: KeyT, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, @@ -1132,12 +1195,14 @@ def hdel( """ if client is None: client = self.get_client(write=True) + name = self.make_key(name, version=version) nkey = self.make_key(key, version=version) return int(client.hdel(name, nkey)) def hlen( self, - name: str, + name: KeyT, + version: Optional[int] = None, client: Optional[Redis] = None, ) -> int: """ @@ -1145,26 +1210,29 @@ def hlen( """ if client is None: client = self.get_client(write=False) - return int(client.hlen(name)) + key_name = self.make_key(name, version=version) + return int(client.hlen(key_name)) def hkeys( self, - name: str, + name: KeyT, + version: Optional[int] = None, client: Optional[Redis] = None, ) -> List[Any]: """ Return a list of keys in hash name. """ + key_name = self.make_key(name, version=version) if client is None: client = self.get_client(write=False) try: - return [self.reverse_key(k.decode()) for k in client.hkeys(name)] + return [self.reverse_key(k.decode()) for k in client.hkeys(key_name)] except _main_exceptions as e: raise ConnectionInterrupted(connection=client) from e def hexists( self, - name: str, + name: KeyT, key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, @@ -1174,5 +1242,26 @@ def hexists( """ if client is None: client = self.get_client(write=False) + key_name = self.make_key(name, version=version) nkey = self.make_key(key, version=version) - return bool(client.hexists(name, nkey)) + return bool(client.hexists(key_name, nkey)) + + def hincrby( + self, + name: KeyT, + key: KeyT, + increment: int = 1, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + key_name = self.make_key(name, version=version) + nkey = self.make_key(key, version=version) + try: + value = client.hincrby(key_name, nkey, increment) + except ResponseError as exc: + value = self.hget(key_name, nkey) + msg = f"Value: {value} is not an integer or out of range." + raise ValueError(msg) from exc + return int(value) diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index 5e2eec90..e90cec0a 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -483,3 +483,125 @@ def smismember( key = self.make_key(key, version=version) client = self.get_server(key) return super().smismember(key, *members, version=version, client=client) + + def hset( + self, + name: KeyT, + key: KeyT, + value: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hset(key_name, key, value, version=version, client=client) + + def hsetnx( + self, + name: KeyT, + key: KeyT, + value: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hsetnx(key_name, key, value, version=version, client=client) + + def hlen( + self, + name: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hlen(key_name, version=version, client=client) + + def hget( + self, + name: KeyT, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Any: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hget(key_name, key, version=version, client=client) + + def hexists( + self, + name: KeyT, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hexists(key_name, key, version=version, client=client) + + def hkeys( + self, + name: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> List[Any]: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hkeys(key_name, client=client) + + def hgetall( + self, + name: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Any: + if client is None: + key_name = self.make_key(name, version=version) + client = self.get_server(key_name) + return super().hgetall(key_name, version=version, client=client) + + def hincrby( + self, + name: KeyT, + key: KeyT, + increment: int = 1, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + name = self.make_key(name, version=version) + client = self.get_server(name) + return super().hincrby( + name, key, increment=increment, version=version, client=client + ) + + def hmget( + self, + name: KeyT, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> List[Any]: + if client is None: + name = self.make_key(name, version=version) + client = self.get_server(name) + return super().hmget(name, *keys, version=version, client=client) + + def hdel( + self, + name: KeyT, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + name = self.make_key(name, version=version) + client = self.get_server(name) + return super().hdel(name, key, version=version, client=client) diff --git a/tests/test_backend.py b/tests/test_backend.py index 8619931e..f787b30b 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -810,17 +810,61 @@ def test_clear(self, cache: RedisCache): assert value_from_cache_after_clear is None def test_hset(self, cache: RedisCache): - if isinstance(cache.client, ShardClient): - pytest.skip("ShardClient doesn't support get_client") cache.hset("foo_hash1", "foo1", "bar1") cache.hset("foo_hash1", "foo2", "bar2") assert cache.hlen("foo_hash1") == 2 assert cache.hexists("foo_hash1", "foo1") assert cache.hexists("foo_hash1", "foo2") + def test_hsetnx(self, cache: RedisCache): + result_foo1 = cache.hsetnx("foo_hash1", "foo1", "bar1") + result_foo2 = cache.hsetnx("foo_hash1", "foo2", "bar2") + result_foo2_1 = cache.hsetnx("foo_hash1", "foo2", "bar2") + assert result_foo1 == 1 + assert result_foo2 == 1 + assert result_foo2_1 == 0 + assert cache.hlen("foo_hash1") == 2 + assert cache.hexists("foo_hash1", "foo1") + assert cache.hexists("foo_hash1", "foo2") + + def test_hget(self, cache: RedisCache): + cache.hset("foo_hash1", "foo1", "bar1") + cache.hset("foo_hash1", "foo2", "bar2") + assert cache.hget("foo_hash1", "foo1") == "bar1" + assert cache.hget("foo_hash1", "foo2") == "bar2" + + def test_hgetall(self, cache: RedisCache): + cache.hset("foo_hash1", "foo1", "bar1") + cache.hset("foo_hash1", "foo2", "bar2") + assert cache.hgetall("foo_hash1") == { + "foo1": "bar1", + "foo2": "bar2", + } + + def test_hmget(self, cache: RedisCache): + cache.hset("foo_hash1", "foo1", "bar1") + cache.hset("foo_hash1", "foo2", "bar2") + cache.hset("foo_hash1", "foo3", "bar3") + assert cache.hmget("foo_hash1", "foo3", "foo1") == ["bar3", "bar1"] + assert cache.hmget("foo_hash1", "foo1", "foo2", "not-a-key") == [ + "bar1", + "bar2", + None, + ] + + def test_hincrby(self, cache: RedisCache): + cache.hset("foo_hash1", "foo1", 1) + assert cache.hincrby("foo_hash1", "foo1") == 2 + + assert cache.hincrby("foo_hash1", "foo2", 3) == 3 + assert cache.hincrby("foo_hash1", "foo2", -3) == 0 + + cache.hset("foo_hash1", "foo3", "not_an_int") + + with pytest.raises(ValueError): + cache.hincrby("foo_hash1", "foo3", -3) + def test_hdel(self, cache: RedisCache): - if isinstance(cache.client, ShardClient): - pytest.skip("ShardClient doesn't support get_client") cache.hset("foo_hash2", "foo1", "bar1") cache.hset("foo_hash2", "foo2", "bar2") assert cache.hlen("foo_hash2") == 2 @@ -831,8 +875,6 @@ def test_hdel(self, cache: RedisCache): assert cache.hexists("foo_hash2", "foo2") def test_hlen(self, cache: RedisCache): - if isinstance(cache.client, ShardClient): - pytest.skip("ShardClient doesn't support get_client") assert cache.hlen("foo_hash3") == 0 cache.hset("foo_hash3", "foo1", "bar1") assert cache.hlen("foo_hash3") == 1 @@ -840,8 +882,6 @@ def test_hlen(self, cache: RedisCache): assert cache.hlen("foo_hash3") == 2 def test_hkeys(self, cache: RedisCache): - if isinstance(cache.client, ShardClient): - pytest.skip("ShardClient doesn't support get_client") cache.hset("foo_hash4", "foo1", "bar1") cache.hset("foo_hash4", "foo2", "bar2") cache.hset("foo_hash4", "foo3", "bar3") @@ -851,8 +891,6 @@ def test_hkeys(self, cache: RedisCache): assert keys[i] == f"foo{i + 1}" def test_hexists(self, cache: RedisCache): - if isinstance(cache.client, ShardClient): - pytest.skip("ShardClient doesn't support get_client") cache.hset("foo_hash5", "foo1", "bar1") assert cache.hexists("foo_hash5", "foo1") assert not cache.hexists("foo_hash5", "foo")