From fc48be5cf42936d5113aa2fe9d5613e49af42d2c Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:11:12 +0500 Subject: [PATCH] Update extract_feature_print.py --- infer/modules/train/extract_feature_print.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/infer/modules/train/extract_feature_print.py b/infer/modules/train/extract_feature_print.py index 143fa6d..60ccea4 100644 --- a/infer/modules/train/extract_feature_print.py +++ b/infer/modules/train/extract_feature_print.py @@ -1,11 +1,12 @@ import os import sys import traceback +import argparse os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" -mdph = "logs/" + {model_name} +model_name = sys.argv[4] device = sys.argv[1] n_part = int(sys.argv[2]) i_part = int(sys.argv[3]) @@ -43,7 +44,7 @@ def forward_dml(ctx, x, scale): fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml -f = open(f"{mdph}/extract_f0_feature.log".format(exp_dir), "a+") +f = open(f"{model_name}/extract_f0_feature.log".format(exp_dir), "a+") def printt(strr): @@ -56,9 +57,9 @@ def printt(strr): model_path = "assets/hubert/hubert_base.pt" printt("exp_dir: " + exp_dir) -wavPath = f"{mdph}/1_16k_wavs".format(exp_dir) +wavPath = f"{model_name}/1_16k_wavs".format(exp_dir) outPath = ( - f"{mdph}/3_feature256".format(exp_dir) if version == "v1" else f"{mdph}/3_feature768".format(exp_dir) + f"{model_name}/3_feature256".format(exp_dir) if version == "v1" else f"{model_name}/3_feature768".format(exp_dir) ) os.makedirs(outPath, exist_ok=True)