Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making use of service resource over the low level client for SQS. #42

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Here is a basic code sample:

**Standard Listener**

::
.. code:: python

from sqs_listener import SqsListener

Expand All @@ -44,7 +44,7 @@ Here is a basic code sample:

**Error Listener**

::
.. code:: python

from sqs_listener import SqsListener
class MyErrorListener(SqsListener):
Expand Down Expand Up @@ -75,7 +75,7 @@ Running as a Daemon

| Typically, in a production environment, you'll want to listen to an SQS queue with a daemonized process.
The simplest way to do this is by running the listener in a detached process. On a typical Linux distribution it might look like this:
|
|
``nohup python my_listener.py > listener.log &``
| And saving the resulting process id for later (for stopping the listener via the ``kill`` command).
|
Expand All @@ -94,7 +94,7 @@ Logging
|
| For instance:

::
.. code:: python

logger = logging.getLogger('sqs_listener')
logger.setLevel(logging.INFO)
Expand All @@ -111,7 +111,7 @@ Logging
|
| Or to a log file:

::
.. code:: python

logger = logging.getLogger('sqs_listener')
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -140,7 +140,7 @@ Sending messages

**Launcher Example**

::
.. code:: python

from sqs_launcher import SqsLauncher

Expand Down
27 changes: 15 additions & 12 deletions sqs_launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,39 +38,43 @@ def __init__(self, queue=None, queue_url=None, create_queue=False, visibility_ti
to finish execution. See http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html
for more information
"""
if not any(queue, queue_url):
if not any([queue, queue_url]):
raise ValueError('Either `queue` or `queue_url` should be provided.')
if (not os.environ.get('AWS_ACCOUNT_ID', None) and
not (boto3.Session().get_credentials().method in ['iam-role', 'assume-role'])):
raise EnvironmentError('Environment variable `AWS_ACCOUNT_ID` not set and no role found.')
# new session for each instantiation
self._session = boto3.session.Session()
self._resource = self._session.resource('sqs')
self._client = self._session.client('sqs')

self._queue_name = queue
self._queue_url = queue_url
self._queue = None
if not queue_url:
queues = self._client.list_queues(QueueNamePrefix=self._queue_name)
queues = self._resource.queues.filter(QueueNamePrefix=self._queue_name)
exists = False
for q in queues.get('QueueUrls', []):
qname = q.split('/')[-1]
for q in queues:
qname = q.url.split('/')[-1]
if qname == self._queue_name:
exists = True
self._queue_url = q

self._queue_url = q.url
self._queue = q
if not exists:
if create_queue:
q = self._client.create_queue(
q = self._resource.create_queue(
QueueName=self._queue_name,
Attributes={
'VisibilityTimeout': visibility_timeout # 10 minutes
}
'VisibilityTimeout': visibility_timeout, # 10 minutes
},
)
self._queue_url = q['QueueUrl']
self._queue_url = q.url
self._queue = q
else:
raise ValueError('No queue found with name ' + self._queue_name)
else:
self._queue_name = self._get_queue_name_from_url(queue_url)
self._queue = self._resource.Queue(queue_url)

def launch_message(self, message, **kwargs):
"""
Expand All @@ -81,8 +85,7 @@ def launch_message(self, message, **kwargs):
:return: (dict) the message response from SQS
"""
sqs_logger.info("Sending message to queue " + self._queue_name)
return self._client.send_message(
QueueUrl=self._queue_url,
return self._queue.send_message(
MessageBody=json.dumps(message),
**kwargs,
)
Expand Down
207 changes: 88 additions & 119 deletions sqs_listener/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

sqs_logger = logging.getLogger('sqs_listener')


class SqsListener(object):
__metaclass__ = ABCMeta

Expand All @@ -37,17 +38,17 @@ def __init__(self, queue, **kwargs):
"""
aws_access_key = kwargs.get('aws_access_key', '')
aws_secret_key = kwargs.get('aws_secret_key', '')

if len(aws_access_key) != 0 and len(aws_secret_key) != 0:
self._aws_account_id = os.environ.get('AWS_ACCOUNT_ID', None)
if all([aws_access_key, aws_secret_key]):
boto3_session = boto3.Session(
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key
aws_secret_access_key=aws_secret_key,
)
else:
if (not os.environ.get('AWS_ACCOUNT_ID', None) and
if (not self._aws_account_id and
not ('iam-role' == boto3.Session().get_credentials().method)):
raise EnvironmentError('Environment variable `AWS_ACCOUNT_ID` not set and no role found.')

self._queue = None # The SQS Queue resource
self._queue_name = queue
self._poll_interval = kwargs.get("interval", 60)
self._queue_visibility_timeout = kwargs.get('visibility_timeout', '600')
Expand All @@ -67,135 +68,116 @@ def __init__(self, queue, **kwargs):
else:
self._session = boto3.session.Session()
self._region_name = kwargs.get('region_name', self._session.region_name)
self._client = self._initialize_client()
self._resource = self._initialize_resource()


def _initialize_client(self):
def _initialize_resource(self):
# new session for each instantiation
ssl = True
if self._region_name == 'elasticmq':
ssl = False

sqs = self._session.client('sqs', region_name=self._region_name, endpoint_url=self._endpoint_name, use_ssl=ssl)
queues = sqs.list_queues(QueueNamePrefix=self._queue_name)
mainQueueExists = False
errorQueueExists = False
if 'QueueUrls' in queues:
for q in queues['QueueUrls']:
qname = q.split('/')[-1]
if qname == self._queue_name:
mainQueueExists = True
if self._error_queue_name and qname == self._error_queue_name:
errorQueueExists = True


# create queue if necessary.
sqs = self._session.resource('sqs', region_name=self._region_name, endpoint_url=self._endpoint_name, use_ssl=ssl)
queues = sqs.queues.filter(QueueNamePrefix=self._queue_name)
main_queue_exists = False
error_queue_exists = False
for q in queues:
qname = q.url.split('/')[-1]
if qname == self._queue_name:
self._queue_url = q.url
main_queue_exists = True
if self._error_queue_name and qname == self._error_queue_name:
error_queue_exists = True

# create queue if necessary.
# creation is idempotent, no harm in calling on a queue if it already exists.
if self._queue_url is None:
if not mainQueueExists:
if not main_queue_exists:
sqs_logger.warning("main queue not found, creating now")

queue_attributes = {
'VisibilityTimeout': self._queue_visibility_timeout, # 10 minutes
}
# is this a fifo queue?
if self._queue_name.endswith(".fifo"):
fifoQueue="true"
q = sqs.create_queue(
QueueName=self._queue_name,
Attributes={
'VisibilityTimeout': self._queue_visibility_timeout, # 10 minutes
'FifoQueue':fifoQueue
}
)
else:
# need to avoid FifoQueue property for normal non-fifo queues
q = sqs.create_queue(
QueueName=self._queue_name,
Attributes={
'VisibilityTimeout': self._queue_visibility_timeout, # 10 minutes
}
)
self._queue_url = q['QueueUrl']

if self._error_queue_name and not errorQueueExists:
queue_attributes["FifoQueue"] = "true"
q = sqs.create_queue(
QueueName=self._queue_name,
Attributes=queue_attributes,
)
self._queue_url = q.url

if self._error_queue_name and not error_queue_exists:
sqs_logger.warning("error queue not found, creating now")
q = sqs.create_queue(
QueueName=self._error_queue_name,
Attributes={
'VisibilityTimeout': self._queue_visibility_timeout # 10 minutes
}
'VisibilityTimeout': self._queue_visibility_timeout, # 10 minutes
},
)

if self._queue_url is None:
if os.environ.get('AWS_ACCOUNT_ID', None):
qs = sqs.get_queue_url(QueueName=self._queue_name,
QueueOwnerAWSAccountId=os.environ.get('AWS_ACCOUNT_ID', None))
if self._aws_account_id:
qs = sqs.get_queue_by_name(
QueueName=self._queue_name,
QueueOwnerAWSAccountId=self._aws_account_id,
)
else:
qs = sqs.get_queue_url(QueueName=self._queue_name)
self._queue_url = qs['QueueUrl']
qs = sqs.get_queue_by_name(
QueueName=self._queue_name,
)
self._queue_url = qs.url
self._queue = sqs.Queue(self._queue_url)
return sqs

def _start_listening(self):
# TODO consider incorporating output processing from here: https://github.com/debrouwere/sqs-antenna/blob/master/antenna/__init__.py
# TODO consider incorporating output processing from here:
# https://github.com/debrouwere/sqs-antenna/blob/master/antenna/__init__.py
while True:
# calling with WaitTimeSecconds of zero show the same behavior as
# calling with `WaitTimeSecconds` of zero show the same behavior as
# not specifiying a wait time, ie: short polling
messages = self._client.receive_message(
QueueUrl=self._queue_url,
MessageAttributeNames=self._message_attribute_names,
messages = self._queue.receive_messages(
AttributeNames=self._attribute_names,
MessageAttributeNames=self._message_attribute_names,
WaitTimeSeconds=self._wait_time,
MaxNumberOfMessages=self._max_number_of_messages
MaxNumberOfMessages=self._max_number_of_messages,
)
if 'Messages' in messages:

sqs_logger.debug(messages)
continue
sqs_logger.info("{} messages received".format(len(messages['Messages'])))
for m in messages['Messages']:
receipt_handle = m['ReceiptHandle']
m_body = m['Body']
message_attribs = None
attribs = None

# catch problems with malformed JSON, usually a result of someone writing poor JSON directly in the AWS console
try:
params_dict = json.loads(m_body)
except:
sqs_logger.warning("Unable to parse message - JSON is not formatted properly")
continue
if 'MessageAttributes' in m:
message_attribs = m['MessageAttributes']
if 'Attributes' in m:
attribs = m['Attributes']
try:
if self._force_delete:
self._client.delete_message(
QueueUrl=self._queue_url,
ReceiptHandle=receipt_handle
)
self.handle_message(params_dict, message_attribs, attribs)
else:
self.handle_message(params_dict, message_attribs, attribs)
self._client.delete_message(
QueueUrl=self._queue_url,
ReceiptHandle=receipt_handle
)
except Exception as ex:
# need exception logtype to log stack trace
sqs_logger.exception(ex)
if self._error_queue_name:
exc_type, exc_obj, exc_tb = sys.exc_info()

sqs_logger.info( "Pushing exception to error queue")
error_launcher = SqsLauncher(queue=self._error_queue_name, create_queue=True)
error_launcher.launch_message(
{
'exception_type': str(exc_type),
'error_message': str(ex.args)
}
)

else:
if not messages:
time.sleep(self._poll_interval)
continue
sqs_logger.debug(messages)
sqs_logger.info("{} messages received".format(len(messages)))
for m in messages:
receipt_handle = m.receipt_handle
m_body = m.body
message_attribs = m.message_attributes
attribs = m.attributes
# catch problems with malformed JSON, usually a result
# of someone writing poor JSON directly in the AWS
# console
try:
params_dict = json.loads(m_body)
except:
sqs_logger.warning("Unable to parse message - JSON is not formatted properly")
continue
try:
if self._force_delete:
m.delete()
self.handle_message(params_dict, message_attribs, attribs)
else:
self.handle_message(params_dict, message_attribs, attribs)
m.delete()
except Exception as ex:
# need exception logtype to log stack trace
sqs_logger.exception(ex)
if self._error_queue_name:
exc_type, exc_obj, exc_tb = sys.exc_info()
sqs_logger.info( "Pushing exception to error queue")
error_launcher = SqsLauncher(queue=self._error_queue_name, create_queue=True)
error_launcher.launch_message(
{
'exception_type': str(exc_type),
'error_message': str(ex.args)
}
)

def listen(self):
sqs_logger.info( "Listening to queue " + self._queue_name)
Expand All @@ -204,19 +186,6 @@ def listen(self):

self._start_listening()

def _prepare_logger(self):
logger = logging.getLogger('eg_daemon')
logger.setLevel(logging.INFO)

sh = logging.StreamHandler(sys.stdout)
sh.setLevel(logging.INFO)

formatstr = '[%(asctime)s - %(name)s - %(levelname)s] %(message)s'
formatter = logging.Formatter(formatstr)

sh.setFormatter(formatter)
logger.addHandler(sh)

@abstractmethod
def handle_message(self, body, attributes, messages_attributes):
"""
Expand Down