From 681f9e15713f76e7239d1ffc14e9a26851a21577 Mon Sep 17 00:00:00 2001 From: Ben Orchard Date: Wed, 13 Dec 2023 11:56:58 +0000 Subject: [PATCH] remove __init__.py Defines a function to load a model (the one used in the paper), but using a static MLflow run ID, so not useful for other users. --- src/gz21_ocean_momentum/__init__.py | 43 ----------------------------- 1 file changed, 43 deletions(-) delete mode 100644 src/gz21_ocean_momentum/__init__.py diff --git a/src/gz21_ocean_momentum/__init__.py b/src/gz21_ocean_momentum/__init__.py deleted file mode 100644 index 2fd45416..00000000 --- a/src/gz21_ocean_momentum/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 25 18:28:27 2019 - -@author: Arthur -""" -import logging -import torch -import mlflow -from .models.utils import load_model_cls, pickle_artifact -from pathlib import Path - -MODEL_RUN_ID = "dc74cea68a7f4c7e98f9228649a97135" - -print("To load the net from the paper, use the function load_paper_net().") - - -def load_paper_net(device: str = "gpu"): - """ - Load the neural network from the paper - """ - model_module_name = "subgrid.models.fully_conv_net" - model_cls_name = "FullyCNN" - model_cls = load_model_cls(model_module_name, model_cls_name) - net = model_cls(2, 4) - if device == "cpu": - transformation = torch.load("./final_transformation.pth") - else: - transformation = pickle_artifact(MODEL_RUN_ID, "models/transformation") - net.final_transformation = transformation - - # Load parameters of pre-trained model - logging.info("Loading the neural net parameters") - client = mlflow.tracking.MlflowClient() - model_file = client.download_artifacts(MODEL_RUN_ID, "models/trained_model.pth") - if device == "cpu": - print("Device: CPU") - model_file = "./nn_weights_cpu.pth" - net.load_state_dict(torch.load(model_file, map_location=torch.device("cpu"))) - else: - net.load_state_dict(torch.load(model_file)) - print(net) - return net