From 4ac097e94e3bbb673edadce12f67439f63882d59 Mon Sep 17 00:00:00 2001 From: Dominik George <dominik.george@teckids.org> Date: Sat, 8 Jan 2022 22:17:21 +0100 Subject: [PATCH] Ensure task is enqueued on commit --- celery_haystack/conf.py | 9 +-------- celery_haystack/signals.py | 33 ++++++++++++++++++++------------- celery_haystack/utils.py | 29 ----------------------------- 3 files changed, 21 insertions(+), 50 deletions(-) diff --git a/celery_haystack/conf.py b/celery_haystack/conf.py index c88c488..7704901 100644 --- a/celery_haystack/conf.py +++ b/celery_haystack/conf.py @@ -55,11 +55,4 @@ class CeleryHaystack(AppConf): return data -signal_processor = getattr(settings, 'HAYSTACK_SIGNAL_PROCESSOR', None) - - -if signal_processor is None: - raise ImproperlyConfigured("When using celery-haystack with Haystack 2.X " - "the HAYSTACK_SIGNAL_PROCESSOR setting must be " - "set. Use 'celery_haystack.signals." - "CelerySignalProcessor' as default.") +signal_processor = getattr(settings, 'HAYSTACK_SIGNAL_PROCESSOR', "celery_haystack.signals.CelerySignalProcessor") diff --git a/celery_haystack/signals.py b/celery_haystack/signals.py index 0c6fe3b..afb6ec6 100644 --- a/celery_haystack/signals.py +++ b/celery_haystack/signals.py @@ -1,34 +1,41 @@ +from django.db import transaction from django.db.models import signals -from haystack.signals import BaseSignalProcessor +from haystack.signals import RealtimeSignalProcessor from haystack.exceptions import NotHandled from haystack.utils import get_identifier -from .utils import enqueue_task +from .conf import settings +from .utils import get_update_task from .indexes import CelerySearchIndex -class CelerySignalProcessor(BaseSignalProcessor): +class CelerySignalProcessor(RealtimeSignalProcessor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._queue = [] def setup(self): - signals.post_save.connect(self.enqueue_save) - signals.post_delete.connect(self.enqueue_delete) + transaction.on_commit(self.run_task) + super().setup() - def teardown(self): - signals.post_save.disconnect(self.enqueue_save) - signals.post_delete.disconnect(self.enqueue_delete) - - enqueue_task(self._queue) - - def enqueue_save(self, sender, instance, **kwargs): + def handle_save(self, sender, instance, **kwargs): return self.enqueue('update', instance, sender, **kwargs) - def enqueue_delete(self, sender, instance, **kwargs): + def handle_delete(self, sender, instance, **kwargs): return self.enqueue('delete', instance, sender, **kwargs) + def run_task(self): + options = {} + if settings.CELERY_HAYSTACK_QUEUE: + options['queue'] = settings.CELERY_HAYSTACK_QUEUE + if settings.CELERY_HAYSTACK_COUNTDOWN: + options['countdown'] = settings.CELERY_HAYSTACK_COUNTDOWN + + task = get_update_task() + task.apply_async((self._queue,), {}, **options) + + def enqueue(self, action, instance, sender, **kwargs): """ Given an individual model instance, determine if a backend diff --git a/celery_haystack/utils.py b/celery_haystack/utils.py index b016e6c..cc1f92d 100644 --- a/celery_haystack/utils.py +++ b/celery_haystack/utils.py @@ -3,7 +3,6 @@ try: from importlib import import_module except ImportError: from django.utils.importlib import import_module -from django.db import connection, transaction from .conf import settings @@ -22,31 +21,3 @@ def get_update_task(task_path=None): raise ImproperlyConfigured('Module "%s" does not define a "%s" ' 'class.' % (module, attr)) return task - - -def enqueue_task(queue, **kwargs): - """ - Common utility for enqueing a task for the given action and - model instance. - """ - options = {} - if settings.CELERY_HAYSTACK_QUEUE: - options['queue'] = settings.CELERY_HAYSTACK_QUEUE - if settings.CELERY_HAYSTACK_COUNTDOWN: - options['countdown'] = settings.CELERY_HAYSTACK_COUNTDOWN - - task = get_update_task() - task_func = lambda: task.apply_async((queue,), kwargs, **options) - - if hasattr(transaction, 'on_commit'): - # Django 1.9 on_commit hook - transaction.on_commit( - task_func - ) - elif hasattr(connection, 'on_commit'): - # Django-transaction-hooks - connection.on_commit( - task_func - ) - else: - task_func() -- GitLab