diff --git a/infer/modules/train/extract_feature_print.py b/infer/modules/train/extract_feature_print.py index 143fa6d..60ccea4 100644 --- a/infer/modules/train/extract_feature_print.py +++ b/infer/modules/train/extract_feature_print.py @@ -1,11 +1,12 @@ import os import sys import traceback +import argparse os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" -mdph = "logs/" + {model_name} +model_name = sys.argv[4] device = sys.argv[1] n_part = int(sys.argv[2]) i_part = int(sys.argv[3]) @@ -43,7 +44,7 @@ def forward_dml(ctx, x, scale): fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml -f = open(f"{mdph}/extract_f0_feature.log".format(exp_dir), "a+") +f = open(f"{model_name}/extract_f0_feature.log".format(exp_dir), "a+") def printt(strr): @@ -56,9 +57,9 @@ def printt(strr): model_path = "assets/hubert/hubert_base.pt" printt("exp_dir: " + exp_dir) -wavPath = f"{mdph}/1_16k_wavs".format(exp_dir) +wavPath = f"{model_name}/1_16k_wavs".format(exp_dir) outPath = ( - f"{mdph}/3_feature256".format(exp_dir) if version == "v1" else f"{mdph}/3_feature768".format(exp_dir) + f"{model_name}/3_feature256".format(exp_dir) if version == "v1" else f"{model_name}/3_feature768".format(exp_dir) ) os.makedirs(outPath, exist_ok=True)