Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pytorch example #185

Merged
merged 243 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
243 commits
Select commit Hold shift + click to select a range
67eed77
fix bug in scoring failure
simplymathematics Dec 2, 2023
ff9cbbd
update power example to support bit-depth search
simplymathematics Dec 2, 2023
e913c96
update result directories
simplymathematics Dec 2, 2023
c8e5a06
revert changes to example/power
simplymathematics Dec 2, 2023
8928e7e
add bit depth example
simplymathematics Dec 2, 2023
97c178e
revert directory changes for power example
simplymathematics Dec 2, 2023
e302e38
fixed directory for mnist in power example
simplymathematics Dec 2, 2023
076dce7
fixed directory for mnist in power example
simplymathematics Dec 2, 2023
b4a5655
use the latest torch torchvision torchaudio
salehsedghpour Dec 2, 2023
96b071c
Merge branch 'main' of github.com:simplymathematics/deckard into bit-…
simplymathematics Dec 3, 2023
7c85be1
update workflow to push on PR
simplymathematics Dec 3, 2023
b09f54c
Merge branch 'bit-depth-power-example' of github.com:simplymathematic…
simplymathematics Dec 3, 2023
c2af1a4
removed double stage folders in log folder
simplymathematics Dec 3, 2023
1712c3b
+ more epochs for cifar100
simplymathematics Dec 3, 2023
d428e48
changed intervals from uniform to log uniform, made learning rate ran…
simplymathematics Dec 3, 2023
74df959
strip whitespace, covert o numeric in compile script
simplymathematics Dec 4, 2023
d8fbf1a
update git ignores
simplymathematics Dec 4, 2023
183ba2a
suport nb_epoch as defence choice
simplymathematics Dec 5, 2023
aeb2b72
remove adv_success from requirements
simplymathematics Dec 5, 2023
6b94128
add "NaN" to nones
simplymathematics Dec 5, 2023
416ed65
update afr script
simplymathematics Dec 5, 2023
a289b18
update mnist .dvc cache
simplymathematics Dec 5, 2023
48a1f8a
updat cifar10 plots
simplymathematics Dec 5, 2023
aac2227
uncomment paretoset in plotting
simplymathematics Dec 5, 2023
2b3752c
fix default defence bug and realtive pathing in compile script
simplymathematics Dec 5, 2023
691fae2
moved plots to subfolder
simplymathematics Dec 5, 2023
687f454
better configuration support
simplymathematics Dec 5, 2023
621d2c1
fix compile script bug
simplymathematics Dec 6, 2023
93befed
update compile and plots yaml for power example
simplymathematics Dec 6, 2023
23fb2e1
fix compile bug
simplymathematics Dec 6, 2023
1c47dfd
update plots
simplymathematics Dec 6, 2023
92bcafd
include plot files in dvc
simplymathematics Dec 6, 2023
aacfb00
update afr to read from conf file
simplymathematics Dec 6, 2023
7a3d1ec
linting
simplymathematics Dec 6, 2023
08a812a
linting
simplymathematics Dec 6, 2023
87660a4
Merge branch 'main' of github.com:simplymathematics/deckard into fix-…
simplymathematics Dec 6, 2023
269cf98
update pytorch example
simplymathematics Dec 6, 2023
957c65d
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 6, 2023
108ff15
update pytorch afr.yaml (not working)
simplymathematics Dec 7, 2023
6868dd2
split cleaning from plotting, but only working for examples/pytorch/m…
simplymathematics Dec 7, 2023
cef183c
working cleaning script
simplymathematics Dec 7, 2023
26c617e
fix pytorch examples with new clean script
simplymathematics Dec 7, 2023
00e1a0a
remove debug check from parse_results
simplymathematics Dec 7, 2023
3a051fa
make deckard a depedendency of the parsing script
simplymathematics Dec 7, 2023
0f15e04
made models.sh easier to read
simplymathematics Dec 7, 2023
2b73578
update afr for pytorch example
simplymathematics Dec 7, 2023
1987660
update power example
simplymathematics Dec 8, 2023
0ac3f5d
update dvc.lock for pytorch example
simplymathematics Dec 8, 2023
da05855
update pytorch/cifar100
simplymathematics Dec 8, 2023
e17706a
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 8, 2023
3c0deb5
update power/plots (not working)
simplymathematics Dec 8, 2023
b2b9157
add docstrings to plots.py
simplymathematics Dec 8, 2023
556898b
update power example with merge script
simplymathematics Dec 9, 2023
85dea7e
add power data
Dec 10, 2023
717b3a9
update configs
simplymathematics Dec 10, 2023
d4792c7
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 10, 2023
3aee986
add combined plots
simplymathematics Dec 11, 2023
269ad24
update afr models
simplymathematics Dec 11, 2023
f20f6d7
added support for dummy variables in afr
simplymathematics Dec 12, 2023
b147fc1
++combined_plots.py and fix afr bug
simplymathematics Dec 12, 2023
2da27ed
add cifar100 l4 power data with commenting everything else
Dec 12, 2023
55e4564
add varepsilon to attack params
simplymathematics Dec 12, 2023
d71b198
add dummy variables
simplymathematics Dec 12, 2023
50ff188
fix rounding bug
simplymathematics Dec 13, 2023
80615f7
update to newest plots
simplymathematics Dec 13, 2023
6376e0c
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Dec 13, 2023
4f3176a
newest plots for power example
simplymathematics Dec 13, 2023
071db7f
linting
simplymathematics Dec 13, 2023
139864d
removed old afr file
simplymathematics Dec 13, 2023
fc4ec91
linting
simplymathematics Dec 13, 2023
8ea0056
Merge branch 'fix-compile-script' of github.com:simplymathematics/dec…
simplymathematics Jan 15, 2024
49ef6c1
update conf
simplymathematics Jan 15, 2024
62aefae
\Merge branch 'fix-compile-script' of github.com:simplymathematics/de…
simplymathematics Jan 15, 2024
4a34fd1
fixed kepler script bug
simplymathematics Jan 15, 2024
fdd2e8a
linting
simplymathematics Jan 15, 2024
0cc7b42
linting
simplymathematics Jan 15, 2024
40047b0
linting
simplymathematics Jan 15, 2024
4826f0a
linting
simplymathematics Jan 15, 2024
3d93817
linting
simplymathematics Jan 15, 2024
8c78260
linting
simplymathematics Jan 15, 2024
105d051
linting
simplymathematics Jan 15, 2024
14e32a8
linting
simplymathematics Jan 15, 2024
f25d72b
linting
simplymathematics Jan 15, 2024
358e759
fixed cifar100 pytorch example script
simplymathematics Jan 15, 2024
4b0ff7d
more resilient wait and cleaning scripts
simplymathematics Jan 15, 2024
4345985
+GZIP example
simplymathematics Jan 22, 2024
c49cd72
bug fixes
simplymathematics Jan 22, 2024
44d6145
fix latex nan bug
simplymathematics Jan 23, 2024
ef5612d
bug fixes
simplymathematics Jan 23, 2024
dc7c478
add index to compilation csv
simplymathematics Jan 23, 2024
9b57058
better defence merging
simplymathematics Jan 23, 2024
ff9d924
fixed bug where x,y scale are None
simplymathematics Jan 23, 2024
9dcc132
update cifar100 confs
simplymathematics Jan 23, 2024
ce45cde
fixed cleaning bug
simplymathematics Jan 23, 2024
285fcea
fixed afr plot rendering bug
simplymathematics Jan 24, 2024
90e41f1
add check for negative predict time
simplymathematics Jan 24, 2024
0254a92
update configs
simplymathematics Jan 24, 2024
f82e1d8
fixed failure rate bug and updated confs
simplymathematics Jan 24, 2024
0971505
Merge branch 'cifar100' of github.com:simplymathematics/deckard into …
simplymathematics Jan 24, 2024
b09a423
change plot default from eps to pdf
simplymathematics Jan 24, 2024
80998a8
fix bug in calculating failure rate when attack size != train size
simplymathematics Jan 24, 2024
c76394e
update mnist confs
simplymathematics Jan 25, 2024
5491358
specify attack size at the command line
simplymathematics Jan 25, 2024
74fbe66
linting
simplymathematics Jan 26, 2024
ab6459f
update all plot configs
simplymathematics Jan 26, 2024
351f597
Fix compile script (#172)
simplymathematics Jan 24, 2024
8f08882
added dummy vars, fixed plots
simplymathematics Jan 26, 2024
66515ff
fix afr.py bugs
simplymathematics Jan 26, 2024
33eeaf4
bad merge?
simplymathematics Jan 26, 2024
66f8319
merge
simplymathematics Jan 26, 2024
6625ae5
fix bugs
simplymathematics Jan 26, 2024
ae10c9f
Merge branch 'main' of github.com:simplymathematics/deckard into note…
simplymathematics Jan 26, 2024
cf265db
linting
simplymathematics Jan 26, 2024
1254a67
linting
simplymathematics Jan 26, 2024
3060e4c
linting
simplymathematics Jan 26, 2024
1da196c
linting
simplymathematics Jan 26, 2024
6609a53
linting
simplymathematics Jan 26, 2024
098af1c
linting
simplymathematics Jan 26, 2024
0f360f4
update linter
simplymathematics Jan 26, 2024
3150f9e
update linter
simplymathematics Jan 26, 2024
4455ef7
update linter
simplymathematics Jan 26, 2024
d17765e
update linter
simplymathematics Jan 26, 2024
6b6caef
linting
simplymathematics Jan 27, 2024
5978b44
update setup, .gitignore
simplymathematics Jan 30, 2024
e8257ed
fix failure rate bug (again)
simplymathematics Jan 30, 2024
aaea279
most up-to-date plots
simplymathematics Jan 31, 2024
bf5eaf8
update failure rate from h to f in pytorch examples
simplymathematics Feb 1, 2024
609bc8d
remove intercept and scale parameters from afr plots
simplymathematics Feb 1, 2024
56b3aa3
remove rows where the score is an error
simplymathematics Feb 1, 2024
ca6c7cc
update plolt confs for pytorch example
simplymathematics Feb 1, 2024
36ea908
allow setting filename from command line of AFR script
simplymathematics Feb 1, 2024
aa50ddc
plot legend tweaks
simplymathematics Feb 1, 2024
8ebdaa1
linting
simplymathematics Feb 4, 2024
2794006
Merge branch 'main' of github.com:simplymathematics/deckard into note…
simplymathematics Feb 4, 2024
5a3226b
linting
simplymathematics Feb 4, 2024
24b1b36
merge with main
simplymathematics Feb 4, 2024
de5ad1b
Update Dockerfile
simplymathematics Feb 4, 2024
f5b0287
Merge branch 'simplymathematics-workflow-diskspace-patch' of https://…
simplymathematics Feb 4, 2024
d9bd067
Merge branch 'notebook-branch' of https://github.com/simplymathematic…
simplymathematics Feb 4, 2024
b331acd
update dockerfile
simplymathematics Feb 4, 2024
5a8f88b
lintin
simplymathematics Feb 4, 2024
2b3cf6e
update gzip configs
simplymathematics Feb 4, 2024
f3123d9
better logging
simplymathematics Feb 4, 2024
d300763
add url validation for data pipeline
simplymathematics Feb 4, 2024
5e90ee8
git rm
simplymathematics Feb 4, 2024
80d6eaa
update truthseeker yaml
simplymathematics Feb 4, 2024
af7066d
add gzip .gitignore
simplymathematics Feb 4, 2024
db34da2
gzip dvc changes
simplymathematics Feb 4, 2024
a0de7bf
add sampling during training
simplymathematics Feb 5, 2024
533ad3e
more resilient find_best script
simplymathematics Feb 5, 2024
670d81a
fixed bug with finding min/max when data is non-numeric
simplymathematics Feb 5, 2024
32b13e7
add support for url/local datasets
simplymathematics Feb 5, 2024
bd2a7a5
add column dropping for data parsing
simplymathematics Feb 6, 2024
ac07dff
find best for multi-objective search
simplymathematics Feb 6, 2024
a773639
better cleaning for experiments without attacks
simplymathematics Feb 6, 2024
0874768
better filetype support when plotting
simplymathematics Feb 6, 2024
266f006
load distance matrix from disk (optionally)
simplymathematics Feb 6, 2024
7b4a517
update confs for gzip
simplymathematics Feb 6, 2024
9c178c4
Merge branches 'notebook-branch' and 'notebook-branch' of github.com:…
simplymathematics Feb 6, 2024
d13412e
update default params
simplymathematics Feb 6, 2024
21e7699
Merge branch 'main' of github.com:simplymathematics/deckard into note…
simplymathematics Mar 27, 2024
c78d51d
update .gitignore, add some models to torch_example
simplymathematics Apr 3, 2024
414198e
refactor confs
simplymathematics Apr 3, 2024
0679acf
update pytorch confs
simplymathematics Apr 3, 2024
d7e0354
minor bug fixes
simplymathematics Apr 3, 2024
43cf4d8
fix small bugs
simplymathematics Apr 3, 2024
5a1705a
add resnet examples
simplymathematics Apr 3, 2024
022cb38
update pytorch experiment confs
simplymathematics Apr 4, 2024
128f9a1
Merge branch 'main' of github.com:simplymathematics/deckard into fix-…
simplymathematics Apr 4, 2024
4836501
config changes
simplymathematics May 10, 2024
44ee1f4
increase timeline resolution
simplymathematics May 13, 2024
b73b006
config changes
simplymathematics May 13, 2024
da7fb05
removed dvc cruft
simplymathematics May 13, 2024
421e052
update cost normalization calculation
simplymathematics May 13, 2024
673fd9b
update afr plotting
simplymathematics May 15, 2024
d27dfea
remove partial effects from pytorch config, add support for aalen add…
simplymathematics May 15, 2024
8148ba4
update dvc file for pytorch plots
simplymathematics May 15, 2024
3f6fb2e
better error handling
simplymathematics May 15, 2024
392b253
Merge branch 'fix-cifar100-accuracy' of github.com:simplymathematics/…
simplymathematics May 15, 2024
cbdc65f
update survival plots
simplymathematics May 16, 2024
466d7d6
merge with soft reset
simplymathematics May 16, 2024
0166f45
++plots to newest overleaf
simplymathematics May 20, 2024
1e28adf
dummy config changes
simplymathematics May 22, 2024
f29effc
config change
simplymathematics May 22, 2024
a857435
re-run plot dvc.yaml
simplymathematics May 22, 2024
92b89dc
config changes
simplymathematics Jun 4, 2024
e0e5671
update .gitignore
simplymathematics Jun 5, 2024
7237859
fixed file bug
simplymathematics Jun 5, 2024
9531165
fix keyword bug
simplymathematics Jun 5, 2024
35f382f
change Coefficient plots to sym log scale
simplymathematics Jun 5, 2024
7a2da3d
and latex to result parsing
simplymathematics Jun 5, 2024
f777481
added support for forloop stage parsing
simplymathematics Jun 5, 2024
46a9be5
streamline some code
simplymathematics Jun 5, 2024
9a08413
config changes + predicting the metric with a model config chosen by key
simplymathematics Jun 13, 2024
5e87696
add plots yaml again
simplymathematics Jun 13, 2024
c35c94c
stop tracking cifar100.yaml
simplymathematics Jun 20, 2024
5dadbcf
fix uncaught exception
simplymathematics Jun 25, 2024
bc80872
make dataset formatting more robust
simplymathematics Jun 25, 2024
d4499db
add a type check
simplymathematics Jun 25, 2024
e537dce
update default configs for each dataset to use env vars instead of ha…
simplymathematics Jun 25, 2024
8a850ba
reconfigure the dvc pipeline for re-running and changing the number o…
simplymathematics Jun 25, 2024
e319871
config changes
simplymathematics Jul 2, 2024
faaa05b
better pytorch out of memory handling
simplymathematics Jul 2, 2024
1ad4118
add normalization to trash metric
simplymathematics Jul 2, 2024
704b883
better convergence error handline
simplymathematics Jul 6, 2024
68a7e63
config changes
simplymathematics Jul 6, 2024
c166940
linting
simplymathematics Jul 6, 2024
909a5d3
linting
simplymathematics Jul 6, 2024
0e18db2
stop tracking cifar100.yaml
simplymathematics Jul 10, 2024
762067d
use pretrained models as initial weights
simplymathematics Jul 18, 2024
b64116d
better error handling
simplymathematics Jul 18, 2024
9b3b0db
remove cruft
simplymathematics Jul 18, 2024
ca0dd7b
delete old configs
simplymathematics Jul 18, 2024
1903461
rename parameter for clarity
simplymathematics Jul 18, 2024
3d71777
moved from plots to conf folder
simplymathematics Jul 18, 2024
e2f1414
update dvc to work with last commit
simplymathematics Jul 18, 2024
6d3a203
config changes for pytorch example
simplymathematics Jul 18, 2024
a719768
linting
simplymathematics Jul 18, 2024
1d07aee
update torch example to use nb_epochs instead of nb_epoch
simplymathematics Jul 19, 2024
ab8a119
linting
simplymathematics Jul 30, 2024
4060aa1
config updates
simplymathematics Jul 30, 2024
031ae8c
Merge branch 'main' of github.com:simplymathematics/deckard into upda…
simplymathematics Jul 30, 2024
ef8a8cd
fixed bad merge
simplymathematics Jul 30, 2024
c185a5f
linting
simplymathematics Jul 30, 2024
afa47cb
update .gitignore
simplymathematics Jul 30, 2024
2796641
stop tracking params file
simplymathematics Jul 30, 2024
59b777f
removed overly verbose logging
simplymathematics Jul 31, 2024
ac972da
broke up attack scripts for better dvc tracking
simplymathematics Jul 31, 2024
898161e
update pytorch confs
simplymathematics Aug 13, 2024
abf273d
add hashable object, better art type checking
simplymathematics Aug 13, 2024
0500458
created hashable object for inheritance
simplymathematics Aug 13, 2024
9b19c8e
changed AFR to AFT
simplymathematics Aug 13, 2024
6ca1c69
add arbitrary set() dictionary to catplot
simplymathematics Aug 13, 2024
8a2f361
add numeric casting to afr
simplymathematics Aug 13, 2024
0240e29
fix logging bug
simplymathematics Aug 13, 2024
db2c9ed
linting
simplymathematics Aug 13, 2024
658e987
better art typing
simplymathematics Aug 13, 2024
40594bb
hashable object
simplymathematics Aug 13, 2024
baee84d
linting
simplymathematics Aug 13, 2024
4490435
Merge branch 'main' of github.com:simplymathematics/deckard into upda…
simplymathematics Aug 13, 2024
a38ed17
fixed hashing bug
simplymathematics Aug 13, 2024
d674b97
fix bug
simplymathematics Aug 13, 2024
171efbd
linting
simplymathematics Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions deckard/base/attack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def __init__(
self.attack_size = attack_size
self.init = AttackInitializer(model, name, **init)
self.kwargs = kwargs
logger.info("Instantiating Attack with id: {}".format(self.__hash__()))

def __hash__(self):
return int(my_hash(self), 16)
Expand Down Expand Up @@ -300,7 +299,6 @@ def __init__(
self.attack_size = attack_size
self.init = AttackInitializer(model, name, **init)
self.kwargs = kwargs
logger.info("Instantiating Attack with id: {}".format(self.__hash__()))

def __hash__(self):
return int(my_hash(self), 16)
Expand Down Expand Up @@ -493,7 +491,6 @@ def __init__(
self.attack_size = attack_size
self.init = AttackInitializer(model, name, **init)
self.kwargs = kwargs
logger.info("Instantiating Attack with id: {}".format(self.__hash__()))

def __hash__(self):
return int(my_hash(self), 16)
Expand Down Expand Up @@ -618,7 +615,6 @@ def __init__(
f"kwargs must be of type DictConfig or dict. Got {type(kwargs)}",
)
self.kwargs = kwargs
logger.info("Instantiating Attack with id: {}".format(self.__hash__()))

def __hash__(self):
return int(my_hash(self), 16)
Expand Down Expand Up @@ -813,7 +809,6 @@ def __init__(
kwargs.update(**kwargs.pop("kwargs"))
self.kwargs = kwargs
self.name = name if name is not None else my_hash(self)
logger.info("Instantiating Attack with id: {}".format(self.name))

def __call__(
self,
Expand Down
1 change: 0 additions & 1 deletion deckard/base/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def save(self, data, filename):
:param filename: str
"""
if filename is not None:
logger.info(f"Saving data to {filename}")
suffix = Path(filename).suffix
Path(filename).parent.mkdir(parents=True, exist_ok=True)
if isinstance(data, dict):
Expand Down
9 changes: 0 additions & 9 deletions deckard/base/data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ class SklearnDataGenerator:
kwargs: dict = field(default_factory=dict)

def __init__(self, name, **kwargs):
logger.info(
f"Instantiating {self.__class__.__name__} with name={name} and kwargs={kwargs}",
)
self.name = name
self.kwargs = {k: v for k, v in kwargs.items() if v is not None}

Expand Down Expand Up @@ -91,9 +88,6 @@ class TorchDataGenerator:
kwargs: dict = field(default_factory=dict)

def __init__(self, name, path=None, **kwargs):
logger.info(
f"Instantiating {self.__class__.__name__} with name={name} and kwargs={kwargs}",
)
self.name = name
self.path = path
self.kwargs = {k: v for k, v in kwargs.items() if v is not None}
Expand Down Expand Up @@ -179,9 +173,6 @@ class KerasDataGenerator:
kwargs: dict = field(default_factory=dict)

def __init__(self, name, **kwargs):
logger.info(
f"Instantiating {self.__class__.__name__} with name={name} and kwargs={kwargs}",
)
self.name = name
self.kwargs = {k: v for k, v in kwargs.items() if v is not None}

Expand Down
2 changes: 0 additions & 2 deletions deckard/base/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(self, **kwargs):
self.kwargs = kwargs

def __call__(self, data: list, model: object, library=None):
logger.info(f"Training model {model} with fit params: {self.kwargs}")
device = str(model.device) if hasattr(model, "device") else "cpu"
trainer = self.kwargs
if library in sklearn_dict.keys():
Expand All @@ -91,7 +90,6 @@ def __call__(self, data: list, model: object, library=None):
try:
start = process_time_ns()
start_timestamp = time()
logger.info(f"Fitting type(model): {type(model)} with kwargs {trainer}")
model.fit(data[0], data[2], **trainer)
end = process_time_ns()
end_timestamp = time()
Expand Down
30 changes: 13 additions & 17 deletions deckard/base/model/sklearn_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,10 @@ class SklearnModelPipelineStage:
kwargs: dict = field(default_factory=dict)

def __init__(self, name, stage_name, **kwargs):
logger.debug(
f"Instantiating {self.__class__.__name__} with name={name} and kwargs={kwargs}",
)
self.name = name
self.kwargs = kwargs
self.stage_name = stage_name

def __hash__(self):
return int(my_hash(self), 16)

def __call__(self, model):
logger.debug(
f"Calling SklearnModelPipelineStage with name={self.name} and kwargs={self.kwargs}",
Expand All @@ -76,7 +70,7 @@ def __call__(self, model):
stage_name = self.stage_name if self.stage_name is not None else name
while "kwargs" in kwargs:
kwargs.update(**kwargs.pop("kwargs"))
if "art." in str(type(model)):
if str(type(model)).startswith("art."):
assert isinstance(
model.model,
BaseEstimator,
Expand All @@ -102,7 +96,6 @@ class SklearnModelPipeline:
pipeline: Dict[str, SklearnModelPipelineStage] = field(default_factory=dict)

def __init__(self, **kwargs):
logger.debug(f"Instantiating {self.__class__.__name__} with kwargs={kwargs}")
pipe = {}
while "kwargs" in kwargs:
pipe.update(**kwargs.pop("kwargs"))
Expand Down Expand Up @@ -145,12 +138,12 @@ def __len__(self):
else:
return 0

def __hash__(self):
return int(my_hash(self), 16)

def __iter__(self):
return iter(self.pipeline)

def __hash__(self):
return int(my_hash(self), 16)

def __call__(self, model):
params = deepcopy(asdict(self))
pipeline = params.pop("pipeline")
Expand All @@ -172,7 +165,7 @@ def __call__(self, model):
elif isinstance(stage, SklearnModelPipelineStage):
model = stage(model=model)
elif hasattr(stage, "fit"):
if "art." in str(type(model)):
if str(type(model)).startswith("art."):
assert isinstance(
model.model,
BaseEstimator,
Expand All @@ -184,12 +177,15 @@ def __call__(self, model):
), f"model must be a sklearn estimator. Got {type(model)}"
if not isinstance(model, Pipeline) and "art." not in str(type(model)):
model = Pipeline([("model", model)])
elif "art." in str(type(model)) and not isinstance(
elif str(type(model)).startswith("art.") and not isinstance(
model.model,
Pipeline,
):
model.model = Pipeline([("model", model.model)])
elif "art." in str(type(model)) and isinstance(model.model, Pipeline):
elif str(type(model)).startswith("art.") and isinstance(
model.model,
Pipeline,
):
model.model.steps.insert(-2, [stage, model.model])
else:
model.steps.insert(-2, [stage, model])
Expand All @@ -213,6 +209,9 @@ class SklearnModelInitializer:
pipeline: SklearnModelPipeline = field(default_factory=None)
kwargs: Union[dict, None] = field(default_factory=dict)

def __hash__(self):
return int(my_hash(self), 16)

def __init__(self, data, model=None, library="sklearn", pipeline={}, **kwargs):
self.data = data
self.model = model
Expand Down Expand Up @@ -267,6 +266,3 @@ def __call__(self):
"fit",
), f"model must have a fit method. Got type {type(model)}"
return model

def __hash__(self):
return int(my_hash(self), 16)
8 changes: 7 additions & 1 deletion deckard/base/utils/hashing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from hashlib import md5
from collections import OrderedDict
from typing import NamedTuple, Union
from dataclasses import asdict, is_dataclass
from dataclasses import asdict, is_dataclass, dataclass
from omegaconf import DictConfig, OmegaConf, SCMode, ListConfig
from copy import deepcopy
import logging
Expand Down Expand Up @@ -71,3 +71,9 @@ def to_dict(obj: Union[dict, OrderedDict, NamedTuple]) -> dict:

def my_hash(obj: Union[dict, OrderedDict, NamedTuple]) -> str:
return md5(str(to_dict(obj)).encode("utf-8")).hexdigest()


@dataclass
class Hashable:
def __hash__(self):
return int(my_hash(self), 16)
10 changes: 6 additions & 4 deletions deckard/layers/afr.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def ccl(p):
ax = plt.gca()
T = model.duration_col
E = model.event_col

# Cast df to numeric DataFrame
for col in df.columns:
df[col] = pd.to_numeric(df[col], errors="raise")
# Drop NaNs
df = df.dropna()
predictions_at_t0 = np.clip(
1 - model.predict_survival_function(df, times=[t0]).T.squeeze(),
1e-10,
Expand Down Expand Up @@ -347,8 +351,6 @@ def plot_aft(
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
# symlog-scale the x-axis
# ax.set_xscale("linear")
ax.get_figure().tight_layout()
ax.get_figure().savefig(file)
plt.gcf().clear()
Expand Down Expand Up @@ -624,7 +626,7 @@ def make_afr_table(
pretty_dataset = dataset.upper()
aft_data = aft_data.round(2)
aft_data.to_csv(folder / "aft_comparison.csv")
logger.info(f"Saved AFR comparison to {folder / 'aft_comparison.csv'}")
logger.info(f"Saved AFT comparison to {folder / 'aft_comparison.csv'}")
aft_data = aft_data.round(2)
aft_data.fillna("--", inplace=True)
aft_data.to_latex(
Expand Down
3 changes: 1 addition & 2 deletions deckard/layers/clean_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def drop_rows_without_results(
logger.info(f"Shape of data before data before dropping na: {data.shape}")
data.dropna(axis=0, subset=[col], inplace=True)
after = data.shape[0]
logger.info(f"Shape of data before data after dropping na: {data.shape}")
logger.info(f"Shape of data after data after dropping na: {data.shape}")
percent_change = (before - after) / before * 100
if percent_change > 5:
# input(f"{percent_change:.2f}% of data dropped for {col}. Press any key to continue.")
Expand Down Expand Up @@ -593,7 +593,6 @@ def clean_data_for_plotting(
data = fill_na(data, fillna)
data = replace_strings_in_data(data, replace_dict)
data = replace_strings_in_columns(data, col_replace_dict)

if len(pareto_dict) > 0:
data = find_pareto_set(data, pareto_dict)
return data
Expand Down
3 changes: 3 additions & 0 deletions deckard/layers/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def cat_plot(
file = Path(file).with_suffix(filetype)
logger.info(f"Rendering graph {file}")
data = digitize_cols(data, digitize)
set_ = kwargs.pop("set", {})
if hue is not None:
data = data.sort_values(by=[hue, x, y])
logger.debug(
Expand Down Expand Up @@ -162,6 +163,8 @@ def cat_plot(
graph.set(xlim=x_lim)
if y_lim is not None:
graph.set(ylim=y_lim)
if len(set_) > 0:
graph.set(**set_)
graph.tight_layout()
graph.savefig(folder / file)
plt.gcf().clear()
Expand Down
16 changes: 8 additions & 8 deletions examples/power/conf/afr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fillna:
weibull:
plot:
file : weibull_aft.pdf
title : Weibull AFR Model
title : Weibull AFT Model
labels:
"Intercept: rho_": "$\\rho$"
"Intercept: lambda_": "$\\lambda$"
Expand All @@ -36,7 +36,7 @@ weibull:
- "file": "weibull_epochs_partial_effect.pdf"
"covariate_array": "model.trainer.np_epochs"
"values_array": [1,10,25,50]
"title": "$S(t)$ for Weibull AFR"
"title": "$S(t)$ for Weibull AFT"
"ylabel": "$\\mathbb{P}~(T>t)$"
"xlabel": "Time $t$ (seconds)"
"legend_kwargs": {
Expand All @@ -46,7 +46,7 @@ weibull:
cox:
plot:
file : cox_aft.pdf
title : Cox AFR Model
title : Cox AFT Model
labels:
"data.sample.random_state": "Random State"
"atk_value": "Attack Strength"
Expand All @@ -65,7 +65,7 @@ cox:
- "file": "cox_epochs_partial_effect.pdf"
"covariate_array": "model.trainer.np_epochs"
"values_array": [1,10,25,50]
"title": "$S(t)$ for Cox AFR"
"title": "$S(t)$ for Cox AFT"
"ylabel": "$\\mathbb{P}~(T>t)$"
"xlabel": "Time $t$ (seconds)"
"legend_kwargs": {
Expand All @@ -75,7 +75,7 @@ cox:
log_logistic:
plot:
file : log_logistic_aft.pdf
title : Log logistic AFR Model
title : Log logistic AFT Model
labels:
"Intercept: beta_": "$\\beta$"
"Intercept: alpha_": "$\\alpha$"
Expand All @@ -96,7 +96,7 @@ log_logistic:
- "file": "log_logistic_epochs_partial_effect.pdf"
"covariate_array": "model.trainer.np_epochs"
"values_array": [1,10,25,50]
"title": "$S(t)$ for Log-Logistic AFR"
"title": "$S(t)$ for Log-Logistic AFT"
"ylabel": "$\\mathbb{P}~(T>t)$"
"xlabel": "Time $t$ (seconds)"
"legend_kwargs": {
Expand All @@ -106,7 +106,7 @@ log_logistic:
log_normal:
plot:
file : log_normal_aft.pdf
title : Log Normal AFR Model
title : Log Normal AFT Model
labels:
"Intercept: sigma_": "$\\sigma$"
"Intercept: mu_": "$\\mu$"
Expand All @@ -127,7 +127,7 @@ log_normal:
- "file": "log_normal_epochs_partial_effect.pdf"
"covariate_array": "model.trainer.np_epochs"
"values_array": [1,10,25,50]
"title": "$S(t)$ for Log-Normal AFR"
"title": "$S(t)$ for Log-Normal AFT"
"ylabel": "$\\mathbb{P}~(T>t)$"
"xlabel": "Time $t$ (seconds)"
"legend_kwargs": {
Expand Down
Loading
Loading