From f0e731362726cc881c7902197588414a6c097c0c Mon Sep 17 00:00:00 2001 From: Denghui Lu Date: Tue, 14 Sep 2021 08:06:53 +0800 Subject: [PATCH] Add error message for repeated model compression (#1136) * add error message for replicated model compression * fix typo --- deepmd/entrypoints/compress.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/deepmd/entrypoints/compress.py b/deepmd/entrypoints/compress.py index 58c3c2a075..b03e8cf653 100644 --- a/deepmd/entrypoints/compress.py +++ b/deepmd/entrypoints/compress.py @@ -87,6 +87,8 @@ def compress( jdata = j_loader(training_script) t_min_nbor_dist = get_min_nbor_dist(jdata, get_rcut(jdata)) + _check_compress_type(input) + tf.constant(t_min_nbor_dist, name = 'train_attr/min_nbor_dist', dtype = GLOBAL_ENER_FLOAT_PRECISION) @@ -137,3 +139,13 @@ def compress( log.info("\n\n") log.info("stage 2: freeze the model") freeze(checkpoint_folder=checkpoint_folder, output=output, node_names=None) + +def _check_compress_type(model_file): + try: + t_model_type = bytes.decode(get_tensor_by_name(model_file, 'model_type')) + except GraphWithoutTensorError as e: + # Compatible with the upgraded model, which has no 'model_type' info + t_model_type = None + + if t_model_type == "compressed_model": + raise RuntimeError("The input frozen model %s has already been compressed! Please do not compress the model repeatedly. " % model_file) \ No newline at end of file