From 6cbc9dfb9174ec38909a3d0e46f584ce109777cf Mon Sep 17 00:00:00 2001 From: Ishan Dutta Date: Tue, 4 Apr 2023 18:58:39 +0530 Subject: [PATCH] Remove NumPy from Callback scripts (#17267) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/pytorch/callbacks/early_stopping.py | 3 +-- src/lightning/pytorch/callbacks/model_checkpoint.py | 3 +-- tests/tests_pytorch/callbacks/test_early_stopping.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 5cbb005416cf7..c16b255a8d0bb 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -21,7 +21,6 @@ import logging from typing import Any, Callable, Dict, Optional, Tuple -import numpy as np import torch from torch import Tensor @@ -123,7 +122,7 @@ def __init__( raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - torch_inf = torch.tensor(np.Inf) + torch_inf = torch.tensor(torch.inf) self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf @property diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index bc4fe28827d12..a7513f6e9b15f 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -27,7 +27,6 @@ from typing import Any, Dict, Optional, Set from weakref import proxy -import numpy as np import torch import yaml from torch import Tensor @@ -442,7 +441,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> self.filename = filename def __init_monitor_mode(self, mode: str) -> None: - torch_inf = torch.tensor(np.Inf) + torch_inf = torch.tensor(torch.inf) mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")} if mode not in mode_dict: diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 4d8198d99e0e5..e5cc77302588c 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -20,7 +20,6 @@ from unittest.mock import Mock import cloudpickle -import numpy as np import pytest import torch @@ -245,7 +244,7 @@ def on_validation_epoch_end(self): assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed" -@pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)]) +@pytest.mark.parametrize("stop_value", [torch.tensor(torch.inf), torch.tensor(torch.nan)]) def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value): losses = [4, 3, stop_value, 2, 1]