diff --git a/celery_redis_sentinel/backend.py b/celery_redis_sentinel/backend.py index 35006e7..8916d5b 100644 --- a/celery_redis_sentinel/backend.py +++ b/celery_redis_sentinel/backend.py @@ -3,10 +3,18 @@ from celery.backends.redis import RedisBackend from kombu.utils import cached_property -from redis import Redis +import celery from .redis_sentinel import EnsuredRedisMixin, get_redis_via_sentinel +try: + if celery.VERSION.major >= 4: + from redis import StrictRedis as Redis + else: + from redis import Redis +except AttributeError: + from redis import Redis + class RedisSentinelBackend(RedisBackend): """ diff --git a/celery_redis_sentinel/register.py b/celery_redis_sentinel/register.py index 06ad8f8..75c35af 100644 --- a/celery_redis_sentinel/register.py +++ b/celery_redis_sentinel/register.py @@ -7,10 +7,13 @@ from .backend import RedisSentinelBackend from .transport import SentinelTransport -if celery.VERSION.major < 4: +try: + if celery.VERSION.major >= 4: + from celery.app.backends import BACKEND_ALIASES + else: + from celery.backends import BACKEND_ALIASES +except AttributeError: from celery.backends import BACKEND_ALIASES -else: - from celery.app.backends import BACKEND_ALIASES def get_class_path(cls): diff --git a/tests/test_backend.py b/tests/test_backend.py index 97935a5..e789e4a 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, unicode_literals import mock +import celery from celery_redis_sentinel.backend import RedisSentinelBackend @@ -25,16 +26,25 @@ def test_client(self, mock_get_redis_via_sentinel): client = backend.client assert client == mock_get_redis_via_sentinel.return_value + + ignored_args = [ + 'redis_class', + 'host', + 'max_connections', + 'password', + 'port', + ] + try: + if celery.VERSION.major >= 4: + ignored_args.append('socket_connect_timeout') + except AttributeError: + pass mock_get_redis_via_sentinel.assert_called_once_with( db=0, - redis_class=mock.ANY, sentinels=[('192.168.1.1', 26379), ('192.168.1.2', 26379), ('192.168.1.3', 26379)], service_name='master', socket_timeout=1, - host=mock.ANY, - max_connections=mock.ANY, - password=mock.ANY, - port=mock.ANY, + **{arg: mock.ANY for arg in ignored_args} ) diff --git a/tests/test_register.py b/tests/test_register.py index 3dea7e1..d119e1c 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -2,11 +2,19 @@ from __future__ import absolute_import, print_function, unicode_literals import mock -from celery.backends import BACKEND_ALIASES +import celery from kombu.transport import TRANSPORT_ALIASES from celery_redis_sentinel.register import get_class_path, register +try: + if celery.VERSION.major >= 4: + from celery.app.backends import BACKEND_ALIASES + else: + from celery.backends import BACKEND_ALIASES +except AttributeError: + from celery.backends import BACKEND_ALIASES + class Foo(object): pass diff --git a/tests/test_transport.py b/tests/test_transport.py index 97acc35..9696277 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -3,6 +3,7 @@ import mock from kombu import Connection +import celery from redis import StrictRedis from celery_redis_sentinel.redis_sentinel import CelerySentinelConnectionPool @@ -35,11 +36,21 @@ def test_sentinel_pool(self, mock_get_redis_via_sentinel, mock_execute_command): pool = channel.sentinel_pool assert pool == mock_get_redis_via_sentinel.return_value.connection_pool + ignored_args = [ + 'host', + 'max_connections', + 'password', + 'port', + 'socket_connect_timeout', + 'socket_keepalive', + 'socket_keepalive_options', + ] + try: + if celery.VERSION.major >= 4: + ignored_args.append('connection_class') + except AttributeError: + pass mock_get_redis_via_sentinel.assert_called_once_with( - host=mock.ANY, - max_connections=mock.ANY, - password=mock.ANY, - port=mock.ANY, connection_pool_class=CelerySentinelConnectionPool, redis_class=channel.Client, db=0, @@ -48,9 +59,7 @@ def test_sentinel_pool(self, mock_get_redis_via_sentinel, mock_execute_command): ('192.168.1.3', 26379)], service_name='master', socket_timeout=1, - socket_connect_timeout=mock.ANY, - socket_keepalive=mock.ANY, - socket_keepalive_options=mock.ANY, + **{arg: mock.ANY for arg in ignored_args} )