Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
simplymathematics committed Nov 30, 2024
1 parent 97a851b commit a832d2f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
4 changes: 3 additions & 1 deletion deckard/base/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(
)
if self.attack is not None:
logger.info(f"Attack: {self.attack}")
adv_metrics = [f"adv_{metric}" for metric in metrics if not metric.startswith("adv_")]
adv_metrics = [
f"adv_{metric}" for metric in metrics if not metric.startswith("adv_")
]
if "adv_train_time" in adv_metrics:
adv_metrics.remove("adv_train_time")
adv_metrics.append("adv_fit_time")
Expand Down
3 changes: 0 additions & 3 deletions examples/gzip/classifier_refactor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Literal
import logging
import numpy as np
import pandas as pd
import brotli
import pickle
import gzip
Expand All @@ -15,7 +14,6 @@
from sklearn.base import BaseEstimator, TransformerMixin



logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -433,7 +431,6 @@ def fit_transform(self, X, y=None):
return self.transform(X)



if __name__ == "__main__":

_config = """
Expand Down
14 changes: 7 additions & 7 deletions examples/gzip/plots.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import numpy as np
from pathlib import Path
from matplotlib.ticker import FixedLocator, NullFormatter

# Set seaborn theme to paper using times new roman font
sns.set_theme(context="paper", style="whitegrid", font="Times New Roman", font_scale=2)
Expand Down Expand Up @@ -39,7 +37,7 @@
data = data.fillna(0)
print(f"Shape of data: {data.shape}")
print(f"Columns: {data.columns}")
data['accuracy'] = data['accuracy'] * 100
data["accuracy"] = data["accuracy"] * 100
tmp_groups = []
for col in group_these:
if col not in data.columns:
Expand All @@ -66,7 +64,9 @@
# add the mean and standard deviation to the group
group[col + "_mean"] = mean
group[col + "_std"] = std
assert f"{col}_mean" in group.columns, f"{col}_mean not in group columns"
assert (
f"{col}_mean" in group.columns
), f"{col}_mean not in group columns"
assert f"{col}_std" in group.columns, f"{col}_std not in group columns"
group = group.drop(col, axis=1)
# group = group.head(1)
Expand Down Expand Up @@ -233,7 +233,7 @@

refit_df = pd.read_csv("output/combined/plots/refit_merged.csv", index_col=0)
refit_df["Algorithm"] = refit_df["algorithm"]
refit_df['accuracy'] = refit_df['accuracy'] * 100
refit_df["accuracy"] = refit_df["accuracy"] * 100
refit_df.dropna(inplace=True, subset=["accuracy"])
acc_graph = sns.relplot(
data=refit_df,
Expand All @@ -249,7 +249,7 @@
hue_order=["Vanilla", "Assumed", "Enforced", "Average"],
style_order=["GZIP", "BZ2", "Brotli", "Hamming", "Ratio", "Levenshtein"],
)

for ax in acc_graph.axes.flat:
# Increase the line thickness
for line in ax.lines:
Expand All @@ -259,7 +259,7 @@
label.set_rotation(45)
# set ylim to [0,1]
ax.set_ylim(0, 100)

acc_graph.set_axis_labels("No. of Training Samples", " Accuracy (%)")
acc_graph.set_titles("{row_name} - {col_name}")
# acc_graph.tight_layout()
Expand Down

0 comments on commit a832d2f

Please sign in to comment.