Skip to content

Commit

Permalink
Add support for reporting extra fields if configured (#20)
Browse files Browse the repository at this point in the history
* Add support for reporting extra fields if configured
  • Loading branch information
hammady authored Dec 19, 2023
1 parent 58dc027 commit a4cc795
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 15 deletions.
21 changes: 15 additions & 6 deletions pyworker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Job(object, metaclass=Meta):
"""docstring for Job"""
def __init__(self, class_name, database, logger,
job_id, queue, run_at, attempts=0, max_attempts=1,
attributes=None, abstract=False):
attributes=None, abstract=False, extra_fields=None):
super(Job, self).__init__()
self.class_name = class_name
self.database = database
Expand All @@ -32,13 +32,14 @@ def __init__(self, class_name, database, logger,
self.max_attempts = max_attempts
self.attributes = attributes
self.abstract = abstract
self.extra_fields = extra_fields

def __str__(self):
return "%s: %s" % (self.__class__.__name__, str(self.__dict__))

@classmethod
def from_row(cls, job_row, max_attempts, database, logger):
'''job_row is a tuple of (id, attempts, run_at, queue, handler)'''
def from_row(cls, job_row, max_attempts, database, logger, extra_fields=None):
'''job_row is a tuple of (id, attempts, run_at, queue, handler, *extra_fields)'''
def extract_class_name(line):
regex = re.compile('object: !ruby/object:(.+)')
match = regex.match(line)
Expand All @@ -60,7 +61,14 @@ def extract_attributes(lines):
attributes.append(line)
return attributes

job_id, attempts, run_at, queue, handler = job_row
def extract_extra_fields(extra_fields, extra_field_values):
if extra_fields is None or extra_field_values is None:
return None

return dict(zip(extra_fields, extra_field_values))

job_id, attempts, run_at, queue, handler, *extra_field_values = job_row
extra_fields_dict = extract_extra_fields(extra_fields, extra_field_values)
handler = handler.splitlines()

class_name = extract_class_name(handler[1])
Expand All @@ -72,7 +80,7 @@ def extract_attributes(lines):
max_attempts=max_attempts,
job_id=job_id, attempts=attempts,
run_at=run_at, queue=queue, database=database,
abstract=True)
abstract=True, extra_fields=extra_fields_dict)

attributes = extract_attributes(handler[2:])
logger.debug("Found attributes: %s" % str(attributes))
Expand All @@ -85,7 +93,8 @@ def extract_attributes(lines):
job_id=job_id, attempts=attempts,
run_at=run_at, queue=queue, database=database,
max_attempts=max_attempts,
attributes=payload['object']['attributes'])
attributes=payload['object']['attributes'],
abstract=False, extra_fields=extra_fields_dict)

def before(self):
self.logger.debug("Running Job.before hook")
Expand Down
26 changes: 21 additions & 5 deletions pyworker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os, signal, traceback
import time
import json
from contextlib import contextmanager
from pyworker.db import DBConnector
from pyworker.job import Job
Expand All @@ -12,7 +13,7 @@ class TimeoutException(Exception): pass
class TerminatedException(Exception): pass

class Worker(object):
def __init__(self, dbstring, logger=None):
def __init__(self, dbstring, logger=None, extra_delayed_job_fields=None):
super(Worker, self).__init__()
self.logger = Logger(logger)
self.logger.info('Starting pyworker...')
Expand All @@ -24,6 +25,7 @@ def __init__(self, dbstring, logger=None):
hostname = os.uname()[1]
pid = os.getpid()
self.name = 'host:%s pid:%d' % (hostname, pid)
self.extra_delayed_job_fields = extra_delayed_job_fields

# Configure NewRelic if ENV variables set
self.newrelic_app = None
Expand Down Expand Up @@ -84,7 +86,16 @@ def _latency(job_run_at):
newrelic.agent.add_custom_attribute('job_queue', job.queue)
newrelic.agent.add_custom_attribute('job_latency', latency)
newrelic.agent.add_custom_attribute('job_attempts', job.attempts)
# TODO report job.enqueue_attributes if available

# Record extra fields if configured
self.logger.debug('job extra fields: %s' % job.extra_fields)
if job.extra_fields is not None:
for key, value in job.extra_fields.items():
# NewRelic only supports string, int, float, bool
if value is not None:
if type(value) not in [str, int, float, bool]:
value = json.dumps(value)
newrelic.agent.add_custom_attribute(key, value)

