Skip to content

Commit

Permalink
Create inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kelpabc123 authored Mar 12, 2024
1 parent 8a58741 commit 56ab997
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
This file gives a sample demonstration of how to use the given functions in Python, for the Voice Safety Classifier model.
"""

import torch
import librosa
import numpy as np
import argparse
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
):
    """Simple feature extraction for wavLM.

    Splits the signal into fixed-length windows (zero-padding the final
    window) so the whole clip can be scored as one batch.

    Parameters
    ----------
    wav : str or array-like
        path to the wav file, or array-like
    sr : int, optional
        sample rate, by default 16_000
    win_len : float, optional
        window length in seconds, by default 15.0
    win_stride : float, optional
        window stride in seconds, by default 15.0
    do_normalize : bool, optional
        whether to mean/std-normalize each window, by default False.

    Returns
    -------
    np.ndarray
        batched input to wavLM, shape [N, T]

    Raises
    ------
    RuntimeError
        if `wav` is neither a path nor convertible to a numpy array.
    """
    # isinstance (not `type(...) ==`) also accepts str subclasses.
    if isinstance(wav, str):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            # Chain the original error instead of printing and discarding it.
            raise RuntimeError(f"could not convert input to an array: {e}") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)  # window length in samples, computed once
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                # Zero-pad the last chunk to the full window length.
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                # Per-window standardization; epsilon guards all-silence windows.
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                break
    else:
        # Clip fits in a single window: feed it through unpadded.
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]


def infer(model, inputs):
    """Run the classifier and map its logits to per-label probabilities.

    Parameters
    ----------
    model : callable
        a loaded sequence-classification model; its output exposes `.logits`
    inputs : torch.Tensor
        batched audio windows, shape [N, T]

    Returns
    -------
    torch.Tensor
        sigmoid of the logits — independent per-label probabilities
    """
    logits = model(inputs).logits
    return torch.sigmoid(torch.Tensor(logits))


if __name__ == "__main__":
    # CLI entry point: score one audio file and print per-window label probabilities.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_file",
        type=str,
        help="File to run inference",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="checkpoint file of model",
    )
    args = parser.parse_args()
    # Order must match the label order the checkpoint was trained with.
    labels_name_list = [
        "Profanity",
        "DatingAndSexting",
        "Racist",
        "Bullying",
        "Other",
        "NoViolation",
    ]
    # Model is trained on only 16kHz audio
    audio, _ = librosa.core.load(args.audio_file, sr=16000)
    input_np = feature_extract_simple(audio, sr=16000)
    input_pt = torch.Tensor(input_np)
    model = WavLMForSequenceClassification.from_pretrained(
        args.model_path, num_labels=len(labels_name_list)
    )
    probs = infer(model, input_pt)
    # Derive the class count from labels_name_list instead of hard-coding 6,
    # so the reshape and the model head cannot silently drift apart.
    probs = probs.reshape(-1, len(labels_name_list)).detach().tolist()
    print(f"Probabilities for {args.audio_file} is:")
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")

0 comments on commit 56ab997

Please sign in to comment.