-
Notifications
You must be signed in to change notification settings - Fork 0
/
biodenoising_live.py
155 lines (137 loc) · 5.45 KB
/
biodenoising_live.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) Earth Species Project. This work is based on Facebook's denoiser.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import argparse
import sys
import sounddevice as sd
import torch
import biodenoising
def _unit_interval(value):
    """Argparse ``type=`` converter: parse ``value`` as a float in [0, 1].

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a number or lies
            outside the closed interval [0, 1] (NaN is rejected as well,
            since every comparison with NaN is False).
    """
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None
    if not 0.0 <= number <= 1.0:
        raise argparse.ArgumentTypeError(f"must be between 0 and 1, got {number}")
    return number


def _positive_int(value):
    """Argparse ``type=`` converter: parse ``value`` as an integer >= 1.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer or is < 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {number}")
    return number


def get_parser():
    """Build the command-line parser for the live denoising script.

    Besides the flags defined here, the model-selection flags are added by
    ``biodenoising.denoiser.pretrained.add_model_flags``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        "denoiser.live",
        description="Performs live speech enhancement, reading audio from "
                    "the default mic (or interface specified by --in) and "
                    "writing the enhanced version to 'Soundflower (2ch)' "
                    "(or the interface specified by --out)."
    )
    parser.add_argument(
        "-i", "--in", dest="in_",
        help="name or index of input interface.")
    parser.add_argument(
        "-o", "--out", default="Soundflower (2ch)",
        help="name or index of output interface.")
    biodenoising.denoiser.pretrained.add_model_flags(parser)
    parser.add_argument(
        "--no_compressor", action="store_false", dest="compressor",
        help="Deactivate compressor on output, might lead to clipping.")
    parser.add_argument(
        "--device", default="cpu",
        help="Torch device the model and input frames are moved to "
             "(e.g. 'cpu' or 'cuda'). Default is cpu.")
    # The help text promises a value in [0, 1]; enforce it at parse time
    # instead of silently passing an out-of-range knob to the streamer.
    parser.add_argument(
        "--dry", type=_unit_interval, default=0.04,
        help="Dry/wet knob, between 0 and 1. 0=maximum noise removal "
             "but it might cause distortions. Default is 0.04")
    parser.add_argument(
        "-t", "--num_threads", type=int,
        help="Number of threads. If you have DDR3 RAM, setting -t 1 can "
             "improve performance.")
    # Processing zero (or a negative number of) frames per step is meaningless.
    parser.add_argument(
        "-f", "--num_frames", type=_positive_int, default=1,
        help="Number of frames to process at once. Larger values increase "
             "the overall lag, but will improve speed.")
    return parser
def parse_audio_device(device):
    """Normalise an ``--in``/``--out`` interface selector.

    Those flags accept either an interface name or an index. A value that
    ``int()`` can parse is returned as an ``int`` (an index). Any other
    value, and ``None`` (no interface given), is returned unchanged.
    """
    if device is not None:
        try:
            device = int(device)
        except ValueError:
            pass  # not numeric: keep it as an interface name
    return device
def query_devices(device, kind):
    """Look up the capabilities of an audio interface, or exit with a hint.

    Args:
        device: interface index, name, or ``None``, as returned by
            ``parse_audio_device``.
        kind: ``"input"`` or ``"output"``; forwarded to
            ``sounddevice.query_devices``.

    Returns:
        The capabilities mapping from ``sounddevice.query_devices``. Callers
        read ``max_input_channels`` / ``max_output_channels`` from it.

    Exits:
        Prints a diagnostic to stderr and calls ``sys.exit(1)`` when
        sounddevice cannot resolve the interface.
    """
    try:
        caps = sd.query_devices(device, kind=kind)
    except (ValueError, sd.PortAudioError) as err:
        # sounddevice raises ValueError for an unknown or ambiguous name, or
        # for a device with no channels of the requested kind. It raises
        # PortAudioError when a numeric index cannot be queried (e.g. out of
        # range). Both cases deserve the same user-facing explanation.
        message = biodenoising.denoiser.bold(f"Invalid {kind} audio interface {device}.\n")
        message += f"Reason: {err}\n"
        message += (
            "If you are on Mac OS X, try installing Soundflower "
            "(https://github.com/mattingalls/Soundflower).\n"
            "You can list available interfaces with `python3 -m sounddevice` on Linux and OS X, "
            "and `python.exe -m sounddevice` on Windows. You must have at least one loopback "
            "audio interface to use this.")
        print(message, file=sys.stderr)
        sys.exit(1)
    return caps
