Skip to content

Commit

Permalink
- Default to 1 thread to use as multithreading slows down on CPU
Browse files Browse the repository at this point in the history
- Use sapply instead of for loop to loop over the windowed samples inside a with_no_grad chunk
- Function silero no longer requires to provide the sample_rate, this is now extracted using audio::load.wave
- Factor out the use of package av to only the examples - internally replaced with package audio
  • Loading branch information
jwijffels committed Mar 18, 2024
1 parent 93caa8c commit 7bbb4d2
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 43 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ importFrom(audio,load.wave)
importFrom(torch,autograd_set_grad_mode)
importFrom(torch,jit_load)
importFrom(torch,jit_scalar)
importFrom(torch,torch_float)
importFrom(torch,torch_set_num_threads)
importFrom(torch,torch_tensor)
importFrom(torch,with_no_grad)
8 changes: 6 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
## CHANGES IN audio.vadsilero VERSION 0.2

- disable the gradient history recording & set model in evaluation (inference) mode
- replace package wav with package audio for reading in the wav file
- Disable the gradient history recording & set model in evaluation (inference) mode
- Replace package wav with package audio for reading in the wav file
- Default to using 1 thread, as multithreading slows things down on CPU
- Use sapply instead of for loop to loop over the windowed samples inside a with_no_grad chunk
- Function silero no longer requires the sample_rate to be provided; it is now extracted using audio::load.wave
- Factor out the use of package av to only the examples - internally replaced with package audio

## CHANGES IN audio.vadsilero VERSION 0.1

