-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8a58741
commit 56ab997
Showing
1 changed file
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
""" | ||
This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model. | ||
""" | ||
|
||
import torch | ||
import librosa | ||
import numpy as np | ||
import argparse | ||
from transformers import WavLMForSequenceClassification | ||
|
||
|
||
def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Simple feature extraction for WavLM.

    Splits the signal into fixed-length windows (zero-padding the final,
    shorter window) so it can be fed to the model as one batch.

    Parameters
    ----------
    wav : str or array-like
        Path to the wav file, or the signal itself as an array-like.
    sr : int, optional
        Sample rate in Hz, by default 16_000.
    win_len : float, optional
        Window length in seconds, by default 15.0.
    win_stride : float, optional
        Window stride in seconds, by default 15.0.
    do_normalize : bool, optional
        Whether to mean/std-normalize each window, by default False.

    Returns
    -------
    np.ndarray
        Batched input to WavLM, shape [N, T].

    Raises
    ------
    RuntimeError
        If `wav` cannot be converted to a numpy array.
    """
    # isinstance() is the idiomatic, subclass-safe type check.
    if isinstance(wav, str):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.asarray(wav).squeeze()
        except Exception as e:
            # Chain the original error so callers see the root cause.
            raise RuntimeError(f"could not convert input to an array: {e}") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)  # window length in samples (hoisted)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                # Zero-pad the last chunk to the same length as the others.
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                # Per-window normalization; epsilon guards divide-by-zero.
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                break
    else:
        # Signal fits in a single window: no chunking needed.
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]
|
||
|
||
def infer(model, inputs):
    """Run the classifier and convert logits to per-label probabilities.

    Parameters
    ----------
    model : callable
        Sequence-classification model; invoked as ``model(inputs)`` and
        expected to return an object exposing a ``logits`` attribute.
    inputs : torch.Tensor
        Batched model input.

    Returns
    -------
    torch.Tensor
        Element-wise sigmoid of the logits (independent per-label probabilities).
    """
    logits = model(inputs).logits
    return torch.sigmoid(torch.Tensor(logits))
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        help="File to run inference",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="checkpoint file of model",
    )
    args = parser.parse_args()
    # Label order must match the classification-head ordering the
    # checkpoint was trained with.
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    # Model is trained on only 16kHz audio
    audio, _ = librosa.core.load(args.audio_file, sr=16000)
    input_np = feature_extract_simple(audio, sr=16000)
    input_pt = torch.Tensor(input_np)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    # Derive the class count from the label list instead of hard-coding 6,
    # so the reshape and the labels cannot drift apart.
    probs = probs.reshape(-1, len(labels_name_list)).detach().tolist()
    print(f"Probabilities for {args.audio_file} is:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")