diff --git a/Dockerfile b/Dockerfile index 38d19a9a4..053439580 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,7 @@ RUN pip install -U pip \ aiokafka \ watchdog \ azure-servicebus \ + psycopg \ && ansible-galaxy collection install ansible.eda RUN bash -c "if [ $DEVEL_COLLECTION_LIBRARY -ne 0 ]; then \ diff --git a/ansible_rulebook/action/helper.py b/ansible_rulebook/action/helper.py index de198b8d8..3d7142b73 100644 --- a/ansible_rulebook/action/helper.py +++ b/ansible_rulebook/action/helper.py @@ -24,6 +24,8 @@ KEY_EDA_VARS = "ansible_eda" INTERNAL_ACTION_STATUS = "successful" +FAILED_STATUS = "failed" +SUCCESSFUL_STATUS = "successful" class Helper: diff --git a/ansible_rulebook/action/pg_notify.py b/ansible_rulebook/action/pg_notify.py new file mode 100644 index 000000000..6c5df29ce --- /dev/null +++ b/ansible_rulebook/action/pg_notify.py @@ -0,0 +1,109 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import uuid + +import xxhash +from psycopg import AsyncClientCursor, AsyncConnection, OperationalError + +from .control import Control +from .helper import FAILED_STATUS, Helper +from .metadata import Metadata + +logger = logging.getLogger(__name__) + +MAX_MESSAGE_LENGTH = 7 * 1024 +MESSAGE_CHUNKED_UUID = "_message_chunked_uuid" +MESSAGE_CHUNK_COUNT = "_message_chunk_count" +MESSAGE_CHUNK_SEQUENCE = "_message_chunk_sequence" +MESSAGE_CHUNK = "_chunk" +MESSAGE_LENGTH = "_message_length" +MESSAGE_XX_HASH = "_message_xx_hash" + + +class PGNotify: + """The PGNotify action sends an event to a PG Pub Sub Channel + Needs + dsn https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-CONNSTRING-KEYWORD-VALUE + channel the channel name to send the notifies + event + """ + + def __init__(self, metadata: Metadata, control: Control, **action_args): + self.helper = Helper(metadata, control, "pg_notify") + self.action_args = action_args + + async def __call__(self): + if not self.action_args["event"]: + return + + try: + async with await AsyncConnection.connect( + conninfo=self.action_args["dsn"], + autocommit=True, + ) as conn: + async with AsyncClientCursor(connection=conn) as cursor: + if self.action_args.get("remove_meta", False): + event = self.action_args["event"].copy() + if "meta" in event: + event.pop("meta") + else: + event = self.action_args["event"] + + payload = json.dumps(event) + message_length = len(payload) + if message_length >= MAX_MESSAGE_LENGTH: + for chunk in self._to_chunks(payload, message_length): + await cursor.execute( + f"NOTIFY {self.action_args['channel']}, " + f"'{json.dumps(chunk)}';" + ) + else: + await cursor.execute( + f"NOTIFY {self.action_args['channel']}, " + f"'{payload}';" + ) + except OperationalError as e: + logger.error(f"PG Notify operational error {e}") + data = dict(status=FAILED_STATUS, message=str(e)) + await self.helper.send_status(data) + raise e + + await self.helper.send_default_status() + + def _to_chunks(self, payload: str, message_length: int): + xx_hash = xxhash.xxh32(payload.encode("utf-8")).hexdigest() + logger.debug("Message length exceeds, will chunk") + message_uuid = str(uuid.uuid4()) + number_of_chunks = int(message_length / MAX_MESSAGE_LENGTH) + 1 + chunked = { + MESSAGE_CHUNKED_UUID: message_uuid, + MESSAGE_CHUNK_COUNT: number_of_chunks, + MESSAGE_LENGTH: message_length, + MESSAGE_XX_HASH: xx_hash, + } + logger.debug(f"Chunk info {message_uuid}") + logger.debug(f"Number of chunks {number_of_chunks}") + logger.debug(f"Total data size {message_length}") + logger.debug(f"XX Hash {xx_hash}") + + sequence = 1 + for i in range(0, message_length, MAX_MESSAGE_LENGTH): + chunked[MESSAGE_CHUNK] = payload[i : i + MAX_MESSAGE_LENGTH] + chunked[MESSAGE_CHUNK_SEQUENCE] = sequence + sequence += 1 + yield chunked diff --git a/ansible_rulebook/rule_set_runner.py b/ansible_rulebook/rule_set_runner.py index 930e45f16..31e95efda 100644 --- a/ansible_rulebook/rule_set_runner.py +++ b/ansible_rulebook/rule_set_runner.py @@ -33,6 +33,7 @@ from ansible_rulebook.action.debug import Debug from ansible_rulebook.action.metadata import Metadata from ansible_rulebook.action.noop import Noop +from ansible_rulebook.action.pg_notify import PGNotify from ansible_rulebook.action.post_event import PostEvent from ansible_rulebook.action.print_event import PrintEvent from ansible_rulebook.action.retract_fact import RetractFact @@ -75,6 +76,7 @@ "run_module": RunModule, "run_job_template": RunJobTemplate, "run_workflow_template": RunWorkflowTemplate, + "pg_notify": PGNotify, } diff --git a/ansible_rulebook/schema/ruleset_schema.json b/ansible_rulebook/schema/ruleset_schema.json index 0f13a9dae..6321f0892 100644 --- a/ansible_rulebook/schema/ruleset_schema.json +++ b/ansible_rulebook/schema/ruleset_schema.json @@ -206,6 +206,9 @@ }, { "$ref": "#/$defs/shutdown-action" + }, + { + "$ref": "#/$defs/pg-notify-action" } ] } @@ -244,6 +247,9 @@ }, { "$ref": "#/$defs/shutdown-action" + }, + { + "$ref": "#/$defs/pg-notify-action" } ] } @@ -510,6 +516,42 @@ ], "additionalProperties": false }, + "pg-notify-action": { + "type": "object", + "properties": { + "pg_notify": { + "type": "object", + "properties": { + "dsn": { + "type": "string" + }, + "channel": { + "type": "string" + }, + "event": { + "type": [ + "string", + "object" + ] + }, + "remove_meta": { + "type": "boolean", + "default": false + } + }, + "required": [ + "dsn", + "channel", + "event" + ], + "additionalProperties": false + } + }, + "required": [ + "pg_notify" + ], + "additionalProperties": false + }, "post-event-action": { "type": "object", "properties": { diff --git a/setup.cfg b/setup.cfg index cca38f83e..799bbc7cc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,9 @@ install_requires = ansible-runner websockets drools_jpy == 0.3.8 - watchdog + watchdog + psycopg + xxhash [options.packages.find] include = diff --git a/tests/unit/action/test_pg_notify.py b/tests/unit/action/test_pg_notify.py new file mode 100644 index 000000000..f0b9f0690 --- /dev/null +++ b/tests/unit/action/test_pg_notify.py @@ -0,0 +1,248 @@ +# Copyright 2023 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +from unittest.mock import MagicMock, patch + +import pytest +from freezegun import freeze_time +from psycopg import OperationalError + +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.helper import FAILED_STATUS, SUCCESSFUL_STATUS +from ansible_rulebook.action.metadata import Metadata +from ansible_rulebook.action.pg_notify import ( + MESSAGE_CHUNK, + MESSAGE_CHUNK_COUNT, + MESSAGE_CHUNKED_UUID, + PGNotify, +) +from ansible_rulebook.conf import settings + +DUMMY_UUID = "eb7de03f-6f8f-4943-b69e-3c90db346edf" +RULE_UUID = "abcdef3f-6f8f-4943-b69e-3c90db346edf" +RULE_SET_UUID = "00aabbcc-1111-2222-b69e-3c90db346edf" +RULE_RUN_AT = "2023-06-11T12:13:10Z" +ACTION_RUN_AT = "2023-06-11T12:13:14Z" +REQUIRED_KEYS = { + "action", + "action_uuid", + "activation_id", + "activation_instance_id", + "message", + "rule_run_at", + "run_at", + "rule", + "ruleset", + "rule_uuid", + "ruleset_uuid", + "status", + "type", + "matching_events", +} + + +class AsyncContextManager: + async def __aenter__(self): + return self + + async def __aexit__(self): + pass + + +def _validate(queue, metadata, status, event, message=None): + while not queue.empty(): + data = queue.get_nowait() + if data["type"] == "Action": + action = data + + assert action["action"] == "pg_notify" + assert action["action_uuid"] == DUMMY_UUID + assert action["activation_id"] == settings.identifier + assert action["rule_run_at"] == metadata.rule_run_at + assert action["rule"] == metadata.rule + assert action["ruleset"] == metadata.rule_set + assert action["rule_uuid"] == metadata.rule_uuid + assert action["ruleset_uuid"] == metadata.rule_set_uuid + assert action["status"] == status + assert action["type"] == "Action" + if action["status"] == SUCCESSFUL_STATUS: + assert action["run_at"] == ACTION_RUN_AT + assert action["matching_events"] == event + assert action.get("message", None) == message + assert len(set(action.keys()).difference(REQUIRED_KEYS)) == 0 + + +TEST_PAYLOADS = [ + ({"abc": "def", "simple": True, "pi": 3.14259}, {"notifies": 1}), + ( + {"abc": "def", "simple": True, "pi": 3.14259, "meta": {"uuid": 1}}, + {"notifies": 1}, + ), + ( + {"a": 1, "blob": "x" * 9000, "y": 365, "phased": True}, + {"notifies": 2, "number_of_chunks": 2}, + ), +] + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +@pytest.mark.parametrize("event,result", TEST_PAYLOADS) +async def test_pg_notify(event, result): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + channel_name = "my_chanel" + control = Control( + queue=queue, + inventory="abc", + hosts=["all"], + variables={"event": event}, + project_data_file="", + ) + + dsn = "host=localhost port=5432 dbname=mydb connect_timeout=10" + action_args = { + "dsn": dsn, + "event": event, + "channel": channel_name, + } + notifies = 0 + if "meta" in event: + action_args["remove_meta"] = True + compared_event = event.copy() + compared_event.pop("meta") + else: + compared_event = event + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with patch( + "ansible_rulebook.action.pg_notify." "AsyncConnection.connect", + return_value=MagicMock(AsyncContextManager()), + ) as conn: + if "exception" in result: + conn.side_effect = result["exception"] + with patch( + "ansible_rulebook.action.pg_notify.AsyncClientCursor", + new=MagicMock(AsyncContextManager()), + ) as cursor: + await PGNotify(metadata, control, **action_args)() + conn.assert_called_once_with(conninfo=dsn, autocommit=True) + conn.assert_called_once() + assert len(cursor.mock_calls) == 3 + result["notifies"] + entire_msg = "" + for c in cursor.mock_calls: + if len(c.args) == 1 and type(c.args[0]) == str: + notifies += 1 + parts = c.args[0].split(" ", 2) + assert len(parts) == 3 + assert parts[0] == "NOTIFY" + assert parts[1].strip(",") == channel_name + payload = json.loads(parts[2][1:-2]) + if MESSAGE_CHUNKED_UUID in payload: + assert ( + payload[MESSAGE_CHUNK_COUNT] + == result["number_of_chunks"] + ) + entire_msg += payload[MESSAGE_CHUNK] + else: + entire_msg = parts[2][1:-2] + + assert notifies == result["notifies"] + assert json.loads(entire_msg) == compared_event + _validate(queue, metadata, SUCCESSFUL_STATUS, {"m": event}) + + +EXCEPTIONAL_PAYLOADS = [ + ( + {"abc": "will fail"}, + {"message": "Kaboom", "exception": OperationalError("Kaboom")}, + ), +] + + +@freeze_time("2023-06-11 12:13:14") +@pytest.mark.asyncio +@pytest.mark.parametrize("event,result", EXCEPTIONAL_PAYLOADS) +async def test_pg_notify_with_exception(event, result): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + channel_name = "my_chanel" + control = Control( + queue=queue, + inventory="abc", + hosts=["all"], + variables={"event": event}, + project_data_file="", + ) + + dsn = "host=localhost port=5432 dbname=mydb connect_timeout=10" + action_args = { + "dsn": dsn, + "event": event, + "channel": channel_name, + } + + with patch("uuid.uuid4", return_value=DUMMY_UUID): + with pytest.raises(OperationalError): + with patch( + "ansible_rulebook.action.pg_notify." "AsyncConnection.connect", + return_value=MagicMock(AsyncContextManager()), + ) as conn: + conn.side_effect = result["exception"] + await PGNotify(metadata, control, **action_args)() + + _validate(queue, metadata, FAILED_STATUS, {"m": event}, result["message"]) + + +@pytest.mark.asyncio +async def test_pg_notify_with_no_event(): + queue = asyncio.Queue() + metadata = Metadata( + rule="r1", + rule_set="rs1", + rule_uuid=RULE_UUID, + rule_set_uuid=RULE_SET_UUID, + rule_run_at=RULE_RUN_AT, + ) + channel_name = "my_chanel" + control = Control( + queue=queue, + inventory="abc", + hosts=["all"], + variables={"event": {}}, + project_data_file="", + ) + + dsn = "host=localhost port=5432 dbname=mydb connect_timeout=10" + action_args = { + "dsn": dsn, + "event": {}, + "channel": channel_name, + } + + await PGNotify(metadata, control, **action_args)() + assert queue.empty()