From c3e7f14d6aef84560b6156fb55bc4b99dfea4e12 Mon Sep 17 00:00:00 2001 From: Charlie Meyers Date: Wed, 29 Nov 2023 22:15:53 +0000 Subject: [PATCH 1/2] fixed device handling bug --- deckard/base/model/art_pipeline.py | 7 ++++++- deckard/base/model/model.py | 3 --- test/pipelines/evasion/.gitignore | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 test/pipelines/evasion/.gitignore diff --git a/deckard/base/model/art_pipeline.py b/deckard/base/model/art_pipeline.py index 6b540f2d..110524bb 100644 --- a/deckard/base/model/art_pipeline.py +++ b/deckard/base/model/art_pipeline.py @@ -64,8 +64,13 @@ def __call__(self): tuple(torch_dict.values()), ): import torch - device_type = "gpu" if torch.cuda.is_available() else "cpu" + if device_type == "gpu": + logger.info("Using GPU") + logger.info("Model moved to GPU") + device = torch.device("cuda") + model.to(device) + data = [d.to(device) for d in data] model = TorchInitializer( data=data, model=model, diff --git a/deckard/base/model/model.py b/deckard/base/model/model.py index d335ef60..01a10927 100644 --- a/deckard/base/model/model.py +++ b/deckard/base/model/model.py @@ -110,8 +110,6 @@ def __init__(self, **kwargs): logger.info(f"Initializing model trainer with kwargs {kwargs}") self.kwargs = kwargs - # def __hash__(self): - # return int(my_hash(self), 16) def __call__(self, data: list, model: object, library=None): logger.info(f"Training model {model} with fit params: {self.kwargs}") @@ -181,7 +179,6 @@ def __call__(self, data: list, model: object, library=None): end_timestamp = time() elif "should be the same" in str(e).lower(): import torch - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") data[0] = torch.from_numpy(data[0]) data[1] = torch.from_numpy(data[1]) diff --git a/test/pipelines/evasion/.gitignore b/test/pipelines/evasion/.gitignore new file mode 100644 index 00000000..47f39c2f --- /dev/null +++ b/test/pipelines/evasion/.gitignore @@ -0,0 +1 @@ +.dvc/* From e362f85c045be40c3bd0baadaf10745d5b58064a Mon Sep 17 00:00:00 2001 From: Charlie Meyers Date: Wed, 29 Nov 2023 22:17:05 +0000 Subject: [PATCH 2/2] linting --- deckard/base/model/art_pipeline.py | 1 + deckard/base/model/model.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deckard/base/model/art_pipeline.py b/deckard/base/model/art_pipeline.py index 110524bb..aece4fb5 100644 --- a/deckard/base/model/art_pipeline.py +++ b/deckard/base/model/art_pipeline.py @@ -64,6 +64,7 @@ def __call__(self): tuple(torch_dict.values()), ): import torch + device_type = "gpu" if torch.cuda.is_available() else "cpu" if device_type == "gpu": logger.info("Using GPU") diff --git a/deckard/base/model/model.py b/deckard/base/model/model.py index 01a10927..5e96557a 100644 --- a/deckard/base/model/model.py +++ b/deckard/base/model/model.py @@ -110,7 +110,6 @@ def __init__(self, **kwargs): logger.info(f"Initializing model trainer with kwargs {kwargs}") self.kwargs = kwargs - def __call__(self, data: list, model: object, library=None): logger.info(f"Training model {model} with fit params: {self.kwargs}") device = str(model.device) if hasattr(model, "device") else "cpu" @@ -179,6 +178,7 @@ def __call__(self, data: list, model: object, library=None): end_timestamp = time() elif "should be the same" in str(e).lower(): import torch + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") data[0] = torch.from_numpy(data[0]) data[1] = torch.from_numpy(data[1])