Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
simplymathematics committed Nov 29, 2023
1 parent c3e7f14 commit e362f85
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions deckard/base/model/art_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __call__(self):
tuple(torch_dict.values()),
):
import torch

device_type = "gpu" if torch.cuda.is_available() else "cpu"
if device_type == "gpu":
logger.info("Using GPU")
Expand Down
2 changes: 1 addition & 1 deletion deckard/base/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def __init__(self, **kwargs):
logger.info(f"Initializing model trainer with kwargs {kwargs}")
self.kwargs = kwargs


def __call__(self, data: list, model: object, library=None):
logger.info(f"Training model {model} with fit params: {self.kwargs}")
device = str(model.device) if hasattr(model, "device") else "cpu"
Expand Down Expand Up @@ -179,6 +178,7 @@ def __call__(self, data: list, model: object, library=None):
end_timestamp = time()
elif "should be the same" in str(e).lower():
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data[0] = torch.from_numpy(data[0])
data[1] = torch.from_numpy(data[1])
Expand Down

0 comments on commit e362f85

Please sign in to comment.