diff --git a/rvc/lib/algorithm/commons.py b/rvc/lib/algorithm/commons.py index 444e095b..2524abc4 100644 --- a/rvc/lib/algorithm/commons.py +++ b/rvc/lib/algorithm/commons.py @@ -156,6 +156,7 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): acts = t_act * s_act return acts + def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]: """ Convert the pad shape to a list of integers. diff --git a/rvc/lib/algorithm/modules.py b/rvc/lib/algorithm/modules.py index ffd589fa..1038356d 100644 --- a/rvc/lib/algorithm/modules.py +++ b/rvc/lib/algorithm/modules.py @@ -1,6 +1,7 @@ import torch from rvc.lib.algorithm.commons import fused_add_tanh_sigmoid_multiply + class WaveNet(torch.nn.Module): """WaveNet residual blocks as used in WaveGlow diff --git a/rvc/lib/zluda.py b/rvc/lib/zluda.py index c5a55fb9..482009cc 100644 --- a/rvc/lib/zluda.py +++ b/rvc/lib/zluda.py @@ -1,16 +1,18 @@ import torch + if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"): _torch_stft = torch.stft + def z_stft( audio: torch.Tensor, n_fft: int, - hop_length: int = None, - win_length: int = None, - window: torch.Tensor = None, - center: bool = True, - pad_mode: str = "reflect", - normalized: bool = False, - onesided: bool = None, + hop_length: int = None, + win_length: int = None, + window: torch.Tensor = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: bool = None, return_complex: bool = None, ): sd = audio.device @@ -26,11 +28,11 @@ def z_stft( onesided=onesided, return_complex=return_complex, ).to(sd) - + def z_jit(f, *_, **__): f.graph = torch._C.Graph() return f - + # hijacks torch.stft = z_stft torch.jit.script = z_jit @@ -38,4 +40,4 @@ def z_jit(f, *_, **__): torch.backends.cudnn.enabled = False torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_math_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp(False) \ No newline at end of file + torch.backends.cuda.enable_mem_efficient_sdp(False) diff --git a/rvc/train/train.py b/rvc/train/train.py index bf1b01bb..a63ce7a1 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -1024,6 +1024,7 @@ def save_to_json( with open(file_path, "w") as f: json.dump(data, f) + if __name__ == "__main__": torch.multiprocessing.set_start_method("spawn") main()