Skip to content

Commit

Permalink
use_tls and require_client_auth
Browse files Browse the repository at this point in the history
Signed-off-by: Chaurasiya, Payal <[email protected]>
  • Loading branch information
payalcha committed Nov 22, 2024
1 parent e4a992e commit 33943fd
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/task_runner_e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
run: |
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py \
-m ${{ env.MODEL_NAME }} --model_name ${{ env.MODEL_NAME }} \
--num_rounds $NUM_ROUNDS --num_collaborators $NUM_COLLABORATORS --disable_tls
--num_rounds $NUM_ROUNDS --num_collaborators $NUM_COLLABORATORS --use_tls "false"
echo "Task runner end to end test run completed"
- name: Print test summary
Expand Down
6 changes: 3 additions & 3 deletions tests/end_to_end/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ Below parameters are available for modification:
1. --num_collaborators <int> - to modify the number of collaborators
2. --num_rounds <int> - to modify the number of rounds to train
3. --model_name <str> - to use a specific model
4. --disable_tls - to disable TLS communication (by default it is enabled)
5. --disable_client_auth - to disable the client authentication (by default it is enabled)
4. --use_tls <str> - to enable TLS communication (by default it is enabled)
5. --require_client_auth <str> - to enable the client authentication (by default it is enabled)

For example, to run Task runner with - torch_cnn_mnist model, 3 collaborators, 5 rounds and non-TLS scenario:

```sh
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py --num_rounds 5 --num_collaborators 3 --model_name torch_cnn_mnist --disable_tls
python -m pytest -s tests/end_to_end/test_suites/task_runner_tests.py --num_rounds 5 --num_collaborators 3 --model_name torch_cnn_mnist --use_tls "true"
```

### Output Structure
Expand Down
67 changes: 17 additions & 50 deletions tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
# Define a named tuple to store the objects for model owner, aggregator, and collaborators
federation_fixture = collections.namedtuple(
"federation_fixture",
"model_owner, aggregator, collaborators, model_name, disable_client_auth, disable_tls, workspace_path, results_dir, num_rounds",
"model_owner, aggregator, collaborators, model_name, require_client_auth, use_tls, workspace_path, results_dir, num_rounds",
)


def pytest_addoption(parser):
"""
Add custom command line options to the pytest parser.
Expand All @@ -29,44 +28,12 @@ def pytest_addoption(parser):
"""
parser.addini("results_dir", "Directory to store test results", default="results")
parser.addini("log_level", "Logging level", default="DEBUG")
parser.addoption(
"--results_dir", action="store", type=str, default="results", help="Results directory"
)
parser.addoption(
"--num_collaborators",
action="store",
type=int,
default=constants.NUM_COLLABORATORS,
help="Number of collaborators",
)
parser.addoption(
"--num_rounds",
action="store",
type=int,
default=constants.NUM_ROUNDS,
help="Number of rounds to train",
)
parser.addoption(
"--model_name",
action="store",
type=str,
help="Model name",
)
parser.addoption(
"--disable_client_auth",
action="store_true",
help="Disable client authentication",
)
parser.addoption(
"--disable_tls",
action="store_true",
help="Disable TLS for communication",
)
parser.addoption(
"--log_memory_usage",
action="store_true",
help="Enable memory log in collaborators and aggregator",
)
parser.addoption("--num_collaborators")
parser.addoption("--num_rounds")
parser.addoption("--model_name")
parser.addoption("--require_client_auth")
parser.addoption("--use_tls")
parser.addoption("--log_memory_usage")


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -234,20 +201,20 @@ def fx_federation(request, pytestconfig):
args = parse_arguments()
# Use the model name from the test case name if not provided as a command line argument
model_name = args.model_name if args.model_name else request.node.name.split("test_")[1]
results_dir = args.results_dir or pytestconfig.getini("results_dir")
results_dir = pytestconfig.getini("results_dir")
num_collaborators = args.num_collaborators
num_rounds = args.num_rounds
disable_client_auth = args.disable_client_auth
disable_tls = args.disable_tls
require_client_auth = True if ( args.require_client_auth.lower() == "true" ) else False
use_tls = True if ( args.use_tls.lower() == "true" ) else False
log_memory_usage = args.log_memory_usage

log.info(
f"Running federation setup using Task Runner API on single machine with below configurations:\n"
f"\tNumber of collaborators: {num_collaborators}\n"
f"\tNumber of rounds: {num_rounds}\n"
f"\tModel name: {model_name}\n"
f"\tClient authentication: {not disable_client_auth}\n"
f"\tTLS: {not disable_tls}\n"
f"\tClient authentication: {require_client_auth}\n"
f"\tTLS: {use_tls}\n"
f"\tMemory Logs: {log_memory_usage}"
)

