Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
simplymathematics committed Nov 28, 2023
1 parent 66aaa73 commit 9aa7960
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 44 deletions.
1 change: 0 additions & 1 deletion deckard/layers/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def get_dvc_stage_params(
directory=".",
name=None,
):

logger.info(
f"Getting params for stage {stage} from {params_file} and {pipeline_file} in {directory}.",
)
Expand Down
110 changes: 67 additions & 43 deletions examples/pytorch/mnist/average_across_random_states.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
" \"train_time\": \"min\",\n",
" \"predict_time\": \"min\",\n",
" \"accuracy\": \"max\",\n",
" \"data.sample.random_state\" : \"diff\",\n",
" \"data.sample.random_state\": \"diff\",\n",
" \"model.trainer.nb_epoch\": \"diff\",\n",
" \"atk_gen\": \"diff\",\n",
" \"atk_param\" : \"diff\",\n",
" \"atk_param\": \"diff\",\n",
" \"adv_accuracy\": \"max\",\n",
" \"adv_fit_time\": \"min\",\n",
"}\n"
"}"
]
},
{
Expand All @@ -59,7 +59,7 @@
"source": [
"layers = df.model_layers.unique()\n",
"layers.sort()\n",
"epochs = df['model.trainer.nb_epoch'].unique()\n",
"epochs = df[\"model.trainer.nb_epoch\"].unique()\n",
"epochs.sort()\n",
"attacks = df.atk_gen.unique()\n",
"attacks.sort()\n",
Expand All @@ -75,7 +75,7 @@
" f\"Attacks: {attacks}\\n\"\n",
" f\"Number of defenses: {len(defenses)}\\n\"\n",
" f\"Defenses: {defenses}\\n\"\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -525,27 +525,30 @@
"sense_dict = {\n",
" \"model_layers\": \"diff\",\n",
" \"accuracy\": \"max\",\n",
" \"data.sample.random_state\" : \"diff\",\n",
" \"data.sample.random_state\": \"diff\",\n",
" \"model.trainer.nb_epoch\": \"diff\",\n",
" \"model_layers\" : \"diff\",\n",
" \"model_layers\": \"diff\",\n",
" \"atk_gen\": \"diff\",\n",
" \"def_gen\": \"diff\",\n",
" \"adv_fit_time\": \"min\",\n",
" \"adv_accuracy\": \"min\",\n",
" \"predict_time\": \"min\",\n",
" \"train_time\" : \"min\",\n",
" \"attack.attack_size\" : \"diff\",\n",
" \"train_time\": \"min\",\n",
" \"attack.attack_size\": \"diff\",\n",
"}\n",
"\n",
"# Average across random states\n",
"scorer = \"accuracy\"\n",
"\n",
"\n",
"def average_across_random_states(df, scorer, sense_dict):\n",
" sense_dict.pop(\"data.sample.random_state\", None)\n",
" group_list = [k for k,v in sense_dict.items() if v == 'diff']\n",
" group_list = [k for k, v in sense_dict.items() if v == \"diff\"]\n",
" group_list_wo_random_state = group_list.copy()\n",
" print(f\"Grouping by {group_list_wo_random_state} for {scorer}\")\n",
" df[f'mean_{scorer}'] = df.groupby(group_list_wo_random_state)[scorer].transform('mean')\n",
" df[f\"mean_{scorer}\"] = df.groupby(group_list_wo_random_state)[scorer].transform(\n",
" \"mean\"\n",
" )\n",
" return df\n",
"\n",
"\n",
Expand All @@ -556,9 +559,10 @@
" df = df.drop(col, axis=1)\n",
" return df\n",
"\n",
"\n",
"def find_pareto_set_for_graph(df, sense_dict):\n",
" scorers = [k for k,v in sense_dict.items() if v in [\"max\", \"min\"]]\n",
" group_list = [k for k,v in sense_dict.items() if v == 'diff']\n",
" scorers = [k for k, v in sense_dict.items() if v in [\"max\", \"min\"]]\n",
" group_list = [k for k, v in sense_dict.items() if v == \"diff\"]\n",
" group_list_wo_attack = group_list.copy()\n",
" for group in group_list:\n",
" if group in [\"atk_gen\", \"atk_value\", \"atk_param\"]:\n",
Expand All @@ -570,7 +574,9 @@
" else:\n",
" continue\n",
" for scorer in scorers:\n",
" scores = df[scorer].fillna(df.groupby(group_list_wo_attack)[scorer].transform('mean'))\n",
" scores = df[scorer].fillna(\n",
" df.groupby(group_list_wo_attack)[scorer].transform(\"mean\")\n",
" )\n",
" df[scorer] = scores.fillna(scores.mean())\n",
" df = average_across_random_states(df, scorer, sense_dict)\n",
" value = sense_dict.get(scorer)\n",
Expand All @@ -581,8 +587,10 @@
" # df = df[bools]\n",
" return df\n",
"\n",
"\n",
"df = find_pareto_set_for_graph(df, sense_dict)\n",
"\n",
"\n",
"def drop_col_if_no_variance(df):\n",
" drop_these = []\n",
" for col in df.columns:\n",
Expand All @@ -591,6 +599,7 @@
" tmp = df.drop(drop_these, axis=1)\n",
" return tmp\n",
"\n",
"\n",
"df = drop_poorly_merged_columns(df)\n",
"\n",
"df"
Expand Down Expand Up @@ -623,7 +632,7 @@
}
],
"source": [
"sns.lineplot(data=df, y=\"adv_log_loss\", x=\"model.trainer.nb_epoch\", hue=\"model_layers\")\n"
"sns.lineplot(data=df, y=\"adv_log_loss\", x=\"model.trainer.nb_epoch\", hue=\"model_layers\")"
]
},
{
Expand Down Expand Up @@ -663,68 +672,83 @@
}
],
"source": [
"from lifelines import CoxPHFitter, KaplanMeierFitter, NelsonAalenFitter, AalenAdditiveFitter, WeibullAFTFitter, LogNormalAFTFitter, LogLogisticAFTFitter, PiecewiseExponentialRegressionFitter\n",
"from lifelines import (\n",
" CoxPHFitter,\n",
" KaplanMeierFitter,\n",
" NelsonAalenFitter,\n",
" AalenAdditiveFitter,\n",
" WeibullAFTFitter,\n",
" LogNormalAFTFitter,\n",
" LogLogisticAFTFitter,\n",
" PiecewiseExponentialRegressionFitter,\n",
")\n",
"\n",
"\n",
"model_dict = {\n",
" \"cox\" : CoxPHFitter,\n",
" \"cox\": CoxPHFitter,\n",
" # \"kaplan_meier\" : KaplanMeierFitter,\n",
" # \"nelson_aalen\" : NelsonAalenFitter,\n",
" # \"aalen_additive\" : AalenAdditiveFitter,\n",
" \"weibull\" : WeibullAFTFitter,\n",
" \"log_normal\" : LogNormalAFTFitter,\n",
" \"log_logistic\" : LogLogisticAFTFitter,\n",
" \"weibull\": WeibullAFTFitter,\n",
" \"log_normal\": LogNormalAFTFitter,\n",
" \"log_logistic\": LogLogisticAFTFitter,\n",
" # \"piecewise_exponential\" : PiecewiseExponentialRegressionFitter,\n",
"}\n",
"\n",
"\n",
"def fit_aft_model(df, sense_dict, model_name):\n",
" \n",
" stratify = ['atk_gen', 'def_gen',]\n",
" subset_df = df.copy()\n",
" stratify = [\n",
" \"atk_gen\",\n",
" \"def_gen\",\n",
" ]\n",
" subset_df = df.copy()\n",
" subset_df = subset_df.drop(stratify, axis=1)\n",
" model = model_dict[model_name]()\n",
" model.fit(df, duration_col ='mean_adv_fit_time', event_col='adv_failures')\n",
" model.fit(df, duration_col=\"mean_adv_fit_time\", event_col=\"adv_failures\")\n",
" model.print_summary()\n",
" plot = model.plot()\n",
" concordance = model.score(df, scoring_method='concordance_index')\n",
" concordance = model.score(df, scoring_method=\"concordance_index\")\n",
" print(f\"Concordance index: {concordance}\")\n",
" measured_median = np.median(df.mean_adv_fit_time / df['attack.attack_size'] * ((1 - df.adv_failures)/100))\n",
" measured_median = np.median(\n",
" df.mean_adv_fit_time / df[\"attack.attack_size\"] * ((1 - df.adv_failures) / 100)\n",
" )\n",
" print(\"Measured median attack time:\", measured_median)\n",
" modelled_median = np.median(model.predict_median(df, ancillary=df))\n",
" print(\"Predicted median attack time:\", modelled_median)\n",
" score = model.score(df, scoring_method='log_likelihood')\n",
" score = model.score(df, scoring_method=\"log_likelihood\")\n",
" score_dict = {\n",
" \"model\" : model_name,\n",
" \"concordance\" : concordance,\n",
" \"measured_median\" : measured_median,\n",
" \"modelled_median\" : modelled_median,\n",
" \"log_likelihood\" : score,\n",
" \"model\": model_name,\n",
" \"concordance\": concordance,\n",
" \"measured_median\": measured_median,\n",
" \"modelled_median\": modelled_median,\n",
" \"log_likelihood\": score,\n",
" }\n",
" return model, plot, score\n",
"\n",
"\n",
"models = {}\n",
"scores = {}\n",
"plots = {}\n",
"stratify = ['atk_gen', 'def_gen']\n",
"stratify = [\"atk_gen\", \"def_gen\"]\n",
"subset_cols = [k for k in sense_dict if k not in stratify]\n",
"aft_df = df[subset_cols].copy()\n",
"aft_df['adv_failures'] = (1 - df['mean_adv_accuracy']) * df['attack.attack_size']\n",
"del aft_df['mean_adv_accuracy']\n",
"aft_df[\"adv_failures\"] = (1 - df[\"mean_adv_accuracy\"]) * df[\"attack.attack_size\"]\n",
"del aft_df[\"mean_adv_accuracy\"]\n",
"new_sense_dict = sense_dict.copy()\n",
"new_sense_dict.update({\"adv_failures\" : sense_dict['mean_adv_accuracy']})\n",
"new_sense_dict.update({\"adv_failures\": sense_dict[\"mean_adv_accuracy\"]})\n",
"new_sense_dict.pop(\"mean_adv_accuracy\", None)\n",
"new_sense_dict\n",
"\n",
"for model_name in model_dict:\n",
" print(f\"Fitting {model_name} model\")\n",
" model, plot, score = fit_aft_model(aft_df, new_sense_dict, model_name)\n",
" models.update({model_name : model})\n",
" scores.update({model_name : score})\n",
" plots.update({model_name : plot})\n",
" plt.xscale('linear')\n",
" models.update({model_name: model})\n",
" scores.update({model_name: score})\n",
" plots.update({model_name: plot})\n",
" plt.xscale(\"linear\")\n",
" plt.show()\n",
" plt.gcf().clear()\n",
" \n",
"\n",
"# scores = pd.DataFrame.from_dict(scores, orient='index', columns=['score'])\n",
"\n",
"# covariates = [k for k,v in sense_dict.items() if v == 'diff']\n",
Expand All @@ -747,9 +771,9 @@
"metadata": {},
"outputs": [],
"source": [
"model = models['weibull']\n",
"model = models[\"weibull\"]\n",
"expectations = model.predict_expectation(df, ancillary=df)\n",
"survival_function = model.predict_survival_function(df, ancillary=df)\n"
"survival_function = model.predict_survival_function(df, ancillary=df)"
]
},
{
Expand Down

0 comments on commit 9aa7960

Please sign in to comment.