Skip to content

Commit

Permalink
Device handling (#150)
Browse files Browse the repository at this point in the history
* fixed device handling bug for devices with unused gpu
* fixed device handling bug for np arrays in data
  • Loading branch information
simplymathematics authored Nov 29, 2023
1 parent e0b6981 commit ecf3bf5
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion deckard/base/model/art_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from art.estimators import BaseEstimator
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

import numpy as np
from .keras_models import KerasInitializer, keras_dict # noqa F401
from .tensorflow_models import ( # noqa F401
TensorflowV1Initializer,
Expand Down Expand Up @@ -71,6 +71,8 @@ def __call__(self):
logger.info("Model moved to GPU")
device = torch.device("cuda")
model.to(device)
if isinstance(data[0][0], np.ndarray):
data = [torch.from_numpy(d).to(device) for d in data]
data = [d.to(device) for d in data]
model = TorchInitializer(
data=data,
Expand Down

0 comments on commit ecf3bf5

Please sign in to comment.