diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f31ad78..c45b176 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -12,19 +12,25 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ["3.8", "3.9", "3.10"]
-        pytorch-version: ["1.4.0", "1.5.1", "1.6.0", "1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1"]
+        pytorch-version: ["1.4.0", "1.5.1", "1.6.0", "1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1", "2.2"]
         include:
           - python-version: 3.11
             pytorch-version: 2.0
           - python-version: 3.11
             pytorch-version: 2.1
+          - python-version: 3.11
+            pytorch-version: 2.2
         exclude:
+          - python-version: 3.8
+            pytorch-version: 1.7.1
           - python-version: 3.9
             pytorch-version: 1.4.0
           - python-version: 3.9
             pytorch-version: 1.5.1
           - python-version: 3.9
             pytorch-version: 1.6.0
+          - python-version: 3.9
+            pytorch-version: 1.7.1
           - python-version: 3.10
             pytorch-version: 1.4.0
 
@@ -56,16 +62,16 @@ jobs:
           pip install torch==${{ matrix.pytorch-version }} torchvision transformers
           pip install compressai
       - name: mypy
-        if: ${{ matrix.pytorch-version == '2.1' }}
+        if: ${{ matrix.pytorch-version == '2.2' }}
         run: |
-          python -m pip install mypy==1.7.1
+          python -m pip install mypy==1.9.0
           mypy --install-types --non-interactive .
       - name: pytest
-        if: ${{ matrix.pytorch-version == '2.1' }}
+        if: ${{ matrix.pytorch-version == '2.2' }}
         run: |
           pytest --cov=torchinfo --cov-report= --durations=0
       - name: pytest
-        if: ${{ matrix.pytorch-version != '2.1' }}
+        if: ${{ matrix.pytorch-version != '2.2' }}
         run: |
           pytest --no-output -k "not test_eval_order_doesnt_matter and not test_google and not test_uninitialized_tensor and not test_input_size_half_precision and not test_recursive_with_missing_layers and not test_flan_t5_small"
       - name: codecov
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 226f1ee..1d3351a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,7 +2,7 @@ ci:
   skip: [mypy, pytest]
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.2
+    rev: v0.3.4
     hooks:
       - id: ruff
         args: [--fix]
diff --git a/ruff.toml b/ruff.toml
index 4f319cb..5a58061 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -1,8 +1,8 @@
 target-version = "py38"
-select = ["ALL"]
-ignore = [
-    "ANN101", # Missing type annotation for `self` in method
-    "ANN102", # Missing type annotation for `cls` in classmethod
+lint.select = ["ALL"]
+lint.ignore = [
+    "ANN101",  # Missing type annotation for `self` in method
+    "ANN102",  # Missing type annotation for `cls` in classmethod
     "ANN401",  # Dynamically typed expressions (typing.Any) are disallowed
     "C901",  # function is too complex (12 > 10)
     "COM812",  # Trailing comma missing
@@ -15,36 +15,34 @@ ignore = [
     "FBT003",  # Boolean positional value in function call
     "FIX002",  # Line contains TODO
     "ISC001",  # Isort
-    "PLR0911", # Too many return statements (11 > 6)
+    "PLR0911",  # Too many return statements (11 > 6)
     "PLR2004",  # Magic value used in comparison, consider replacing 2 with a constant variable
-    "PLR0912", # Too many branches
+    "PLR0912",  # Too many branches
     "PLR0913",  # Too many arguments to function call
     "PLR0915",  # Too many statements
     "S101",  # Use of `assert` detected
     "S311",  # Standard pseudo-random generators are not suitable for cryptographic purposes
     "T201",  # print() found
     "T203",  # pprint() found
-    "TCH001", # Move application import into a type-checking block
-    "TCH003", # Move standard library import into a type-checking block
-    "TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
-    "TD003", # Missing issue link on the line following this TODO
-    "TD005", # Missing issue description after `TODO`
-    "TRY003", # Avoid specifying long messages outside the exception class
+    "TD002",  # Missing author in TODO; try: `# TODO(<author_name>): ...`
+    "TD003",  # Missing issue link on the line following this TODO
+    "TD005",  # Missing issue description after `TODO`
+    "TRY003",  # Avoid specifying long messages outside the exception class
 
     # torchinfo-specific ignores
-    "N803", # Argument name `A_i` should be lowercase
-    "N806", # Variable `G` in function should be lowercase
+    "N803",  # Argument name `A_i` should be lowercase
+    "N806",  # Variable `G` in function should be lowercase
     "BLE001",  # Do not catch blind exception: `Exception`
-    "PLW0602", # Using global for `_cached_forward_pass` but no assignment is done
-    "PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged
-    "PLW2901", # `for` loop variable `name` overwritten by assignment target
-    "SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
-    "SLF001", # Private member accessed: `_modules`
-    "TCH002", # Move third-party import into a type-checking block
-    "TRY004", # Prefer `TypeError` exception for invalid type
-    "TRY301", # Abstract `raise` to an inner function
+    "PLW0602",  # Using global for `_cached_forward_pass` but no assignment is done
+    "PLW0603",  # Using the global statement to update `_cached_forward_pass` is discouraged
+    "PLW2901",  # `for` loop variable `name` overwritten by assignment target
+    "SIM108",  # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
+    "SLF001",  # Private member accessed: `_modules`
+    "TCH002",  # Move third-party import into a type-checking block
+    "TRY004",  # Prefer `TypeError` exception for invalid type
+    "TRY301",  # Abstract `raise` to an inner function
 ]
-exclude = ["tests"] # TODO: check tests too
+exclude = ["tests"]  # TODO: check tests too
 
-[flake8-pytest-style]
+[lint.flake8-pytest-style]
 fixture-parentheses = false
diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py
index 2ab5181..f91a2f7 100644
--- a/tests/fixtures/models.py
+++ b/tests/fixtures/models.py
@@ -302,7 +302,7 @@ def __init__(self) -> None:
         self.constant = 5
 
     def forward(self, x: dict[int, torch.Tensor], scale_factor: int) -> torch.Tensor:
-        return cast(torch.Tensor, scale_factor * (x[256] + x[512][0]) * self.constant)
+        return scale_factor * (x[256] + x[512][0]) * self.constant
 
 
 class ModuleDictModel(nn.Module):
@@ -358,7 +358,7 @@ def __int__(self) -> IntWithGetitem:
         return self
 
     def __getitem__(self, val: int) -> torch.Tensor:
-        return cast(torch.Tensor, self.tensor * val)
+        return self.tensor * val
 
 
 class EdgecaseInputOutputModel(nn.Module):
@@ -575,7 +575,7 @@ def __init__(self) -> None:
         self.b = nn.Parameter(torch.empty(10), requires_grad=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, self.w * x + self.b)
+        return self.w * x + self.b
 
 
 class MixedTrainable(nn.Module):
@@ -717,7 +717,7 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h = torch.mm(x, self.a) + self.b
         if self.output_dim is None:
-            return cast(torch.Tensor, h)
+            return h
         return cast(torch.Tensor, self.fc2(h))
 
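The four `cast` removals above share one rationale: the casts only existed because older torch stubs typed tensor arithmetic loosely, so mypy could not see a `Tensor` result. Recent stubs annotate `Tensor.__mul__`/`__add__` as returning `Tensor`, letting mypy infer the type unaided. A minimal sketch of the idea, assuming such stubs (`affine` is a hypothetical name, not part of the patch):

    import torch

    def affine(w: torch.Tensor, x: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Inferred as torch.Tensor under current stubs; this previously needed
        # cast(torch.Tensor, w * x + b) to pass mypy's return-type check.
        return w * x + b

    print(affine(torch.ones(3), torch.full((3,), 2.0), torch.zeros(3)))

Note the one cast left in place, `cast(torch.Tensor, self.fc2(h))`, wraps an `nn.Module` call, and `Module.__call__` is still typed as returning `Any`, so that cast still does real work.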
diff --git a/torchinfo/formatting.py b/torchinfo/formatting.py
index 4a5e58c..3b48405 100644
--- a/torchinfo/formatting.py
+++ b/torchinfo/formatting.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import math
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from .enums import ColumnSettings, RowSettings, Units, Verbosity
-from .layer_info import LayerInfo
+
+if TYPE_CHECKING:
+    from .layer_info import LayerInfo
 
 HEADER_TITLES = {
     ColumnSettings.KERNEL_SIZE: "Kernel Shape",
diff --git a/torchinfo/layer_info.py b/torchinfo/layer_info.py
index 57752c5..0224a0a 100644
--- a/torchinfo/layer_info.py
+++ b/torchinfo/layer_info.py
@@ -119,8 +119,8 @@ def calculate_size(
         size = list(inputs.size())
         elem_bytes = inputs.element_size()
 
-    elif isinstance(inputs, np.ndarray):
-        inputs_ = torch.from_numpy(inputs)
+    elif isinstance(inputs, np.ndarray):  # type: ignore[unreachable]
+        inputs_ = torch.from_numpy(inputs)  # type: ignore[unreachable]
         size, elem_bytes = list(inputs_.size()), inputs_.element_size()
 
     elif isinstance(inputs, (list, tuple)):
@@ -217,9 +217,9 @@ def calculate_num_params(self) -> None:
             final_name = name
         # Fix the final row to display more nicely
         if self.inner_layers:
-            self.inner_layers[final_name][
-                ColumnSettings.NUM_PARAMS
-            ] = f"└─{self.inner_layers[final_name][ColumnSettings.NUM_PARAMS][2:]}"
+            self.inner_layers[final_name][ColumnSettings.NUM_PARAMS] = (
+                f"└─{self.inner_layers[final_name][ColumnSettings.NUM_PARAMS][2:]}"
+            )
 
     def calculate_macs(self) -> None:
         """
@@ -322,8 +322,9 @@ def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], i
         size, elem_bytes = nested_list_size(inputs.tensors)
     elif isinstance(inputs, torch.Tensor):
         size, elem_bytes = list(inputs.size()), inputs.element_size()
-    elif isinstance(inputs, np.ndarray):
-        inputs_torch = torch.from_numpy(inputs)  # preserves dtype
+    elif isinstance(inputs, np.ndarray):  # type: ignore[unreachable]
+        # preserves dtype
+        inputs_torch = torch.from_numpy(inputs)  # type: ignore[unreachable]
         size, elem_bytes = list(inputs_torch.size()), inputs_torch.element_size()
     elif not hasattr(inputs, "__getitem__") or not inputs:
         size, elem_bytes = [], 0
@@ -358,8 +359,8 @@ def rgetattr(module: nn.Module, attr: str) -> torch.Tensor | None:
         if not hasattr(module, attr_i):
             return None
         module = getattr(module, attr_i)
-    assert isinstance(module, torch.Tensor)
-    return module
+    assert isinstance(module, torch.Tensor)  # type: ignore[unreachable]
+    return module  # type: ignore[unreachable]
 
 
 def get_children_layers(summary_list: list[LayerInfo], index: int) -> list[LayerInfo]:
diff --git a/torchinfo/model_statistics.py b/torchinfo/model_statistics.py
index da0abc4..854bea8 100644
--- a/torchinfo/model_statistics.py
+++ b/torchinfo/model_statistics.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from .enums import Units
 from .formatting import CONVERSION_FACTORS, FormattingOptions
-from .layer_info import LayerInfo
+
+if TYPE_CHECKING:
+    from .layer_info import LayerInfo
 
 
 class ModelStatistics:
diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py
index f511cac..4e58fc5 100644
--- a/torchinfo/torchinfo.py
+++ b/torchinfo/torchinfo.py
@@ -484,7 +484,7 @@ def get_device(
             model_parameter = None
 
         if model_parameter is not None and model_parameter.is_cuda:
-            return model_parameter.device  # type: ignore[no-any-return]
+            return model_parameter.device
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")
     return None
 
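Two patterns in the torchinfo/ hunks deserve a gloss. The formatting.py and model_statistics.py changes defer the `LayerInfo` import to type-checking time: with `from __future__ import annotations` in effect, annotations are stored as strings and never evaluated at runtime, so an import used only in annotations can move under `if TYPE_CHECKING:` (which is also why the TCH001/TCH003 ignores could be deleted from ruff.toml above). A minimal sketch of the pattern, using a flat placeholder module name rather than torchinfo's real layout:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Executed only by static type checkers; at runtime TYPE_CHECKING is
        # False, so nothing is imported and no import cycle can form.
        from layer_info import LayerInfo  # placeholder module

    def first_layer(summary_list: list[LayerInfo]) -> LayerInfo:
        # The annotations are plain strings at runtime (PEP 563 semantics),
        # so LayerInfo need not be importable when this module loads.
        return summary_list[0]

The `# type: ignore[unreachable]` additions in layer_info.py go the opposite way: the mypy bump to 1.9.0 evidently makes mypy judge the `isinstance(inputs, np.ndarray)` branches and the post-loop code in `rgetattr` unreachable given the declared parameter types, so the comments suppress the new diagnostic without changing behavior. Likewise, the `no-any-return` ignore in torchinfo.py could be dropped, presumably because `.device` is now precisely typed in the stubs.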