Migrate to ruff and pyright
DubiousCactus committed Jun 4, 2024
1 parent 24a62f5 commit 3b0ef23
Showing 11 changed files with 191 additions and 91 deletions.
16 changes: 5 additions & 11 deletions .github/workflows/python-app.yml
@@ -26,21 +26,15 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install flake8 pytest mypy types-PyYAML types-tqdm
+        pip install ruff pytest pyright types-PyYAML types-tqdm
         pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-    - name: Lint with flake8
+    - name: Lint with ruff
       run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-    - name: Type check with mypy (training script)
+        ruff
+    - name: Type check with pyright
       run: |
-        mypy train.py --disable-error-code=import-untyped
-    - name: Type check with mypy (test script)
-      run: |
-        mypy test.py --disable-error-code=import-untyped
+        pyright
     #- name: Test with pytest
     #run: |
     #pytest
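A note on the new lint step: it invokes bare `ruff` with no subcommand, and on the ruff versions current at this commit the linter is normally invoked as `ruff check`, so this step may exit with a usage error. A minimal sketch of an explicit variant of the two new steps, assuming ruff >= 0.4 (the `ruff check .` invocation is a suggestion, not what the commit contains):

    - name: Lint with ruff
      run: |
        # `check` is ruff's lint subcommand; `.` lints the whole repo.
        ruff check .
    - name: Type check with pyright
      run: |
        pyright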
20 changes: 8 additions & 12 deletions .pre-commit-config.yaml
@@ -8,16 +8,12 @@ repos:
       - id: end-of-file-fixer
       - id: check-yaml
       - id: check-added-large-files
-  - repo: https://github.com/PyCQA/autoflake
-    rev: v2.2.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.4.7
     hooks:
-      - id: autoflake
-        args: [--in-place, --remove-all-unused-imports, -r]
-  - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
+      # Run the linter.
+      - id: ruff
+        args: [ --select, I, --fix ] # Sort imports too
+      # Run the formatter.
+      - id: ruff-format
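One ruff repo now replaces three hooks: the `ruff` hook, restricted here to the import-sorting rules (`--select I`) with autofix enabled, stands in for isort, and `ruff-format` stands in for black (autoflake's unused-import removal corresponds to ruff's F401 fix, which these args, taken alone, do not select). The linter runs before the formatter, so the formatter gets the final say on layout. To exercise the migrated hooks locally, the standard pre-commit invocation (not part of the commit) is:

    pre-commit run --all-files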
12 changes: 8 additions & 4 deletions dataset/base/__init__.py
@@ -12,18 +12,17 @@
 transforming it may be extended through class inheritance in a specific dataset file.
 """
 
-
 import abc
 import os
 import os.path as osp
-from typing import Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 from hydra.utils import get_original_cwd
 from torch.utils.data import Dataset
 
 
-class BaseDataset(Dataset, abc.ABC):
+class BaseDataset(Dataset[Any], abc.ABC):
     def __init__(
         self,
         dataset_root: str,
@@ -36,6 +35,8 @@ def __init__(
         tiny: bool = False,
     ) -> None:
         super().__init__()
+        self._samples: Union[Dict[Any, Any], List[Any], torch.Tensor]
+        self._labels: Union[Dict[Any, Any], List[Any], torch.Tensor]
         self._samples, self._labels = self._load(dataset_root, tiny, split, seed)
         self._augment = augment and split == "train"
         self._normalize = normalize
@@ -49,7 +50,10 @@ def __init__(
     @abc.abstractmethod
     def _load(
         self, dataset_root: str, tiny: bool, split: str, seed: int
-    ) -> Tuple[Union[dict, list, torch.Tensor], Union[dict, list, torch.Tensor]]:
+    ) -> Tuple[
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+    ]:
         # Implement this
         raise NotImplementedError
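With `_load`'s return type now spelled out, a concrete subclass has a precise contract to implement. A hypothetical sketch (this `ToyDataset`, its paths, and its sizes are invented for illustration, not part of the commit):

    from typing import Any, Dict, List, Tuple, Union

    import torch

    from dataset.base import BaseDataset


    class ToyDataset(BaseDataset):
        def _load(
            self, dataset_root: str, tiny: bool, split: str, seed: int
        ) -> Tuple[
            Union[Dict[str, Any], List[Any], torch.Tensor],
            Union[Dict[str, Any], List[Any], torch.Tensor],
        ]:
            n = 10 if tiny else 1000
            # Any of the three annotated container types is acceptable; here
            # the samples are a list of file paths and the labels a tensor.
            samples: List[Any] = [f"{dataset_root}/{split}/{i}.png" for i in range(n)]
            labels = torch.zeros(n, dtype=torch.long)
            return samples, labels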
33 changes: 19 additions & 14 deletions dataset/base/image.py
@@ -9,13 +9,12 @@
 Base dataset for images.
 """
 
-
 import abc
-from typing import List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from torchvision.io.image import read_image
-from torchvision.transforms import transforms
+from torchvision.io.image import read_image  # type: ignore
+from torchvision.transforms import transforms  # type: ignore
 
 from dataset.base import BaseDataset
 
@@ -32,7 +31,7 @@ def __init__(
         dataset_name: str,
         split: str,
         seed: int,
-        img_size: Optional[tuple] = None,
+        img_size: Optional[tuple[int, ...]] = None,
         augment: bool = False,
         normalize: bool = False,
         tiny: bool = False,
@@ -49,21 +48,21 @@
             tiny=tiny,
         )
         self._img_size = self.IMG_SIZE if img_size is None else img_size
-        self._transforms = transforms.Compose(
+        self._transforms: Callable[[torch.Tensor], torch.Tensor] = transforms.Compose(
             [
                 transforms.Resize(self._img_size),
             ]
         )
-        self._normalization = transforms.Normalize(
-            self.IMAGE_NET_MEAN, self.IMAGE_NET_STD
+        self._normalization: Callable[[torch.Tensor], torch.Tensor] = (
+            transforms.Normalize(self.IMAGE_NET_MEAN, self.IMAGE_NET_STD)
         )
         try:
             import albumentations as A  # type: ignore
         except ImportError:
             raise ImportError(
                 "Please install albumentations to use the augmentation pipeline."
             )
-        self._augs = A.Compose(
+        self._augs: Callable[..., Dict[str, Any]] = A.Compose(
             [
                 A.RandomCropFromBorders(),
                 A.RandomBrightnessContrast(),
@@ -74,22 +73,28 @@
     @abc.abstractmethod
     def _load(
         self, dataset_root: str, tiny: bool, split: str, seed: int
-    ) -> Tuple[Union[dict, list, torch.Tensor], Union[dict, list, torch.Tensor]]:
+    ) -> Tuple[
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+        Union[Dict[str, Any], List[Any], torch.Tensor],
+    ]:
         # Implement this
         raise NotImplementedError
 
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         This should be common to all image datasets!
         Override if you need something else.
         """
         # ==== Load image and apply transforms ===
-        img = read_image(self._samples[index])  # Returns a Tensor
+        img: torch.Tensor
+        img = read_image(self._samples[index])  # type: ignore
+        if not isinstance(img, torch.Tensor):
+            raise ValueError("Image not loaded as a Tensor.")
         img = self._transforms(img)
         if self._normalize:
             img = self._normalization(img)
         if self._augment:
-            img = self._augs(image=img)
+            img = self._augs(image=img)["image"]
         # ==== Load label and apply transforms ===
-        label = self._labels[index]
+        label: Any = self._labels[index]
         return img, label
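The added `["image"]` indexing in `__getitem__` is the substantive fix here: an albumentations `Compose` returns a dict of named targets, not the transformed image itself, so the old code assigned a dict to `img`. A self-contained sketch of the call convention (using a NumPy image, as in the albumentations documentation; the values are placeholders):

    import albumentations as A
    import numpy as np

    augs = A.Compose([A.RandomBrightnessContrast(p=1.0)])
    img = np.zeros((64, 64, 3), dtype=np.uint8)
    out = augs(image=img)  # returns {"image": <augmented array>}, not the array
    img = out["image"]     # the indexing the fixed __getitem__ now performs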
2 changes: 1 addition & 1 deletion dataset/example.py
@@ -38,7 +38,7 @@ def __init__(
             dataset_name,
             split,
             seed,
-            (img_dim, img_dim),
+            (img_dim, img_dim) if img_dim is not None else None,
             augment=augment,
             normalize=normalize,
             debug=debug,
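The conditional guards against `img_dim` being `None`: the old code happily built `(None, None)`, which passes as a plain `tuple` but not as the `Optional[tuple[int, ...]]` the base class now declares. A standalone illustration (the function name is invented for this sketch):

    from typing import Optional

    def as_img_size(img_dim: Optional[int]) -> Optional[tuple[int, ...]]:
        return (img_dim, img_dim) if img_dim is not None else None

    assert as_img_size(32) == (32, 32)
    assert as_img_size(None) is None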
53 changes: 35 additions & 18 deletions launch_experiment.py
@@ -8,6 +8,7 @@
 
 import os
 from dataclasses import asdict
+from typing import Any, Optional
 
 import hydra_zen
 import torch
@@ -17,8 +18,9 @@
 from hydra.utils import to_absolute_path
 from hydra_zen import just
 from hydra_zen.typing import Partial
+from torch.utils.data import DataLoader, Dataset
 
-import conf.experiment  # Must import the config to add all components to the store! # noqa
+import conf.experiment as exp_conf  # type: ignore
 from conf import project as project_conf
 from model import TransparentDataParallel
 from src.base_tester import BaseTester
@@ -27,13 +29,13 @@
 
 
 def launch_experiment(
-    run,
-    data_loader: Partial[torch.utils.data.DataLoader],
+    run: exp_conf.RunConfig,
+    data_loader: Partial[torch.utils.data.DataLoader],  # type: ignore
     optimizer: Partial[torch.optim.Optimizer],
-    scheduler: Partial[torch.optim.lr_scheduler._LRScheduler],
+    scheduler: Partial[torch.optim.lr_scheduler.LRScheduler],
     trainer: Partial[BaseTrainer],
     tester: Partial[BaseTester],
-    dataset: Partial[torch.utils.data.Dataset],
+    dataset: Partial[Dataset[Any]],
     model: Partial[torch.nn.Module],
     training_loss: Partial[torch.nn.Module],
 ):
@@ -65,19 +67,19 @@ def launch_experiment(
 
     "============ Partials instantiation ============"
     model_inst = model(
-        encoder_input_dim=just(dataset).img_dim ** 2
+        encoder_input_dim=just(dataset).img_dim ** 2  # type: ignore
     )  # Use just() to get the config out of the Zen-Partial
     print(model_inst)
     print(f"Number of parameters: {sum(p.numel() for p in model_inst.parameters())}")
     print(
         f"Number of trainable parameters: {sum(p.numel() for p in model_inst.parameters() if p.requires_grad)}"
     )
-    train_dataset, val_dataset, test_dataset = None, None, None
+    train_dataset: Optional[Dataset[Any]] = None
+    val_dataset: Optional[Dataset[Any]] = None
+    test_dataset: Optional[Dataset[Any]] = None
     if run.training_mode:
-        train_dataset, val_dataset = (
-            dataset(split="train", seed=run.seed),
-            dataset(split="val", seed=run.seed),
-        )
+        train_dataset = dataset(split="train", seed=run.seed)
+        val_dataset = dataset(split="val", seed=run.seed)
     else:
         test_dataset = dataset(split="test", augment=False, seed=run.seed)
 
@@ -104,38 +106,45 @@
     )
     model_inst = TransparentDataParallel(model_inst)
 
-    if not run.training_mode:
-        training_loss_inst = None
-    else:
+    training_loss_inst: Optional[torch.nn.Module] = None
+    if run.training_mode:
         training_loss_inst = training_loss()
 
     "============ CUDA ============"
     model_inst: torch.nn.Module = to_cuda_(model_inst)  # type: ignore
-    training_loss_inst: torch.nn.Module = to_cuda_(training_loss_inst)  # type: ignore
+    training_loss_inst = to_cuda_(training_loss_inst)  # type: ignore
 
     "============ Weights & Biases ============"
     if project_conf.USE_WANDB:
         # exp_conf is a string, so we need to load it back to a dict:
         exp_conf = yaml.safe_load(exp_conf)
-        wandb.init(
+        wandb.init(  # type: ignore
            project=project_conf.PROJECT_NAME,
            name=run_name,
            config=exp_conf,
        )
-        wandb.watch(model_inst, log="all", log_graph=True)
+        wandb.watch(model_inst, log="all", log_graph=True)  # type: ignore
     " ============ Reproducibility of data loaders ============ "
     g = None
     if project_conf.REPRODUCIBLE:
         g = torch.Generator()
         g.manual_seed(run.seed)
 
-    train_loader_inst, val_loader_inst, test_loader_inst = None, None, None
+    train_loader_inst: Optional[DataLoader[Any]] = None
+    val_loader_inst: Optional[DataLoader[Dataset[Any]]] = None
+    test_loader_inst: Optional[DataLoader[Any]] = None
     if run.training_mode:
+        if train_dataset is None or val_dataset is None:
+            raise ValueError(
+                "train_dataset and val_dataset must be defined in training mode!"
+            )
         train_loader_inst = data_loader(train_dataset, generator=g)
         val_loader_inst = data_loader(
             val_dataset, generator=g, shuffle=False, drop_last=False
         )
     else:
+        if test_dataset is None:
+            raise ValueError("test_dataset must be defined in testing mode!")
         test_loader_inst = data_loader(
             test_dataset, generator=g, shuffle=False, drop_last=False
         )
@@ -167,6 +176,12 @@ def launch_experiment(
     )
 
     if run.training_mode:
+        if training_loss_inst is None:
+            raise ValueError("training_loss must be defined in training mode!")
+        if val_loader_inst is None or train_loader_inst is None:
+            raise ValueError(
+                "val_loader and train_loader must be defined in training mode!"
+            )
         trainer(
             run_name=run_name,
             model=model_inst,
@@ -187,6 +202,8 @@ def launch_experiment(
             model_ckpt_path=model_ckpt_path,
         )
     else:
+        if test_loader_inst is None:
+            raise ValueError("test_loader must be defined in testing mode!")
         tester(
             run_name=run_name,
             model=model_inst,
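Most additions in this file follow a single pyright-driven pattern: declare a variable as `Optional[...]`, assign it in only one branch, and guard with an explicit `None` check before use so the checker can narrow the type. A self-contained sketch of the pattern (all names invented for illustration):

    from typing import Optional

    def run(training_mode: bool) -> None:
        loss_name: Optional[str] = None
        if training_mode:
            loss_name = "mse"
        # ... unrelated setup happens here ...
        if training_mode:
            if loss_name is None:  # explicit guard narrows Optional[str] to str
                raise ValueError("loss_name must be defined in training mode!")
            print(loss_name.upper())  # safe: pyright knows this is a str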
(Diffs for the remaining five changed files did not load and are not shown.)