From 5db3f10211adcb1fcb9461c96d3b62287fd26757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ru=20Ke=C3=AFn?= <3181182+alphasentaurii@users.noreply.github.com> Date: Wed, 3 Apr 2024 19:06:27 -0400 Subject: [PATCH] builder/save-keras-archive-default (#50) * save model using keras archive format by default * remove extra lines * turn off keras archive for older mods --- CHANGES.rst | 9 ++++++++ spacekit/builder/architect.py | 39 ++++++++++++++++++++++---------- spacekit/skopes/hst/svm/train.py | 6 +++-- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index edccbd1..76a9863 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,12 @@ +1.1.0 (unreleased) +================== + +new features +------------ + +- `architect.builder.Builder.save_model` uses preferred keras archive format by default [#50] + + 1.0.1 (2024-04-03) ================== diff --git a/spacekit/builder/architect.py b/spacekit/builder/architect.py index 7f0d03d..c487c6a 100644 --- a/spacekit/builder/architect.py +++ b/spacekit/builder/architect.py @@ -145,7 +145,7 @@ def load_pretrained_network(self, arch=None): self.log.error(err) sys.exit(1) model_src = "spacekit.builder.trained_networks" - archive_file = f"{arch}.zip" # hst_cal.zip | jwt_cal.zip | svm_align.zip + archive_file = f"{arch}.zip" # hst_cal.zip | jwst_cal.zip | svm_align.zip with importlib.resources.path(model_src, archive_file) as mod: self.model_path = mod if self.blueprint is None: @@ -338,7 +338,16 @@ def set_callbacks(self, patience=15): self.callbacks = [checkpoint_cb, early_stopping_cb] return self.callbacks - def save_model(self, weights=True, output_path="."): + def save_keras_model(self, model_path): + dpath = os.path.dirname(model_path) + name = os.path.basename(model_path) + if not name.endswith("keras"): + name += ".keras" + keras_model_path = os.path.join(dpath, name) + self.model.save(keras_model_path) + self.model_path = keras_model_path + + def save_model(self, weights=True, output_path=".", keras_archive=True): """The model architecture, and training configuration (including the optimizer, losses, and metrics) are stored in saved_model.pb. The weights are saved in the variables/ directory. @@ -348,6 +357,8 @@ def save_model(self, weights=True, output_path="."): save weights learned by the model separately also, by default True output_path : str, optional where to save the model files, by default "." + keras_archive : bool, optional + save model using new (preferred) keras archive format, by default True """ if self.name is None: self.name = str(self.model.name_scope().rstrip("/")) @@ -357,16 +368,20 @@ def save_model(self, weights=True, output_path="."): model_name = self.name model_path = os.path.join(output_path, "models", model_name) - weights_path = f"{model_path}/weights/ckpt" - self.model.save(model_path) - if weights is True: - self.model.save_weights(weights_path) - for root, _, files in os.walk(model_path): - indent = " " * root.count(os.sep) - print("{}{}/".format(indent, os.path.basename(root))) - for filename in files: - print("{}{}".format(indent + " ", filename)) - self.model_path = model_path + + if keras_archive is True: + self.save_keras_model(model_path) + else: + self.model.save(model_path) + if weights is True: + weights_path = f"{model_path}/weights/ckpt" + self.model.save_weights(weights_path) + for root, _, files in os.walk(model_path): + indent = " " * root.count(os.sep) + print("{}{}/".format(indent, os.path.basename(root))) + for filename in files: + print("{}{}".format(indent + " ", filename)) + self.model_path = model_path def model_diagram( self, diff --git a/spacekit/skopes/hst/svm/train.py b/spacekit/skopes/hst/svm/train.py index 12f2541..f9b17cf 100644 --- a/spacekit/skopes/hst/svm/train.py +++ b/spacekit/skopes/hst/svm/train.py @@ -164,7 +164,7 @@ def load_ensemble_data( def train_ensemble( - XTR, YTR, XTS, YTS, model_name="ensembleSVM", params=None, output_path=None + XTR, YTR, XTS, YTS, model_name="ensembleSVM", params=None, output_path=None, keras=False, ): """Build, compile and fit an ensemble model with regression test data and image input arrays. @@ -215,7 +215,7 @@ def train_ensemble( if output_path is None: output_path = os.getcwd() model_outpath = os.path.join(output_path, os.path.dirname(model_name)) - ens.save_model(weights=True, output_path=model_outpath) + ens.save_model(weights=True, output_path=model_outpath, keras_archive=keras) return ens @@ -267,6 +267,7 @@ def run_training( model_name="ensembleSVM", params=None, output_path=None, + keras=False, ): """Main calling function to load and prep the data, train the model, compute results and save to disk. @@ -305,6 +306,7 @@ def run_training( model_name=model_name, params=params, output_path=output_path, + keras=keras, ) com, val = compute_results(ens, tv_idx, val_set=(XVL, YVL), output_path=output_path) return ens, com, val