Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

558 implement a base class for scripting models #559

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/training/default_training.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ training:
max_used_perc:
state_dict_path:
state_dict_strict_load: True
script_model: False
compute_sampler_weights: False

# precision: 16
Expand Down
14 changes: 14 additions & 0 deletions tests/utils/test_script_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
import pytest
from utils.script_model import ScriptModel

def test_script_model():
model = torch.nn.Linear(3, 1)
script_model = ScriptModel(model,
input_shape=(1, 3),)

input_tensor = torch.rand((1, 3))
output = script_model.forward(input_tensor)

assert output.shape == (1, 1)
assert isinstance(output, torch.Tensor)
15 changes: 15 additions & 0 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tiling_segmentation import Tiler
from utils import augmentation as aug
from dataset import create_dataset
from utils.script_model import ScriptModel
from utils.logger import InformationLogger, tsv_line, get_logger, set_tracker
from utils.loss import verify_weights, define_loss
from utils.metrics import create_metrics_dict, calculate_batch_metrics
Expand Down Expand Up @@ -553,6 +554,7 @@ def train(cfg: DictConfig) -> None:
train_state_dict_path = get_key_def('state_dict_path', cfg['training'], default=None, expected_type=str)
state_dict_strict = get_key_def('state_dict_strict_load', cfg['training'], default=True, expected_type=bool)
dropout_prob = get_key_def('factor', cfg['scheduler']['params'], default=None, expected_type=float)
scriptmodel = get_key_def('script_model', cfg['training'], default=False, expected_type=bool)
# if error
if train_state_dict_path and not Path(train_state_dict_path).is_file():
raise logging.critical(
Expand Down Expand Up @@ -792,6 +794,19 @@ def train(cfg: DictConfig) -> None:

cur_elapsed = time.time() - since
# logging.info(f'\nCurrent elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

# Script model
if scriptmodel:
model_to_script = ScriptModel(model,
device=device,
input_shape=(1, num_bands, patches_size, patches_size),
mean=mean,
std=std,
min=scale[0],
max=scale[1])

scripted_model = torch.jit.script(model_to_script)
scripted_model.save(output_path.joinpath('scripted_model.pt'))
LucaRom marked this conversation as resolved.
Show resolved Hide resolved

# load checkpoint model and evaluate it on test dataset.
if int(cfg['general']['max_epochs']) > 0: # if num_epochs is set to 0, model is loaded to evaluate on test set
Expand Down
32 changes: 32 additions & 0 deletions utils/script_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

class ScriptModel(torch.nn.Module):
Copy link
Collaborator

@remtav remtav May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Very curious to start using scripted models. Thanks for this addition. Could a test for the __init__ and forward() to validate these methods act as expected be relevant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I added tests.

def __init__(self,
model,
device = torch.device("cpu"),
input_shape = (1, 3, 512, 512),
mean = [0.405,0.432,0.397],
std = [0.164,0.173,0.153],
min = 0,
max = 255,
scaled_min = 0.0,
scaled_max = 1.0):
super().__init__()
self.device = device
self.mean = torch.tensor(mean).resize_(len(mean), 1)
self.std = torch.tensor(std).resize_(len(std), 1)
self.min = min
self.max = max
self.min_val = scaled_min
self.max_val = scaled_max

input_tensor = torch.rand(input_shape).to(self.device)
self.model_scripted = torch.jit.trace(model.eval(), input_tensor)

def forward(self, input):
shape = input.shape
B, C = shape[0], shape[1]
input = (self.max_val - self.min_val) * (input - self.min) / (self.max -self.min) + self.min_val
input = (input.view(B, C, -1) - self.mean) / self.std
input = input.view(shape)
return self.model_scripted(input.to(self.device))
Loading