Skip to content

Commit

Permalink
builder/save-keras-archive-default (#50)
Browse files Browse the repository at this point in the history
* save model using keras archive format by default

* remove extra lines

* turn off keras archive for older mods
  • Loading branch information
alphasentaurii committed Apr 6, 2024
1 parent 98a0533 commit 5db3f10
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 14 deletions.
9 changes: 9 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
1.1.0 (unreleased)
==================

new features
------------

- `architect.builder.Builder.save_model` uses preferred keras archive format by default [#50]


1.0.1 (2024-04-03)
==================

Expand Down
39 changes: 27 additions & 12 deletions spacekit/builder/architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def load_pretrained_network(self, arch=None):
self.log.error(err)
sys.exit(1)
model_src = "spacekit.builder.trained_networks"
archive_file = f"{arch}.zip" # hst_cal.zip | jwt_cal.zip | svm_align.zip
archive_file = f"{arch}.zip" # hst_cal.zip | jwst_cal.zip | svm_align.zip
with importlib.resources.path(model_src, archive_file) as mod:
self.model_path = mod
if self.blueprint is None:
Expand Down Expand Up @@ -338,7 +338,16 @@ def set_callbacks(self, patience=15):
self.callbacks = [checkpoint_cb, early_stopping_cb]
return self.callbacks

def save_model(self, weights=True, output_path="."):
def save_keras_model(self, model_path):
dpath = os.path.dirname(model_path)
name = os.path.basename(model_path)
if not name.endswith("keras"):
name += ".keras"
keras_model_path = os.path.join(dpath, name)
self.model.save(keras_model_path)
self.model_path = keras_model_path

def save_model(self, weights=True, output_path=".", keras_archive=True):
"""The model architecture, and training configuration (including the optimizer, losses, and metrics)
are stored in saved_model.pb. The weights are saved in the variables/ directory.
Expand All @@ -348,6 +357,8 @@ def save_model(self, weights=True, output_path="."):
save weights learned by the model separately also, by default True
output_path : str, optional
where to save the model files, by default "."
keras_archive : bool, optional
save model using new (preferred) keras archive format, by default True
"""
if self.name is None:
self.name = str(self.model.name_scope().rstrip("/"))
Expand All @@ -357,16 +368,20 @@ def save_model(self, weights=True, output_path="."):
model_name = self.name

model_path = os.path.join(output_path, "models", model_name)
weights_path = f"{model_path}/weights/ckpt"
self.model.save(model_path)
if weights is True:
self.model.save_weights(weights_path)
for root, _, files in os.walk(model_path):
indent = " " * root.count(os.sep)
print("{}{}/".format(indent, os.path.basename(root)))
for filename in files:
print("{}{}".format(indent + " ", filename))
self.model_path = model_path

if keras_archive is True:
self.save_keras_model(model_path)
else:
self.model.save(model_path)
if weights is True:
weights_path = f"{model_path}/weights/ckpt"
self.model.save_weights(weights_path)
for root, _, files in os.walk(model_path):
indent = " " * root.count(os.sep)
print("{}{}/".format(indent, os.path.basename(root)))
for filename in files:
print("{}{}".format(indent + " ", filename))
self.model_path = model_path

def model_diagram(
self,
Expand Down
6 changes: 4 additions & 2 deletions spacekit/skopes/hst/svm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def load_ensemble_data(


def train_ensemble(
XTR, YTR, XTS, YTS, model_name="ensembleSVM", params=None, output_path=None
XTR, YTR, XTS, YTS, model_name="ensembleSVM", params=None, output_path=None, keras=False,
):
"""Build, compile and fit an ensemble model with regression test data and image input arrays.
Expand Down Expand Up @@ -215,7 +215,7 @@ def train_ensemble(
if output_path is None:
output_path = os.getcwd()
model_outpath = os.path.join(output_path, os.path.dirname(model_name))
ens.save_model(weights=True, output_path=model_outpath)
ens.save_model(weights=True, output_path=model_outpath, keras_archive=keras)
return ens


Expand Down Expand Up @@ -267,6 +267,7 @@ def run_training(
model_name="ensembleSVM",
params=None,
output_path=None,
keras=False,
):
"""Main calling function to load and prep the data, train the model, compute results and save to disk.
Expand Down Expand Up @@ -305,6 +306,7 @@ def run_training(
model_name=model_name,
params=params,
output_path=output_path,
keras=keras,
)
com, val = compute_results(ens, tv_idx, val_set=(XVL, YVL), output_path=output_path)
return ens, com, val
Expand Down

0 comments on commit 5db3f10

Please sign in to comment.