diff --git a/main.py b/main.py index 1b31311..2570022 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,21 @@ K.set_session(sess) +def get_config(): + if is_coco: + import coco + class InferenceConfig(coco.CocoConfig): + GPU_COUNT = 1 + IMAGES_PER_GPU = 1 + + config = InferenceConfig() + + else: + config = mask_config(NUMBER_OF_CLASSES) + + return config + + def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): graph = sess.graph @@ -82,8 +97,13 @@ def make_serving_ready(model_path, save_serve_path, version_number): print("*" * 80) +# Load Mask RCNN config +# you can also load your own config in here. +# config = your_custom_config_class +config = get_config() + + # LOAD MODEL -config = mask_config(NUMBER_OF_CLASSES) model = MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config) model.load_weights(H5_WEIGHT_PATH, by_name=True) diff --git a/user_config.py b/user_config.py index 5cd5278..1b8bb44 100644 --- a/user_config.py +++ b/user_config.py @@ -1,5 +1,8 @@ # User Define parameters +# Make it True if you want to use the provided coco weights +is_coco = False + # keras model directory path MODEL_DIR = '/keras_model/'