Expand All @@ -270,16 +237,16 @@ def fx_federation(request, pytestconfig):
model_owner.modify_plan(
new_rounds=num_rounds,
num_collaborators=num_collaborators,
disable_client_auth=disable_client_auth,
disable_tls=disable_tls,
require_client_auth=require_client_auth,
use_tls=use_tls,
)
except Exception as e:
log.error(f"Failed to modify the plan: {e}")
raise e

# For TLS enabled (default) scenario: when the workspace is certified, the collaborators are registered as well
# For TLS disabled scenario: collaborators need to be registered explicitly
if args.disable_tls:
if args.use_tls:
log.info("Disabling TLS for communication")
try:
model_owner.register_collaborators(num_collaborators)
Expand Down Expand Up @@ -321,8 +288,8 @@ def fx_federation(request, pytestconfig):
aggregator=aggregator,
collaborators=collaborators,
model_name=model_name,
disable_client_auth=disable_client_auth,
disable_tls=disable_tls,
require_client_auth=require_client_auth,
use_tls=use_tls,
workspace_path=workspace_path,
results_dir=results_dir,
num_rounds=num_rounds,
Expand Down
10 changes: 5 additions & 5 deletions tests/end_to_end/models/participants.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ def certify_collaborator(self, collaborator_name):
raise e
return True

def modify_plan(self, new_rounds=None, num_collaborators=None, disable_client_auth=False, disable_tls=False):
def modify_plan(self, new_rounds=None, num_collaborators=None, require_client_auth=False, use_tls=False):
"""
Modify the plan to train the model
Args:
new_rounds (int): Number of rounds to train
num_collaborators (int): Number of collaborators
disable_client_auth (bool): Disable client authentication
disable_tls (bool): Disable TLS communication
require_client_auth (bool): Disable client authentication
use_tls (bool): Disable TLS communication
Returns:
bool: True if successful, else False
"""
Expand All @@ -139,8 +139,8 @@ def modify_plan(self, new_rounds=None, num_collaborators=None, disable_client_au
data["collaborator"]["settings"]["log_memory_usage"] = self.log_memory_usage

data["data_loader"]["settings"]["collaborator_count"] = int(self.num_collaborators)
data["network"]["settings"]["require_client_auth"] = not disable_client_auth
data["network"]["settings"]["use_tls"] = not disable_tls
data["network"]["settings"]["require_client_auth"] = not require_client_auth
data["network"]["settings"]["use_tls"] = not use_tls

with open(self.plan_path, "w+") as write_file:
yaml.dump(data, write_file)
Expand Down
6 changes: 3 additions & 3 deletions tests/end_to_end/test_suites/task_runner_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_torch_cnn_mnist(fx_federation):
log.info("Testing torch_cnn_mnist model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
if fx_federation.use_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
Expand All @@ -32,7 +32,7 @@ def test_keras_cnn_mnist(fx_federation):
log.info("Testing keras_cnn_mnist model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
if fx_federation.use_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
Expand All @@ -50,7 +50,7 @@ def test_torch_cnn_histology(fx_federation):
log.info("Testing torch_cnn_histology model")

# Setup PKI for trusted communication within the federation
if not fx_federation.disable_tls:
if fx_federation.use_tls:
assert fed_helper.setup_pki(fx_federation), "Failed to setup PKI for trusted communication"

# Start the federation
Expand Down
11 changes: 6 additions & 5 deletions tests/end_to_end/utils/conftest_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@ def parse_arguments():
- num_collaborators (int, default=2): Number of collaborators
- num_rounds (int, default=5): Number of rounds to train
- model_name (str, default="torch_cnn_mnist"): Model name
- disable_client_auth (bool): Disable client authentication
- disable_tls (bool): Disable TLS for communication
- require_client_auth (bool): Disable client authentication
- use_tls (bool): Disable TLS for communication
- log_memory_usage (bool): Enable Memory leak logs
Raises:
SystemExit: If the required arguments are not provided or if any argument parsing error occurs.
"""
try:
parser = argparse.ArgumentParser(description="Provide the required arguments to run the tests")
parser.add_argument("--results_dir", type=str, required=False, default="results", help="Directory to store the results")
# parser.add_argument("--results_dir", type=str, required=False, default="results", help="Directory to store the results")
parser.add_argument("--num_collaborators", type=int, default=2, help="Number of collaborators")
parser.add_argument("--num_rounds", type=int, default=5, help="Number of rounds to train")
parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--disable_client_auth", action="store_true", help="Disable client authentication")
parser.add_argument("--disable_tls", action="store_true", help="Disable TLS for communication")
parser.add_argument("--require_client_auth", type=str, default="True", help="Enable client authentication")
parser.add_argument("--use_tls", type=str, default="True", help="Enable TLS for communication")
parser.add_argument("--log_memory_usage", action="store_true", help="Enable Memory leak logs")
# parser.add_argument("--junitxml" , type=str, default="report.xml", help="Path to store the JUnit XML report")
args = parser.parse_known_args()[0]
return args

Expand Down

0 comments on commit 33943fd

Please sign in to comment.