From 63b1497db2504ef4d213f88fa4da9645ddfe1a4a Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 27 Oct 2023 11:49:09 +0200 Subject: [PATCH] refactor: deduplicate code in templates and tests --- ansible_rulebook/action/run_job_template.py | 10 +--- .../action/run_workflow_template.py | 10 +--- ansible_rulebook/util.py | 18 ++++++++ tests/unit/action/conftest.py | 28 +++++++++++ tests/unit/action/test_controller.py | 46 +++++++++++++++++++ tests/unit/action/test_run_job_template.py | 46 ------------------- .../unit/action/test_run_workflow_template.py | 46 ------------------- 7 files changed, 96 insertions(+), 108 deletions(-) create mode 100644 tests/unit/action/conftest.py create mode 100644 tests/unit/action/test_controller.py diff --git a/ansible_rulebook/action/run_job_template.py b/ansible_rulebook/action/run_job_template.py index 62084e65..891dd1b1 100644 --- a/ansible_rulebook/action/run_job_template.py +++ b/ansible_rulebook/action/run_job_template.py @@ -25,7 +25,7 @@ JobTemplateNotFoundException, ) from ansible_rulebook.job_template_runner import job_template_runner -from ansible_rulebook.util import run_at +from ansible_rulebook.util import process_controller_host_limit, run_at from .control import Control from .helper import Helper @@ -46,13 +46,7 @@ def __init__(self, metadata: Metadata, control: Control, **action_args): self.organization = self.action_args["organization"] self.job_id = str(uuid.uuid4()) self.job_args = self.action_args.get("job_args", {}) - if "limit" in self.job_args: - if isinstance(self.job_args["limit"], list): - self.job_args["limit"] = ",".join(self.job_args["limit"]) - else: - self.job_args["limit"] = str(self.job_args["limit"]) - else: - self.job_args["limit"] = ",".join(self.helper.control.hosts) + process_controller_host_limit(self) self.controller_job = {} async def __call__(self): diff --git a/ansible_rulebook/action/run_workflow_template.py b/ansible_rulebook/action/run_workflow_template.py index 3dccbc39..00807062 100644 --- a/ansible_rulebook/action/run_workflow_template.py +++ b/ansible_rulebook/action/run_workflow_template.py @@ -25,7 +25,7 @@ WorkflowJobTemplateNotFoundException, ) from ansible_rulebook.job_template_runner import job_template_runner -from ansible_rulebook.util import run_at +from ansible_rulebook.util import process_controller_host_limit, run_at from .control import Control from .helper import Helper @@ -46,13 +46,7 @@ def __init__(self, metadata: Metadata, control: Control, **action_args): self.organization = self.action_args["organization"] self.job_id = str(uuid.uuid4()) self.job_args = self.action_args.get("job_args", {}) - if "limit" in self.job_args: - if isinstance(self.job_args["limit"], list): - self.job_args["limit"] = ",".join(self.job_args["limit"]) - else: - self.job_args["limit"] = str(self.job_args["limit"]) - else: - self.job_args["limit"] = ",".join(self.helper.control.hosts) + process_controller_host_limit(self) self.controller_job = {} async def __call__(self): diff --git a/ansible_rulebook/util.py b/ansible_rulebook/util.py index beefbcbf..ff8716d4 100644 --- a/ansible_rulebook/util.py +++ b/ansible_rulebook/util.py @@ -261,3 +261,21 @@ def _builtin_filter_path(name: str) -> Tuple[bool, str]: dirname = os.path.dirname(os.path.realpath(__file__)) path = os.path.join(dirname, "event_filter", filter_name + ".py") return os.path.exists(path), path + + +# TODO(alex): This function should be removed after the +# controller templates are refactored to deduplicate code +def process_controller_host_limit(template_obj): + if "limit" in template_obj.job_args: + if isinstance(template_obj.job_args["limit"], list): + template_obj.job_args["limit"] = ",".join( + template_obj.job_args["limit"], + ) + else: + template_obj.job_args["limit"] = str( + template_obj.job_args["limit"], + ) + else: + template_obj.job_args["limit"] = ",".join( + template_obj.helper.control.hosts, + ) diff --git a/tests/unit/action/conftest.py b/tests/unit/action/conftest.py new file mode 100644 index 00000000..a664eee6 --- /dev/null +++ b/tests/unit/action/conftest.py @@ -0,0 +1,28 @@ +import asyncio + +import pytest + +from ansible_rulebook.action.control import Control +from ansible_rulebook.action.metadata import Metadata + + +@pytest.fixture +def base_metadata(): + return Metadata( + rule="r1", + rule_set="rs1", + rule_uuid="u1", + rule_set_uuid="u2", + rule_run_at="abc", + ) + + +@pytest.fixture +def base_control(): + return Control( + queue=asyncio.Queue(), + inventory="abc", + hosts=["all"], + variables={"a": 1}, + project_data_file="", + ) diff --git a/tests/unit/action/test_controller.py b/tests/unit/action/test_controller.py new file mode 100644 index 00000000..5ba46db3 --- /dev/null +++ b/tests/unit/action/test_controller.py @@ -0,0 +1,46 @@ +import pytest + +from ansible_rulebook.action.run_job_template import RunJobTemplate +from ansible_rulebook.action.run_workflow_template import RunWorkflowTemplate + + +@pytest.mark.parametrize( + "template_class", + [ + pytest.param(RunJobTemplate, id="job_template"), + pytest.param(RunWorkflowTemplate, id="workflow_template"), + ], +) +@pytest.mark.parametrize( + "input,expected", + [ + pytest.param({"limit": "localhost"}, "localhost", id="single_host"), + pytest.param( + {"limit": "localhost,localhost2"}, + "localhost,localhost2", + id="multiple_hosts_str", + ), + pytest.param( + {"limit": ["localhost", "localhost2"]}, + "localhost,localhost2", + id="multiple_hosts", + ), + pytest.param({}, "all", id="default"), + ], +) +@pytest.mark.asyncio +async def test_controller_custom_host_limit( + input, expected, template_class, base_metadata, base_control +): + """Test controller templates process the host limit in job_args.""" + action_args = { + "name": "fred", + "organization": "Default", + "retries": 1, + "retry": True, + "delay": 1, + "set_facts": True, + "job_args": input, + } + template = template_class(base_metadata, base_control, **action_args) + assert template.job_args["limit"] == expected diff --git a/tests/unit/action/test_run_job_template.py b/tests/unit/action/test_run_job_template.py index 28f34a16..8646729e 100644 --- a/tests/unit/action/test_run_job_template.py +++ b/tests/unit/action/test_run_job_template.py @@ -215,49 +215,3 @@ async def test_run_job_template_retries(): drools_mock.assert_called_once() _validate(queue, True) - - -@pytest.mark.parametrize( - "input,output", - [ - pytest.param({"limit": "localhost"}, "localhost", id="single_host"), - pytest.param( - {"limit": "localhost,localhost2"}, - "localhost,localhost2", - id="multiple_hosts_str", - ), - pytest.param( - {"limit": ["localhost", "localhost2"]}, - "localhost,localhost2", - id="multiple_hosts", - ), - pytest.param({}, "all", id="default"), - ], -) -def test_custom_host_limit(input, output): - queue = asyncio.Queue() - metadata = Metadata( - rule="r1", - rule_set="rs1", - rule_uuid="u1", - rule_set_uuid="u2", - rule_run_at="abc", - ) - control = Control( - queue=queue, - inventory="abc", - hosts=["all"], - variables={"a": 1}, - project_data_file="", - ) - action_args = { - "name": "fred", - "organization": "Default", - "retries": 1, - "retry": True, - "delay": 1, - "set_facts": True, - "job_args": input, - } - template = RunJobTemplate(metadata, control, **action_args) - assert template.job_args["limit"] == output diff --git a/tests/unit/action/test_run_workflow_template.py b/tests/unit/action/test_run_workflow_template.py index a8f25158..c780fe17 100644 --- a/tests/unit/action/test_run_workflow_template.py +++ b/tests/unit/action/test_run_workflow_template.py @@ -218,49 +218,3 @@ async def test_run_workflow_template_retries(): drools_mock.assert_called_once() _validate(queue, True) - - -@pytest.mark.parametrize( - "input,output", - [ - pytest.param({"limit": "localhost"}, "localhost", id="single_host"), - pytest.param( - {"limit": "localhost,localhost2"}, - "localhost,localhost2", - id="multiple_hosts_str", - ), - pytest.param( - {"limit": ["localhost", "localhost2"]}, - "localhost,localhost2", - id="multiple_hosts", - ), - pytest.param({}, "all", id="default"), - ], -) -def test_custom_host_limit(input, output): - queue = asyncio.Queue() - metadata = Metadata( - rule="r1", - rule_set="rs1", - rule_uuid="u1", - rule_set_uuid="u2", - rule_run_at="abc", - ) - control = Control( - queue=queue, - inventory="abc", - hosts=["all"], - variables={"a": 1}, - project_data_file="", - ) - action_args = { - "name": "fred", - "organization": "Default", - "retries": 1, - "retry": True, - "delay": 1, - "set_facts": True, - "job_args": input, - } - template = RunWorkflowTemplate(metadata, control, **action_args) - assert template.job_args["limit"] == output