Drop-in celery AbortableTask replacement

24 October 2011 (updated 04 March 2015)

If you need to report progress from your tasks (that is, you call update_state inside the task), you cannot use the bundled AbortableTask from celery.contrib.abortable. It signals an abort by storing an ABORTED state in the result backend, which is the same slot your progress updates write to. A progress update can overwrite a pending abort request, so you get race conditions.
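To see why sharing one state slot loses aborts, here is a toy in-memory model. It is not Celery code: the `ResultStore` class and both helper functions are hypothetical. The store keeps one state per task id, the way a result backend does. The real race depends on timing; the model simply replays the unlucky interleaving deterministically, then contrasts it with keeping abort requests in a separate set.

```python
class ResultStore:
    """Hypothetical stand-in for a result backend: one state per task id."""

    def __init__(self):
        self.states = {}

    def store(self, task_id, state):
        self.states[task_id] = state


def shared_slot_abort_is_seen(store, task_id):
    # The client requests an abort by writing ABORTED into the state slot...
    store.store(task_id, "ABORTED")
    # ...but before the task polls, it reports progress into the same slot.
    store.store(task_id, "PROGRESS")
    # When the task finally checks, the abort request is gone.
    return store.states[task_id] == "ABORTED"


def separate_set_abort_is_seen(store, aborted_ids, task_id):
    # Abort requests live in their own set, like the Redis set in this post.
    aborted_ids.add(task_id)
    # Progress still goes to the result store and cannot clobber the request.
    store.store(task_id, "PROGRESS")
    return task_id in aborted_ids


print(shared_slot_abort_is_seen(ResultStore(), "t1"))          # False
print(separate_set_abort_is_seen(ResultStore(), set(), "t1"))  # True
```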

You can use revokes to abort tasks, but they don't give you enough control, and there's no guarantee that your tasks will stop gracefully (or stop at all). If enabled, a revoke can raise SoftTimeLimitExceeded inside the task (via the TERM signal), but cleanup can be tricky: if the task is in the middle of a call into a C extension, the exception is delayed until that call returns. See the signal module docs for what happens when an exception is raised from a signal handler (which is what celery does).
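The mechanism is easy to reproduce with the standard library alone. In this sketch, `SoftLimit` is a hypothetical stand-in for SoftTimeLimitExceeded, and SIGUSR1 is just a convenient signal for the demo, not a claim about which signal Celery uses. A handler that raises makes the exception surface at whatever Python code the main thread happens to be running, so cleanup has to live in a `finally` block. It assumes Unix, Python 3.8+ (for `signal.raise_signal`), and that it runs in the main thread.

```python
import signal


class SoftLimit(Exception):
    """Hypothetical stand-in for celery's SoftTimeLimitExceeded."""


def _raise_soft_limit(signum, frame):
    # Raising inside a handler makes the exception appear at whatever
    # Python bytecode the main thread is executing when the handler runs.
    raise SoftLimit()


def interrupted_work():
    log = []
    previous = signal.signal(signal.SIGUSR1, _raise_soft_limit)
    try:
        log.append("started")
        signal.raise_signal(signal.SIGUSR1)  # simulate the worker's signal
        while True:  # busy pure-Python loop; the pending handler runs by here
            pass
    except SoftLimit:
        log.append("interrupted")
    finally:
        # The exception can arrive anywhere, so cleanup belongs in finally.
        signal.signal(signal.SIGUSR1, previous)
        log.append("cleaned up")
    return log


print(interrupted_work())  # ['started', 'interrupted', 'cleaned up']
```

A C extension call would not be interrupted this way: the handler only runs once control returns to the interpreter, which is exactly the delay described above.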

An alternative is to store the ids of aborted tasks in a Redis set, separate from the task state. If you use the Redis broker, you can use this drop-in replacement:

```python
from contextlib import contextmanager

import celery
from celery import Task
from celery.result import AsyncResult
from django.conf import settings

# Abort flags live in a Redis set, so the broker must be Redis.
assert settings.BROKER_TRANSPORT == 'redis', "AbortableTask can only work with a 'redis' BROKER_TRANSPORT"
REDIS_KEY = getattr(settings, 'ABORTABLE_REDIS_KEY', 'task-aborts')


@contextmanager
def client_from_pool():
    """Borrow a broker connection from the pool and yield its Redis client."""
    connection = celery.current_app.pool.acquire()
    try:
        yield connection.default_channel.client
    finally:
        connection.release()


class AbortableAsyncResult(AsyncResult):
    """Client-side handle: lets callers request or check an abort."""

    def is_aborted(self):
        with client_from_pool() as client:
            return bool(client.sismember(REDIS_KEY, self.id))

    def abort(self):
        # Only records the request; the task must poll is_aborted() and stop.
        with client_from_pool() as client:
            client.sadd(REDIS_KEY, self.id)


class AbortableTask(Task):
    """Base class for tasks that poll for abort requests stored in Redis."""

    abstract = True  # don't register this base class as a task itself

    def AsyncResult(self, task_id):
        return AbortableAsyncResult(task_id, backend=self.backend,
                                    task_name=self.name)

    def is_aborted(self, **kwargs):
        task_id = kwargs.get('task_id', self.request.id)
        with client_from_pool() as client:
            return bool(client.sismember(REDIS_KEY, task_id))

    def cleanup(self, **kwargs):
        # Remove the abort flag so the set doesn't grow without bound.
        task_id = kwargs.get('task_id', self.request.id)
        with client_from_pool() as client:
            client.srem(REDIS_KEY, task_id)

    def after_return(self, status, retval, task_id, args, kwargs, einfo):
        # Called once the task has finished, whatever its final state.
        self.cleanup(task_id=task_id)
```
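A task built on this AbortableTask would typically poll `is_aborted()` between units of work and report progress as it goes. The sketch below shows that loop with both hooks passed in as plain callables, so it runs without a broker; `process_items` is a hypothetical helper. In a real task you would pass `self.is_aborted` and something like `lambda n: self.update_state(state='PROGRESS', meta={'done': n})`.

```python
from typing import Callable, Iterable, List


def process_items(items: Iterable[str],
                  is_aborted: Callable[[], bool],
                  report_progress: Callable[[int], None]) -> List[str]:
    """Hypothetical task body: work in small steps, checking for aborts."""
    done = []
    for item in items:
        # Check between units of work, so an abort stops us at a safe point.
        if is_aborted():
            break
        done.append(item.upper())  # stand-in for the real work
        # In a real task this writes to the result backend; with abort flags
        # kept in a separate Redis set it cannot overwrite an abort request.
        report_progress(len(done))
    return done


progress = []
flags = iter([False, False, True])  # the abort arrives before the third item
print(process_items(["a", "b", "c", "d"], lambda: next(flags), progress.append))
print(progress)  # [1, 2]
```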

AbortableTask uses the broker's connection pool if it is enabled (and you should enable it: just set BROKER_POOL_LIMIT to a positive number).
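For reference, a settings fragment that satisfies the module's import-time assertion and enables the pool might look like this. The URL and pool size are example values, not recommendations, and the names are the old-style uppercase settings this post relies on.

```python
# settings.py (Django): example values only; point the URL at your Redis.
BROKER_TRANSPORT = 'redis'               # checked by the assert in the module
BROKER_URL = 'redis://localhost:6379/0'  # hypothetical host, port and db
BROKER_POOL_LIMIT = 10                   # 0 or None would disable the pool
ABORTABLE_REDIS_KEY = 'task-aborts'      # optional; this is the default key
```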
