From 3b0ef239c6d7985dcad4f4ce7e70dcebaf7c4b08 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 4 Jun 2024 23:28:22 +0100 Subject: [PATCH] Migrate to ruff and pyright --- .github/workflows/python-app.yml | 16 ++----- .pre-commit-config.yaml | 20 ++++---- dataset/base/__init__.py | 12 +++-- dataset/base/image.py | 33 +++++++------ dataset/example.py | 2 +- launch_experiment.py | 53 ++++++++++++++------- pyproject.toml | 81 +++++++++++++++++++++++++++++++- requirements.txt | 4 +- src/base_tester.py | 16 +++++-- utils/__init__.py | 39 ++++++++------- utils/anim.py | 6 +-- 11 files changed, 191 insertions(+), 91 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 054ef6b..9d81fe9 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,21 +26,15 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest mypy types-PyYAML types-tqdm + pip install ruff pytest pyright types-PyYAML types-tqdm pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 + - name: Lint with ruff run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Type check with mypy (training script) + ruff + - name: Type check with pyright run: | - mypy train.py --disable-error-code=import-untyped - - name: Type check with mypy (test script) - run: | - mypy test.py --disable-error-code=import-untyped + pyright #- name: Test with pytest #run: | #pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0631bbd..b11f153 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,16 +8,12 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files -- repo: https://github.com/PyCQA/autoflake - rev: v2.2.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.4.7 hooks: - - id: autoflake - args: [--in-place, --remove-all-unused-imports, -r] -- repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black + # Run the linter. + - id: ruff + args: [ --select, I, --fix ] # Sort imports too + # Run the formatter. + - id: ruff-format diff --git a/dataset/base/__init__.py b/dataset/base/__init__.py index c6d0e9e..6438ca9 100644 --- a/dataset/base/__init__.py +++ b/dataset/base/__init__.py @@ -12,18 +12,17 @@ transforming it may be extended through class inheritance in a specific dataset file. 
""" - import abc import os import os.path as osp -from typing import Tuple, Union +from typing import Any, Dict, List, Tuple, Union import torch from hydra.utils import get_original_cwd from torch.utils.data import Dataset -class BaseDataset(Dataset, abc.ABC): +class BaseDataset(Dataset[Any], abc.ABC): def __init__( self, dataset_root: str, @@ -36,6 +35,8 @@ def __init__( tiny: bool = False, ) -> None: super().__init__() + self._samples: Union[Dict[Any, Any], List[Any], torch.Tensor] + self._labels: Union[Dict[Any, Any], List[Any], torch.Tensor] self._samples, self._labels = self._load(dataset_root, tiny, split, seed) self._augment = augment and split == "train" self._normalize = normalize @@ -49,7 +50,10 @@ def __init__( @abc.abstractmethod def _load( self, dataset_root: str, tiny: bool, split: str, seed: int - ) -> Tuple[Union[dict, list, torch.Tensor], Union[dict, list, torch.Tensor]]: + ) -> Tuple[ + Union[Dict[str, Any], List[Any], torch.Tensor], + Union[Dict[str, Any], List[Any], torch.Tensor], + ]: # Implement this raise NotImplementedError diff --git a/dataset/base/image.py b/dataset/base/image.py index f3cbf81..25adb6d 100644 --- a/dataset/base/image.py +++ b/dataset/base/image.py @@ -9,13 +9,12 @@ Base dataset for images. 
""" - import abc -from typing import List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from torchvision.io.image import read_image -from torchvision.transforms import transforms +from torchvision.io.image import read_image # type: ignore +from torchvision.transforms import transforms # type: ignore from dataset.base import BaseDataset @@ -32,7 +31,7 @@ def __init__( dataset_name: str, split: str, seed: int, - img_size: Optional[tuple] = None, + img_size: Optional[tuple[int, ...]] = None, augment: bool = False, normalize: bool = False, tiny: bool = False, @@ -49,13 +48,13 @@ def __init__( tiny=tiny, ) self._img_size = self.IMG_SIZE if img_size is None else img_size - self._transforms = transforms.Compose( + self._transforms: Callable[[torch.Tensor], torch.Tensor] = transforms.Compose( [ transforms.Resize(self._img_size), ] ) - self._normalization = transforms.Normalize( - self.IMAGE_NET_MEAN, self.IMAGE_NET_STD + self._normalization: Callable[[torch.Tensor], torch.Tensor] = ( + transforms.Normalize(self.IMAGE_NET_MEAN, self.IMAGE_NET_STD) ) try: import albumentations as A # type: ignore @@ -63,7 +62,7 @@ def __init__( raise ImportError( "Please install albumentations to use the augmentation pipeline." ) - self._augs = A.Compose( + self._augs: Callable[..., Dict[str, Any]] = A.Compose( [ A.RandomCropFromBorders(), A.RandomBrightnessContrast(), @@ -74,22 +73,28 @@ def __init__( @abc.abstractmethod def _load( self, dataset_root: str, tiny: bool, split: str, seed: int - ) -> Tuple[Union[dict, list, torch.Tensor], Union[dict, list, torch.Tensor]]: + ) -> Tuple[ + Union[Dict[str, Any], List[Any], torch.Tensor], + Union[Dict[str, Any], List[Any], torch.Tensor], + ]: # Implement this raise NotImplementedError - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: """ This should be common to all image datasets! Override if you need something else. 
""" # ==== Load image and apply transforms === - img = read_image(self._samples[index]) # Returns a Tensor + img: torch.Tensor + img = read_image(self._samples[index]) # type: ignore + if not isinstance(img, torch.Tensor): + raise ValueError("Image not loaded as a Tensor.") img = self._transforms(img) if self._normalize: img = self._normalization(img) if self._augment: - img = self._augs(image=img) + img = self._augs(image=img)["image"] # ==== Load label and apply transforms === - label = self._labels[index] + label: Any = self._labels[index] return img, label diff --git a/dataset/example.py b/dataset/example.py index 80dedf0..a2a673f 100644 --- a/dataset/example.py +++ b/dataset/example.py @@ -38,7 +38,7 @@ def __init__( dataset_name, split, seed, - (img_dim, img_dim), + (img_dim, img_dim) if img_dim is not None else None, augment=augment, normalize=normalize, debug=debug, diff --git a/launch_experiment.py b/launch_experiment.py index 5f32853..4085787 100644 --- a/launch_experiment.py +++ b/launch_experiment.py @@ -8,6 +8,7 @@ import os from dataclasses import asdict +from typing import Any, Optional import hydra_zen import torch @@ -17,8 +18,9 @@ from hydra.utils import to_absolute_path from hydra_zen import just from hydra_zen.typing import Partial +from torch.utils.data import DataLoader, Dataset -import conf.experiment # Must import the config to add all components to the store! 
# noqa +import conf.experiment as exp_conf # type: ignore from conf import project as project_conf from model import TransparentDataParallel from src.base_tester import BaseTester @@ -27,13 +29,13 @@ def launch_experiment( - run, - data_loader: Partial[torch.utils.data.DataLoader], + run: exp_conf.RunConfig, + data_loader: Partial[torch.utils.data.DataLoader], # type: ignore optimizer: Partial[torch.optim.Optimizer], - scheduler: Partial[torch.optim.lr_scheduler._LRScheduler], + scheduler: Partial[torch.optim.lr_scheduler.LRScheduler], trainer: Partial[BaseTrainer], tester: Partial[BaseTester], - dataset: Partial[torch.utils.data.Dataset], + dataset: Partial[Dataset[Any]], model: Partial[torch.nn.Module], training_loss: Partial[torch.nn.Module], ): @@ -65,19 +67,19 @@ def launch_experiment( "============ Partials instantiation ============" model_inst = model( - encoder_input_dim=just(dataset).img_dim ** 2 + encoder_input_dim=just(dataset).img_dim ** 2 # type: ignore ) # Use just() to get the config out of the Zen-Partial print(model_inst) print(f"Number of parameters: {sum(p.numel() for p in model_inst.parameters())}") print( f"Number of trainable parameters: {sum(p.numel() for p in model_inst.parameters() if p.requires_grad)}" ) - train_dataset, val_dataset, test_dataset = None, None, None + train_dataset: Optional[Dataset[Any]] = None + val_dataset: Optional[Dataset[Any]] = None + test_dataset: Optional[Dataset[Any]] = None if run.training_mode: - train_dataset, val_dataset = ( - dataset(split="train", seed=run.seed), - dataset(split="val", seed=run.seed), - ) + train_dataset = dataset(split="train", seed=run.seed) + val_dataset = dataset(split="val", seed=run.seed) else: test_dataset = dataset(split="test", augment=False, seed=run.seed) @@ -104,38 +106,45 @@ def launch_experiment( ) model_inst = TransparentDataParallel(model_inst) - if not run.training_mode: - training_loss_inst = None - else: + training_loss_inst: Optional[torch.nn.Module] = None + if 
run.training_mode: training_loss_inst = training_loss() "============ CUDA ============" model_inst: torch.nn.Module = to_cuda_(model_inst) # type: ignore - training_loss_inst: torch.nn.Module = to_cuda_(training_loss_inst) # type: ignore + training_loss_inst = to_cuda_(training_loss_inst) # type: ignore "============ Weights & Biases ============" if project_conf.USE_WANDB: # exp_conf is a string, so we need to load it back to a dict: exp_conf = yaml.safe_load(exp_conf) - wandb.init( + wandb.init( # type: ignore project=project_conf.PROJECT_NAME, name=run_name, config=exp_conf, ) - wandb.watch(model_inst, log="all", log_graph=True) + wandb.watch(model_inst, log="all", log_graph=True) # type: ignore " ============ Reproducibility of data loaders ============ " g = None if project_conf.REPRODUCIBLE: g = torch.Generator() g.manual_seed(run.seed) - train_loader_inst, val_loader_inst, test_loader_inst = None, None, None + train_loader_inst: Optional[DataLoader[Any]] = None + val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None + test_loader_inst: Optional[DataLoader[Any]] = None if run.training_mode: + if train_dataset is None or val_dataset is None: + raise ValueError( + "train_dataset and val_dataset must be defined in training mode!" + ) train_loader_inst = data_loader(train_dataset, generator=g) val_loader_inst = data_loader( val_dataset, generator=g, shuffle=False, drop_last=False ) else: + if test_dataset is None: + raise ValueError("test_dataset must be defined in testing mode!") test_loader_inst = data_loader( test_dataset, generator=g, shuffle=False, drop_last=False ) @@ -167,6 +176,12 @@ def launch_experiment( ) if run.training_mode: + if training_loss_inst is None: + raise ValueError("training_loss must be defined in training mode!") + if val_loader_inst is None or train_loader_inst is None: + raise ValueError( + "val_loader and train_loader must be defined in training mode!" 
+ ) trainer( run_name=run_name, model=model_inst, @@ -187,6 +202,8 @@ def launch_experiment( model_ckpt_path=model_ckpt_path, ) else: + if test_loader_inst is None: + raise ValueError("test_loader must be defined in testing mode!") tester( run_name=run_name, model=model_inst, diff --git a/pyproject.toml b/pyproject.toml index 5d7bf33..615a051 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,79 @@ -[tool.isort] -profile = "black" +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "vendor", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py38" + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix` is provided). +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. 
+# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = false + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" diff --git a/requirements.txt b/requirements.txt index ba4ea64..ddfcdc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,7 @@ torchmetrics plotext tqdm wandb -isort -black -autoflake +ruff pre-commit blosc2 ipython diff --git a/src/base_tester.py b/src/base_tester.py index e71270c..15ebd70 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -11,7 +11,7 @@ import signal from collections import defaultdict -from typing import Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union import torch from torch.utils.data import DataLoader @@ -22,15 +22,18 @@ from src.base_trainer import BaseTrainer from utils import to_cuda, update_pbar_str +T = TypeVar("T") + class BaseTester(BaseTrainer): def __init__( self, run_name: str, - data_loader: DataLoader, + data_loader: DataLoader[T], model: torch.nn.Module, model_ckpt_path: str, - **kwargs, + training_loss: Optional[torch.nn.Module] = None, + **kwargs: Optional[Dict[str, Any]], ) -> None: """Base trainer class. Args: @@ -39,7 +42,8 @@ def __init__( train_loader (torch.utils.data.DataLoader): Training dataloader. val_loader (torch.utils.data.DataLoader): Validation dataloader. """ - _ = kwargs + _args = kwargs + _loss = training_loss self._run_name = run_name self._model = model assert model_ckpt_path is not None, "No model checkpoint path provided." @@ -76,7 +80,9 @@ def _test_iteration( # TODO: Compute your metrics here! return {} - def test(self, visualize_every: int = 0, **kwargs): + def test( + self, visualize_every: int = 0, **kwargs: Optional[Dict[str, Any]] + ) -> None: """Computes the average loss on the test set. 
Args: visualize_every (int, optional): Visualize the model predictions every n batches. diff --git a/utils/__init__.py b/utils/__init__.py index 1b25f83..27150de 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -6,24 +6,25 @@ # Distributed under terms of the MIT license. -import importlib -import inspect +# import importlib +# import inspect import random -import sys + +# import sys import traceback from contextlib import contextmanager -from typing import Any, List, Tuple, Union +from typing import Any, Callable, Dict, List, Union -import IPython +# import IPython import numpy as np import torch -import tqdm +from tqdm import tqdm from conf import project as project_conf def seed_everything(seed: int): - torch.manual_seed(seed) + torch.manual_seed(seed) # type: ignore np.random.seed(seed) random.seed(seed) # torch.use_deterministic_algorithms(True) @@ -33,7 +34,7 @@ def seed_everything(seed: int): torch.backends.cudnn.benchmark = False -def to_cuda_(x: Any) -> Union[Tuple, List, torch.Tensor, torch.nn.Module]: +def to_cuda_(x: Any) -> Any: device = "cpu" dtype = x.dtype if isinstance(x, torch.Tensor) else None if project_conf.USE_CUDA_IF_AVAILABLE and torch.cuda.is_available(): @@ -46,19 +47,19 @@ def to_cuda_(x: Any) -> Union[Tuple, List, torch.Tensor, torch.nn.Module]: if isinstance(x, (torch.Tensor, torch.nn.Module)): x = x.to(device, dtype=dtype) elif isinstance(x, tuple): - x = tuple(to_cuda_(t) for t in x) + x = tuple(to_cuda_(t) for t in x) # type: ignore elif isinstance(x, list): - x = [to_cuda_(t) for t in x] + x = [to_cuda_(t) for t in x] # type: ignore elif isinstance(x, dict): - x = {key: to_cuda_(value) for key, value in x.items()} + x = {key: to_cuda_(value) for key, value in x.items()} # type: ignore return x -def to_cuda(func): +def to_cuda(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator to move function arguments to cuda if available and if they are torch tensors, torch modules or tuples/lists of.""" - def wrapper(*args, 
**kwargs): + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: args = to_cuda_(args) for key, value in kwargs.items(): kwargs[key] = to_cuda_(value) @@ -71,11 +72,11 @@ def colorize(string: str, ansii_code: Union[int, str]) -> str: return f"\033[{ansii_code}m{string}\033[0m" -def blink_pbar(i: int, pbar: tqdm.tqdm, n: int) -> None: +def blink_pbar(i: int, pbar: tqdm, n: int) -> None: """Blink the progress bar every n iterations. Args: i (int): current iteration - pbar (tqdm.tqdm): progress bar + pbar (tqdm): progress bar n (int): blink every n iterations """ if i % n == 0: @@ -88,7 +89,7 @@ def blink_pbar(i: int, pbar: tqdm.tqdm, n: int) -> None: @contextmanager def colorize_prints(ansii_code: Union[int, str]): - if type(ansii_code) is str: + if isinstance(ansii_code, str): ansii_code = project_conf.ANSI_COLORS[ansii_code] print(f"\033[{ansii_code}m", end="") try: @@ -97,10 +98,10 @@ def colorize_prints(ansii_code: Union[int, str]): print("\033[0m", end="") -def update_pbar_str(pbar: tqdm.tqdm, string: str, color_code: int) -> None: +def update_pbar_str(pbar: tqdm, string: str, color_code: int) -> None: """Update the progress bar string. 
Args: - pbar (tqdm.tqdm): progress bar + pbar (tqdm): progress bar string (str): string to update the progress bar with color_code (int): color code for the string """ @@ -117,6 +118,7 @@ def get_function_frame(func, exc_traceback): return None +''' # TODO: Refactor this def debug_trace(callable): """ @@ -344,3 +346,4 @@ def __new__(cls, name, bases, dct): obj = super().__new__(cls, name, bases, dct) obj = debug_methods(obj) return obj +''' diff --git a/utils/anim.py b/utils/anim.py index 947236d..59b08df 100644 --- a/utils/anim.py +++ b/utils/anim.py @@ -9,8 +9,8 @@ from typing import Dict, Optional import numpy as np -import pyvista as pv -from trimesh import Trimesh +import pyvista as pv # type: ignore +from trimesh import Trimesh # type: ignore colors = { "hand": np.array([181 / 255, 144 / 255, 191 / 255]), @@ -75,7 +75,7 @@ def __init__( ): super().__init__() try: - import scenepic as sp + import scenepic as sp # type: ignore except ImportError: raise Exception( "scenepic not installed. "