Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cifar100 accuracy #181

Merged
merged 208 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
208 commits
Select commit Hold shift + click to select a range
67eed77
fix bug in scoring failure
simplymathematics Dec 2, 2023
ff9cbbd
update power example to support bit-depth search
simplymathematics Dec 2, 2023
e913c96
update result directories
simplymathematics Dec 2, 2023
c8e5a06
revert changes to example/power
simplymathematics Dec 2, 2023
8928e7e
add bit depth example
simplymathematics Dec 2, 2023
97c178e
revert directory changes for power example
simplymathematics Dec 2, 2023
e302e38
fixed directory for mnist in power example
simplymathematics Dec 2, 2023
076dce7
fixed directory for mnist in power example
simplymathematics Dec 2, 2023
b4a5655
use the latest torch torchvision torchaudio
salehsedghpour Dec 2, 2023
96b071c
Merge branch 'main' of github.com:simplymathematics/deckard into bit-…
simplymathematics Dec 3, 2023
7c85be1
update workflow to push on PR
simplymathematics Dec 3, 2023
b09f54c
Merge branch 'bit-depth-power-example' of github.com:simplymathematic…
simplymathematics Dec 3, 2023
c2af1a4
removed double stage folders in log folder
simplymathematics Dec 3, 2023
1712c3b
+ more epochs for cifar100
simplymathematics Dec 3, 2023
d428e48
changed intervals from uniform to log uniform, made learning rate ran…
simplymathematics Dec 3, 2023
74df959
strip whitespace, covert o numeric in compile script
simplymathematics Dec 4, 2023
d8fbf1a
update git ignores
simplymathematics Dec 4, 2023
183ba2a
suport nb_epoch as defence choice
simplymathematics Dec 5, 2023
aeb2b72
remove adv_success from requirements
simplymathematics Dec 5, 2023
6b94128
add "NaN" to nones
simplymathematics Dec 5, 2023
416ed65
update afr script
simplymathematics Dec 5, 2023
a289b18
update mnist .dvc cache
simplymathematics Dec 5, 2023
48a1f8a
updat cifar10 plots
simplymathematics Dec 5, 2023
aac2227
uncomment paretoset in plotting
simplymathematics Dec 5, 2023
2b3752c
fix default defence bug and realtive pathing in compile script
simplymathematics Dec 5, 2023
691fae2
moved plots to subfolder
simplymathematics Dec 5, 2023
687f454
better configuration support
simplymathematics Dec 5, 2023
621d2c1
fix compile script bug
simplymathematics Dec 6, 2023
93befed
update compile and plots yaml for power example
simplymathematics Dec 6, 2023
23fb2e1
fix compile bug
simplymathematics Dec 6, 2023
1c47dfd
update plots
simplymathematics Dec 6, 2023
92bcafd
include plot files in dvc
simplymathematics Dec 6, 2023
aacfb00
update afr to read from conf file
simplymathematics Dec 6, 2023
7a3d1ec
linting
simplymathematics Dec 6, 2023
08a812a
linting
simplymathematics Dec 6, 2023
87660a4
Merge branch 'main' of github.com:simplymathematics/deckard into fix-…
simplymathematics Dec 6, 2023
269cf98
update pytorch example
simplymathematics Dec 6, 2023
957c65d
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 6, 2023
108ff15
update pytorch afr.yaml (not working)
simplymathematics Dec 7, 2023
6868dd2
split cleaning from plotting, but only working for examples/pytorch/m…
simplymathematics Dec 7, 2023
cef183c
working cleaning script
simplymathematics Dec 7, 2023
26c617e
fix pytorch examples with new clean script
simplymathematics Dec 7, 2023
00e1a0a
remove debug check from parse_results
simplymathematics Dec 7, 2023
3a051fa
make deckard a depedendency of the parsing script
simplymathematics Dec 7, 2023
0f15e04
made models.sh easier to read
simplymathematics Dec 7, 2023
2b73578
update afr for pytorch example
simplymathematics Dec 7, 2023
1987660
update power example
simplymathematics Dec 8, 2023
0ac3f5d
update dvc.lock for pytorch example
simplymathematics Dec 8, 2023
da05855
update pytorch/cifar100
simplymathematics Dec 8, 2023
e17706a
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 8, 2023
3c0deb5
update power/plots (not working)
simplymathematics Dec 8, 2023
b2b9157
add docstrings to plots.py
simplymathematics Dec 8, 2023
556898b
update power example with merge script
simplymathematics Dec 9, 2023
85dea7e
add power data
Dec 10, 2023
717b3a9
update configs
simplymathematics Dec 10, 2023
d4792c7
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 10, 2023
3aee986
add combined plots
simplymathematics Dec 11, 2023
269ad24
update afr models
simplymathematics Dec 11, 2023
f20f6d7
added support for dummy variables in afr
simplymathematics Dec 12, 2023
b147fc1
++combined_plots.py and fix afr bug
simplymathematics Dec 12, 2023
2da27ed
add cifar100 l4 power data with commenting everything else
Dec 12, 2023
55e4564
add varepsilon to attack params
simplymathematics Dec 12, 2023
d71b198
add dummy variables
simplymathematics Dec 12, 2023
50ff188
fix rounding bug
simplymathematics Dec 13, 2023
80615f7
update to newest plots
simplymathematics Dec 13, 2023
6376e0c
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 13, 2023
4f3176a
newest plots for power example
simplymathematics Dec 13, 2023
071db7f
linting
simplymathematics Dec 13, 2023
139864d
removed old afr file
simplymathematics Dec 13, 2023
fc4ec91
linting
simplymathematics Dec 13, 2023
8ea0056
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Jan 15, 2024
49ef6c1
update conf
simplymathematics Jan 15, 2024
62aefae
\Merge branch 'fix-compile-script' of github.com:simplymathematics/de…
simplymathematics Jan 15, 2024
4a34fd1
fixed kepler script bug
simplymathematics Jan 15, 2024
fdd2e8a
linting
simplymathematics Jan 15, 2024
0cc7b42
linting
simplymathematics Jan 15, 2024
40047b0
linting
simplymathematics Jan 15, 2024
4826f0a
linting
simplymathematics Jan 15, 2024
3d93817
linting
simplymathematics Jan 15, 2024
8c78260
linting
simplymathematics Jan 15, 2024
105d051
linting
simplymathematics Jan 15, 2024
14e32a8
linting
simplymathematics Jan 15, 2024
f25d72b
linting
simplymathematics Jan 15, 2024
358e759
fixed cifar100 pytorch example script
simplymathematics Jan 15, 2024
4b0ff7d
more resilient wait and cleaning scripts
simplymathematics Jan 15, 2024
4345985
+GZIP example
simplymathematics Jan 22, 2024
c49cd72
bug fixes
simplymathematics Jan 22, 2024
44d6145
fix latex nan bug
simplymathematics Jan 23, 2024
ef5612d
bug fixes
simplymathematics Jan 23, 2024
dc7c478
add index to compilation csv
simplymathematics Jan 23, 2024
9b57058
better defence merging
simplymathematics Jan 23, 2024
ff9d924
fixed bug where x,y scale are None
simplymathematics Jan 23, 2024
9dcc132
update cifar100 confs
simplymathematics Jan 23, 2024
ce45cde
fixed cleaning bug
simplymathematics Jan 23, 2024
285fcea
fixed afr plot rendering bug
simplymathematics Jan 24, 2024
90e41f1
add check for negative predict time
simplymathematics Jan 24, 2024
0254a92
update configs
simplymathematics Jan 24, 2024
f82e1d8
fixed failure rate bug and updated confs
simplymathematics Jan 24, 2024
0971505
Merge branch 'cifar100' of github.com:simplymathematics/deckard into …
simplymathematics Jan 24, 2024
b09a423
change plot default from eps to pdf
simplymathematics Jan 24, 2024
80998a8
fix bug in calculating failure rate when attack size != train size
simplymathematics Jan 24, 2024
c76394e
update mnist confs
simplymathematics Jan 25, 2024
5491358
specify attack size at the command line
simplymathematics Jan 25, 2024
74fbe66
linting
simplymathematics Jan 26, 2024
ab6459f
update all plot configs
simplymathematics Jan 26, 2024
351f597
Fix compile script (#172)
simplymathematics Jan 24, 2024
8f08882
added dummy vars, fixed plots
simplymathematics Jan 26, 2024
66515ff
fix afr.py bugs
simplymathematics Jan 26, 2024
33eeaf4
bad merge?
simplymathematics Jan 26, 2024
66f8319
merge
simplymathematics Jan 26, 2024
6625ae5
fix bugs
simplymathematics Jan 26, 2024
ae10c9f
Merge branch 'main' of github.com:simplymathematics/deckard into note…
simplymathematics Jan 26, 2024
cf265db
linting
simplymathematics Jan 26, 2024
1254a67
linting
simplymathematics Jan 26, 2024
3060e4c
linting
simplymathematics Jan 26, 2024
1da196c
linting
simplymathematics Jan 26, 2024
6609a53
linting
simplymathematics Jan 26, 2024
098af1c
linting
simplymathematics Jan 26, 2024
0f360f4
update linter
simplymathematics Jan 26, 2024
3150f9e
update linter
simplymathematics Jan 26, 2024
4455ef7
update linter
simplymathematics Jan 26, 2024
d17765e
update linter
simplymathematics Jan 26, 2024
6b6caef
linting
simplymathematics Jan 27, 2024
5978b44
update setup, .gitignore
simplymathematics Jan 30, 2024
e8257ed
fix failure rate bug (again)
simplymathematics Jan 30, 2024
aaea279
most up-to-date plots
simplymathematics Jan 31, 2024
bf5eaf8
update failure rate from h to f in pytorch examples
simplymathematics Feb 1, 2024
609bc8d
remove intercept and scale parameters from afr plots
simplymathematics Feb 1, 2024
56b3aa3
remove rows where the score is an error
simplymathematics Feb 1, 2024
ca6c7cc
update plolt confs for pytorch example
simplymathematics Feb 1, 2024
36ea908
allow setting filename from command line of AFR script
simplymathematics Feb 1, 2024
aa50ddc
plot legend tweaks
simplymathematics Feb 1, 2024
8ebdaa1
linting
simplymathematics Feb 4, 2024
2794006
Merge branch 'main' of github.com:simplymathematics/deckard into note…
simplymathematics Feb 4, 2024
5a3226b
linting
simplymathematics Feb 4, 2024
24b1b36
merge with main
simplymathematics Feb 4, 2024
de5ad1b
Update Dockerfile
simplymathematics Feb 4, 2024
f5b0287
Merge branch 'simplymathematics-workflow-diskspace-patch' of https://…
simplymathematics Feb 4, 2024
d9bd067
Merge branch 'notebook-branch' of https://github.com/simplymathematic…
simplymathematics Feb 4, 2024
b331acd
update dockerfile
simplymathematics Feb 4, 2024
5a8f88b
lintin
simplymathematics Feb 4, 2024
2b3cf6e
update gzip configs
simplymathematics Feb 4, 2024
f3123d9
better logging
simplymathematics Feb 4, 2024
d300763
add url validation for data pipeline
simplymathematics Feb 4, 2024
5e90ee8
git rm
simplymathematics Feb 4, 2024
80d6eaa
update truthseeker yaml
simplymathematics Feb 4, 2024
af7066d
add gzip .gitignore
simplymathematics Feb 4, 2024
db34da2
gzip dvc changes
simplymathematics Feb 4, 2024
a0de7bf
add sampling during training
simplymathematics Feb 5, 2024
533ad3e
more resilient find_best script
simplymathematics Feb 5, 2024
670d81a
fixed bug with finding min/max when data is non-numeric
simplymathematics Feb 5, 2024
32b13e7
add support for url/local datasets
simplymathematics Feb 5, 2024
bd2a7a5
add column dropping for data parsing
simplymathematics Feb 6, 2024
ac07dff
find best for multi-objective search
simplymathematics Feb 6, 2024
a773639
better cleaning for experiments without attacks
simplymathematics Feb 6, 2024
0874768
better filetype support when plotting
simplymathematics Feb 6, 2024
266f006
load distance matrix from disk (optionally)
simplymathematics Feb 6, 2024
7b4a517
update confs for gzip
simplymathematics Feb 6, 2024
9c178c4
Merge branches 'notebook-branch' and 'notebook-branch' of github.com:…
simplymathematics Feb 6, 2024
d13412e
update default params
simplymathematics Feb 6, 2024
21e7699
Merge branch 'main' of github.com:simplymathematics/deckard into note…
simplymathematics Mar 27, 2024
c78d51d
update .gitignore, add some models to torch_example
simplymathematics Apr 3, 2024
414198e
refactor confs
simplymathematics Apr 3, 2024
0679acf
update pytorch confs
simplymathematics Apr 3, 2024
d7e0354
minor bug fixes
simplymathematics Apr 3, 2024
43cf4d8
fix small bugs
simplymathematics Apr 3, 2024
5a1705a
add resnet examples
simplymathematics Apr 3, 2024
022cb38
update pytorch experiment confs
simplymathematics Apr 4, 2024
128f9a1
Merge branch 'main' of github.com:simplymathematics/deckard into fix-…
simplymathematics Apr 4, 2024
4836501
config changes
simplymathematics May 10, 2024
44ee1f4
increase timeline resolution
simplymathematics May 13, 2024
b73b006
config changes
simplymathematics May 13, 2024
da7fb05
removed dvc cruft
simplymathematics May 13, 2024
421e052
update cost normalization calculation
simplymathematics May 13, 2024
673fd9b
update afr plotting
simplymathematics May 15, 2024
d27dfea
remove partial effects from pytorch config, add support for aalen add…
simplymathematics May 15, 2024
8148ba4
update dvc file for pytorch plots
simplymathematics May 15, 2024
3f6fb2e
better error handling
simplymathematics May 15, 2024
392b253
Merge branch 'fix-cifar100-accuracy' of github.com:simplymathematics/…
simplymathematics May 15, 2024
cbdc65f
update survival plots
simplymathematics May 16, 2024
466d7d6
merge with soft reset
simplymathematics May 16, 2024
0166f45
++plots to newest overleaf
simplymathematics May 20, 2024
1e28adf
dummy config changes
simplymathematics May 22, 2024
f29effc
config change
simplymathematics May 22, 2024
a857435
re-run plot dvc.yaml
simplymathematics May 22, 2024
92b89dc
config changes
simplymathematics Jun 4, 2024
e0e5671
update .gitignore
simplymathematics Jun 5, 2024
7237859
fixed file bug
simplymathematics Jun 5, 2024
9531165
fix keyword bug
simplymathematics Jun 5, 2024
35f382f
change Coefficient plots to sym log scale
simplymathematics Jun 5, 2024
7a2da3d
and latex to result parsing
simplymathematics Jun 5, 2024
f777481
added support for forloop stage parsing
simplymathematics Jun 5, 2024
46a9be5
streamline some code
simplymathematics Jun 5, 2024
9a08413
config changes + predicting the metric with a model config chosen by key
simplymathematics Jun 13, 2024
5e87696
add plots yaml again
simplymathematics Jun 13, 2024
c35c94c
stop tracking cifar100.yaml
simplymathematics Jun 20, 2024
5dadbcf
fix uncaught exception
simplymathematics Jun 25, 2024
bc80872
make dataset formatting more robust
simplymathematics Jun 25, 2024
d4499db
add a type check
simplymathematics Jun 25, 2024
e537dce
update default configs for each dataset to use env vars instead of ha…
simplymathematics Jun 25, 2024
8a850ba
reconfigure the dvc pipeline for re-running and changing the number o…
simplymathematics Jun 25, 2024
e319871
config changes
simplymathematics Jul 2, 2024
faaa05b
better pytorch out of memory handling
simplymathematics Jul 2, 2024
1ad4118
add normalization to trash metric
simplymathematics Jul 2, 2024
704b883
better convergence error handline
simplymathematics Jul 6, 2024
68a7e63
config changes
simplymathematics Jul 6, 2024
c166940
linting
simplymathematics Jul 6, 2024
909a5d3
linting
simplymathematics Jul 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,10 @@ env/
# screenlog
screenlog.*

# tmp.py
tmp.py
# env310
env310/*
# dvc read-write protect files in the examples folder
examples/*/.dvc/tmp/lock
examples/*/.dvc/tmp/lock
examples/*/.dvc/tmp/rwlock
examples/*/.dvc/tmp/rwlock.lock
4 changes: 2 additions & 2 deletions deckard/base/attack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,15 @@ def __call__(
)
except TypeError as e:
logger.error(f"Failed to compute success rate. Error: {e}")
if attack_file is not None:
if attack_file is not None and not Path(attack_file).exists():
self.data.save(samples, attack_file)
if adv_predictions_file is not None and Path(adv_predictions_file).exists():
adv_predictions = self.data.load(adv_predictions_file)
results["adv_predictions"] = np.array(adv_predictions)
else:
adv_predictions = model.predict(samples)
results["adv_predictions"] = np.array(adv_predictions)
if adv_predictions_file is not None:
if adv_predictions_file is not None and not Path(adv_predictions_file).exists():
self.data.save(adv_predictions, adv_predictions_file)
if adv_probabilities_file is not None and Path(adv_probabilities_file).exists():
adv_probabilities = self.data.load(adv_probabilities_file)
Expand Down
8 changes: 5 additions & 3 deletions deckard/base/data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
make_moons,
make_circles,
)
from torchvision.io import read_image, read_file

try:
from torchvision.io import read_image, read_file
except ImportError:
pass
from art.utils import load_mnist, load_cifar10, load_diabetes, to_categorical
from ..utils import my_hash

Expand Down Expand Up @@ -225,8 +229,6 @@ def __call__(self):
return TorchDataGenerator(self.name, **self.kwargs)()
elif self.name in KERAS_DATASETS:
return KerasDataGenerator(self.name, **self.kwargs)()
elif isinstance(self.name, str) and Path(self.name).exists():
return SklearnDataGenerator(self.name, **self.kwargs)()
else: # pragma: no cover
raise ValueError(
f"Invalid name {self.name}. Please choose from {ALL_DATASETS}",
Expand Down
2 changes: 2 additions & 0 deletions deckard/base/data/sklearn_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __call__(self, X_train, X_test, y_train, y_test):
name = self.kwargs.pop("_target_", self.name)
dict_ = {"_target_": name}
dict_.update(**self.kwargs)
while "kwargs" in dict_:
dict_.update(**dict_.pop("kwargs"))
obj = instantiate(dict_)
X_train = obj.fit(X_train).transform(X_train)
X_test = obj.transform(X_test)
Expand Down
16 changes: 8 additions & 8 deletions deckard/base/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __hash__(self):
name = str(self.name).encode("utf-8")
return int.from_bytes(name, "little")

def __call__(self):
def __call__(self, **kwargs):
"""Runs the experiment. If the experiment has already been run, it will load the results from disk. If scorer is not None, it will return the score for the specified scorer. If scorer is None, it will return the score for the first scorer in the ScorerDict.
:param scorer: The scorer to return the score for. If None, the score for the first scorer in the ScorerDict will be returned.
:type scorer: str
Expand All @@ -134,7 +134,8 @@ def __call__(self):
else:
score_dict = {}
results = {}
results["score_dict_file"] = score_dict
results["score_dict"] = score_dict
files.update(**results)
#########################################################################
# Load or generate data
#########################################################################
Expand All @@ -143,10 +144,10 @@ def __call__(self):
# Load or train model
#########################################################################
if self.model is not None:
model_results = self.model(data, **files)
model_results = self.model(**files)
score_dict.update(**model_results.pop("time_dict", {}))
score_dict.update(**model_results.pop("score_dict", {}))
model = model_results["model"]
files.update(**model_results)
# Prefer probabilities, then loss_files, then predictions
if (
"probabilities" in model_results
Expand Down Expand Up @@ -174,8 +175,6 @@ def __call__(self):
##########################################################################
if self.attack is not None:
adv_results = self.attack(
data,
model,
**files,
)
if "adv_predictions" in adv_results:
Expand All @@ -195,6 +194,7 @@ def __call__(self):
if "adv_success" in adv_results:
adv_success = adv_results["adv_success"]
score_dict.update({"adv_success": adv_success})
files.update(**adv_results)
##########################################################################
# Score results
#########################################################################
Expand All @@ -216,7 +216,7 @@ def __call__(self):
logger.debug(f" len(preds) : {len(preds)}")
new_score_dict = self.scorers(ground_truth, preds)
score_dict.update(**new_score_dict)
results["score_dict_file"] = score_dict
results["score_dict"] = score_dict
if "adv_preds" in locals():
ground_truth = data[3][: len(adv_preds)]
adv_preds = adv_preds[: len(ground_truth)]
Expand All @@ -225,7 +225,7 @@ def __call__(self):
f"adv_{key}": value for key, value in adv_score_dict.items()
}
score_dict.update(**adv_score_dict)
results["score_dict_file"] = score_dict
results["score_dict"] = score_dict
# # Save results
if "score_dict_file" in files and files["score_dict_file"] is not None:
if Path(files["score_dict_file"]).exists():
Expand Down
8 changes: 4 additions & 4 deletions deckard/base/files/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,17 @@ def __call__(self):
files = dict(self.get_filenames())
return files

def get_filenames(self):
def get_filenames(self, **kwargs):
files = deepcopy(self.files)
files.update(**kwargs)
files = self._set_filenames(**files)
return files

def _set_filenames(self, **kwargs):
name = self.name
stage = self.stage
if hasattr(self, "files"):
kwargs.update(self.files)
files = dict(kwargs)
files = self.files
files.update(**kwargs)
new_files = {}
directory = self.directory
reports = self.reports
Expand Down
4 changes: 1 addition & 3 deletions deckard/base/model/art_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class ArtPipelineStage:
kwargs: dict = field(default_factory=dict)

def __init__(self, name=None, **kwargs):
logger.info(f"Creating pipeline stage: {name} kwargs: {kwargs}")
self.name = name
kwargs.update(**kwargs.pop("kwargs", {}))
self.kwargs = kwargs
Expand Down Expand Up @@ -69,7 +68,6 @@ def __call__(self):

device_type = "gpu" if torch.cuda.is_available() else "cpu"
if device_type == "gpu":
logger.info("Using GPU")
number_of_devices = torch.cuda.device_count()
num = randint(0, number_of_devices - 1)
device = torch.device(f"cuda:{num}")
Expand Down Expand Up @@ -205,7 +203,7 @@ def __call__(self, model: object, data: list) -> BaseEstimator:
name = params.pop("name", None)
kwargs = params.pop("kwargs", {})
else:
name = None
name = self.library
kwargs = {}
pre_def = []
post_def = []
Expand Down
46 changes: 44 additions & 2 deletions deckard/base/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def __call__(self, data: list, model: object, library=None):
if library in sklearn_dict.keys():
pass
elif library in torch_dict.keys():
pass
trainer["nb_epochs"] = trainer.pop(
"nb_epochs",
trainer.pop("epochs", trainer.pop("nb_epoch", 10)),
)
elif library in keras_dict.keys():
pass
elif library in tensorflow_dict.keys():
Expand Down Expand Up @@ -142,7 +145,7 @@ def __call__(self, data: list, model: object, library=None):
model.fit(data[0], data[2], **trainer)
end = process_time_ns()
end_timestamp = time()
except Exception as e:
except AttributeError as e:
raise e
except RuntimeError as e: # pragma: no cover
if "eager mode" in str(e) and library in tensorflow_dict.keys():
Expand Down Expand Up @@ -176,6 +179,45 @@ def __call__(self, data: list, model: object, library=None):
model.fit(data[0], data[2], **trainer)
end = process_time_ns()
end_timestamp = time()
elif "disable eager execution" in str(e):
logger.warning("Disabling eager execution for Tensorflow.")
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns()
end_timestamp = time()
elif "out of memory" in set(e).lower() and library in torch_dict.keys():
import torch

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
# Pick the GPU with the most free memory
free_memory = [
torch.cuda.memory_reserved(i) - torch.cuda.memory_allocated(i)
for i in range(torch.cuda.device_count())
]
device = f"cuda:{free_memory.index(max(free_memory))}"
data[0] = torch.from_numpy(data[0])
data[1] = torch.from_numpy(data[1])
data[0] = torch.Tensor.float(data[0])
data[1] = torch.Tensor.float(data[1])
data[0].to(device)
data[2] = torch.from_numpy(data[2])
data[3] = torch.from_numpy(data[3])
data[2] = torch.Tensor.float(data[2])
data[3] = torch.Tensor.float(data[3])
data[2].to(device)
model.model.to(device) if hasattr(model, "model") else model.to(device)
start = process_time_ns()
start_timestamp = time()
model.fit(data[0], data[2], **trainer)
end = process_time_ns()
end_timestamp = time()

else:
raise e
time_dict = {
Expand Down
Loading
Loading