From 4ac097e94e3bbb673edadce12f67439f63882d59 Mon Sep 17 00:00:00 2001
From: Dominik George <dominik.george@teckids.org>
Date: Sat, 8 Jan 2022 22:17:21 +0100
Subject: [PATCH] Ensure task is enqueued on commit

---
 celery_haystack/conf.py    |  9 +--------
 celery_haystack/signals.py | 33 ++++++++++++++++++++-------------
 celery_haystack/utils.py   | 29 -----------------------------
 3 files changed, 21 insertions(+), 50 deletions(-)

diff --git a/celery_haystack/conf.py b/celery_haystack/conf.py
index c88c488..7704901 100644
--- a/celery_haystack/conf.py
+++ b/celery_haystack/conf.py
@@ -55,11 +55,4 @@ class CeleryHaystack(AppConf):
         return data
 
 
-signal_processor = getattr(settings, 'HAYSTACK_SIGNAL_PROCESSOR', None)
-
-
-if signal_processor is None:
-    raise ImproperlyConfigured("When using celery-haystack with Haystack 2.X "
-                               "the HAYSTACK_SIGNAL_PROCESSOR setting must be "
-                               "set. Use 'celery_haystack.signals."
-                               "CelerySignalProcessor' as default.")
+signal_processor = getattr(settings, 'HAYSTACK_SIGNAL_PROCESSOR', "celery_haystack.signals.CelerySignalProcessor")
diff --git a/celery_haystack/signals.py b/celery_haystack/signals.py
index 0c6fe3b..afb6ec6 100644
--- a/celery_haystack/signals.py
+++ b/celery_haystack/signals.py
@@ -1,34 +1,41 @@
+from django.db import transaction
 from django.db.models import signals
 
-from haystack.signals import BaseSignalProcessor
+from haystack.signals import RealtimeSignalProcessor
 from haystack.exceptions import NotHandled
 from haystack.utils import get_identifier
 
-from .utils import enqueue_task
+from .conf import settings
+from .utils import get_update_task
 from .indexes import CelerySearchIndex
 
 
-class CelerySignalProcessor(BaseSignalProcessor):
+class CelerySignalProcessor(RealtimeSignalProcessor):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._queue = []
 
     def setup(self):
-        signals.post_save.connect(self.enqueue_save)
-        signals.post_delete.connect(self.enqueue_delete)
+        transaction.on_commit(self.run_task)
+        super().setup()
 
-    def teardown(self):
-        signals.post_save.disconnect(self.enqueue_save)
-        signals.post_delete.disconnect(self.enqueue_delete)
-
-        enqueue_task(self._queue)
-
-    def enqueue_save(self, sender, instance, **kwargs):
+    def handle_save(self, sender, instance, **kwargs):
         return self.enqueue('update', instance, sender, **kwargs)
 
-    def enqueue_delete(self, sender, instance, **kwargs):
+    def handle_delete(self, sender, instance, **kwargs):
         return self.enqueue('delete', instance, sender, **kwargs)
 
+    def run_task(self):
+        options = {}
+        if settings.CELERY_HAYSTACK_QUEUE:
+            options['queue'] = settings.CELERY_HAYSTACK_QUEUE
+        if settings.CELERY_HAYSTACK_COUNTDOWN:
+            options['countdown'] = settings.CELERY_HAYSTACK_COUNTDOWN
+
+        task = get_update_task()
+        task.apply_async((self._queue,), {}, **options)
+
+
     def enqueue(self, action, instance, sender, **kwargs):
         """
         Given an individual model instance, determine if a backend
diff --git a/celery_haystack/utils.py b/celery_haystack/utils.py
index b016e6c..cc1f92d 100644
--- a/celery_haystack/utils.py
+++ b/celery_haystack/utils.py
@@ -3,7 +3,6 @@ try:
     from importlib import import_module
 except ImportError:
     from django.utils.importlib import import_module
-from django.db import connection, transaction
 
 from .conf import settings
 
@@ -22,31 +21,3 @@ def get_update_task(task_path=None):
         raise ImproperlyConfigured('Module "%s" does not define a "%s" '
                                    'class.' % (module, attr))
     return task
-
-
-def enqueue_task(queue, **kwargs):
-    """
-    Common utility for enqueing a task for the given action and
-    model instance.
-    """
-    options = {}
-    if settings.CELERY_HAYSTACK_QUEUE:
-        options['queue'] = settings.CELERY_HAYSTACK_QUEUE
-    if settings.CELERY_HAYSTACK_COUNTDOWN:
-        options['countdown'] = settings.CELERY_HAYSTACK_COUNTDOWN
-
-    task = get_update_task()
-    task_func = lambda: task.apply_async((queue,), kwargs, **options)
-
-    if hasattr(transaction, 'on_commit'):
-        # Django 1.9 on_commit hook
-        transaction.on_commit(
-            task_func
-        )
-    elif hasattr(connection, 'on_commit'):
-        # Django-transaction-hooks
-        connection.on_commit(
-            task_func
-        )
-    else:
-        task_func()
-- 
GitLab