diff --git a/django_prometheus/cache/backends/redis.py b/django_prometheus/cache/backends/redis.py index bc982dac..707c85ee 100644 --- a/django_prometheus/cache/backends/redis.py +++ b/django_prometheus/cache/backends/redis.py @@ -1,5 +1,4 @@ from django import VERSION as DJANGO_VERSION -from django_redis import cache, exceptions from django_prometheus.cache.metrics import ( django_cache_get_fail_total, @@ -9,30 +8,6 @@ ) -class RedisCache(cache.RedisCache): - """Inherit redis to add metrics about hit/miss/interruption ratio""" - - @cache.omit_exception - def get(self, key, default=None, version=None, client=None): - try: - django_cache_get_total.labels(backend="redis").inc() - cached = self.client.get(key, default=None, version=version, client=client) - except exceptions.ConnectionInterrupted as e: - django_cache_get_fail_total.labels(backend="redis").inc() - if self._ignore_exceptions: - if self._log_ignored_exceptions: - cache.logger.error(str(e)) - return default - raise - else: - if cached is not None: - django_cache_hits_total.labels(backend="redis").inc() - return cached - else: - django_cache_misses_total.labels(backend="redis").inc() - return default - - if DJANGO_VERSION >= (4, 0): from django.core.cache.backends.redis import RedisCache as DjangoRedisCache @@ -50,3 +25,29 @@ def get(self, key, default=None, version=None): else: django_cache_misses_total.labels(backend="native_redis").inc() return default +else: + # Fallback for django_redis + from django_redis import cache, exceptions + + class RedisCache(cache.RedisCache): + """Inherit redis to add metrics about hit/miss/interruption ratio""" + + @cache.omit_exception + def get(self, key, default=None, version=None, client=None): + try: + django_cache_get_total.labels(backend="redis").inc() + cached = self.client.get(key, default=None, version=version, client=client) + except exceptions.ConnectionInterrupted as e: + django_cache_get_fail_total.labels(backend="redis").inc() + if self._ignore_exceptions: + if self._log_ignored_exceptions: + cache.logger.error(str(e)) + return default + raise + else: + if cached is not None: + django_cache_hits_total.labels(backend="redis").inc() + return cached + else: + django_cache_misses_total.labels(backend="redis").inc() + return default