Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Code changes due to TLS Parameter changes in openfl #1170

Merged
merged 11 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/task_runner_e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
run: |
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py \
-m ${{ env.MODEL_NAME }} --model_name ${{ env.MODEL_NAME }} \
--num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --disable_tls
--num_rounds ${{ env.NUM_ROUNDS }} --num_collaborators ${{ env.NUM_COLLABORATORS }} --use_tls False
echo "Task runner end to end test run completed"

- name: Print test summary
Expand Down
6 changes: 3 additions & 3 deletions tests/end_to_end/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ Below parameters are available for modification:
1. --num_collaborators <int> - to modify the number of collaborators
2. --num_rounds <int> - to modify the number of rounds to train
3. --model_name <str> - to use a specific model
4. --disable_tls - to disable TLS communication (by default it is enabled)
5. --disable_client_auth - to disable the client authentication (by default it is enabled)
4. --use_tls <str> - to enable TLS communication (by default it is enabled)
5. --require_client_auth <str> - to enable the client authentication (by default it is enabled)

For example, to run Task runner with - torch_cnn_mnist model, 3 collaborators, 5 rounds and non-TLS scenario:

```sh
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py --num_rounds 5 --num_collaborators 3 --model_name torch_cnn_mnist --disable_tls
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py --num_rounds 5 --num_collaborators 3 --model_name torch_cnn_mnist --use_tls "true"
```

### Output Structure
Expand Down
67 changes: 17 additions & 50 deletions tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
# Define a named tuple to store the objects for model owner, aggregator, and collaborators
federation_fixture = collections.namedtuple(
"federation_fixture",
"model_owner, aggregator, collaborators, model_name, disable_client_auth, disable_tls, workspace_path, results_dir, num_rounds",
"model_owner, aggregator, collaborators, model_name, require_client_auth, use_tls, workspace_path, results_dir, num_rounds",
)


def pytest_addoption(parser):
"""
Add custom command line options to the pytest parser.
Expand All @@ -29,44 +28,12 @@ def pytest_addoption(parser):
"""
parser.addini("results_dir", "Directory to store test results", default="results")
parser.addini("log_level", "Logging level", default="DEBUG")
parser.addoption(
"--results_dir", action="store", type=str, default="results", help="Results directory"
)
parser.addoption(
"--num_collaborators",
action="store",
type=int,
default=constants.NUM_COLLABORATORS,
help="Number of collaborators",
)
parser.addoption(
"--num_rounds",
action="store",
type=int,
default=constants.NUM_ROUNDS,
help="Number of rounds to train",
)
parser.addoption(
"--model_name",
action="store",
type=str,
help="Model name",
)
parser.addoption(
"--disable_client_auth",
action="store_true",
help="Disable client authentication",
)
parser.addoption(
"--disable_tls",
action="store_true",
help="Disable TLS for communication",
)
parser.addoption(
"--log_memory_usage",
action="store_true",
help="Enable memory log in collaborators and aggregator",
)
parser.addoption("--num_collaborators")
parser.addoption("--num_rounds")
parser.addoption("--model_name")
parser.addoption("--require_client_auth")
parser.addoption("--use_tls")
parser.addoption("--log_memory_usage")


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -234,20 +201,20 @@ def fx_federation(request, pytestconfig):
args = parse_arguments()
# Use the model name from the test case name if not provided as a command line argument
model_name = args.model_name if args.model_name else request.node.name.split("test_")[1]
results_dir = args.results_dir or pytestconfig.getini("results_dir")
results_dir = pytestconfig.getini("results_dir")
num_collaborators = args.num_collaborators
num_rounds = args.num_rounds
disable_client_auth = args.disable_client_auth
disable_tls = args.disable_tls
require_client_auth = True if ( args.require_client_auth.lower() == "true" ) else False
use_tls = True if ( args.use_tls.lower() == "true" ) else False
payalcha marked this conversation as resolved.
Show resolved Hide resolved
log_memory_usage = args.log_memory_usage

log.info(
f"Running federation setup using Task Runner API on single machine with below configurations:\n"
f"\tNumber of collaborators: {num_collaborators}\n"
f"\tNumber of rounds: {num_rounds}\n"
f"\tModel name: {model_name}\n"
f"\tClient authentication: {not disable_client_auth}\n"
f"\tTLS: {not disable_tls}\n"
f"\tClient authentication: {require_client_auth}\n"
f"\tTLS: {use_tls}\n"
f"\tMemory Logs: {log_memory_usage}"
)

Expand All @@ -270,16 +237,16 @@ def fx_federation(request, pytestconfig):
model_owner.modify_plan(
new_rounds=num_rounds,
num_collaborators=num_collaborators,
disable_client_auth=disable_client_auth,
disable_tls=disable_tls,
require_client_auth=require_client_auth,
use_tls=use_tls,
)
except Exception as e:
log.error(f"Failed to modify the plan: {e}")
raise e

