From 6b60297fc6632a5322d3a06bcb94ff805d130b31 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 22 Sep 2024 17:06:58 -0400 Subject: [PATCH 1/2] Fix typo with torch cuda is_available call Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/mixins/tensorable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py index d07a2b5..7393837 100644 --- a/src/mixins/tensorable.py +++ b/src/mixins/tensorable.py @@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs): def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): if do_cuda is None: - do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available + do_cuda = self.do_cuda if hasattr(self, "do_cuda") else torch.cuda.is_available() return T.cuda() if do_cuda else T From 8a303a3bbb83f603dec1aa979afcf1a5be03de9b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 22 Sep 2024 17:07:14 -0400 Subject: [PATCH 2/2] Fix typo with torch cuda is_available call Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/mixins/tensorable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mixins/tensorable.py b/src/mixins/tensorable.py index 7393837..c603fe1 100644 --- a/src/mixins/tensorable.py +++ b/src/mixins/tensorable.py @@ -12,7 +12,7 @@ class TensorableMixin: Tensor_T = Union[torch.Tensor, tuple["Tensor_T"], dict[Hashable, "Tensor_T"]] def __init__(self, *args, **kwargs): - self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available) + self.do_cuda = kwargs.get("do_cuda", torch.cuda.is_available()) def _cuda(self, T: torch.Tensor, do_cuda: bool | None = None): if do_cuda is None: