Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Afr big merge #176

Merged
merged 93 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
67eed77
fix bug in scoring failure
simplymathematics Dec 2, 2023
ff9cbbd
update power example to support bit-depth search
simplymathematics Dec 2, 2023
e913c96
update result directories
simplymathematics Dec 2, 2023
c8e5a06
revert changes to example/power
simplymathematics Dec 2, 2023
8928e7e
add bit depth example
simplymathematics Dec 2, 2023
97c178e
revert directory changes for power example
simplymathematics Dec 2, 2023
e302e38
fixed directory for mnist in power example
simplymathematics Dec 2, 2023
076dce7
fixed directory for mnist in power example
simplymathematics Dec 2, 2023
b4a5655
use the latest torch torchvision torchaudio
salehsedghpour Dec 2, 2023
96b071c
Merge branch 'main' of github.com:simplymathematics/deckard into bit-…
simplymathematics Dec 3, 2023
7c85be1
update workflow to push on PR
simplymathematics Dec 3, 2023
b09f54c
Merge branch 'bit-depth-power-example' of github.com:simplymathematic…
simplymathematics Dec 3, 2023
c2af1a4
removed double stage folders in log folder
simplymathematics Dec 3, 2023
1712c3b
+ more epochs for cifar100
simplymathematics Dec 3, 2023
d428e48
changed intervals from uniform to log uniform, made learning rate ran…
simplymathematics Dec 3, 2023
74df959
strip whitespace, covert o numeric in compile script
simplymathematics Dec 4, 2023
d8fbf1a
update git ignores
simplymathematics Dec 4, 2023
183ba2a
suport nb_epoch as defence choice
simplymathematics Dec 5, 2023
aeb2b72
remove adv_success from requirements
simplymathematics Dec 5, 2023
6b94128
add "NaN" to nones
simplymathematics Dec 5, 2023
416ed65
update afr script
simplymathematics Dec 5, 2023
a289b18
update mnist .dvc cache
simplymathematics Dec 5, 2023
48a1f8a
updat cifar10 plots
simplymathematics Dec 5, 2023
aac2227
uncomment paretoset in plotting
simplymathematics Dec 5, 2023
2b3752c
fix default defence bug and realtive pathing in compile script
simplymathematics Dec 5, 2023
691fae2
moved plots to subfolder
simplymathematics Dec 5, 2023
687f454
better configuration support
simplymathematics Dec 5, 2023
621d2c1
fix compile script bug
simplymathematics Dec 6, 2023
93befed
update compile and plots yaml for power example
simplymathematics Dec 6, 2023
23fb2e1
fix compile bug
simplymathematics Dec 6, 2023
1c47dfd
update plots
simplymathematics Dec 6, 2023
92bcafd
include plot files in dvc
simplymathematics Dec 6, 2023
aacfb00
update afr to read from conf file
simplymathematics Dec 6, 2023
7a3d1ec
linting
simplymathematics Dec 6, 2023
08a812a
linting
simplymathematics Dec 6, 2023
87660a4
Merge branch 'main' of github.com:simplymathematics/deckard into fix-…
simplymathematics Dec 6, 2023
269cf98
update pytorch example
simplymathematics Dec 6, 2023
957c65d
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 6, 2023
108ff15
update pytorch afr.yaml (not working)
simplymathematics Dec 7, 2023
6868dd2
split cleaning from plotting, but only working for examples/pytorch/m…
simplymathematics Dec 7, 2023
cef183c
working cleaning script
simplymathematics Dec 7, 2023
26c617e
fix pytorch examples with new clean script
simplymathematics Dec 7, 2023
00e1a0a
remove debug check from parse_results
simplymathematics Dec 7, 2023
3a051fa
make deckard a depedendency of the parsing script
simplymathematics Dec 7, 2023
0f15e04
made models.sh easier to read
simplymathematics Dec 7, 2023
2b73578
update afr for pytorch example
simplymathematics Dec 7, 2023
1987660
update power example
simplymathematics Dec 8, 2023
0ac3f5d
update dvc.lock for pytorch example
simplymathematics Dec 8, 2023
da05855
update pytorch/cifar100
simplymathematics Dec 8, 2023
e17706a
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 8, 2023
3c0deb5
update power/plots (not working)
simplymathematics Dec 8, 2023
b2b9157
add docstrings to plots.py
simplymathematics Dec 8, 2023
556898b
update power example with merge script
simplymathematics Dec 9, 2023
85dea7e
add power data
Dec 10, 2023
717b3a9
update configs
simplymathematics Dec 10, 2023
d4792c7
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 10, 2023
3aee986
add combined plots
simplymathematics Dec 11, 2023
269ad24
update afr models
simplymathematics Dec 11, 2023
f20f6d7
added support for dummy variables in afr
simplymathematics Dec 12, 2023
b147fc1
++combined_plots.py and fix afr bug
simplymathematics Dec 12, 2023
2da27ed
add cifar100 l4 power data with commenting everything else
Dec 12, 2023
55e4564
add varepsilon to attack params
simplymathematics Dec 12, 2023
d71b198
add dummy variables
simplymathematics Dec 12, 2023
50ff188
fix rounding bug
simplymathematics Dec 13, 2023
80615f7
update to newest plots
simplymathematics Dec 13, 2023
6376e0c
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 13, 2023
4f3176a
newest plots for power example
simplymathematics Dec 13, 2023
071db7f
linting
simplymathematics Dec 13, 2023
139864d
removed old afr file
simplymathematics Dec 13, 2023
fc4ec91
linting
simplymathematics Dec 13, 2023
8ea0056
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Jan 15, 2024
49ef6c1
update conf
simplymathematics Jan 15, 2024
62aefae
\Merge branch 'fix-compile-script' of github.com:simplymathematics/de…
simplymathematics Jan 15, 2024
4a34fd1
fixed kepler script bug
simplymathematics Jan 15, 2024
fdd2e8a
linting
simplymathematics Jan 15, 2024
0cc7b42
linting
simplymathematics Jan 15, 2024
40047b0
linting
simplymathematics Jan 15, 2024
4826f0a
linting
simplymathematics Jan 15, 2024
3d93817
linting
simplymathematics Jan 15, 2024
8c78260
linting
simplymathematics Jan 15, 2024
105d051
linting
simplymathematics Jan 15, 2024
14e32a8
linting
simplymathematics Jan 15, 2024
f25d72b
linting
simplymathematics Jan 15, 2024
b12afe9
add dvc.yaml
salehsedghpour Feb 13, 2024
b27f8dc
add query_kepler
salehsedghpour Feb 13, 2024
f8eaa75
Merge branch 'fix-compile-script' of https://github.com/simplymathema…
simplymathematics Mar 20, 2024
beffc10
update layers
simplymathematics Mar 25, 2024
5c89e10
update confs
simplymathematics Mar 25, 2024
8380fdc
merge
simplymathematics Mar 25, 2024
a5e366f
fix post-merge scripts
simplymathematics Mar 25, 2024
50452e8
Merge branch 'fix-compile-script' of https://github.com/simplymathema…
simplymathematics Mar 25, 2024
694b9d4
linting
simplymathematics Mar 25, 2024
6f30a1d
linting
simplymathematics Mar 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 34 additions & 41 deletions deckard/layers/afr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pandas as pd
import numpy as np
from pathlib import Path