# For TLS enabled (default) scenario: when the workspace is certified, the collaborators are registered as well
# For TLS disabled scenario: collaborators need to be registered explicitly
if args.disable_tls:
if args.use_tls:
log.info("Disabling TLS for communication")
payalcha marked this conversation as resolved.
Show resolved Hide resolved
try:
model_owner.register_collaborators(num_collaborators)
Expand Down Expand Up @@ -321,8 +288,8 @@ def fx_federation(request, pytestconfig):
aggregator=aggregator,
collaborators=collaborators,
model_name=model_name,
disable_client_auth=disable_client_auth,
disable_tls=disable_tls,
require_client_auth=require_client_auth,
use_tls=use_tls,
workspace_path=workspace_path,
results_dir=results_dir,
num_rounds=num_rounds,
Expand Down
10 changes: 5 additions & 5 deletions tests/end_to_end/models/participants.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ def certify_collaborator(self, collaborator_name):
raise e
return True

def modify_plan(self, new_rounds=None, num_collaborators=None, disable_client_auth=False, disable_tls=False):
def modify_plan(self, new_rounds=None, num_collaborators=None, require_client_auth=False, use_tls=False):
payalcha marked this conversation as resolved.
Show resolved Hide resolved
"""
Modify the plan to train the model
Args:
new_rounds (int): Number of rounds to train
num_collaborators (int): Number of collaborators
disable_client_auth (bool): Disable client authentication
disable_tls (bool): Disable TLS communication
require_client_auth (bool): Disable client authentication
use_tls (bool): Disable TLS communication
payalcha marked this conversation as resolved.
Show resolved Hide resolved
Returns:
bool: True if successful, else False
"""
Expand All @@ -139,8 +139,8 @@ def modify_plan(self, new_rounds=None, num_collaborators=None, disable_client_au
data["collaborator"]["settings"]["log_memory_usage"] = self.log_memory_usage

data["data_loader"]["settings"]["collaborator_count"] = int(self.num_collaborators)
data["network"]["settings"]["disable_client_auth"] = disable_client_auth
data["network"]["settings"]["tls"] = not disable_tls
data["network"]["settings"]["require_client_auth"] = require_client_auth
data["network"]["settings"]["use_tls"] = use_tls

with open(self.plan_path, "w+") as write_file:
yaml.dump(data, write_file)
Expand Down
6 changes: 3 additions & 3 deletions tests/end_to_end/test_suites/task_runner_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_torch_cnn_mnist(fx_federation):
log.info("Testing torch_cnn_mnist model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
if fx_federation.use_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
Expand All @@ -32,7 +32,7 @@ def test_keras_cnn_mnist(fx_federation):
log.info("Testing keras_cnn_mnist model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
if fx_federation.use_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
Expand All @@ -50,7 +50,7 @@ def test_torch_cnn_histology(fx_federation):
log.info("Testing torch_cnn_histology model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
if fx_federation.use_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
Expand Down
11 changes: 6 additions & 5 deletions tests/end_to_end/utils/conftest_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@ def parse_arguments():
- num_collaborators (int, default=2): Number of collaborators
- num_rounds (int, default=5): Number of rounds to train
- model_name (str, default="torch_cnn_mnist"): Model name
- disable_client_auth (bool): Disable client authentication
- disable_tls (bool): Disable TLS for communication
- require_client_auth (bool): Disable client authentication
- use_tls (bool): Disable TLS for communication
payalcha marked this conversation as resolved.
Show resolved Hide resolved
- log_memory_usage (bool): Enable Memory leak logs

Raises:
SystemExit: If the required arguments are not provided or if any argument parsing error occurs.
"""
try:
parser = argparse.ArgumentParser(description="Provide the required arguments to run the tests")
parser.add_argument("--results_dir", type=str, required=False, default="results", help="Directory to store the results")
# parser.add_argument("--results_dir", type=str, required=False, default="results", help="Directory to store the results")
payalcha marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument("--num_collaborators", type=int, default=2, help="Number of collaborators")
parser.add_argument("--num_rounds", type=int, default=5, help="Number of rounds to train")
parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--disable_client_auth", action="store_true", help="Disable client authentication")
parser.add_argument("--disable_tls", action="store_true", help="Disable TLS for communication")
parser.add_argument("--require_client_auth", type=str, default="True", help="Enable client authentication")
parser.add_argument("--use_tls", type=str, default="True", help="Enable TLS for communication")
parser.add_argument("--log_memory_usage", action="store_true", help="Enable Memory leak logs")
# parser.add_argument("--junitxml" , type=str, default="report.xml", help="Path to store the JUnit XML report")
payalcha marked this conversation as resolved.
Show resolved Hide resolved
args = parser.parse_known_args()[0]
return args

Expand Down