def _denoise_loop(args, model, streamer, stream_in, stream_out, channels_out):
    """Read, denoise and play back audio until interrupted with Ctrl+C.

    Args:
        args: parsed command-line namespace. Reads ``device`` and
            ``compressor``.
        model: the loaded model. Only ``sample_rate`` is read here.
        streamer: the ``DemucsStreamer`` wrapping ``model``.
        stream_in: started sounddevice input stream.
        stream_out: started sounddevice output stream.
        channels_out: number of channels ``stream_out`` was opened with. The
            denoised signal is duplicated onto each of them.
    """
    first = True
    current_time = 0  # seconds of audio consumed so far (audio time, not wall clock)
    last_log_time = 0
    last_error_time = 0
    cooldown_time = 2  # min. seconds of audio between two "too slow" warnings
    log_delta = 10  # seconds of audio between two timing reports
    sr_ms = model.sample_rate / 1000  # samples per millisecond
    stride_ms = streamer.stride / sr_ms
    print(f"Ready to process audio, total lag: {streamer.total_length / sr_ms:.1f}ms.")
    try:
        while True:
            if current_time > last_log_time + log_delta:
                last_log_time = current_time
                tpf = streamer.time_per_frame * 1000
                # Real-time factor: processing time per frame relative to
                # the audio duration of one stride (< 1 keeps up).
                rtf = tpf / stride_ms
                print(f"time per frame: {tpf:.1f}ms, ", end='')
                print(f"RTF: {rtf:.1f}")
                streamer.reset_time_per_frame()

            # The first read requests streamer.total_length samples, presumably
            # to prime the streamer's context. Every later read advances by
            # one stride.
            length = streamer.total_length if first else streamer.stride
            first = False
            current_time += length / model.sample_rate
            frame, overflow = stream_in.read(length)
            # Average the input channels into a single signal for the model.
            frame = torch.from_numpy(frame).mean(dim=1).to(args.device)
            with torch.no_grad():
                out = streamer.feed(frame[None])[0]
            if not out.numel():
                # The streamer produced no output for this chunk yet.
                continue
            if args.compressor:
                # Soft-limit into (-0.99, 0.99) to avoid clipping.
                out = 0.99 * torch.tanh(out)
            out = out[:, None].repeat(1, channels_out)
            mx = out.abs().max().item()
            if mx > 1:
                print("Clipping!!")
            out.clamp_(-1, 1)
            out = out.cpu().numpy()
            underflow = stream_out.write(out)
            if overflow or underflow:
                if current_time >= last_error_time + cooldown_time:
                    last_error_time = current_time
                    tpf = 1000 * streamer.time_per_frame
                    print(f"Not processing audio fast enough, time per frame is {tpf:.1f}ms "
                          f"(should be less than {stride_ms:.1f}ms).")
    except KeyboardInterrupt:
        print("Stopping")


def main():
    """Entry point: run live denoising from an input to an output interface.

    Parses the command line, loads the pretrained model, and wraps it in a
    ``DemucsStreamer``. Both audio interfaces are validated before any stream
    is opened. The streams are opened as context managers, so they are
    stopped and closed however the processing loop ends (Ctrl+C or an
    unexpected error).
    """
    args = get_parser().parse_args()
    if args.num_threads:
        torch.set_num_threads(args.num_threads)

    model = biodenoising.denoiser.pretrained.get_model(args).to(args.device)
    model.eval()
    print("Model loaded.")
    streamer = biodenoising.denoiser.demucs.DemucsStreamer(
        model, dry=args.dry, num_frames=args.num_frames)

    # Resolve both interfaces up front so a bad --out is reported before an
    # input stream has been opened.
    device_in = parse_audio_device(args.in_)
    caps = query_devices(device_in, "input")
    channels_in = min(caps['max_input_channels'], 2)
    device_out = parse_audio_device(args.out)
    caps = query_devices(device_out, "output")
    channels_out = min(caps['max_output_channels'], 2)

    in_opts = dict(device=device_in, samplerate=model.sample_rate, channels=channels_in)
    out_opts = dict(device=device_out, samplerate=model.sample_rate, channels=channels_out)
    # Entering each stream's context starts it (input first). Leaving the
    # block, even via an exception, stops and closes them in reverse order.
    with sd.InputStream(**in_opts) as stream_in, sd.OutputStream(**out_opts) as stream_out:
        _denoise_loop(args, model, streamer, stream_in, stream_out, channels_out)


if __name__ == "__main__":
    main()