Skip to content

Commit

Permalink
Merge branch 'main' into fix-tree
Browse files Browse the repository at this point in the history
  • Loading branch information
TylerYep authored Dec 28, 2023
2 parents be7967e + 80d3e67 commit 9fd3a13
Show file tree
Hide file tree
Showing 15 changed files with 51 additions and 118 deletions.
18 changes: 4 additions & 14 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,14 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10"]
pytorch-version: ["1.4.0", "1.5.1", "1.6.0", "1.7.1", "1.8", "1.9", "1.10", "1.11", "1.12", "1.13", "2.0", "2.1"]
include:
- python-version: 3.11
pytorch-version: 2.0
- python-version: 3.11
pytorch-version: 2.1
exclude:
- python-version: 3.7
pytorch-version: 1.11
- python-version: 3.7
pytorch-version: 1.12
- python-version: 3.7
pytorch-version: 1.13
- python-version: 3.7
pytorch-version: 2.0
- python-version: 3.7
pytorch-version: 2.1

- python-version: 3.9
pytorch-version: 1.4.0
- python-version: 3.9
Expand Down Expand Up @@ -63,12 +52,13 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install mypy pytest pytest-cov
python -m pip install pytest pytest-cov
pip install torch==${{ matrix.pytorch-version }} torchvision transformers
pip install compressai
- name: mypy
if: ${{ matrix.pytorch-version == '2.1' }}
run: |
python -m pip install mypy==1.7.1
mypy --install-types --non-interactive .
- name: pytest
if: ${{ matrix.pytorch-version == '2.1' }}
Expand All @@ -77,6 +67,6 @@ jobs:
- name: pytest
if: ${{ matrix.pytorch-version != '2.1' }}
run: |
pytest --no-output -k "not test_eval_order_doesnt_matter and not test_google and not test_uninitialized_tensor and not test_input_size_half_precision and not test_recursive_with_missing_layers"
pytest --no-output -k "not test_eval_order_doesnt_matter and not test_google and not test_uninitialized_tensor and not test_input_size_half_precision and not test_recursive_with_missing_layers and not test_flan_t5_small"
- name: codecov
uses: codecov/codecov-action@v1
43 changes: 5 additions & 38 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,45 +1,12 @@
ci:
skip: [mypy, pytest]
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/hadialqattan/pycln
rev: v2.2.2
hooks:
- id: pycln
args: [--all]

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
hooks:
- id: black
args: [-C]

- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies:
[
flake8-future-annotations,
flake8-bugbear,
flake8-comprehensions,
]

- repo: https://github.com/PyCQA/pylint
rev: v3.0.1
hooks:
- id: pylint
args: ["--disable=import-error"]
- id: ruff
args: [ --fix ]
- id: ruff-format

- repo: local
hooks:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# torchinfo

