Skip to content

Commit

Permalink
Merge branch 'develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
payalcha authored Nov 21, 2024
2 parents e4a992e + 3f4ce31 commit 17531e1
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/task_runner_e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
run: |
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py \
-m ${{ env.MODEL_NAME }} --model_name ${{ env.MODEL_NAME }} \
--num_rounds $NUM_ROUNDS --num_collaborators $NUM_COLLABORATORS
--num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }}
echo "Task runner end to end test run completed"
- name: Print test summary
Expand Down Expand Up @@ -140,7 +140,7 @@ jobs:
run: |
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py \
-m ${{ env.MODEL_NAME }} --model_name ${{ env.MODEL_NAME }} \
--num_rounds $NUM_ROUNDS --num_collaborators $NUM_COLLABORATORS --disable_tls
--num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_tls
echo "Task runner end to end test run completed"
- name: Print test summary
Expand Down
34 changes: 24 additions & 10 deletions tests/end_to_end/utils/summary_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from lxml import etree
import os

import tests.end_to_end.utils.constants as constants

# Initialize the XML parser
parser = etree.XMLParser(recover=True, encoding='utf-8')
tree = ET.parse("results/results.xml", parser=parser)
Expand Down Expand Up @@ -101,29 +103,41 @@ def get_testcase_result():
return database_list


if __name__ == "__main__":
def main():
"""
Main function to get the test case results and aggregator logs
And write the results to GitHub step summary
IMP: Do not fail the test in any scenario
"""
result = get_testcase_result()

if not all([os.getenv(var) for var in ["NUM_COLLABORATORS", "NUM_ROUNDS", "MODEL_NAME", "GITHUB_STEP_SUMMARY"]]):
print("One or more environment variables not set. Skipping writing to GitHub step summary")
return

num_cols = os.getenv("NUM_COLLABORATORS")
num_rounds = os.getenv("NUM_ROUNDS")
model_name = os.getenv("MODEL_NAME")
summary_file = os.getenv("GITHUB_STEP_SUMMARY")

if not model_name:
print("MODEL_NAME is not set, cannot find out aggregator logs")
agg_accuracy = "Not Found"
else:
workspace_name = "workspace_" + model_name
agg_log_file = os.path.join("results", workspace_name, "aggregator.log")
agg_accuracy = get_aggregated_accuracy(agg_log_file)
# Validate the model name and create the workspace name
if not model_name.upper() in constants.ModelName._member_names_:
print(f"Invalid model name: {model_name}. Skipping writing to GitHub step summary")
return

# Write the results to GitHub step summary
with open(os.getenv('GITHUB_STEP_SUMMARY'), 'a') as fh:
workspace_name = "workspace_" + model_name
agg_log_file = os.path.join("results", workspace_name, "aggregator.log")
agg_accuracy = get_aggregated_accuracy(agg_log_file)

# Write the results to GitHub step summary file
# This file is created at runtime by the GitHub action, thus we cannot verify its existence beforehand
with open(summary_file, 'a') as fh:
# DO NOT change the print statements
print("| Name | Time (in seconds) | Result | Error (if any) | Collaborators | Rounds to train | Score (if applicable) |", file=fh)
print("| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |", file=fh)
for item in result:
print(f"| {item['name']} | {item['time']} | {item['result']} | {item['err_msg']} | {num_cols} | {num_rounds} | {agg_accuracy} |", file=fh)


if __name__ == "__main__":
main()
15 changes: 13 additions & 2 deletions tests/github/test_double_ws_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from concurrent.futures import ProcessPoolExecutor
import psutil

from tests.github.utils import create_certified_workspace, certify_aggregator, create_collaborator
from tests.github.utils import create_certified_workspace, certify_aggregator, create_collaborator, is_path_name_allowed
from openfl.utilities.utils import getfqdn_env

if __name__ == '__main__':