yield task
else:
Expand Down Expand Up @@ -119,6 +130,10 @@ def get_job_row():
now, expired = str(now), str(expired)
queues = self.queue_names.split(',')
queues = ', '.join(["'%s'" % q for q in queues])
fields = ['id', 'attempts', 'run_at', 'queue', 'handler']
if self.extra_delayed_job_fields:
fields += self.extra_delayed_job_fields
fields = ', '.join(fields)
query = '''
UPDATE delayed_jobs SET locked_at = '%s', locked_by = '%s'
WHERE id IN (SELECT delayed_jobs.id FROM delayed_jobs
Expand All @@ -127,16 +142,17 @@ def get_job_row():
OR locked_by = '%s') AND failed_at IS NULL)
AND delayed_jobs.queue IN (%s)
ORDER BY priority ASC, run_at ASC LIMIT 1 FOR UPDATE) RETURNING
id, attempts, run_at, queue, handler
''' % (now, self.name, now, expired, self.name, queues)
%s
''' % (now, self.name, now, expired, self.name, queues, fields)
self.logger.debug('query: %s' % query)
self._cursor.execute(query)
return self._cursor.fetchone()

job_row = get_job_row()
if job_row:
return Job.from_row(job_row, max_attempts=self.max_attempts,
database=self.database, logger=self.logger)
database=self.database, logger=self.logger,
extra_fields=self.extra_delayed_job_fields)
else:
return None

Expand Down
52 changes: 51 additions & 1 deletion tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def setUp(self):
self.mock_queue = 'default'
self.mock_max_attempts = 5
self.mock_now = datetime.datetime(2023, 10, 7, 0, 0, 0)
self.mock_extra_fields = {
'extra_field1': 'extra_field1_value',
'extra_field2': 100,
'extra_field3': True,
'extra_field4': {'a': [1, 2, 3]},
'extra_field5': None
}

def tearDown(self):
pass
Expand All @@ -38,16 +45,37 @@ def load_job(self, filename):
self.mock_max_attempts,
MagicMock(), MagicMock())

def load_job_with_extra_fields(self, filename):
mock_handler = self.load_fixture(filename)
mock_row = (
self.mock_job_id,
self.mock_attempts,
self.mock_run_at,
self.mock_queue,
mock_handler,
*self.mock_extra_fields.values()
)
return Job.from_row(mock_row,
self.mock_max_attempts,
MagicMock(), MagicMock(),
extra_fields=self.mock_extra_fields.keys())

def load_unregistered_job(self):
return self.load_job('handler_unregistered.yaml')

def load_unregistered_job_with_extra_fields(self):
return self.load_job_with_extra_fields('handler_unregistered.yaml')

def load_registered_job(self):
job = self.load_job('handler_registered.yaml')
job.error = MagicMock()
job.failure = MagicMock()
job._update_job = MagicMock()
return job

def load_registered_job_with_extra_fields(self):
return self.load_job_with_extra_fields('handler_registered.yaml')

def load_registered_job_with_attempts_exceeded(self):
job = self.load_registered_job()
job.attempts = self.mock_max_attempts - 1
Expand All @@ -61,20 +89,37 @@ def test_from_row_when_unregistered_class_returns_abstract_job_instance(self):
self.assertEqual(job.class_name, 'UnregisteredJob')
self.assertEqual(job.abstract, True)

def test_from_row_when_unregistered_class_returns_job_instance_without_attributes(self):
job = self.load_unregistered_job()

self.assertEqual(job.job_id, self.mock_job_id)
self.assertEqual(job.attempts, self.mock_attempts)
self.assertEqual(job.run_at, self.mock_run_at)
self.assertEqual(job.queue, self.mock_queue)
self.assertEqual(job.max_attempts, self.mock_max_attempts)
self.assertIsNone(job.extra_fields)
self.assertIsNone(job.attributes)

def test_from_row_when_unregistered_class_returns_job_instance_with_extra_fields(self):
job = self.load_unregistered_job_with_extra_fields()

self.assertDictEqual(job.extra_fields, self.mock_extra_fields)

def test_from_row_when_registered_class_returns_concrete_job_instance(self):
job = self.load_registered_job()

self.assertEqual(job.class_name, 'RegisteredJob')
self.assertEqual(job.abstract, False)

def test_from_row_when_registered_class_returns_concrete_job_instance_with_attributes(self):
def test_from_row_when_registered_class_returns_job_instance_with_attributes(self):
job = self.load_registered_job()

self.assertEqual(job.job_id, self.mock_job_id)
self.assertEqual(job.attempts, self.mock_attempts)
self.assertEqual(job.run_at, self.mock_run_at)
self.assertEqual(job.queue, self.mock_queue)
self.assertEqual(job.max_attempts, self.mock_max_attempts)
self.assertIsNone(job.extra_fields)
# below attributes match the registered class fixture
self.assertDictEqual(job.attributes, {
'id': 100,
Expand All @@ -84,6 +129,11 @@ def test_from_row_when_registered_class_returns_concrete_job_instance_with_attri
'is_blind': True
})

def test_from_row_when_registered_class_returns_job_instance_with_extra_fields(self):
job = self.load_registered_job_with_extra_fields()

self.assertDictEqual(job.extra_fields, self.mock_extra_fields)

#********** .set_error_unlock tests **********#

def assert_job_updated_field(self, job, field, value):
Expand Down
47 changes: 44 additions & 3 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import json
from unittest import TestCase
from unittest.mock import patch, MagicMock
from pyworker.worker import Worker, TerminatedException, get_current_time
Expand All @@ -20,6 +21,15 @@ def setUp(self, mock_db):
queue='default',
attempts=0,
run_at=mocked_run_at)
self.mock_extra_fields = {
'extra_field1_str': 'extra_field1_value',
'extra_field2_int': 100,
'extra_field3_float': 1.1,
'extra_field4_bool': True,
'extra_field5_bool': False,
'extra_field6_json': {'a': [1, 2, 3]},
'extra_field7_none': None
}

def tearDown(self):
pass
Expand All @@ -38,6 +48,13 @@ def test_worker_init(self, mock_db, *_):
self.assertEqual(worker.max_run_time, 3600)
self.assertEqual(worker.queue_names, 'default')
self.assertEqual(worker.name, 'host:localhost pid:1234')
self.assertIsNone(worker.extra_delayed_job_fields)

@patch('pyworker.worker.DBConnector')
def test_worker_init_with_extra_delayed_job_fields(self, *_):
worker = Worker('dummy', extra_delayed_job_fields=self.mock_extra_fields.keys())

self.assertEqual(worker.extra_delayed_job_fields, self.mock_extra_fields.keys())

#********** .run tests **********#

Expand Down Expand Up @@ -85,12 +102,18 @@ def assert_instrument_context_reports_custom_attributes(self, job, newrelic_agen
newrelic_agent.add_custom_attribute.assert_any_call('job_queue', job.queue)
newrelic_agent.add_custom_attribute.assert_any_call('job_latency', self.mocked_latency)
newrelic_agent.add_custom_attribute.assert_any_call('job_attempts', job.attempts)
if job.extra_fields is not None:
for key, value in job.extra_fields.items():
if value is not None:
if key.endswith('_json'):
value = json.dumps(value)
newrelic_agent.add_custom_attribute.assert_any_call(key, value)

def test_worker_handle_job_when_job_is_none_does_nothing(self):
self.worker.handle_job(None) # no error raised

@patch('pyworker.worker.newrelic.agent', return_value=MagicMock())
def test_worker_handle_job_when_job_is_unsupported_type_sets_error(self, newrelic_agent):
def test_worker_handle_job_when_job_is_unsupported_type_sets_error(self, *_):
job = self.mock_job
job.abstract = True

Expand All @@ -114,6 +137,20 @@ def test_worker_handle_job_when_job_is_unsupported_type_reports_error_to_newreli
newrelic_agent.record_exception.assert_called_once()
newrelic_agent.add_custom_attribute.assert_any_call('error', True)

@patch('pyworker.worker.get_current_time')
@patch('pyworker.worker.newrelic.agent', return_value=MagicMock())
def test_worker_handle_job_when_job_is_unsupported_type_reports_extra_fields_to_newrelic(
self, newrelic_agent, get_current_time):
get_current_time.return_value = self.mocked_now
job = self.mock_job
job.abstract = True
job.extra_fields = self.mock_extra_fields
self.worker.newrelic_app = MagicMock()

self.worker.handle_job(job)

self.assert_instrument_context_reports_custom_attributes(job, newrelic_agent)

def test_worker_handle_job_calls_all_hooks_then_removes_from_queue(self):
self.worker.handle_job(self.mock_job)

Expand All @@ -129,11 +166,13 @@ def test_worker_handle_job_calls_all_hooks_then_removes_from_queue(self):
def test_worker_handle_job_when_no_errors_reports_success_to_newrelic(
self, newrelic_agent, get_current_time):
get_current_time.return_value = self.mocked_now
job = self.mock_job
job.extra_fields = self.mock_extra_fields
self.worker.newrelic_app = MagicMock()

self.worker.handle_job(self.mock_job)
self.worker.handle_job(job)

self.assert_instrument_context_reports_custom_attributes(self.mock_job, newrelic_agent)
self.assert_instrument_context_reports_custom_attributes(job, newrelic_agent)
newrelic_agent.record_exception.assert_not_called()
newrelic_agent.add_custom_attribute.assert_any_call('error', False)
newrelic_agent.add_custom_attribute.assert_any_call('job_failure', False)
Expand All @@ -156,6 +195,7 @@ def test_worker_handle_job_when_error_report_to_newrelic(self,
job = self.mock_job
job.set_error_unlock.return_value = False
job.run.side_effect = Exception('test error')
job.extra_fields = self.mock_extra_fields
self.worker.newrelic_app = MagicMock()

self.worker.handle_job(job)
Expand All @@ -174,6 +214,7 @@ def test_worker_handle_job_when_permanent_error_reports_failure_to_newrelic(
job = self.mock_job
job.set_error_unlock.return_value = True
job.run.side_effect = Exception('test error')
job.extra_fields = self.mock_extra_fields
self.worker.newrelic_app = MagicMock()

self.worker.handle_job(job)
Expand Down

0 comments on commit a4cc795

Please sign in to comment.