Skip to content

Commit

Permalink
Refactor export_model method in SegmentationSegformer
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Oct 11, 2024
1 parent f0e4590 commit d83f213
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions geo_deep_learning/tasks_with_models/segmentation_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,22 @@ def on_train_end(self):
best_model_dir = Path(best_model_path).parent
best_model_name = Path(best_model_path).stem
best_model_export_path = str(best_model_dir / f"{best_model_name}_scripted.pt")
self.export_model(best_model_export_path, self.trainer.datamodule)

@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None, **kwargs):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
model = cls(**checkpoint['hyper_parameters'])
model.load_state_dict(checkpoint['state_dict'])
return model
self.export_model(best_model_path, best_model_export_path, self.trainer.datamodule)

def export_model(self, checkpoint_path: str, export_path: str, datamodule: LightningDataModule):
best_model = self.load_from_checkpoint(checkpoint_path)
input_channels = self.hparams["init_args"]["in_channels"]
map_location = "cuda"
if self.device.type == "cpu":
map_location = "cpu"
best_model = self.__class__.load_from_checkpoint(checkpoint_path, map_location=map_location)
best_model.eval()

scrpted_model = script_model(best_model, datamodule)
dummy_input = torch.randn(1, self.hparams.in_channels, *datamodule.patch_size)
scrpted_model = script_model(best_model.model, datamodule)
patch_size = datamodule.patch_size
dummy_input = torch.rand(1, input_channels, *patch_size, device=torch.device(map_location))
traced_model = torch.jit.trace(scrpted_model, dummy_input)
torch.jit.save(traced_model, export_path)
print(f"Model exported to TorchScript {export_path}")
print(f"Model exported to TorchScript")



0 comments on commit d83f213

Please sign in to comment.