From f8cd2a7587ea453a44968de86dc4dadca2358fe6 Mon Sep 17 00:00:00 2001 From: benjaminastrand Date: Thu, 12 Sep 2024 10:31:11 +0200 Subject: [PATCH] Clean up and run 20 trials a 50 rounds --- .../notebooks/Hyperparameter_Tuning.ipynb | 699 +++++------------- 1 file changed, 170 insertions(+), 529 deletions(-) diff --git a/examples/notebooks/Hyperparameter_Tuning.ipynb b/examples/notebooks/Hyperparameter_Tuning.ipynb index 660d42b4d..41d8f9dbd 100644 --- a/examples/notebooks/Hyperparameter_Tuning.ipynb +++ b/examples/notebooks/Hyperparameter_Tuning.ipynb @@ -6,7 +6,7 @@ "source": [ "## Hyperparameter tuning of the server-side optimizer with Optuna\n", "\n", - "This notebook shows specifically how to tune the *learning rate* of *FedAdam* using the Optuna package. Tuning of other hyperparameter and/or other server-side optimizers can be done analogously. The notebook *Aggregators.ipynb* shows how to use different aggregators with the FEDn Python API.\n", + "This notebook shows how to tune hyperparameters of the server-side optimizer, specifically the *learning rate* of *FedAdam*, using the Optuna package. Optuna supports Bayesian optimization for the selection of hyperparameter values. Tuning of other hyperparameter and/or other server-side optimizers can be done analogously. The notebook *Aggregators.ipynb* shows how to use different aggregators with the FEDn Python API.\n", "\n", "For a complete list of implemented interfaces, please refer to the [FEDn APIs](https://fedn.readthedocs.io/en/latest/fedn.network.api.html#module-fedn.network.api.client). \n", "\n", @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -40,11 +40,11 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "CONTROLLER_HOST = 'fedn.scaleoutsystems.com/' # TODO byt ut till lokal\n", + "CONTROLLER_HOST = 'fedn.scaleoutsystems.com/fedn.scaleoutsystems.com/'\n", "ACCESS_TOKEN = ''\n", "client = APIClient(CONTROLLER_HOST,token=ACCESS_TOKEN, secure=True,verify=True)" ] @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -71,79 +71,87 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "### Using Optuna to tune the server-side learning rate of FedAdam\n", + "The Optuna framework expects the user to define an objective function, which is used to evaluate the model given a certain set of hyperparameter values. This notebook is based on an existing example on the [FEDn Github](https://github.com/scaleoutsystems/fedn/tree/master/examples/mnist-pytorch), where we use a simple PyTorch model on the MNIST handwritten digit dataset. To evaluate the performance given different hyperparameter values, we will view the accuracy on the test set as the validation accuracy and we want to find the learning rate that maximizes this metric.\n", + "\n", "### Defining the objective function\n", "\n", - "Optuna expects an objective function - the function that evaluates a certain set of hyperparameter values. In this example, we will use the test accuracy as a validation score and we want to maximize it.\n", + "For each choice of hyperparameter values, we start a new session using FEDn and train the global model with the current hyperparameter values. When the session has finished, we evaluate the performance attained in the session. This is where the objective function comes into play! The objective function should follows these steps:\n", + "\n", + "1. Set a range for each hyperparameter to tune using the `trial` object in Optuna.\n", + "2. **Train the model**, using the hyperparameters suggested by Optuna.\n", + "3. Calculate and **return an evaluation metric**.\n", "\n", - "For each set of hyperparameter values, each `trial`, we will start a new session using the FEDn Python API. In each session/trial, we will select the model with the highest test accuracy and use that in the Optuna objective function to evaluate the trial.\n", + "But before we define the objective function, we will create a function that defines how the evaluation metric shall be calculated (step 3) after each finished session. Below are two suggested methods for evaluating the performance attained in a session:\n", "\n", - "The `objective()` function gives us some flexibility in how we choose to evaluate the choice of hyperparameters in each trial/session. Below are two examples on how to calculate the attained validation accuracy in a session.\n", + "* **Highest score** - select the highest achieved test accuracy out of all rounds in the session.\n", + "* **Average final few rounds** - compute the average test accuracy over the final few (ex. 5) rounds to account for the stochastic nature of the test accuracy score.\n", "\n", - "* ``" + "…and how to implement them using FEDn, where the `eval_method` parameter determines which of the two methods to use:\n" ] }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ - "# Helper function to get the highest test accuracy within a session\n", - "def get_highest_test_accuracy_in_session(client, n_rounds):\n", - " best_accuracy = 0\n", - " validations_in_session = client.get_validations()['result'][:n_rounds]\n", - " for validation in validations_in_session:\n", - " val_accuracy = json.loads(validation['data'])['test_accuracy']\n", - " if val_accuracy > best_accuracy:\n", - " best_accuracy = val_accuracy\n", - "\n", - " return best_accuracy\n", - "\n", - "# Helper function to get the average test accuracy over the last 10 rounds in a session\n", - "def get_test_accuracy_in_session_smooth(client, n_rounds):\n", + "def get_test_accuracy(client, n_rounds_in_session, eval_method='highest'):\n", " \n", - " n_rounds_to_avg = 5\n", - " if n_rounds_to_avg > n_rounds:\n", - " n_rounds_to_avg = n_rounds\n", + " # Set number of rounds to average for 'smooth' method\n", + " if eval_method == 'smooth':\n", + " n_rounds_to_eval = min(5, n_rounds_in_session)\n", + " else:\n", + " n_rounds_to_eval = n_rounds_in_session\n", " \n", - " # New\n", - " models = client.get_model_trail()[-n_rounds_to_avg:] # model with index -1 lacks validations -> seed model??\n", - " # print(f'models: {len(models)}\\n {models}')\n", - " model_test_acc = []\n", + " # Get models in session based on eval_method\n", + " models_in_session = client.get_model_trail()[-n_rounds_to_eval:]\n", "\n", - " # Loop over the last 'n_rounds_to_avg' rounds\n", - " for model_index, model in enumerate(models):\n", - " \n", + " session_test_accuracy_scores = []\n", + " for model in models_in_session:\n", " model_id = model[\"model\"]\n", - " validations = client.get_validations(model_id=model_id)\n", - " # print(f'Validation nr. {model_index}: {validations}')\n", - " a = []\n", "\n", - " # Loop over all contributing clients\n", - " for validation in validations['result']: \n", + " # Wait to receive validation data\n", + " wait_time = 0\n", + " while True:\n", + " time.sleep(1)\n", + " wait_time += 1\n", + " validations = client.get_validations(model_id=model_id)\n", + " if validations['count'] != 0 or wait_time == 60:\n", + " break\n", + "\n", + " # Average test accuracy over all contributing clients\n", + " model_test_accuracy_scores = []\n", + " for validation in validations['result']:\n", " metrics = json.loads(validation['data'])\n", - " a.append(metrics['test_accuracy'])\n", + " model_test_accuracy_scores.append(metrics['test_accuracy'])\n", " \n", - " model_test_acc.append(a)\n", - " print(f'Model id: {model_id}, Validations: {validations}')\n", - "\n", - " mean_val_accuracies = [np.mean(x) for x in model_test_acc]\n", - " print(f'Mean accuracy: {mean_val_accuracies}')\n", + " session_test_accuracy_scores.append(model_test_accuracy_scores)\n", "\n", - " # Old\n", - " # validations_to_avg = client.get_validations()['result'][:n_rounds_to_avg]\n", - " # val_accuracies = [json.loads(validation['data'])['test_accuracy'] for validation in validations_to_avg]\n", + " client_avg_test_accuracy_scores = [np.mean(x) for x in session_test_accuracy_scores]\n", "\n", - " mean_val_accuracy = np.mean(mean_val_accuracies)\n", - " print(f'Validation accuracy scores:\\n{mean_val_accuracies}')\n", - " print(f'Average validation accuracy: {mean_val_accuracy}')\n", + " if eval_method == 'highest':\n", + " # Return the highest test accuracy\n", + " return np.amax(client_avg_test_accuracy_scores)\n", + " elif eval_method == 'smooth':\n", + " # Return the calculated mean accuracy\n", + " return np.mean(client_avg_test_accuracy_scores)\n", + " else:\n", + " raise ValueError(\"Invalid eval_method. Use 'highest' or 'smooth'.\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have created a function to use in step 3, we will define the objective function. The code below shows how we can complete the three steps of the objective function with FEDn. The range in which Optuna will look for hyperparameter values is defined in **step 1**. Note that we are only tuning the learning rate of FedAdam in this example to keep things simple. **Step 2** entails starting a session and waiting for it to finish before evaluating the resulting model. In **step 3**, we simply call the function that we defined above and return the result.\n", "\n", - " return mean_val_accuracy" + "**Note:** We start from the seed model in each session to ensure that each trial has the same starting point." ] }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -152,11 +160,12 @@ "# Objective function which will be sent to Optuna to evaluate the selection of hyperparameter values\n", "def objective(trial):\n", " # Number of rounds per session\n", - " n_rounds = 5\n", + " n_rounds = 50\n", "\n", - " # Suggest hyperparameter priors\n", + " # 1. Suggest hyperparameter priors\n", " learning_rate = trial.suggest_float(\"learning_rate\", 1e-3, 1e-1, log=True)\n", "\n", + " # 2. Train the model\n", " # Set session configurations (from seed model)\n", " session_config = {\n", " \"helper\": \"numpyhelper\",\n", @@ -175,80 +184,53 @@ " \n", " # Wait for the session to finish\n", " while not client.session_is_finished(session_id):\n", - " time.sleep(2)\n", + " time.sleep(1)\n", " \n", - " # Return validation accuracy for session\n", - " return get_test_accuracy_in_session_smooth(client=client, n_rounds=n_rounds)" + " # 3. Return validation accuracy for session\n", + " return get_test_accuracy(client=client, n_rounds_in_session=n_rounds, eval_method=\"smooth\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Creating and running an Optuna study\n", + "### Creating, running and analyzing an Optuna study\n", "\n", - "Here we create an Optuna study. Since we are using the test accuracy for evaluation, we want to maximize the objective function in this example. We pass the objective function defined earlier when calling `study.optimize()` and select the number of trials we want to perform.\n", + "It’s time to create and run our study to let Optuna find optimal server-side learning rate for FedAdam. At this stage, all that is left to do is to tell Optuna in which direction to optimize the objective function and how many hyperparameter values we want to try. We create an Optuna `study` object and since we are using the test accuracy for evaluation, we want to `maximize` the objective function in this example. We run the `optimize()` method, passing the `objective` function we defined earlier as a parameter and specify the number of hyperparameter values we want to try via the `n_trials` parameter. \n", "\n", - "**Note:** Each trial starts a session, so the number of sessions is `n_trials`." + "**Note:** Each trial starts a session, so the number of sessions will be `n_trials`." ] }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[I 2024-09-09 11:40:23,930] A new study created in memory with name: no-name-9f40df3c-4c9f-4283-a7df-2608de273a5f\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model id: f8df5900-233b-4ee8-aa27-cbb197f517b4, Validations: {'count': 2, 'result': [{'correlation_id': '765ae8e1-bcfc-41b9-90e5-c47d122139ee', 'data': '{\"training_loss\": 2.9918696880340576, \"training_accuracy\": 0.2761666774749756, \"test_loss\": 3.145538568496704, \"test_accuracy\": 0.25600001215934753}', 'id': '66dec299819bae1528aeeefd', 'meta': '', 'model_id': 'f8df5900-233b-4ee8-aa27-cbb197f517b4', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client735', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:40:41.057123Z'}, {'correlation_id': '4bbb7175-52d1-4ad5-82cc-b97f8b728d19', 'data': '{\"training_loss\": 3.0555734634399414, \"training_accuracy\": 0.26350000500679016, \"test_loss\": 3.146380662918091, \"test_accuracy\": 0.25200000405311584}', 'id': '66dec299819bae1528aeeefb', 'meta': '', 'model_id': 'f8df5900-233b-4ee8-aa27-cbb197f517b4', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client268', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:40:41.056228Z'}]}\n", - "Model id: ac4fb06e-0186-4220-b0ea-dda55bcba92a, Validations: {'count': 2, 'result': [{'correlation_id': 'ddcf1b24-8b55-4cf5-9133-bc35637c2bf3', 'data': '{\"training_loss\": 5.030279636383057, \"training_accuracy\": 0.2666666805744171, \"test_loss\": 4.994809150695801, \"test_accuracy\": 0.27399998903274536}', 'id': '66dec2a7819bae1528aeef0c', 'meta': '', 'model_id': 'ac4fb06e-0186-4220-b0ea-dda55bcba92a', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client735', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:40:55.140524Z'}, {'correlation_id': 'c22bb9c4-4557-4b6e-8aac-28e907393ea4', 'data': '{\"training_loss\": 4.781228065490723, \"training_accuracy\": 0.2773333191871643, \"test_loss\": 4.559061527252197, \"test_accuracy\": 0.3059999942779541}', 'id': '66dec2a7819bae1528aeef09', 'meta': '', 'model_id': 'ac4fb06e-0186-4220-b0ea-dda55bcba92a', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client268', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:40:55.025714Z'}]}\n", - "Model id: 65343657-8bd2-4e48-9bfc-9dc1c90f766b, Validations: {'count': 2, 'result': [{'correlation_id': '9414bb80-d5c1-4c1b-a67f-c3500aa02961', 'data': '{\"training_loss\": 2.4126389026641846, \"training_accuracy\": 0.4465000033378601, \"test_loss\": 2.2521755695343018, \"test_accuracy\": 0.44699999690055847}', 'id': '66dec2b3819bae1528aeef19', 'meta': '', 'model_id': '65343657-8bd2-4e48-9bfc-9dc1c90f766b', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client268', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:41:07.439818Z'}, {'correlation_id': '1c6fa889-309e-4bc9-ae20-7f45bc7342fc', 'data': '{\"training_loss\": 2.369710922241211, \"training_accuracy\": 0.4494999945163727, \"test_loss\": 2.3147921562194824, \"test_accuracy\": 0.4399999976158142}', 'id': '66dec2b3819bae1528aeef17', 'meta': '', 'model_id': '65343657-8bd2-4e48-9bfc-9dc1c90f766b', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client735', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:41:07.436095Z'}]}\n", - "Model id: 56c93771-f8e7-4987-9f1f-fc5d050a415e, Validations: {'count': 2, 'result': [{'correlation_id': '7aded382-761f-4b55-8984-c24863d09d20', 'data': '{\"training_loss\": 0.8602045774459839, \"training_accuracy\": 0.7173333168029785, \"test_loss\": 0.9392972588539124, \"test_accuracy\": 0.6869999766349792}', 'id': '66dec2bf819bae1528aeef27', 'meta': '', 'model_id': '56c93771-f8e7-4987-9f1f-fc5d050a415e', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client268', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:41:19.240899Z'}, {'correlation_id': '2c35452a-393b-40c9-8f6f-e800419e16d2', 'data': '{\"training_loss\": 0.8731573224067688, \"training_accuracy\": 0.7064999938011169, \"test_loss\": 1.0654996633529663, \"test_accuracy\": 0.6869999766349792}', 'id': '66dec2bf819bae1528aeef25', 'meta': '', 'model_id': '56c93771-f8e7-4987-9f1f-fc5d050a415e', 'receiver': {'clientId': '', 'name': 'hyperparametertuning-hrt-fedn', 'role': 'COMBINER'}, 'sender': {'clientId': '', 'name': 'client735', 'role': 'WORKER'}, 'session_id': 'cc9be434-4dda-4dc5-b40e-762ff38cefac', 'timestamp': '2024-09-09T09:41:19.225037Z'}]}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/core/fromnumeric.py:3464: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/core/_methods.py:192: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n", - "[W 2024-09-09 11:41:29,099] Trial 0 failed with parameters: {'learning_rate': 0.09377062950652325} because of the following error: The value nan is not acceptable.\n", - "[W 2024-09-09 11:41:29,100] Trial 0 failed with value nan.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model id: e4a747bb-f5a2-4845-898d-ca6e9d005bb0, Validations: {'count': 0, 'result': []}\n", - "Mean accuracy: [0.2540000081062317, 0.28999999165534973, 0.44349999725818634, 0.6869999766349792, nan]\n", - "Validation accuracy scores:\n", - "[0.2540000081062317, 0.28999999165534973, 0.44349999725818634, 0.6869999766349792, nan]\n", - "Average validation accuracy: nan\n" - ] - }, - { - "ename": "ValueError", - "evalue": "No trials are completed yet.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[80], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Optimize hyperparameters\u001b[39;00m\n\u001b[1;32m 5\u001b[0m study\u001b[38;5;241m.\u001b[39moptimize(objective, n_trials\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBest hyperparameters:\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[43mstudy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbest_params\u001b[49m)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBest value:\u001b[39m\u001b[38;5;124m\"\u001b[39m, study\u001b[38;5;241m.\u001b[39mbest_value)\n", - "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/optuna/study/study.py:119\u001b[0m, in \u001b[0;36mStudy.best_params\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbest_params\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]:\n\u001b[1;32m 109\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return parameters of the best trial in the study.\u001b[39;00m\n\u001b[1;32m 110\u001b[0m \n\u001b[1;32m 111\u001b[0m \u001b[38;5;124;03m .. note::\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 116\u001b[0m \n\u001b[1;32m 117\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbest_trial\u001b[49m\u001b[38;5;241m.\u001b[39mparams\n", - "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/optuna/study/study.py:162\u001b[0m, in \u001b[0;36mStudy.best_trial\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_multi_objective():\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 158\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mA single best trial cannot be retrieved from a multi-objective study. Consider \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing Study.best_trials to retrieve a list containing the best trials.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 160\u001b[0m )\n\u001b[0;32m--> 162\u001b[0m best_trial \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_storage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_best_trial\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_study_id\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;66;03m# If the trial with the best value is infeasible, select the best trial from all feasible\u001b[39;00m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;66;03m# trials. Note that the behavior is undefined when constrained optimization without the\u001b[39;00m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# violation value in the best-valued trial.\u001b[39;00m\n\u001b[1;32m 167\u001b[0m constraints \u001b[38;5;241m=\u001b[39m best_trial\u001b[38;5;241m.\u001b[39msystem_attrs\u001b[38;5;241m.\u001b[39mget(_CONSTRAINTS_KEY)\n", - "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/optuna/storages/_in_memory.py:232\u001b[0m, in \u001b[0;36mInMemoryStorage.get_best_trial\u001b[0;34m(self, study_id)\u001b[0m\n\u001b[1;32m 229\u001b[0m best_trial_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_studies[study_id]\u001b[38;5;241m.\u001b[39mbest_trial_id\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m best_trial_id \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 232\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo trials are completed yet.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_studies[study_id]\u001b[38;5;241m.\u001b[39mdirections) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 235\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBest trial can be obtained only for single-objective optimization.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 236\u001b[0m )\n", - "\u001b[0;31mValueError\u001b[0m: No trials are completed yet." + "[I 2024-09-11 17:55:58,359] A new study created in memory with name: no-name-3523664a-7d8a-415c-953f-a97630ce9ef8\n", + "[I 2024-09-11 17:59:21,745] Trial 0 finished with value: 0.9595600068569183 and parameters: {'learning_rate': 0.012301378174433742}. Best is trial 0 with value: 0.9595600068569183.\n", + "[I 2024-09-11 18:02:47,210] Trial 1 finished with value: 0.9620999932289124 and parameters: {'learning_rate': 0.023368957405967835}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:06:14,743] Trial 2 finished with value: 0.911080002784729 and parameters: {'learning_rate': 0.0013378005538716429}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:09:36,215] Trial 3 finished with value: 0.9557799935340882 and parameters: {'learning_rate': 0.007007650689123562}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:12:54,568] Trial 4 finished with value: 0.9598599970340729 and parameters: {'learning_rate': 0.042762352644500096}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:16:21,127] Trial 5 finished with value: 0.9210599958896637 and parameters: {'learning_rate': 0.0015413604505815153}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:19:59,964] Trial 6 finished with value: 0.9619599997997283 and parameters: {'learning_rate': 0.029375669161110958}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:23:23,559] Trial 7 finished with value: 0.9521999955177307 and parameters: {'learning_rate': 0.004598986580964227}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:26:47,148] Trial 8 finished with value: 0.9614199936389923 and parameters: {'learning_rate': 0.04232333700633576}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:30:06,728] Trial 9 finished with value: 0.8862399995326996 and parameters: {'learning_rate': 0.0010082622326759046}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:33:28,295] Trial 10 finished with value: 0.9530400037765503 and parameters: {'learning_rate': 0.0944781292696637}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:36:55,057] Trial 11 finished with value: 0.9606599926948547 and parameters: {'learning_rate': 0.019765558641651543}. Best is trial 1 with value: 0.9620999932289124.\n", + "[I 2024-09-11 18:40:21,750] Trial 12 finished with value: 0.9623999953269958 and parameters: {'learning_rate': 0.026254173914001054}. Best is trial 12 with value: 0.9623999953269958.\n", + "[I 2024-09-11 18:43:46,536] Trial 13 finished with value: 0.9538000106811524 and parameters: {'learning_rate': 0.08192645761678255}. Best is trial 12 with value: 0.9623999953269958.\n", + "[I 2024-09-11 18:47:10,396] Trial 14 finished with value: 0.9603799998760223 and parameters: {'learning_rate': 0.015592889217911915}. Best is trial 12 with value: 0.9623999953269958.\n", + "[I 2024-09-11 18:50:32,160] Trial 15 finished with value: 0.9487399995326996 and parameters: {'learning_rate': 0.0038523074238855904}. Best is trial 12 with value: 0.9623999953269958.\n", + "[I 2024-09-11 18:53:58,143] Trial 16 finished with value: 0.9629000008106232 and parameters: {'learning_rate': 0.022691713565568716}. Best is trial 16 with value: 0.9629000008106232.\n", + "[I 2024-09-11 18:57:23,139] Trial 17 finished with value: 0.9598199963569641 and parameters: {'learning_rate': 0.009542234485749465}. Best is trial 16 with value: 0.9629000008106232.\n", + "[I 2024-09-11 19:00:47,088] Trial 18 finished with value: 0.960340005159378 and parameters: {'learning_rate': 0.05229102934003298}. Best is trial 16 with value: 0.9629000008106232.\n", + "[I 2024-09-11 19:04:10,285] Trial 19 finished with value: 0.9630400002002716 and parameters: {'learning_rate': 0.02945876449044553}. Best is trial 19 with value: 0.9630400002002716.\n" ] } ], @@ -257,48 +239,49 @@ "study = optuna.create_study(direction=\"maximize\")\n", "\n", "# Optimize hyperparameters\n", - "study.optimize(objective, n_trials=1)\n", - "print(\"Best hyperparameters:\", study.best_params)\n", - "print(\"Best value:\", study.best_value)" + "study.optimize(objective, n_trials=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Visualize Optuna's optimization\n", - "\n" + "Now we can easily access the results through the `study` object, for example the best learning rate:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 43, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "0.02945876449044553" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import optuna.visualization as vis\n", - "\n", - "vis.plot_slice(study)" + "opt_learning_rate = study.best_params['learning_rate']\n", + "opt_learning_rate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "…and visualize the optimization process:" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 44, "metadata": {}, "outputs": [ - { - "ename": "ValueError", - "evalue": "Mime type rendering requires nbformat>=4.2.0 but it is not installed", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/IPython/core/formatters.py:920\u001b[0m, in \u001b[0;36mIPythonDisplayFormatter.__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 918\u001b[0m method \u001b[38;5;241m=\u001b[39m get_real_method(obj, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_method)\n\u001b[1;32m 919\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 920\u001b[0m \u001b[43mmethod\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 921\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n", - "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/plotly/basedatatypes.py:832\u001b[0m, in \u001b[0;36mBaseFigure._ipython_display_\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 829\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mio\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpio\u001b[39;00m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pio\u001b[38;5;241m.\u001b[39mrenderers\u001b[38;5;241m.\u001b[39mrender_on_display \u001b[38;5;129;01mand\u001b[39;00m pio\u001b[38;5;241m.\u001b[39mrenderers\u001b[38;5;241m.\u001b[39mdefault:\n\u001b[0;32m--> 832\u001b[0m \u001b[43mpio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 833\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 834\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mrepr\u001b[39m(\u001b[38;5;28mself\u001b[39m))\n", - "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/plotly/io/_renderers.py:394\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(fig, renderer, validate, **kwargs)\u001b[0m\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 390\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMime type rendering requires ipython but it is not installed\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 391\u001b[0m )\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m nbformat \u001b[38;5;129;01mor\u001b[39;00m Version(nbformat\u001b[38;5;241m.\u001b[39m__version__) \u001b[38;5;241m<\u001b[39m Version(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m4.2.0\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 394\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 395\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMime type rendering requires nbformat>=4.2.0 but it is not installed\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 396\u001b[0m )\n\u001b[1;32m 398\u001b[0m ipython_display\u001b[38;5;241m.\u001b[39mdisplay(bundle, raw\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 400\u001b[0m \u001b[38;5;66;03m# external renderers\u001b[39;00m\n", - "\u001b[0;31mValueError\u001b[0m: Mime type rendering requires nbformat>=4.2.0 but it is not installed" - ] - }, { "data": { "application/vnd.plotly.v1+json": { @@ -328,87 +311,7 @@ 16, 17, 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - 51, - 52, - 53, - 54, - 55, - 56, - 57, - 58, - 59, - 60, - 61, - 62, - 63, - 64, - 65, - 66, - 67, - 68, - 69, - 70, - 71, - 72, - 73, - 74, - 75, - 76, - 77, - 78, - 79, - 80, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 88, - 89, - 90, - 91, - 92, - 93, - 94, - 95, - 96, - 97, - 98, - 99 + 19 ], "colorbar": { "title": { @@ -465,208 +368,48 @@ "showlegend": false, "type": "scatter", "x": [ - 0.01060256463073004, - 0.06363319752806389, - 0.044054634905094, - 0.005457455596524018, - 0.0006534358443990889, - 0.005409564385563341, - 0.000296117238409861, - 0.0001548659658268065, - 0.0059773222630647126, - 0.009610312272318417, - 0.07936629680356555, - 0.0999017085272383, - 0.03030693483792421, - 0.09839071999176227, - 0.0018120798841072265, - 0.02630581835365611, - 0.0015862265658428628, - 0.018963170798261392, - 0.05591498203997656, - 0.0507252826898167, - 0.015598714196031616, - 0.04981982237560933, - 0.04472041885602553, - 0.04569663540155219, - 0.02409255925048435, - 0.012153379377141215, - 0.06204357797505839, - 0.0318047942959313, - 0.0026889985534921607, - 0.008406962615909752, - 0.07570103253376405, - 0.059086197752773524, - 0.03714901133834086, - 0.0646915046763561, - 0.018225411751837776, - 0.060873750215417036, - 0.0005937941753721039, - 0.03794230073789979, - 0.021776847720301837, - 0.012825763536089739, - 0.006945279111941595, - 0.003924318415667345, - 0.0009962905683287265, - 0.007485623216043211, - 0.08121672594329421, - 0.03218747043584, - 0.0043209290517303035, - 0.09007673305976491, - 0.00011242006456928975, - 0.052926437331853586, - 0.015080515584708165, - 0.06571961786071892, - 0.06853902191868609, - 0.043188154256233874, - 0.03137506646247659, - 0.09815877477219788, - 0.024965815660978004, - 0.0002369246671375477, - 0.05402719364497381, - 0.04024469007474665, - 0.0021520115377091278, - 0.06777601481845777, - 0.054552697374482345, - 0.07824127973056091, - 0.062190956331727144, - 0.04574533489252647, - 0.026700039728034024, - 0.035572253575668344, - 0.020947893979189017, - 0.09807853814362605, - 0.074008662788353, - 0.006494329870318023, - 0.013324381014246824, - 0.047348523959182987, - 0.05213136025170205, - 0.0468237789021135, - 0.009671712975184507, - 0.028449834843744366, - 0.03736743378716483, - 0.0011102163493235233, - 0.059419725114923834, - 0.029590575438061956, - 0.017734648001224127, - 0.08418584424126585, - 0.08334759428158998, - 0.06386707358622472, - 0.08176816323358746, - 0.04691662345629954, - 0.0005092474657505243, - 0.06835465614566856, - 0.040926964321475, - 0.05663192019646022, - 0.03407804204205137, - 0.083789776074211, - 0.02789909609279049, - 0.04909602447498623, - 0.07135979861985685, - 0.02340684816934831, - 0.04183807409164346, - 0.09709298119891982 + 0.012301378174433742, + 0.023368957405967835, + 0.0013378005538716429, + 0.007007650689123562, + 0.042762352644500096, + 0.0015413604505815153, + 0.029375669161110958, + 0.004598986580964227, + 0.04232333700633576, + 0.0010082622326759046, + 0.0944781292696637, + 0.019765558641651543, + 0.026254173914001054, + 0.08192645761678255, + 0.015592889217911915, + 0.0038523074238855904, + 0.022691713565568716, + 0.009542234485749465, + 0.05229102934003298, + 0.02945876449044553 ], "y": [ - 0.18700000643730164, - 0.4269999861717224, - 0.2070000022649765, - 0.3019999861717224, - 0.23999999463558197, - 0.1420000046491623, - 0.2409999966621399, - 0.13600000739097595, - 0.13199999928474426, - 0.2669999897480011, - 0.39899998903274536, - 0.2280000001192093, - 0.09700000286102295, - 0.3709999918937683, - 0.11599999666213989, - 0.164000004529953, - 0.3959999978542328, - 0.16599999368190765, - 0.4099999964237213, - 0.23499999940395355, - 0.27799999713897705, - 0.453000009059906, - 0.2160000056028366, - 0.27900001406669617, - 0.30799999833106995, - 0.3959999978542328, - 0.4350000023841858, - 0.26899999380111694, - 0.36500000953674316, - 0.17599999904632568, - 0.36399999260902405, - 0.1469999998807907, - 0.20600000023841858, - 0.3330000042915344, - 0.23499999940395355, - 0.4059999883174896, - 0.23199999332427979, - 0.14399999380111694, - 0.3569999933242798, - 0.4009999930858612, - 0.4180000126361847, - 0.3070000112056732, - 0.19499999284744263, - 0.14900000393390656, - 0.34299999475479126, - 0.18799999356269836, - 0.3499999940395355, - 0.2070000022649765, - 0.15299999713897705, - 0.13199999928474426, - 0.19099999964237213, - 0.41499999165534973, - 0.22300000488758087, - 0.17900000512599945, - 0.2759999930858612, - 0.36899998784065247, - 0.14499999582767487, - 0.335999995470047, - 0.13300000131130219, - 0.30399999022483826, - 0.3540000021457672, - 0.17000000178813934, - 0.2549999952316284, - 0.24799999594688416, - 0.15299999713897705, - 0.21699999272823334, - 0.31700000166893005, - 0.35499998927116394, - 0.38100001215934753, - 0.39399999380111694, - 0.13099999725818634, - 0.17800000309944153, - 0.27900001406669617, - 0.4050000011920929, - 0.28600001335144043, - 0.22100000083446503, - 0.23499999940395355, - 0.4050000011920929, - 0.35499998927116394, - 0.33799999952316284, - 0.15299999713897705, - 0.20100000500679016, - 0.35499998927116394, - 0.40700000524520874, - 0.14499999582767487, - 0.10499999672174454, - 0.16899999976158142, - 0.1550000011920929, - 0.3019999861717224, - 0.14100000262260437, - 0.1940000057220459, - 0.3370000123977661, - 0.22499999403953552, - 0.3490000069141388, - 0.13199999928474426, - 0.3779999911785126, - 0.3160000145435333, - 0.13699999451637268, - 0.38999998569488525, - 0.335999995470047 + 0.9595600068569183, + 0.9620999932289124, + 0.911080002784729, + 0.9557799935340882, + 0.9598599970340729, + 0.9210599958896637, + 0.9619599997997283, + 0.9521999955177307, + 0.9614199936389923, + 0.8862399995326996, + 0.9530400037765503, + 0.9606599926948547, + 0.9623999953269958, + 0.9538000106811524, + 0.9603799998760223, + 0.9487399995326996, + 0.9629000008106232, + 0.9598199963569641, + 0.960340005159378, + 0.9630400002002716 ] } ], @@ -1502,125 +1245,23 @@ } } } - }, - "text/html": [ - "
\n", - "
" - ], - "text/plain": [ - "Figure({\n", - " 'data': [{'marker': {'color': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", - " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,\n", - " 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,\n", - " 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,\n", - " 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,\n", - " 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,\n", - " 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85,\n", - " 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,\n", - " 98, 99],\n", - " 'colorbar': {'title': {'text': 'Trial'}, 'x': 1.0, 'xpad': 40},\n", - " 'colorscale': [[0.0, 'rgb(247,251,255)'], [0.125,\n", - " 'rgb(222,235,247)'], [0.25,\n", - " 'rgb(198,219,239)'], [0.375,\n", - " 'rgb(158,202,225)'], [0.5,\n", - " 'rgb(107,174,214)'], [0.625,\n", - " 'rgb(66,146,198)'], [0.75,\n", - " 'rgb(33,113,181)'], [0.875,\n", - " 'rgb(8,81,156)'], [1.0, 'rgb(8,48,107)']],\n", - " 'line': {'color': 'Grey', 'width': 0.5}},\n", - " 'mode': 'markers',\n", - " 'name': 'Feasible Trial',\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'x': [0.01060256463073004, 0.06363319752806389, 0.044054634905094,\n", - " 0.005457455596524018, 0.0006534358443990889,\n", - " 0.005409564385563341, 0.000296117238409861,\n", - " 0.0001548659658268065, 0.0059773222630647126,\n", - " 0.009610312272318417, 0.07936629680356555, 0.0999017085272383,\n", - " 0.03030693483792421, 0.09839071999176227,\n", - " 0.0018120798841072265, 0.02630581835365611,\n", - " 0.0015862265658428628, 0.018963170798261392,\n", - " 0.05591498203997656, 0.0507252826898167, 0.015598714196031616,\n", - " 0.04981982237560933, 0.04472041885602553, 0.04569663540155219,\n", - " 0.02409255925048435, 0.012153379377141215, 0.06204357797505839,\n", - " 0.0318047942959313, 0.0026889985534921607,\n", - " 0.008406962615909752, 0.07570103253376405,\n", - " 0.059086197752773524, 0.03714901133834086, 0.0646915046763561,\n", - " 0.018225411751837776, 0.060873750215417036,\n", - " 0.0005937941753721039, 0.03794230073789979,\n", - " 0.021776847720301837, 0.012825763536089739,\n", - " 0.006945279111941595, 0.003924318415667345,\n", - " 0.0009962905683287265, 0.007485623216043211,\n", - " 0.08121672594329421, 0.03218747043584, 0.0043209290517303035,\n", - " 0.09007673305976491, 0.00011242006456928975,\n", - " 0.052926437331853586, 0.015080515584708165,\n", - " 0.06571961786071892, 0.06853902191868609, 0.043188154256233874,\n", - " 0.03137506646247659, 0.09815877477219788, 0.024965815660978004,\n", - " 0.0002369246671375477, 0.05402719364497381,\n", - " 0.04024469007474665, 0.0021520115377091278,\n", - " 0.06777601481845777, 0.054552697374482345, 0.07824127973056091,\n", - " 0.062190956331727144, 0.04574533489252647,\n", - " 0.026700039728034024, 0.035572253575668344,\n", - " 0.020947893979189017, 0.09807853814362605, 0.074008662788353,\n", - " 0.006494329870318023, 0.013324381014246824,\n", - " 0.047348523959182987, 0.05213136025170205, 0.0468237789021135,\n", - " 0.009671712975184507, 0.028449834843744366,\n", - " 0.03736743378716483, 0.0011102163493235233,\n", - " 0.059419725114923834, 0.029590575438061956,\n", - " 0.017734648001224127, 0.08418584424126585, 0.08334759428158998,\n", - " 0.06386707358622472, 0.08176816323358746, 0.04691662345629954,\n", - " 0.0005092474657505243, 0.06835465614566856, 0.040926964321475,\n", - " 0.05663192019646022, 0.03407804204205137, 0.083789776074211,\n", - " 0.02789909609279049, 0.04909602447498623, 0.07135979861985685,\n", - " 0.02340684816934831, 0.04183807409164346, 0.09709298119891982],\n", - " 'y': [0.18700000643730164, 0.4269999861717224, 0.2070000022649765,\n", - " 0.3019999861717224, 0.23999999463558197, 0.1420000046491623,\n", - " 0.2409999966621399, 0.13600000739097595, 0.13199999928474426,\n", - " 0.2669999897480011, 0.39899998903274536, 0.2280000001192093,\n", - " 0.09700000286102295, 0.3709999918937683, 0.11599999666213989,\n", - " 0.164000004529953, 0.3959999978542328, 0.16599999368190765,\n", - " 0.4099999964237213, 0.23499999940395355, 0.27799999713897705,\n", - " 0.453000009059906, 0.2160000056028366, 0.27900001406669617,\n", - " 0.30799999833106995, 0.3959999978542328, 0.4350000023841858,\n", - " 0.26899999380111694, 0.36500000953674316, 0.17599999904632568,\n", - " 0.36399999260902405, 0.1469999998807907, 0.20600000023841858,\n", - " 0.3330000042915344, 0.23499999940395355, 0.4059999883174896,\n", - " 0.23199999332427979, 0.14399999380111694, 0.3569999933242798,\n", - " 0.4009999930858612, 0.4180000126361847, 0.3070000112056732,\n", - " 0.19499999284744263, 0.14900000393390656, 0.34299999475479126,\n", - " 0.18799999356269836, 0.3499999940395355, 0.2070000022649765,\n", - " 0.15299999713897705, 0.13199999928474426, 0.19099999964237213,\n", - " 0.41499999165534973, 0.22300000488758087, 0.17900000512599945,\n", - " 0.2759999930858612, 0.36899998784065247, 0.14499999582767487,\n", - " 0.335999995470047, 0.13300000131130219, 0.30399999022483826,\n", - " 0.3540000021457672, 0.17000000178813934, 0.2549999952316284,\n", - " 0.24799999594688416, 0.15299999713897705, 0.21699999272823334,\n", - " 0.31700000166893005, 0.35499998927116394, 0.38100001215934753,\n", - " 0.39399999380111694, 0.13099999725818634, 0.17800000309944153,\n", - " 0.27900001406669617, 0.4050000011920929, 0.28600001335144043,\n", - " 0.22100000083446503, 0.23499999940395355, 0.4050000011920929,\n", - " 0.35499998927116394, 0.33799999952316284, 0.15299999713897705,\n", - " 0.20100000500679016, 0.35499998927116394, 0.40700000524520874,\n", - " 0.14499999582767487, 0.10499999672174454, 0.16899999976158142,\n", - " 0.1550000011920929, 0.3019999861717224, 0.14100000262260437,\n", - " 0.1940000057220459, 0.3370000123977661, 0.22499999403953552,\n", - " 0.3490000069141388, 0.13199999928474426, 0.3779999911785126,\n", - " 0.3160000145435333, 0.13699999451637268, 0.38999998569488525,\n", - " 0.335999995470047]}],\n", - " 'layout': {'template': '...',\n", - " 'title': {'text': 'Slice Plot'},\n", - " 'xaxis': {'title': {'text': 'learning_rate'}, 'type': 'log'},\n", - " 'yaxis': {'title': {'text': 'Objective Value'}}}\n", - "})" - ] + } }, - "execution_count": 53, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "vis.plot_optimization_history(study)" + "optuna.visualization.plot_slice(study)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Conclusion\n", + "\n", + "In this post, we showed how to integrate Optuna with FEDn for hyperparameter tuning, using the example of tuning the learning rate of FedAdam. By defining an objective function and leveraging Optuna's efficient optimization, we automated the search for the best server-side learning rate to maximize test accuracy. With FEDn’s flexible API, we were able to evaluate performance in a flexible manner, whether by selecting the highest accuracy or averaging the final rounds." ] } ], @@ -1640,7 +1281,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.0" + "version": "3.12.6" } }, "nbformat": 4,