From 238e33d35085b38cbd993e8eae547a4dc9098204 Mon Sep 17 00:00:00 2001 From: Tyler Yep Date: Wed, 17 May 2023 00:12:40 -0700 Subject: [PATCH] Onboard onto ruff --- profiler.py | 6 ++--- ruff.toml | 45 ++++++++++++++++++++++++++++++++++++ tests/half_precision_test.py | 13 +++++------ torchinfo/layer_info.py | 7 +++--- 4 files changed, 57 insertions(+), 14 deletions(-) create mode 100644 ruff.toml diff --git a/profiler.py b/profiler.py index bc382d9..1b08218 100644 --- a/profiler.py +++ b/profiler.py @@ -2,10 +2,10 @@ import pstats import random -import torchvision # type: ignore[import] # pylint: disable=unused-import # noqa -from tqdm import trange # pylint: disable=unused-import # noqa +import torchvision # type: ignore[import] # pylint: disable=unused-import # noqa: F401, E501 +from tqdm import trange # pylint: disable=unused-import # noqa: F401 -from torchinfo import summary # pylint: disable=unused-import # noqa +from torchinfo import summary # pylint: disable=unused-import # noqa: F401 def profile() -> None: diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..b087b7a --- /dev/null +++ b/ruff.toml @@ -0,0 +1,45 @@ +select = ["ALL"] +ignore = [ + "ANN101", # Missing type annotation for `self` in method + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "C901", # function is too complex (12 > 10) + "COM812", # Trailing comma missing + "D", # Docstring rules + "EM101", # Exception must not use a string literal, assign to variable first + "EM102", # Exception must not use an f-string literal, assign to variable first + "ERA001", # Found commented-out code + "FBT001", # Boolean positional arg in function definition + "FBT002", # Boolean default value in function definition + "FBT003", # Boolean positional value in function call + "PLR0911", # Too many return statements (11 > 6) + "PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable + "PLR0912", # Too many branches + "PLR0913", # Too many arguments to function call + "PLR0915", # Too many statements + "PTH123", # `open()` should be replaced by `Path.open()` + "S101", # Use of `assert` detected + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "T201", # print() found + "TCH001", # Move application import into a type-checking block + "TCH003", # Move standard library import into a type-checking block + "TRY003", # Avoid specifying long messages outside the exception class + + # torchinfo-specific ignores + "N803", # Argument name `A_i` should be lowercase + "N806", # Variable `G` in function should be lowercase + "BLE001", # Do not catch blind exception: `Exception` + "PLW0602", # Using global for `_cached_forward_pass` but no assignment is done + "PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged + "PLW2901", # `for` loop variable `name` overwritten by assignment target + "SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block + "SLF001", # Private member accessed: `_modules` + "TCH002", # Move third-party import into a type-checking block + "TRY004", # Prefer `TypeError` exception for invalid type + "TRY301", # Abstract `raise` to an inner function +] +target-version = "py37" +exclude = ["tests"] + +[flake8-pytest-style] +fixture-parentheses = false diff --git a/tests/half_precision_test.py b/tests/half_precision_test.py index 037f30b..6d4bf97 100644 --- a/tests/half_precision_test.py +++ b/tests/half_precision_test.py @@ -1,6 +1,5 @@ import pytest import torch -from pytest import approx from tests.fixtures.models import LinearModel, LSTMNet, SingleInputNet from torchinfo import summary @@ -38,12 +37,12 @@ def test_linear_model_half() -> None: x = x.type(torch.float16).cuda() results_half = summary(model, input_data=x) - assert ModelStatistics.to_megabytes(results_half.total_param_bytes) == approx( - ModelStatistics.to_megabytes(results.total_param_bytes) / 2 - ) - assert ModelStatistics.to_megabytes(results_half.total_output_bytes) == approx( - ModelStatistics.to_megabytes(results.total_output_bytes) / 2 - ) + assert ModelStatistics.to_megabytes( + results_half.total_param_bytes + ) == pytest.approx(ModelStatistics.to_megabytes(results.total_param_bytes) / 2) + assert ModelStatistics.to_megabytes( + results_half.total_output_bytes + ) == pytest.approx(ModelStatistics.to_megabytes(results.total_output_bytes) / 2) @staticmethod def test_lstm_half() -> None: diff --git a/torchinfo/layer_info.py b/torchinfo/layer_info.py index d3a3b43..5850206 100644 --- a/torchinfo/layer_info.py +++ b/torchinfo/layer_info.py @@ -203,10 +203,9 @@ def calculate_num_params(self) -> None: # kernel_size for inner layer parameters ksize = list(param.size()) - if name == "weight": - # to make [in_shape, out_shape, ksize, ksize] - if len(ksize) > 1: - ksize[0], ksize[1] = ksize[1], ksize[0] + # to make [in_shape, out_shape, ksize, ksize] + if name == "weight" and len(ksize) > 1: + ksize[0], ksize[1] = ksize[1], ksize[0] # RNN modules have inner weights such as weight_ih_l0 # Don't show parameters for the overall model, show for individual layers