Expand Down
2 changes: 1 addition & 1 deletion R/pkg.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#' @importFrom audio load.wave
#' @importFrom torch torch_tensor jit_load jit_scalar autograd_set_grad_mode
#' @importFrom torch torch_tensor jit_load jit_scalar autograd_set_grad_mode torch_float with_no_grad torch_set_num_threads
NULL
90 changes: 53 additions & 37 deletions R/silero.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#' It works with .wav audio files with a sample rate of 8 or 16 Khz and can be applied over a window of either 32, 64 or 96 milliseconds.
#' @param file the path to an audio file which should be a wav file in 16 bit with mono PCM samples (pcm_s16le codec) with a sampling rate of either 8Khz or 16KHz
#' @param milliseconds integer with the size of the window, in milliseconds, over which the VAD signal is computed. Can only be 32, 64 or 96. Defaults to 64.
#' @param sample_rate integer with the sample rate of \code{file}. If not provided, will use package av to extract it.
#' @param threshold numeric threshold on the probability: if the probability is above this threshold, the segment is detected as voiced. Defaults to 0.5.
#' @param threads integer with the number of threads to use, which is passed on to \code{\link[torch]{torch_set_num_threads}}. Defaults to 1.
#' @return an object of class \code{VAD} which is a list with elements
#' \itemize{
#' \item{file: the path to the file}
Expand Down Expand Up @@ -32,42 +32,42 @@
#' plot(vad$vad$millisecond, vad$vad$probability, type = "l",
#' xlab = "Millisecond", ylab = "Probability voiced")
#'
#' \dontrun{
#' library(av)
#' x <- read_audio_bin(file)
#' library(audio)
#' x <- load.wave(file)
#' plot(seq_along(x) / 16000, x, type = "l")
#' abline(v = vad$vad_segments$start, col = "red", lwd = 2)
#' abline(v = vad$vad_segments$end, col = "blue", lwd = 2)
#'
#' \dontrun{
#' ##
#' ## If you have audio which is not in mono or another sample rate
#' ## consider using R package av to convert to the desired format
#' library(av)
#' av_media_info(file)
#' av_audio_convert(file, output = "audio_pcm_16khz.wav",
#' format = "wav", channels = 1, sample_rate = 16000)
#' vad <- silero("audio_pcm_16khz.wav", milliseconds = 64)
#' }
silero <- function(file,
milliseconds = 64,
sample_rate,
threshold = 0.5){
threshold = 0.5,
threads = 1L){
stopifnot(file.exists(file))
if(requireNamespace(package = "av", quietly = TRUE)){
info <- av::av_media_info(file)
if(info$audio$channels != 1){
stop(sprintf("%s does not contain audio in mono", file))
}
if(missing(sample_rate)){
sample_rate <- info$audio$sample_rate
}
}
milliseconds <- as.integer(milliseconds)
stopifnot(milliseconds %in% c(32L, 64L, 96L))
sound <- audio::load.wave(file)
sample_rate <- attr(sound, which = "rate")
sample_rate <- as.integer(sample_rate)
if(is.matrix(sound)){
stop(sprintf("%s does not contain audio in mono", file))
}
if(!sample_rate %in% c(8000L, 16000L)){
stop(sprintf("%s should be in 8000Hz or 16000Hz, not in %s Hz", file, sample_rate))
}
torch_set_num_threads(threads)

model <- SILERO()
msg <- predict.SILERO(model, file, sample_rate = sample_rate, milliseconds = milliseconds, threshold = threshold)
msg <- predict.SILERO(model, sound, file = file, sample_rate = sample_rate, milliseconds = milliseconds, threshold = threshold)

## Get groups of sequences of voice/non-voice
grp <- rle(msg$vad$has_voice)
Expand Down Expand Up @@ -101,9 +101,7 @@ SILERO <- function(){
out
}

predict.SILERO <- function(object, newdata, sample_rate, milliseconds, window = milliseconds * (sample_rate / 1000), threshold = 0.5){
sound <- audio::load.wave(newdata)
sample_rate <- attr(sound, which = "rate")
predict.SILERO <- function(object, sound, file = "", sample_rate, milliseconds, window = milliseconds * (sample_rate / 1000), threshold = 0.5){
n_samples <- length(sound)
sample_rate <- torch::jit_scalar(as.integer(sample_rate))

Expand All @@ -113,31 +111,49 @@ predict.SILERO <- function(object, newdata, sample_rate, milliseconds, window =
if(!window %in% c(256, 512, 768, 1024, 1536)){
stop("Unknown combination of milliseconds and sample_rate")
}

test <- torch::torch_tensor(sound)

elements <- seq.int(from = 1, to = n_samples, by = window)
out <- numeric(length = length(elements))
for(i in seq_along(elements)){
#cat(i, sep = "\n")
if((elements[i]+window-1) > n_samples){
samples <- sound[elements[i]:length(sound)]
samples <- c(samples, rep(as.numeric(0), times = window - length(samples)))
samples <- torch::torch_tensor(samples)
out[i] <- as.numeric(object$model$forward(samples, sr = sample_rate))
}else{
samples <- test[elements[i]:(elements[i]+window-1)]
#samples <- torch::torch_tensor(samples)
#print(str(samples))
out[i] <- as.numeric(object$model$forward(samples, sr = sample_rate))
}
}

# test <- torch::torch_tensor(sound)
# for(i in seq_along(elements)){
# #cat(i, sep = "\n")
# if((elements[i]+window-1) > n_samples){
# samples <- sound[elements[i]:length(sound)]
# samples <- c(samples, rep(as.numeric(0), times = window - length(samples)))
# samples <- torch::torch_tensor(samples)
# out[i] <- as.numeric(object$model$forward(samples, sr = sample_rate))
# }else{
# samples <- test[elements[i]:(elements[i]+window-1)]
# #samples <- torch::torch_tensor(samples)
# #print(str(samples))
# out[i] <- as.numeric(object$model$forward(samples, sr = sample_rate))
# }
# }

samples <- torch::torch_tensor(rep(0, times = window), dtype = torch::torch_float())
with_no_grad({
out <- sapply(seq_along(elements), FUN = function(i){
if((elements[i]+window-1) > n_samples){
samples <- sound[elements[i]:length(sound)]
samples <- c(samples, rep(as.numeric(0), times = window - length(samples)))
samples <- torch::torch_tensor(samples, dtype = torch::torch_float())
#samples[] <- samples
score <- object$model$forward(samples, sr = sample_rate)
}else{
samples[] <- sound[elements[i]:(elements[i]+window-1)]
score <- object$model$forward(samples, sr = sample_rate)
}
as.numeric(score)
}, USE.NAMES = FALSE)
})


sample_rate <- as.integer(sample_rate)
vad <- data.frame(millisecond = elements, probability = out, stringsAsFactors = FALSE)
vad$has_voice <- ifelse(vad$probability > threshold, TRUE, FALSE)
vad$millisecond <- as.integer(vad$millisecond / (sample_rate / 1000))
msg <- list(
file= newdata,
file = file,
sample_rate = sample_rate,
channels = 1L,
samples = n_samples,
Expand Down
12 changes: 9 additions & 3 deletions man/silero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7bbb4d2

Please sign in to comment.