def main():
# Test the pipeline
parser = argparse.ArgumentParser()
workspace_choice = []
Expand All @@ -31,6 +32,12 @@

args = parser.parse_args()
fed_workspace = args.fed_workspace

# Check if the path name is allowed before creating the workspace
if not is_path_name_allowed(fed_workspace):
print(f"The path name {fed_workspace} is not allowed")
return

archive_name = f'{fed_workspace}.zip'
fqdn = getfqdn_env()
template = args.template
Expand Down Expand Up @@ -81,3 +88,7 @@
dir1 = workspace_root / col1 / fed_workspace
executor.submit(check_call, ['fx', 'collaborator', 'start', '-n', col1], cwd=dir1)
shutil.rmtree(workspace_root)


if __name__ == '__main__':
main()
14 changes: 12 additions & 2 deletions tests/github/test_gandlf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from subprocess import check_call
from concurrent.futures import ProcessPoolExecutor

from tests.github.utils import create_collaborator, certify_aggregator
from tests.github.utils import create_collaborator, certify_aggregator, is_path_name_allowed
from openfl.utilities.utils import getfqdn_env


Expand All @@ -19,7 +19,7 @@ def exec(command, directory):
check_call(command)


if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--template', default='keras_cnn_mnist')
parser.add_argument('--fed_workspace', default='fed_work12345alpha81671')
Expand All @@ -34,6 +34,12 @@ def exec(command, directory):
origin_dir = Path().resolve()
args = parser.parse_args()
fed_workspace = args.fed_workspace

# Check if the path name is allowed before creating the workspace
if not is_path_name_allowed(fed_workspace):
print(f"The path name {fed_workspace} is not allowed")
return

archive_name = f'{fed_workspace}.zip'
fqdn = getfqdn_env()
template = args.template
Expand Down Expand Up @@ -116,3 +122,7 @@ def exec(command, directory):
dir2 = workspace_root / col2 / fed_workspace
executor.submit(exec, ['fx', 'collaborator', 'start', '-n', col2], dir2)
shutil.rmtree(workspace_root)


if __name__ == '__main__':
main()
15 changes: 13 additions & 2 deletions tests/github/test_hello_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from concurrent.futures import ProcessPoolExecutor

from openfl.utilities.utils import rmtree
from tests.github.utils import create_collaborator, create_certified_workspace, certify_aggregator
from tests.github.utils import create_collaborator, create_certified_workspace, certify_aggregator, is_path_name_allowed
from openfl.utilities.utils import getfqdn_env

if __name__ == '__main__':

def main():
# Test the pipeline
parser = argparse.ArgumentParser()
workspace_choice = []
Expand All @@ -32,6 +33,12 @@
origin_dir = Path.cwd().resolve()
args = parser.parse_args()
fed_workspace = args.fed_workspace

# Check if the path name is allowed before creating the workspace
if not is_path_name_allowed(fed_workspace):
print(f"The path name {fed_workspace} is not allowed")
return

archive_name = f'{fed_workspace}.zip'
fqdn = getfqdn_env()
template = args.template
Expand Down Expand Up @@ -73,3 +80,7 @@

os.chdir(origin_dir)
rmtree(workspace_root)


if __name__ == '__main__':
main()
19 changes: 19 additions & 0 deletions tests/github/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,22 @@ def create_signed_cert_for_collaborator(col, data_path):
os.remove(f)
# Remove request archive
os.remove(f'col_{col}_to_agg_cert_request.zip')


def is_path_name_allowed(path):
"""
Check if given path name is allowed.
Allow alphanumeric characters, hyphens and underscores.
Also, / in case of a nested directory.
Args:
path (str): The path name to check.
Returns:
bool: True if the path name is allowed, False otherwise.
"""
special_characters = "!@#$%^&*()+?=,<>"

if any(c in special_characters for c in path):
return False
else:
return True

0 comments on commit 17531e1

Please sign in to comment.