[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/)
[![PyPI version](https://badge.fury.io/py/torchinfo.svg)](https://badge.fury.io/py/torchinfo)
[![Conda version](https://img.shields.io/conda/vn/conda-forge/torchinfo)](https://anaconda.org/conda-forge/torchinfo)
[![Build Status](https://github.com/TylerYep/torchinfo/actions/workflows/test.yml/badge.svg)](https://github.com/TylerYep/torchinfo/actions/workflows/test.yml)
Expand Down Expand Up @@ -470,7 +470,7 @@ Estimated Total Size (MB): 0.00
All issues and pull requests are much appreciated! If you are wondering how to build the project:

- torchinfo is actively developed using the lastest version of Python.
- Changes should be backward compatible to Python 3.7, and will follow Python's End-of-Life guidance for old versions.
- Changes should be backward compatible to Python 3.8, and will follow Python's End-of-Life guidance for old versions.
- Run `pip install -r requirements-dev.txt`. We use the latest versions of all dev packages.
- Run `pre-commit install`.
- To use auto-formatting tools, use `pre-commit run -a`.
Expand Down
6 changes: 3 additions & 3 deletions profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import pstats
import random

import torchvision # type: ignore[import-untyped] # pylint: disable=unused-import # noqa: F401, E501
from tqdm import trange # pylint: disable=unused-import # noqa: F401
import torchvision # type: ignore[import-untyped] # noqa: F401
from tqdm import trange # noqa: F401

from torchinfo import summary # pylint: disable=unused-import # noqa: F401
from torchinfo import summary # noqa: F401


def profile() -> None:
Expand Down
5 changes: 0 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
black
codecov
flake8
isort
mypy
pycln
pylint
pytest
pytest-cov
pre-commit
Expand Down
11 changes: 8 additions & 3 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
target-version = "py38"
select = ["ALL"]
ignore = [
"ANN101", # Missing type annotation for `self` in method
Expand All @@ -12,17 +13,22 @@ ignore = [
"FBT001", # Boolean positional arg in function definition
"FBT002", # Boolean default value in function definition
"FBT003", # Boolean positional value in function call
"FIX002", # Line contains TODO
"ISC001", # Isort
"PLR0911", # Too many return statements (11 > 6)
"PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable
"PLR0912", # Too many branches
"PLR0913", # Too many arguments to function call
"PLR0915", # Too many statements
"PTH123", # `open()` should be replaced by `Path.open()`
"S101", # Use of `assert` detected
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"T201", # print() found
"T203", # pprint() found
"TCH001", # Move application import into a type-checking block
"TCH003", # Move standard library import into a type-checking block
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
"TD003", # Missing issue link on the line following this TODO
"TD005", # Missing issue description after `TODO`
"TRY003", # Avoid specifying long messages outside the exception class

# torchinfo-specific ignores
Expand All @@ -38,8 +44,7 @@ ignore = [
"TRY004", # Prefer `TypeError` exception for invalid type
"TRY301", # Abstract `raise` to an inner function
]
target-version = "py37"
exclude = ["tests"]
exclude = ["tests"] # TODO: check tests too

[flake8-pytest-style]
fixture-parentheses = false
38 changes: 4 additions & 34 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,15 @@ keywords = torch pytorch torchsummary torch-summary summary keras deep-learning

[options]
packages = torchinfo
python_requires = >=3.7
python_requires = >=3.8
include_package_data = True

[options.package_data]
torchinfo = py.typed

[mypy]
strict = True
implicit_reexport = True
warn_unreachable = True
disallow_any_unimported = True
extra_checks = True
enable_error_code = ignore-without-code

[pylint.main]
evaluation = 10.0 - ((float(5 * error + warning + refactor + convention + info) / statement) * 10)

[pylint.MESSAGES CONTROL]
extension-pkg-whitelist = torch
enable =
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
disable =
missing-module-docstring,
missing-function-docstring,
too-many-instance-attributes,
too-many-arguments,
too-many-branches,
too-many-locals,
invalid-name,
line-too-long, # Covered by flake8
no-member,
fixme,
duplicate-code,

[isort]
profile = black

[flake8]
max-line-length = 88
extend-ignore = E203,F401

[tool:pytest]
python_files = *_test.py
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest

from torchinfo import ModelStatistics
from torchinfo.formatting import HEADER_TITLES, ColumnSettings
from torchinfo.enums import ColumnSettings
from torchinfo.formatting import HEADER_TITLES
from torchinfo.torchinfo import clear_cached_forward_pass


Expand Down
1 change: 0 additions & 1 deletion tests/fixtures/genotype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# type: ignore
# pylint: skip-file
from collections import namedtuple

import torch
Expand Down
11 changes: 5 additions & 6 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pylint: disable=too-few-public-methods
from __future__ import annotations

import math
Expand Down Expand Up @@ -180,7 +179,7 @@ def __init__(self) -> None:
self.weights = torch.nn.ParameterList(
[
torch.nn.Parameter(weight)
for weight in torch.Tensor(100, 300).split([100, 200], dim=1) # type: ignore[no-untyped-call] # noqa: E501
for weight in torch.Tensor(100, 300).split([100, 200], dim=1) # type: ignore[no-untyped-call]
]
)

Expand Down Expand Up @@ -303,7 +302,7 @@ def __init__(self) -> None:
self.constant = 5

def forward(self, x: dict[int, torch.Tensor], scale_factor: int) -> torch.Tensor:
return scale_factor * (x[256] + x[512][0]) * self.constant
return cast(torch.Tensor, scale_factor * (x[256] + x[512][0]) * self.constant)


class ModuleDictModel(nn.Module):
Expand Down Expand Up @@ -359,7 +358,7 @@ def __int__(self) -> IntWithGetitem:
return self

def __getitem__(self, val: int) -> torch.Tensor:
return self.tensor * val
return cast(torch.Tensor, self.tensor * val)


class EdgecaseInputOutputModel(nn.Module):
Expand Down Expand Up @@ -576,7 +575,7 @@ def __init__(self) -> None:
self.b = nn.Parameter(torch.empty(10), requires_grad=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w * x + self.b
return cast(torch.Tensor, self.w * x + self.b)


class MixedTrainable(nn.Module):
Expand Down Expand Up @@ -718,7 +717,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = torch.mm(x, self.a) + self.b
if self.output_dim is None:
return h
return cast(torch.Tensor, h)
return cast(torch.Tensor, self.fc2(h))


Expand Down
3 changes: 1 addition & 2 deletions tests/fixtures/tmva_net.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# type: ignore
# pylint: skip-file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn


class DoubleConvBlock(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions tests/torchinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,14 +558,14 @@ def test_empty_module_list() -> None:


def test_single_parameter_model() -> None:
class ParameterA(nn.Module): # pylint: disable=too-few-public-methods
class ParameterA(nn.Module):
"""A model with one parameter."""

def __init__(self) -> None:
super().__init__()
self.w = nn.Parameter(torch.zeros(1024))

class ParameterB(nn.Module): # pylint: disable=too-few-public-methods
class ParameterB(nn.Module):
"""A model with one parameter and one Conv2d layer."""

def __init__(self) -> None:
Expand Down
8 changes: 8 additions & 0 deletions torchinfo/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
class Mode(str, Enum):
"""Enum containing all model modes."""

__slots__ = ()

TRAIN = "train"
EVAL = "eval"

Expand All @@ -15,6 +17,8 @@ class Mode(str, Enum):
class RowSettings(str, Enum):
"""Enum containing all available row settings."""

__slots__ = ()

DEPTH = "depth"
VAR_NAMES = "var_names"
ASCII_ONLY = "ascii_only"
Expand All @@ -25,6 +29,8 @@ class RowSettings(str, Enum):
class ColumnSettings(str, Enum):
"""Enum containing all available column settings."""

__slots__ = ()

KERNEL_SIZE = "kernel_size"
INPUT_SIZE = "input_size"
OUTPUT_SIZE = "output_size"
Expand All @@ -38,6 +44,8 @@ class ColumnSettings(str, Enum):
class Units(str, Enum):
"""Enum containing all available bytes units."""

__slots__ = ()

AUTO = "auto"
MEGABYTES = "M"
GIGABYTES = "G"
Expand Down
2 changes: 1 addition & 1 deletion torchinfo/layer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def trainable(self) -> str:

@staticmethod
def calculate_size(
inputs: DETECTED_INPUT_OUTPUT_TYPES, batch_dim: int | None
inputs: DETECTED_INPUT_OUTPUT_TYPES | None, batch_dim: int | None
) -> tuple[list[int], int]:
"""
Set input_size or output_size using the model's inputs.
Expand Down
Loading

0 comments on commit 9fd3a13

Please sign in to comment.