
Remove NumPy from Callback scripts (Lightning-AI#17267)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Apr 4, 2023
1 parent ee5bee0 commit 6cbc9df
Showing 3 changed files with 3 additions and 6 deletions.
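
All three edits follow the same pattern: the np.Inf / np.inf / np.nan constants are replaced by their torch equivalents, which lets the numpy import be dropped from each file. As a quick sanity check (a sketch for illustration only, not part of the commit), the two spellings produce the same value and the same tensor:

    # Sketch: torch.inf is a drop-in replacement for NumPy's infinity constant.
    import numpy as np
    import torch

    assert float(np.inf) == float(torch.inf) == float("inf")
    assert torch.equal(torch.tensor(np.inf), torch.tensor(torch.inf))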
3 changes: 1 addition & 2 deletions src/lightning/pytorch/callbacks/early_stopping.py
@@ -21,7 +21,6 @@
 import logging
 from typing import Any, Callable, Dict, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import Tensor

@@ -123,7 +122,7 @@ def __init__(
 raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
 
 self.min_delta *= 1 if self.monitor_op == torch.gt else -1
-torch_inf = torch.tensor(np.Inf)
+torch_inf = torch.tensor(torch.inf)
 self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
 @property
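For context, a small standalone sketch of how this sentinel is used (inferred from the hunk above, not a verbatim excerpt from EarlyStopping): best_score starts at +inf when a lower metric is better and at -inf when a higher metric is better, so the first monitored value always counts as an improvement.

    # Minimal sketch of the best-score initialization shown above; NumPy-free.
    import torch

    monitor_op = torch.lt                      # "min" mode: a lower metric is better
    torch_inf = torch.tensor(torch.inf)
    best_score = torch_inf if monitor_op == torch.lt else -torch_inf

    current = torch.tensor(0.42)
    assert monitor_op(current, best_score)     # any finite value beats the +inf sentinel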
3 changes: 1 addition & 2 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,7 +27,6 @@
 from typing import Any, Dict, Optional, Set
 from weakref import proxy
 
-import numpy as np
 import torch
 import yaml
 from torch import Tensor
@@ -442,7 +441,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) ->
 self.filename = filename
 
 def __init_monitor_mode(self, mode: str) -> None:
-torch_inf = torch.tensor(np.Inf)
+torch_inf = torch.tensor(torch.inf)
 mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")}
 
 if mode not in mode_dict:
3 changes: 1 addition & 2 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -20,7 +20,6 @@
 from unittest.mock import Mock
 
 import cloudpickle
-import numpy as np
 import pytest
 import torch

@@ -245,7 +244,7 @@ def on_validation_epoch_end(self):
 assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed"
 
 
-@pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)])
+@pytest.mark.parametrize("stop_value", [torch.tensor(torch.inf), torch.tensor(torch.nan)])
 def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
 
 losses = [4, 3, stop_value, 2, 1]
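The test swap is likewise one-for-one: torch.inf and torch.nan are plain Python floats, so the parametrized stop values remain non-finite exactly as they were with np.inf and np.nan. A minimal sketch of that property (illustration only, not part of the diff):

    # Sketch: both stop values are non-finite, matching the NumPy-based originals.
    import torch

    for stop_value in (torch.tensor(torch.inf), torch.tensor(torch.nan)):
        assert not torch.isfinite(stop_value)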
