Skip to content

Commit

Permalink
Update extract_feature_print.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bebra777228 authored Mar 4, 2024
1 parent 81302c1 commit fc48be5
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions infer/modules/train/extract_feature_print.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import sys
import traceback
import argparse

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

mdph = "logs/" + {model_name}
model_name = sys.argv[4]
device = sys.argv[1]
n_part = int(sys.argv[2])
i_part = int(sys.argv[3])
Expand Down Expand Up @@ -43,7 +44,7 @@ def forward_dml(ctx, x, scale):

fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml

f = open(f"{mdph}/extract_f0_feature.log".format(exp_dir), "a+")
f = open(f"{model_name}/extract_f0_feature.log".format(exp_dir), "a+")


def printt(strr):
Expand All @@ -56,9 +57,9 @@ def printt(strr):
model_path = "assets/hubert/hubert_base.pt"

printt("exp_dir: " + exp_dir)
wavPath = f"{mdph}/1_16k_wavs".format(exp_dir)
wavPath = f"{model_name}/1_16k_wavs".format(exp_dir)
outPath = (
f"{mdph}/3_feature256".format(exp_dir) if version == "v1" else f"{mdph}/3_feature768".format(exp_dir)
f"{model_name}/3_feature256".format(exp_dir) if version == "v1" else f"{model_name}/3_feature768".format(exp_dir)
)
os.makedirs(outPath, exist_ok=True)

Expand Down

0 comments on commit fc48be5

Please sign in to comment.