Skip to content

Commit

Permalink
Merge pull request #743 from IAHispano/formatter/main
Browse files Browse the repository at this point in the history
chore(format): run black on main
  • Loading branch information
blaisewf authored Sep 27, 2024
2 parents 268a81f + 098f80d commit a0cc410
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
1 change: 1 addition & 0 deletions rvc/lib/algorithm/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
acts = t_act * s_act
return acts


def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
"""
Convert the pad shape to a list of integers.
Expand Down
1 change: 1 addition & 0 deletions rvc/lib/algorithm/modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from rvc.lib.algorithm.commons import fused_add_tanh_sigmoid_multiply


class WaveNet(torch.nn.Module):
"""WaveNet residual blocks as used in WaveGlow
Expand Down
22 changes: 12 additions & 10 deletions rvc/lib/zluda.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import torch

if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
_torch_stft = torch.stft

def z_stft(
audio: torch.Tensor,
n_fft: int,
hop_length: int = None,
win_length: int = None,
window: torch.Tensor = None,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: bool = None,
hop_length: int = None,
win_length: int = None,
window: torch.Tensor = None,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: bool = None,
return_complex: bool = None,
):
sd = audio.device
Expand All @@ -26,16 +28,16 @@ def z_stft(
onesided=onesided,
return_complex=return_complex,
).to(sd)

def z_jit(f, *_, **__):
f.graph = torch._C.Graph()
return f

# hijacks
torch.stft = z_stft
torch.jit.script = z_jit
# disabling unsupported cudnn
torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
1 change: 1 addition & 0 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,7 @@ def save_to_json(
with open(file_path, "w") as f:
json.dump(data, f)


if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")
main()

0 comments on commit a0cc410

Please sign in to comment.