diff --git a/.github/workflows/task_runner_e2e.yml b/.github/workflows/task_runner_e2e.yml index 98ffde8433..a7e01c579f 100644 --- a/.github/workflows/task_runner_e2e.yml +++ b/.github/workflows/task_runner_e2e.yml @@ -73,7 +73,7 @@ jobs: run: | python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py \ -m ${{ env.MODEL_NAME }} --model_name ${{ env.MODEL_NAME }} \ - --num_rounds $NUM_ROUNDS --num_collaborators $NUM_COLLABORATORS + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} echo "Task runner end to end test run completed" - name: Print test summary @@ -140,7 +140,7 @@ jobs: run: | python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py \ -m ${{ env.MODEL_NAME }} --model_name ${{ env.MODEL_NAME }} \ - --num_rounds $NUM_ROUNDS --num_collaborators $NUM_COLLABORATORS --disable_tls + --num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_tls echo "Task runner end to end test run completed" - name: Print test summary diff --git a/tests/end_to_end/utils/summary_helper.py b/tests/end_to_end/utils/summary_helper.py index 685b94bb6a..599c4b6ec5 100644 --- a/tests/end_to_end/utils/summary_helper.py +++ b/tests/end_to_end/utils/summary_helper.py @@ -5,6 +5,8 @@ from lxml import etree import os +import tests.end_to_end.utils.constants as constants + # Initialize the XML parser parser = etree.XMLParser(recover=True, encoding='utf-8') tree = ET.parse("results/results.xml", parser=parser) @@ -101,29 +103,41 @@ def get_testcase_result(): return database_list -if __name__ == "__main__": +def main(): """ Main function to get the test case results and aggregator logs And write the results to GitHub step summary + IMP: Do not fail the test in any scenario """ result = get_testcase_result() + if not all([os.getenv(var) for var in ["NUM_COLLABORATORS", "NUM_ROUNDS", "MODEL_NAME", "GITHUB_STEP_SUMMARY"]]): + print("One or more environment variables not set. Skipping writing to GitHub step summary") + return + num_cols = os.getenv("NUM_COLLABORATORS") num_rounds = os.getenv("NUM_ROUNDS") model_name = os.getenv("MODEL_NAME") + summary_file = os.getenv("GITHUB_STEP_SUMMARY") - if not model_name: - print("MODEL_NAME is not set, cannot find out aggregator logs") - agg_accuracy = "Not Found" - else: - workspace_name = "workspace_" + model_name - agg_log_file = os.path.join("results", workspace_name, "aggregator.log") - agg_accuracy = get_aggregated_accuracy(agg_log_file) + # Validate the model name and create the workspace name + if not model_name.upper() in constants.ModelName._member_names_: + print(f"Invalid model name: {model_name}. Skipping writing to GitHub step summary") + return - # Write the results to GitHub step summary - with open(os.getenv('GITHUB_STEP_SUMMARY'), 'a') as fh: + workspace_name = "workspace_" + model_name + agg_log_file = os.path.join("results", workspace_name, "aggregator.log") + agg_accuracy = get_aggregated_accuracy(agg_log_file) + + # Write the results to GitHub step summary file + # This file is created at runtime by the GitHub action, thus we cannot verify its existence beforehand + with open(summary_file, 'a') as fh: # DO NOT change the print statements print("| Name | Time (in seconds) | Result | Error (if any) | Collaborators | Rounds to train | Score (if applicable) |", file=fh) print("| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |", file=fh) for item in result: print(f"| {item['name']} | {item['time']} | {item['result']} | {item['err_msg']} | {num_cols} | {num_rounds} | {agg_accuracy} |", file=fh) + + +if __name__ == "__main__": + main() diff --git a/tests/github/test_double_ws_export.py b/tests/github/test_double_ws_export.py index da15a97d9f..95c9440b31 100644 --- a/tests/github/test_double_ws_export.py +++ b/tests/github/test_double_ws_export.py @@ -10,10 +10,11 @@ from concurrent.futures import ProcessPoolExecutor import psutil -from tests.github.utils import create_certified_workspace, certify_aggregator, create_collaborator +from tests.github.utils import create_certified_workspace, certify_aggregator, create_collaborator, is_path_name_allowed from openfl.utilities.utils import getfqdn_env -if __name__ == '__main__': + +def main(): # Test the pipeline parser = argparse.ArgumentParser() workspace_choice = [] @@ -31,6 +32,12 @@ args = parser.parse_args() fed_workspace = args.fed_workspace + + # Check if the path name is allowed before creating the workspace + if not is_path_name_allowed(fed_workspace): + print(f"The path name {fed_workspace} is not allowed") + return + archive_name = f'{fed_workspace}.zip' fqdn = getfqdn_env() template = args.template @@ -81,3 +88,7 @@ dir1 = workspace_root / col1 / fed_workspace executor.submit(check_call, ['fx', 'collaborator', 'start', '-n', col1], cwd=dir1) shutil.rmtree(workspace_root) + + +if __name__ == '__main__': + main() diff --git a/tests/github/test_gandlf.py b/tests/github/test_gandlf.py index 68b00a2382..a57f9f53a0 100644 --- a/tests/github/test_gandlf.py +++ b/tests/github/test_gandlf.py @@ -10,7 +10,7 @@ from subprocess import check_call from concurrent.futures import ProcessPoolExecutor -from tests.github.utils import create_collaborator, certify_aggregator +from tests.github.utils import create_collaborator, certify_aggregator, is_path_name_allowed from openfl.utilities.utils import getfqdn_env @@ -19,7 +19,7 @@ def exec(command, directory): check_call(command) -if __name__ == '__main__': +def main(): parser = argparse.ArgumentParser() parser.add_argument('--template', default='keras_cnn_mnist') parser.add_argument('--fed_workspace', default='fed_work12345alpha81671') @@ -34,6 +34,12 @@ def exec(command, directory): origin_dir = Path().resolve() args = parser.parse_args() fed_workspace = args.fed_workspace + + # Check if the path name is allowed before creating the workspace + if not is_path_name_allowed(fed_workspace): + print(f"The path name {fed_workspace} is not allowed") + return + archive_name = f'{fed_workspace}.zip' fqdn = getfqdn_env() template = args.template @@ -116,3 +122,7 @@ def exec(command, directory): dir2 = workspace_root / col2 / fed_workspace executor.submit(exec, ['fx', 'collaborator', 'start', '-n', col2], dir2) shutil.rmtree(workspace_root) + + +if __name__ == '__main__': + main() diff --git a/tests/github/test_hello_federation.py b/tests/github/test_hello_federation.py index 7e9e676fbe..e9071eed1c 100644 --- a/tests/github/test_hello_federation.py +++ b/tests/github/test_hello_federation.py @@ -9,10 +9,11 @@ from concurrent.futures import ProcessPoolExecutor from openfl.utilities.utils import rmtree -from tests.github.utils import create_collaborator, create_certified_workspace, certify_aggregator +from tests.github.utils import create_collaborator, create_certified_workspace, certify_aggregator, is_path_name_allowed from openfl.utilities.utils import getfqdn_env -if __name__ == '__main__': + +def main(): # Test the pipeline parser = argparse.ArgumentParser() workspace_choice = [] @@ -32,6 +33,12 @@ origin_dir = Path.cwd().resolve() args = parser.parse_args() fed_workspace = args.fed_workspace + + # Check if the path name is allowed before creating the workspace + if not is_path_name_allowed(fed_workspace): + print(f"The path name {fed_workspace} is not allowed") + return + archive_name = f'{fed_workspace}.zip' fqdn = getfqdn_env() template = args.template @@ -73,3 +80,7 @@ os.chdir(origin_dir) rmtree(workspace_root) + + +if __name__ == '__main__': + main() diff --git a/tests/github/utils.py b/tests/github/utils.py index b265448111..06e7f8ae97 100644 --- a/tests/github/utils.py +++ b/tests/github/utils.py @@ -117,3 +117,22 @@ def create_signed_cert_for_collaborator(col, data_path): os.remove(f) # Remove request archive os.remove(f'col_{col}_to_agg_cert_request.zip') + + +def is_path_name_allowed(path): + """ + Check if given path name is allowed. + Allow alphanumeric characters, hyphens and underscores. + Also, / in case of a nested directory. + + Args: + path (str): The path name to check. + Returns: + bool: True if the path name is allowed, False otherwise. + """ + special_characters = "!@#$%^&*()+?=,<>" + + if any(c in special_characters for c in path): + return False + else: + return True