Skip to content

Commit

Permalink
avoid using torch_tensor twice
Browse files Browse the repository at this point in the history
  • Loading branch information
jwijffels committed Mar 18, 2024
1 parent 859c261 commit 93caa8c
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ URL: https://github.com/bnosac/audio.vadsilero
Encoding: UTF-8
Imports:
torch,
wav
audio
Suggests:
av
RoxygenNote: 7.1.2
3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

S3method(print,VAD)
export(silero)
importFrom(audio,load.wave)
importFrom(torch,autograd_set_grad_mode)
importFrom(torch,jit_load)
importFrom(torch,jit_scalar)
importFrom(torch,torch_tensor)
importFrom(wav,read_wav)
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## CHANGES IN audio.vadsilero VERSION 0.2

- disable gradient history recording and set the model to evaluation (inference) mode
- replace package wav with package audio for reading wav files

## CHANGES IN audio.vadsilero VERSION 0.1

Expand Down
2 changes: 1 addition & 1 deletion R/pkg.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#' @importFrom wav read_wav
#' @importFrom audio load.wave
#' @importFrom torch torch_tensor jit_load jit_scalar autograd_set_grad_mode
NULL
16 changes: 10 additions & 6 deletions R/silero.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,19 @@ SILERO <- function(){
}

predict.SILERO <- function(object, newdata, sample_rate, milliseconds, window = milliseconds * (sample_rate / 1000), threshold = 0.5){
sound <- audio::load.wave(newdata)
sample_rate <- attr(sound, which = "rate")
n_samples <- length(sound)
sample_rate <- torch::jit_scalar(as.integer(sample_rate))

if(!sample_rate %in% c(8000, 16000)){
stop("sample_rate should be 8000 or 16000")
}
if(!window %in% c(256, 512, 768, 1024, 1536)){
stop("Unknown combination of milliseconds and sample_rate")
}
sound <- wav::read_wav(newdata)
sound <- sound[1, ]

test <- torch::torch_tensor(sound)
n_samples <- length(sound)
sample_rate <- torch::jit_scalar(sample_rate)

elements <- seq.int(from = 1, to = n_samples, by = window)
out <- numeric(length = length(elements))
Expand All @@ -121,10 +123,12 @@ predict.SILERO <- function(object, newdata, sample_rate, milliseconds, window =
if((elements[i]+window-1) > n_samples){
samples <- sound[elements[i]:length(sound)]
samples <- c(samples, rep(as.numeric(0), times = window - length(samples)))
out[i] <- as.numeric(object$model$forward(torch::torch_tensor(samples), sr = sample_rate))
samples <- torch::torch_tensor(samples)
out[i] <- as.numeric(object$model$forward(samples, sr = sample_rate))
}else{
samples <- test[elements[i]:(elements[i]+window-1)]
samples <- torch::torch_tensor(samples)
#samples <- torch::torch_tensor(samples)
#print(str(samples))
out[i] <- as.numeric(object$model$forward(samples, sr = sample_rate))
}
}
Expand Down
5 changes: 3 additions & 2 deletions man/silero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 93caa8c

Please sign in to comment.