Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed Jul 27, 2024
1 parent 6e75ec4 commit 7eadbe0
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.7
rev: v0.5.5
hooks:
# Run the linter.
- id: ruff
Expand Down
6 changes: 3 additions & 3 deletions bootstrap/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ def parallelize_model(model: torch.nn.Module) -> torch.nn.Module:


def make_optimizer(
    optimizer_partial: Partial[torch.optim.Optimizer], model: torch.nn.Module
) -> torch.optim.Optimizer:
    """Instantiate an optimizer from a partially-configured constructor.

    Args:
        optimizer_partial: A partial optimizer constructor (hyper-parameters
            already bound via the config) awaiting only the parameters to
            optimize.
        model: The model whose parameters the optimizer will update.

    Returns:
        The fully-instantiated optimizer, bound to ``model.parameters()``.
    """
    # NOTE(review): use the public ``torch.optim.Optimizer`` path here —
    # ``torch.optim.optimizer`` is a private submodule and not part of the
    # documented PyTorch API; the public alias is stable across releases.
    return optimizer_partial(model.parameters())


def make_scheduler(
scheduler_partial: Partial[torch.optim.lr_scheduler.LRScheduler],
optimizer: torch.optim.Optimizer,
optimizer: torch.optim.optimizer.Optimizer,
epochs: int,
) -> torch.optim.lr_scheduler.LRScheduler:
scheduler = scheduler_partial(
Expand Down
2 changes: 1 addition & 1 deletion bootstrap/launch_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def init_wandb(
def launch_experiment(
run, # type: ignore
data_loader: Partial[DataLoader[Any]],
optimizer: Partial[torch.optim.Optimizer],
optimizer: Partial[torch.optim.optimizer.Optimizer],
scheduler: Partial[torch.optim.lr_scheduler.LRScheduler],
trainer: Partial[BaseTrainer],
tester: Partial[BaseTester],
Expand Down
25 changes: 13 additions & 12 deletions conf/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@
zen_partial=True, populate_full_signature=False
)

""" ================== Dataset ================== """
# ================== Dataset ==================


# Dataclasses are a great and simple way to define a base config group with default values.
# Dataclasses are a great and simple way to define a base config group with default
# values.
@dataclass
class ExampleDatasetConf:
dataset_name: str = "image_dataset"
Expand Down Expand Up @@ -86,7 +87,7 @@ class ExampleDatasetConf:
name="image_a_tiny",
)

""" ================== Dataloader & sampler ================== """
# ================== Dataloader & sampler ==================


@dataclass
Expand All @@ -107,12 +108,12 @@ class DataloaderConf:
persistent_workers: bool = False


""" ================== Model ================== """
# ================== Model ==================
# Pre-set the group for store's model entries
model_store = store(group="model")

# Not that encoder_input_dim depend on dataset.img_dim, so we need to use a partial to set them in
# the launch_experiment function.
# Note that encoder_input_dim depends on dataset.img_dim, so we need to use a
# partial to set them in the launch_experiment function.
model_store(
pbuilds(
ExampleModel,
Expand All @@ -134,7 +135,7 @@ class DataloaderConf:
name="model_b",
)

""" ================== Losses ================== """
# ================== Losses ==================
training_loss_store = store(group="training_loss")
training_loss_store(
pbuilds(
Expand All @@ -145,7 +146,7 @@ class DataloaderConf:
)


""" ================== Optimizer ================== """
# ================== Optimizer ==================


@dataclass
Expand All @@ -157,21 +158,21 @@ class Optimizer:
# Register the optimizer config group. Each entry wraps a torch optimizer
# class in a partial build so the learning rate etc. come from the config
# while the model parameters are supplied later by `make_optimizer`.
opt_store = store(group="optimizer")
opt_store(
    pbuilds(
        # NOTE(review): reference the public `torch.optim.Adam` — the
        # `torch.optim.adam` submodule is private API.
        torch.optim.Adam,
        builds_bases=(Optimizer,),
    ),
    name="adam",
)
opt_store(
    pbuilds(
        # NOTE(review): likewise, `torch.optim.SGD` is the documented path.
        torch.optim.SGD,
        builds_bases=(Optimizer,),
    ),
    name="sgd",
)


""" ================== Scheduler ================== """
# ================== Scheduler ==================
sched_store = store(group="scheduler")
sched_store(
pbuilds(
Expand All @@ -197,7 +198,7 @@ class Optimizer:
name="cosine",
)

""" ================== Experiment ================== """
# ================== Experiment ==================


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from rich.text import Text
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import MeanMetric

Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
"""Base trainer class.
Args:
model (torch.nn.Module): Model to train.
opt (torch.optim.Optimizer): Optimizer to use.
opt (torch.optim.optimizer.Optimizer): Optimizer to use.
train_loader (torch.utils.data.DataLoader): Training dataloader.
val_loader (torch.utils.data.DataLoader): Validation dataloader.
"""
Expand Down

0 comments on commit 7eadbe0

Please sign in to comment.