From 566df8355c907a9d688d4b124f09deae98acd2a8 Mon Sep 17 00:00:00 2001 From: Hossam Hammady Date: Tue, 19 Dec 2023 18:51:18 -0500 Subject: [PATCH] Inject reporter into job from worker --- pyworker/job.py | 13 +++++++++---- pyworker/worker.py | 3 ++- tests/test_job.py | 24 ++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/pyworker/job.py b/pyworker/job.py index 82781d9..aa120a9 100644 --- a/pyworker/job.py +++ b/pyworker/job.py @@ -19,7 +19,8 @@ class Job(object, metaclass=Meta): """docstring for Job""" def __init__(self, class_name, database, logger, job_id, queue, run_at, attempts=0, max_attempts=1, - attributes=None, abstract=False, extra_fields=None): + attributes=None, abstract=False, extra_fields=None, + reporter=None): super(Job, self).__init__() self.class_name = class_name self.database = database @@ -33,12 +34,14 @@ def __init__(self, class_name, database, logger, self.attributes = attributes self.abstract = abstract self.extra_fields = extra_fields + self.reporter = reporter def __str__(self): return "%s: %s" % (self.__class__.__name__, str(self.__dict__)) @classmethod - def from_row(cls, job_row, max_attempts, database, logger, extra_fields=None): + def from_row(cls, job_row, max_attempts, database, logger, + extra_fields=None, reporter=None): '''job_row is a tuple of (id, attempts, run_at, queue, handler, *extra_fields)''' def extract_class_name(line): regex = re.compile('object: !ruby/object:(.+)') @@ -80,7 +83,8 @@ def extract_extra_fields(extra_fields, extra_field_values): max_attempts=max_attempts, job_id=job_id, attempts=attempts, run_at=run_at, queue=queue, database=database, - abstract=True, extra_fields=extra_fields_dict) + abstract=True, extra_fields=extra_fields_dict, + reporter=reporter) attributes = extract_attributes(handler[2:]) logger.debug("Found attributes: %s" % str(attributes)) @@ -94,7 +98,8 @@ def extract_extra_fields(extra_fields, extra_field_values): run_at=run_at, queue=queue, database=database, max_attempts=max_attempts, attributes=payload['object']['attributes'], - abstract=False, extra_fields=extra_fields_dict) + abstract=False, extra_fields=extra_fields_dict, + reporter=reporter) def before(self): self.logger.debug("Running Job.before hook") diff --git a/pyworker/worker.py b/pyworker/worker.py index 8a544b3..5e4a13d 100644 --- a/pyworker/worker.py +++ b/pyworker/worker.py @@ -145,7 +145,8 @@ def get_job_row(): if job_row: return Job.from_row(job_row, max_attempts=self.max_attempts, database=self.database, logger=self.logger, - extra_fields=self.extra_delayed_job_fields) + extra_fields=self.extra_delayed_job_fields, + reporter=self.reporter) else: return None diff --git a/tests/test_job.py b/tests/test_job.py index 75076b8..3d6b5df 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -66,6 +66,11 @@ def load_unregistered_job(self): def load_unregistered_job_with_extra_fields(self): return self.load_job_with_extra_fields('handler_unregistered.yaml') + def load_unregistered_job_with_reporter(self, reporter): + job = self.load_unregistered_job() + job.reporter = reporter + return job + def load_registered_job(self): job = self.load_job('handler_registered.yaml') job.error = MagicMock() @@ -76,6 +81,11 @@ def load_registered_job(self): def load_registered_job_with_extra_fields(self): return self.load_job_with_extra_fields('handler_registered.yaml') + def load_registered_job_with_reporter(self, reporter): + job = self.load_registered_job() + job.reporter = reporter + return job + def load_registered_job_with_attempts_exceeded(self): job = self.load_registered_job() job.attempts = self.mock_max_attempts - 1 @@ -99,12 +109,19 @@ def test_from_row_when_unregistered_class_returns_job_instance_without_attribute self.assertEqual(job.max_attempts, self.mock_max_attempts) self.assertIsNone(job.extra_fields) self.assertIsNone(job.attributes) + self.assertIsNone(job.reporter) def test_from_row_when_unregistered_class_returns_job_instance_with_extra_fields(self): job = self.load_unregistered_job_with_extra_fields() self.assertDictEqual(job.extra_fields, self.mock_extra_fields) + def test_from_row_when_unregistered_class_returns_abstract_job_instance_with_reporter(self): + mock_reporter = MagicMock() + job = self.load_unregistered_job_with_reporter(mock_reporter) + + self.assertEqual(job.reporter, mock_reporter) + def test_from_row_when_registered_class_returns_concrete_job_instance(self): job = self.load_registered_job() @@ -128,12 +145,19 @@ def test_from_row_when_registered_class_returns_job_instance_with_attributes(sel 'total_articles': 1000, 'is_blind': True }) + self.assertIsNone(job.reporter) def test_from_row_when_registered_class_returns_job_instance_with_extra_fields(self): job = self.load_registered_job_with_extra_fields() self.assertDictEqual(job.extra_fields, self.mock_extra_fields) + def test_from_row_when_registered_class_returns_concrete_job_instance_with_reporter(self): + mock_reporter = MagicMock() + job = self.load_registered_job_with_reporter(mock_reporter) + + self.assertEqual(job.reporter, mock_reporter) + #********** .set_error_unlock tests **********# def assert_job_updated_field(self, job, field, value):