source file: /opt/devel/celery/testproj/../celery/tests/test_task.py
file stats: 143 lines, 143 executed: 100.0% covered
   1. import unittest
   2. import uuid
   3. import logging
   4. from StringIO import StringIO
   5. 
   6. from celery import task
   7. from celery import registry
   8. from celery.log import setup_logger
   9. from celery import messaging
  10. 
  11. 
  12. # Task run functions can't be closures/lambdas, as they're pickled.
  13. def return_True(self, **kwargs):
  14.     return True
  15. 
  16. 
  17. def raise_exception(self, **kwargs):
  18.     raise Exception("%s error" % self.__class__)
  19. 
  20. 
  21. class IncrementCounterTask(task.Task):
  22.     name = "c.unittest.increment_counter_task"
  23.     count = 0
  24. 
  25.     def run(self, increment_by, **kwargs):
  26.         increment_by = increment_by or 1
  27.         self.__class__.count += increment_by
  28. 
  29. 
  30. class TestCeleryTasks(unittest.TestCase):
  31. 
  32.     def createTaskCls(self, cls_name, task_name=None):
  33.         attrs = {}
  34.         if task_name:
  35.             attrs["name"] = task_name
  36.         cls = type(cls_name, (task.Task, ), attrs)
  37.         cls.run = return_True
  38.         return cls
  39. 
  40.     def assertNextTaskDataEquals(self, consumer, task_id, task_name,
  41.             **kwargs):
  42.         next_task = consumer.fetch()
  43.         task_data = consumer.decoder(next_task.body)
  44.         self.assertEquals(task_data["celeryID"], task_id)
  45.         self.assertEquals(task_data["celeryTASK"], task_name)
  46.         for arg_name, arg_value in kwargs.items():
  47.             self.assertEquals(task_data.get(arg_name), arg_value)
  48. 
  49.     def test_raising_task(self):
  50.         rtask = self.createTaskCls("RaisingTask", "c.unittest.t.rtask")
  51.         rtask.run = raise_exception
  52.         sio = StringIO()
  53. 
  54.         taskinstance = rtask()
  55.         taskinstance(loglevel=logging.INFO, logfile=sio)
  56.         self.assertTrue(sio.getvalue().find("Task got exception") != -1)
  57. 
  58.     def test_incomplete_task_cls(self):
  59.         class IncompleteTask(task.Task):
  60.             name = "c.unittest.t.itask"
  61. 
  62.         self.assertRaises(NotImplementedError, IncompleteTask().run)
  63. 
  64.     def test_regular_task(self):
  65.         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
  66.         self.assertTrue(isinstance(T1(), T1))
  67.         self.assertTrue(T1().run())
  68.         self.assertTrue(callable(T1()),
  69.                 "Task class is callable()")
  70.         self.assertTrue(T1()(),
  71.                 "Task class runs run() when called")
  72. 
  73.         # task without name raises NotImplementedError
  74.         T2 = self.createTaskCls("T2")
  75.         self.assertRaises(NotImplementedError, T2)
  76. 
  77.         registry.tasks.register(T1)
  78.         t1 = T1()
  79.         consumer = t1.get_consumer()
  80.         self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
  81.         consumer.discard_all()
  82.         self.assertTrue(consumer.fetch() is None)
  83. 
  84.         # Without arguments.
  85.         tid = t1.delay()
  86.         self.assertNextTaskDataEquals(consumer, tid, t1.name)
  87. 
  88.         # With arguments.
  89.         tid2 = task.delay_task(t1.name, name="George Constanza")
  90.         self.assertNextTaskDataEquals(consumer, tid2, t1.name,
  91.                 name="George Constanza")
  92. 
  93.         self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
  94.                 "some.task.that.should.never.exist.X.X.X.X.X")
  95. 
  96.         # Discarding all tasks.
  97.         task.discard_all()
  98.         tid3 = task.delay_task(t1.name)
  99.         self.assertEquals(task.discard_all(), 1)
 100.         self.assertTrue(consumer.fetch() is None)
 101. 
 102.         self.assertFalse(task.is_done(tid))
 103.         task.mark_as_done(tid, result=None)
 104.         self.assertTrue(task.is_done(tid))
 105. 
 106. 
 107.         publisher = t1.get_publisher()
 108.         self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
 109. 
 110.     def test_taskmeta_cache(self):
 111.         # TODO Needs to test task meta without TASK_META_USE_DB.
 112.         tid = str(uuid.uuid4())
 113.         ckey = task.gen_task_done_cache_key(tid)
 114.         self.assertTrue(ckey.rfind(tid) != -1)
 115. 
 116. 
 117. class TestTaskSet(unittest.TestCase):
 118. 
 119.     def test_counter_taskset(self):
 120.         ts = task.TaskSet(IncrementCounterTask, [
 121.             {},
 122.             {"increment_by": 2},
 123.             {"increment_by": 3},
 124.             {"increment_by": 4},
 125.             {"increment_by": 5},
 126.             {"increment_by": 6},
 127.             {"increment_by": 7},
 128.             {"increment_by": 8},
 129.             {"increment_by": 9},
 130.         ])
 131.         self.assertEquals(ts.task_name, IncrementCounterTask.name)
 132.         self.assertEquals(ts.total, 9)
 133. 
 134.         taskset_id, subtask_ids = ts.run()
 135. 
 136.         consumer = IncrementCounterTask().get_consumer()
 137.         for subtask_id in subtask_ids:
 138.             m = consumer.decoder(consumer.fetch().body)
 139.             self.assertEquals(m.get("celeryTASKSET"), taskset_id)
 140.             self.assertEquals(m.get("celeryTASK"), IncrementCounterTask.name)
 141.             self.assertEquals(m.get("celeryID"), subtask_id)
 142.             IncrementCounterTask().run(increment_by=m.get("increment_by"))
 143.         self.assertEquals(IncrementCounterTask.count, sum(xrange(1, 10)))