import logging
import yaml
import argparse
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from lifelines import (
WeibullAFTFitter,
Expand All @@ -12,13 +14,13 @@
CoxPHFitter,
)
from .clean_data import drop_frames_without_results
import matplotlib
import logging
import yaml
import argparse
from .plots import set_matplotlib_vars


logger = logging.getLogger(__name__)

sns.set_theme(style="whitegrid", font_scale=1.8, font="times new roman")


def plot_aft(
df,
Expand All @@ -30,12 +32,11 @@ def plot_aft(
xlabel=None,
ylabel=None,
replacement_dict={},
filetype=".eps",
folder=".",
legend={},
**kwargs,
):
file = Path(folder, file).with_suffix(filetype)
file = Path(folder, file)
aft = fit_aft(df, event_col, duration_col, mtype, kwargs)
columns = list(df.columns)
columns.remove(event_col)
Expand All @@ -58,7 +59,6 @@ def plot_aft(
ax.get_figure().tight_layout()
ax.get_figure().savefig(file)
logger.info(f"Saved graph to {file}")
plt.show()
plt.gcf().clear()
return ax, aft

Expand Down Expand Up @@ -102,11 +102,10 @@ def plot_partial_effects(
replacement_dict={},
cmap="coolwarm",
folder=".",
filetype=".eps",
**kwargs,
):
plt.gcf().clear()
file = Path(folder, file).with_suffix(filetype)
file = Path(folder, file)
partial_effects = aft.plot_partial_effects_on_outcome(
covariate_array,
values_array,
Expand Down Expand Up @@ -135,7 +134,6 @@ def score_model(aft, train, test):
train_score = aft.score(train)
test_score = aft.score(test)
scores = {"train_score": train_score, "test_score": test_score}
plt.show()
return scores


Expand Down Expand Up @@ -175,7 +173,7 @@ def make_afr_table(
aft_data.to_csv(folder / "aft_comparison.csv", na_rep="--")
logger.info(f"Saved AFT comparison to {folder / 'aft_comparison.csv'}")
aft_data.to_latex(
buf=folder / f"{filename}.tex",
buf=Path(folder / "aft_comparison.tex").as_posix(),
float_format="%.3g",
na_rep="--",
label=label,
Expand Down Expand Up @@ -203,16 +201,13 @@ def clean_data_for_aft(
), f"Target {target} not in dataframe with columns {subset.columns}"
logger.info(f"Shape of dirty data: {subset.shape}")
cleaned = pd.DataFrame()
covariate_list.append(target)

if target not in covariate_list:
covariate_list.append(target)
logger.info(f"Covariates : {covariate_list}")
for kwarg in covariate_list:
assert kwarg in subset.columns, f"{kwarg} not in data.columns"
cleaned = pd.concat([cleaned, subset[kwarg]], axis=1)
cols = cleaned.columns
cleaned = pd.DataFrame(subset, columns=cols)
cleaned.index = subset.index
# remove rows with -1e10 or 1e10, which are placeholders for run-time errors depending on the direction of optimization
cols = list(cleaned.columns)
for col in cols:
cleaned = cleaned[cleaned[col] != -1e10]
cleaned = cleaned[cleaned[col] != 1e10]
Expand Down Expand Up @@ -249,8 +244,8 @@ def split_data_for_aft(
assert (
duration_col in cleaned
), f"Duration {duration_col} not in dataframe with columns {cleaned.columns}"
X_train = X_train.dropna(axis=0, how="any")
X_test = X_test.dropna(axis=0, how="any")
# X_train = X_train.dropna(axis=0, how="any")
# X_test = X_test.dropna(axis=0, how="any")
X_train = pd.DataFrame(X_train, columns=cleaned.columns)
X_test = pd.DataFrame(X_test, columns=cleaned.columns)
return X_train, X_test
Expand Down Expand Up @@ -333,18 +328,6 @@ def render_all_afr_plots(
print("*" * 80)


def set_matplotlib_vars(matplotlib_dict=None):
if matplotlib_dict is None:
matplotlib_dict = {
"font": {
"family": "Times New Roman",
"weight": "bold",
"size": 22,
},
}
matplotlib.rc(**matplotlib_dict)


def fillna(data, config):
fillna = config.pop("fillna", {})
for k, v in fillna.items():
Expand All @@ -354,9 +337,19 @@ def fillna(data, config):

if "__main__" == __name__:
afr_parser = argparse.ArgumentParser()
afr_parser.add_argument("--target", type=str, default="adv_failures")
afr_parser.add_argument("--duration_col", type=str, default="adv_fit_time")
afr_parser.add_argument("--dataset", type=str, default="mnist")
afr_parser.add_argument(
"--target",
type=str,
help="Failure count column",
required=True,
)
afr_parser.add_argument(
"--duration_col",
type=str,
help="Duration column",
required=True,
)
afr_parser.add_argument("--dataset", type=str, help="Dataset name", required=True)
afr_parser.add_argument("--data_file", type=str, default="data.csv")
afr_parser.add_argument("--config_file", type=str, default="afr.yaml")
afr_parser.add_argument("--plots_folder", type=str, default="plots")
Expand Down Expand Up @@ -394,12 +387,8 @@ def fillna(data, config):
covariates = config.get("covariates", [])
assert len(covariates) > 0, "No covariates specified in config file"

# Cannot fit AFT models with missing values
logger.info(f"Shape of data before data before dropping na: {data.shape}")
data = drop_frames_without_results(data, covariates)
logger.info(f"Shape of data before data before dropping na: {data.shape}")
# Converting accuracy to unnormalized count, if needed
if "adv_failures" in covariates and "adv_failures" in data.columns:
if "adv_failures" in covariates:
logger.info("Adding adv_failures to data")
assert "adv_accuracy" in data.columns, "adv_accuracy not in data"
assert "attack.attack_size" in data.columns, "attack.attack_size not in data"
Expand All @@ -415,6 +404,10 @@ def fillna(data, config):
:,
"data.sample.test_size",
]
# Cannot fit AFT models with missing values
logger.info(f"Shape of data before data before dropping na: {data.shape}")
data = drop_frames_without_results(data, covariates)
logger.info(f"Shape of data before data before dropping na: {data.shape}")
# Plotting AFT models
render_all_afr_plots(
config,
Expand Down
1 change: 0 additions & 1 deletion deckard/layers/clean_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,6 @@ def main(args):

if "adv_accuracy" in results.columns:
results = calculate_failure_rate(results)

results = min_max_scaling(results, *min_max)
output_file = save_results(
results,
Expand Down
22 changes: 18 additions & 4 deletions deckard/layers/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@
sns.set_theme(style="whitegrid", font_scale=1.8, font="times new roman")


def set_matplotlib_vars(matplotlib_dict=None):
if matplotlib_dict is None:
matplotlib_dict = {
"font": {
"family": "Times New Roman",
"weight": "bold",
"size": 22,
},
}
else:
assert isinstance(matplotlib_dict, dict), "matplotlib_dict must be a dictionary"
for k, v in matplotlib_dict.items():
plt.rc(k, **v)


def cat_plot(
data,
x,
Expand Down Expand Up @@ -266,6 +281,9 @@ def scatter_plot(
file = Path(file).with_suffix(filetype)
logger.info(f"Rendering graph {file}")
data = data.sort_values(by=[hue, x, y])
assert hue in data.columns, f"{hue} not in data columns"
assert x in data.columns, f"{x} not in data columns"
assert y in data.columns, f"{y} not in data columns"
graph = sns.scatterplot(
data=data,
x=x,
Expand Down Expand Up @@ -356,20 +374,16 @@ def main(args):
logger.info(f"Creating folder {FOLDER}")
FOLDER.mkdir(parents=True, exist_ok=True)

i = 0
cat_plot_list = big_dict.get("cat_plot", [])
for dict_ in cat_plot_list:
i += 1
cat_plot(data, **dict_, folder=FOLDER, filetype=IMAGE_FILETYPE)

line_plot_list = big_dict.get("line_plot", [])
for dict_ in line_plot_list:
i += 1
line_plot(data, **dict_, folder=FOLDER, filetype=IMAGE_FILETYPE)

scatter_plot_list = big_dict.get("scatter_plot", [])
for dict_ in scatter_plot_list:
i += 1
scatter_plot(data, **dict_, folder=FOLDER, filetype=IMAGE_FILETYPE)


Expand Down
Loading
Loading