
Onboard onto ruff
TylerYep committed Sep 9, 2023
1 parent 96a710a commit 238e33d
Showing 4 changed files with 57 additions and 14 deletions.
6 changes: 3 additions & 3 deletions profiler.py
@@ -2,10 +2,10 @@
import pstats
import random

-import torchvision  # type: ignore[import] # pylint: disable=unused-import # noqa
-from tqdm import trange  # pylint: disable=unused-import # noqa
+import torchvision  # type: ignore[import] # pylint: disable=unused-import # noqa: F401, E501
+from tqdm import trange  # pylint: disable=unused-import # noqa: F401

-from torchinfo import summary  # pylint: disable=unused-import # noqa
+from torchinfo import summary  # pylint: disable=unused-import # noqa: F401


def profile() -> None:
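The profiler.py change swaps blanket `# noqa` comments for rule-scoped ones. With `select = ["ALL"]` in the new ruff.toml below, ruff itself polices this, most likely via PGH004, which flags a bare `# noqa`. A minimal sketch of the difference, using stdlib imports for illustration:

# Blanket suppression: hides every rule on the line; flagged under select = ["ALL"].
import contextlib  # noqa

# Scoped suppression: only F401 ("imported but unused") is silenced here.
import functools  # noqa: F401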
45 changes: 45 additions & 0 deletions ruff.toml
@@ -0,0 +1,45 @@
select = ["ALL"]
ignore = [
"ANN101", # Missing type annotation for `self` in method
"ANN102", # Missing type annotation for `cls` in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"C901", # function is too complex (12 > 10)
"COM812", # Trailing comma missing
"D", # Docstring rules
"EM101", # Exception must not use a string literal, assign to variable first
"EM102", # Exception must not use an f-string literal, assign to variable first
"ERA001", # Found commented-out code
"FBT001", # Boolean positional arg in function definition
"FBT002", # Boolean default value in function definition
"FBT003", # Boolean positional value in function call
"PLR0911", # Too many return statements (11 > 6)
"PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable
"PLR0912", # Too many branches
"PLR0913", # Too many arguments to function call
"PLR0915", # Too many statements
"PTH123", # `open()` should be replaced by `Path.open()`
"S101", # Use of `assert` detected
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"T201", # print() found
"TCH001", # Move application import into a type-checking block
"TCH003", # Move standard library import into a type-checking block
"TRY003", # Avoid specifying long messages outside the exception class

# torchinfo-specific ignores
"N803", # Argument name `A_i` should be lowercase
"N806", # Variable `G` in function should be lowercase
"BLE001", # Do not catch blind exception: `Exception`
"PLW0602", # Using global for `_cached_forward_pass` but no assignment is done
"PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged
"PLW2901", # `for` loop variable `name` overwritten by assignment target
"SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
"SLF001", # Private member accessed: `_modules`
"TCH002", # Move third-party import into a type-checking block
"TRY004", # Prefer `TypeError` exception for invalid type
"TRY301", # Abstract `raise` to an inner function
]
target-version = "py37"
exclude = ["tests"]

[flake8-pytest-style]
fixture-parentheses = false
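Three of the ignores above (TCH001, TCH002, TCH003) disable ruff's flake8-type-checking rules, which would otherwise require imports used only in annotations to live under an `if TYPE_CHECKING:` guard. A minimal sketch of what those rules push toward (the `torch` usage is illustrative, not taken from this diff):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    import torch


def describe(model: torch.nn.Module) -> str:
    return type(model).__name__

Ignoring the TCH rules keeps torchinfo's imports at module level, where they remain available at runtime.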
13 changes: 6 additions & 7 deletions tests/half_precision_test.py
@@ -1,6 +1,5 @@
import pytest
import torch
-from pytest import approx

from tests.fixtures.models import LinearModel, LSTMNet, SingleInputNet
from torchinfo import summary
@@ -38,12 +37,12 @@ def test_linear_model_half() -> None:
        x = x.type(torch.float16).cuda()
        results_half = summary(model, input_data=x)

-        assert ModelStatistics.to_megabytes(results_half.total_param_bytes) == approx(
-            ModelStatistics.to_megabytes(results.total_param_bytes) / 2
-        )
-        assert ModelStatistics.to_megabytes(results_half.total_output_bytes) == approx(
-            ModelStatistics.to_megabytes(results.total_output_bytes) / 2
-        )
+        assert ModelStatistics.to_megabytes(
+            results_half.total_param_bytes
+        ) == pytest.approx(ModelStatistics.to_megabytes(results.total_param_bytes) / 2)
+        assert ModelStatistics.to_megabytes(
+            results_half.total_output_bytes
+        ) == pytest.approx(ModelStatistics.to_megabytes(results.total_output_bytes) / 2)

    @staticmethod
    def test_lstm_half() -> None:
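The test change replaces the bare `approx` name with the qualified `pytest.approx`, in line with ruff's flake8-pytest-style checks (PT013 asks for a plain `import pytest` rather than `from pytest import ...`). A standalone sketch of the pattern, with made-up numbers:

import pytest


def test_half_precision_halves_bytes() -> None:
    full_mb, half_mb = 2.0000001, 1.0
    # approx tolerates floating-point rounding rather than demanding exact equality
    assert half_mb == pytest.approx(full_mb / 2)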
7 changes: 3 additions & 4 deletions torchinfo/layer_info.py
@@ -203,10 +203,9 @@ def calculate_num_params(self) -> None:

            # kernel_size for inner layer parameters
            ksize = list(param.size())
-            if name == "weight":
-                # to make [in_shape, out_shape, ksize, ksize]
-                if len(ksize) > 1:
-                    ksize[0], ksize[1] = ksize[1], ksize[0]
+            # to make [in_shape, out_shape, ksize, ksize]
+            if name == "weight" and len(ksize) > 1:
+                ksize[0], ksize[1] = ksize[1], ksize[0]

            # RNN modules have inner weights such as weight_ih_l0
            # Don't show parameters for the overall model, show for individual layers
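The layer_info.py edit collapses a nested `if` into a single guard, the shape ruff's flake8-simplify rule SIM102 (collapsible `if`) nudges toward; behavior is unchanged. A self-contained check with a hypothetical Conv2d-style weight shape:

# Hypothetical kernel size list: [out_channels, in_channels, k, k].
ksize = [16, 3, 5, 5]
name = "weight"
if name == "weight" and len(ksize) > 1:
    # Swap to [in_shape, out_shape, ksize, ksize], as the diff's comment describes.
    ksize[0], ksize[1] = ksize[1], ksize[0]
assert ksize == [3, 16, 5, 5]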
