Pure torch implementation #433

Draft · wants to merge 4 commits into base: main
76 changes: 51 additions & 25 deletions DeepFilterNet/df/deepfilternet3.py
@@ -4,6 +4,7 @@
import torch
from loguru import logger
from torch import Tensor, nn
import torch.nn.functional as F

import df.multiframe as MF
from df.config import Csv, DfParams, config
@@ -49,6 +50,9 @@ def __init__(self):
self.emb_num_layers: int = config(
"EMB_NUM_LAYERS", cast=int, default=2, section=self.section
)
self.emb_gru_skip_enc: str = config(
"EMB_GRU_SKIP_ENC", default="none", section=self.section
)
self.emb_gru_skip: str = config("EMB_GRU_SKIP", default="none", section=self.section)
self.df_hidden_dim: int = config(
"DF_HIDDEN_DIM", cast=int, default=256, section=self.section
@@ -99,6 +103,9 @@ def __init__(self):
self.erb_conv0 = Conv2dNormAct(
1, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
)
self.conv_buffer_size = p.conv_kernel_inp[0] - 1
self.conv_ch = p.conv_ch

conv_layer = partial(
Conv2dNormAct,
in_ch=p.conv_ch,
@@ -110,8 +117,9 @@ def __init__(self):
self.erb_conv1 = conv_layer(fstride=2)
self.erb_conv2 = conv_layer(fstride=2)
self.erb_conv3 = conv_layer(fstride=1)
self.df_conv0_ch = p.conv_ch
self.df_conv0 = Conv2dNormAct(
- 2, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
+ 2, self.df_conv0_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True
)
self.df_conv1 = conv_layer(fstride=2)
self.erb_bins = p.nb_erb
@@ -128,13 +136,27 @@ def __init__(self):
else:
self.combine = Add()
self.emb_n_layers = p.emb_num_layers
if p.emb_gru_skip_enc == "none":
skip_op = None
elif p.emb_gru_skip_enc == "identity":
assert self.emb_in_dim == self.emb_out_dim, "Dimensions do not match"
skip_op = partial(nn.Identity)
elif p.emb_gru_skip_enc == "groupedlinear":
skip_op = partial(
GroupedLinearEinsum,
input_size=self.emb_out_dim,
hidden_size=self.emb_out_dim,
groups=p.lin_groups,
)
else:
raise NotImplementedError()
self.emb_gru = SqueezedGRU_S(
self.emb_in_dim,
self.emb_dim,
output_size=self.emb_out_dim,
num_layers=1,
batch_first=True,
- gru_skip_op=None,
+ gru_skip_op=skip_op,
linear_groups=p.lin_groups,
linear_act_layer=partial(nn.ReLU, inplace=True),
)
@@ -143,8 +165,8 @@ def __init__(self):
self.lsnr_offset = p.lsnr_min

def forward(
- self, feat_erb: Tensor, feat_spec: Tensor
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ self, feat_erb: Tensor, feat_spec: Tensor, hidden: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
# Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands.
# erb: [B, 1, T, Fe]
# spec: [B, 2, T, Fc]
@@ -154,14 +176,14 @@ def forward(
e2 = self.erb_conv2(e1) # [B, C*4, T, F/4]
e3 = self.erb_conv3(e2) # [B, C*4, T, F/4]
c0 = self.df_conv0(feat_spec) # [B, C, T, Fc]
- c1 = self.df_conv1(c0) # [B, C*2, T, Fc]
+ c1 = self.df_conv1(c0) # [B, C*2, T, Fc/2]
cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1]
cemb = self.df_fc_emb(cemb) # [T, B, C * F/4]
- emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F/4]
+ emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F]
emb = self.combine(emb, cemb)
- emb, _ = self.emb_gru(emb) # [B, T, -1]
+ emb, hidden = self.emb_gru(emb, hidden) # [B, T, -1]
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
- return e0, e1, e2, e3, emb, c0, lsnr
+ return e0, e1, e2, e3, emb, c0, lsnr, hidden


class ErbDecoder(nn.Module):
@@ -221,16 +243,16 @@ def __init__(self):
p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid
)

- def forward(self, emb, e3, e2, e1, e0) -> Tensor:
+ def forward(self, emb: Tensor, e3: Tensor, e2: Tensor, e1: Tensor, e0: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
# Estimates erb mask
b, _, t, f8 = e3.shape
- emb, _ = self.emb_gru(emb)
+ emb, hidden = self.emb_gru(emb, hidden)
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8]
e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4]
e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2]
e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F]
m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F]
- return m
+ return m, hidden


class DfOutputReshapeMF(nn.Module):
@@ -271,6 +293,7 @@ def __init__(self):

conv_layer = partial(Conv2dNormAct, separable=True, bias=False)
kt = p.df_pathway_kernel_size_t
self.conv_buffer_size = kt - 1
self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1))

self.df_gru = SqueezedGRU_S(
@@ -299,16 +322,15 @@ def __init__(self):
self.df_out = nn.Sequential(df_out, nn.Tanh())
self.df_fc_a = nn.Sequential(nn.Linear(self.df_n_hidden, 1), nn.Sigmoid())

- def forward(self, emb: Tensor, c0: Tensor) -> Tuple[Tensor, Tensor]:
+ def forward(self, emb: Tensor, c0: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
b, t, _ = emb.shape
- c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden
+ c, hidden = self.df_gru(emb, hidden) # [B, T, H], H: df_n_hidden
if self.df_skip is not None:
c = c + self.df_skip(emb)
c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last
- alpha = self.df_fc_a(c) # [B, T, 1]
c = self.df_out(c) # [B, T, F*O*2], O: df_order
c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2]
- return c, alpha
+ return c, hidden


class DfNet(nn.Module):
@@ -333,12 +355,12 @@ def __init__(
self.emb_dim: int = layer_width * p.nb_erb
self.erb_bins: int = p.nb_erb
if p.conv_lookahead > 0:
- assert p.conv_lookahead == p.df_lookahead
+ assert p.conv_lookahead >= p.df_lookahead
self.pad_feat = nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0)
else:
self.pad_feat = nn.Identity()
if p.df_lookahead > 0:
- self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -p.df_lookahead, p.df_lookahead), 0.0)
+ self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, p.df_lookahead - 1, -p.df_lookahead + 1), 0.0)
else:
self.pad_spec = nn.Identity()
self.register_buffer("erb_fb", erb_fb)
@@ -381,9 +403,11 @@ def forward(
"""
feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)

- feat_erb = self.pad_feat(feat_erb)
- feat_spec = self.pad_feat(feat_spec)
- e0, e1, e2, e3, emb, c0, lsnr = self.enc(feat_erb, feat_spec)
+ # feat_erb = self.pad_feat(feat_erb)
+ # feat_spec = self.pad_feat(feat_spec)
+ spec = self.pad_spec(spec)
+
+ e0, e1, e2, e3, emb, c0, lsnr, _ = self.enc(feat_erb, feat_spec, hidden=None)

if self.lsnr_droput:
idcs = lsnr.squeeze() > -10.0
@@ -400,19 +424,21 @@ def forward(

if self.run_erb:
if self.lsnr_droput:
- m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0)
+ m[:, :, idcs], _ = self.erb_dec(emb, e3, e2, e1, e0, hidden=None)
else:
- m = self.erb_dec(emb, e3, e2, e1, e0)
- spec_m = self.mask(spec, m)
+ m, _ = self.erb_dec(emb, e3, e2, e1, e0, hidden=None)
+
+ pad_spec = F.pad(spec, (0, 0, 0, 0, 1, -1, 0, 0), value=0)
+ spec_m = self.mask(pad_spec, m)
else:
m = torch.zeros((), device=spec.device)
spec_m = torch.zeros_like(spec)

if self.run_df:
if self.lsnr_droput:
- df_coefs[:, idcs] = self.df_dec(emb, c0)[0]
+ df_coefs[:, idcs], _ = self.df_dec(emb, c0, hidden=None)
else:
- df_coefs = self.df_dec(emb, c0)[0]
+ df_coefs, _ = self.df_dec(emb, c0, hidden=None)
df_coefs = self.df_out_transform(df_coefs)
spec = self.df_op(spec, df_coefs)
spec[..., self.nb_df :, :] = spec_m[..., self.nb_df :, :]
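The recurring change in this file: each sub-network's `forward` now threads an explicit GRU `hidden` state in and out instead of discarding it, which is what makes chunk-by-chunk streaming inference possible. A minimal sketch of the pattern with a bare `nn.GRU` (illustrative sizes, not the PR's actual modules):

```
import torch
from torch import nn

gru = nn.GRU(input_size=64, hidden_size=64, batch_first=True)

hidden = None  # torch treats None as an all-zero initial state
chunks = torch.randn(10, 1, 5, 64)  # ten chunks of shape [B=1, T=5, F=64]
for chunk in chunks:
    out, hidden = gru(chunk, hidden)  # carry the state into the next chunk
```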
2 changes: 1 addition & 1 deletion DeepFilterNet/df/modules.py
@@ -729,7 +729,7 @@ def __init__(
else:
self.linear_out = nn.Identity()

- def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]:
+ def forward(self, input: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
x = self.linear_in(input)
x, h = self.gru(x, h)
x = self.linear_out(x)
2 changes: 1 addition & 1 deletion DeepFilterNet/df/multiframe.py
@@ -71,7 +71,7 @@ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bo
if real:
self.pad = nn.ConstantPad3d((0, 0, 0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
else:
- self.pad = nn.ConstantPad2d((0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
+ self.pad = nn.ConstantPad2d((0, 0, frame_size - lookahead, lookahead - 1), 0.0)
self.need_unfold = frame_size > 1
self.lookahead = lookahead

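For reference, `nn.ConstantPad2d` takes `(left, right, top, bottom)` and pads the last two dimensions, so the changed pair above pads the time axis; the new values move one frame of padding from the future side to the past side. A quick sketch of the semantics (shapes are illustrative only):

```
import torch
from torch import nn

frame_size, lookahead = 5, 2
# The last dim (frequency) is untouched; the time axis gains
# (frame_size - lookahead) past frames and (lookahead - 1) future frames.
pad = nn.ConstantPad2d((0, 0, frame_size - lookahead, lookahead - 1), 0.0)
x = torch.randn(1, 10, 4)  # [B, T, F]
print(pad(x).shape)        # torch.Size([1, 14, 4])
```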
17 changes: 12 additions & 5 deletions libDF/src/tract.rs
@@ -586,9 +586,14 @@ impl DfTract {
}
}

- pub fn set_spec_buffer(&mut self, spec: ArrayView2<f32>) -> Result<()> {
- debug_assert_eq!(self.spec_buf.shape(), spec.shape());
- let mut buf = self.spec_buf.to_array_view_mut()?.into_shape([self.ch, self.n_freqs])?;
+ pub fn set_spec_buffer(&mut self, spec: ArrayView2<Complex32>) -> Result<()> {
+ let mut buf = as_arrayview_mut_complex(
+ self.spec_buf.to_array_view_mut::<f32>().unwrap(),
+ &[self.ch, self.n_freqs],
+ );
+
+ debug_assert_eq!(buf.shape(), spec.shape());

for (i_ch, mut b_ch) in spec.outer_iter().zip(buf.outer_iter_mut()) {
for (&i, b) in i_ch.iter().zip(b_ch.iter_mut()) {
*b = i
@@ -645,7 +650,7 @@ fn df(
debug_assert_eq!(ch, spec_out.shape()[0]);
debug_assert!(spec.len() >= df_order);
let mut o_f: ArrayViewMut2<Complex32> =
- as_array_mut_complex(spec_out.to_array_view_mut::<f32>()?, &[ch, n_freqs])
+ as_arrayview_mut_complex(spec_out.to_array_view_mut::<f32>()?, &[ch, n_freqs])
.into_dimensionality()?;
// Zero relevant frequency bins of output
o_f.slice_mut(s![.., ..nb_df]).fill(Complex32::default());
@@ -913,6 +918,7 @@ pub fn as_slice_complex(buffer: &[f32]) -> &[Complex32] {
}
}

#[allow(clippy::needless_pass_by_ref_mut)]
pub fn as_slice_mut_complex(buffer: &mut [f32]) -> &mut [Complex32] {
unsafe {
let ptr = buffer.as_ptr() as *mut Complex32;
@@ -921,6 +927,7 @@ pub fn as_slice_mut_complex(buffer: &mut [f32]) -> &mut [Complex32] {
}
}

#[allow(clippy::needless_pass_by_ref_mut)]
pub fn as_slice_mut_real(buffer: &mut [Complex32]) -> &mut [f32] {
unsafe {
let ptr = buffer.as_ptr() as *mut f32;
@@ -960,7 +967,7 @@ pub fn as_arrayview_complex<'a>(
ArrayViewD::from_shape_ptr(shape, ptr)
}
}
- pub fn as_array_mut_complex<'a>(
+ pub fn as_arrayview_mut_complex<'a>(
buffer: ArrayViewMutD<'a, f32>,
shape: &[usize], // having an explicit shape parameter allows to also squeeze axes.
) -> ArrayViewMutD<'a, Complex32> {
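The complex-view helpers above reinterpret interleaved `f32` real/imaginary pairs as `Complex32` without copying. The same trick in numpy terms (a sketch for intuition, not part of the bindings):

```
import numpy as np

buf = np.arange(8, dtype=np.float32)  # interleaved re/im pairs
cplx = buf.view(np.complex64)         # zero-copy reinterpretation, like as_slice_complex
print(cplx)                           # [0.+1.j 2.+3.j 4.+5.j 6.+7.j]
```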
8 changes: 7 additions & 1 deletion pyDF/Cargo.toml
@@ -11,7 +11,13 @@ crate-type = ["cdylib"]
path = "src/lib.rs"

[dependencies]
- deep_filter = { features = ["transforms", "logging"], path = "../libDF" }
+ deep_filter = { path = "../libDF", default-features = false, features = [
+ "tract",
+ "use-jemalloc",
+ "default-model",
+ ] }

pyo3 = { version = "0.19", features = ["extension-module"]}
numpy = "0.19"
ndarray = "^0.15"
tract-core = "^0.19.4"
38 changes: 36 additions & 2 deletions pyDF/src/lib.rs
@@ -3,17 +3,50 @@ use df::transforms::{
TransformError,
};
use df::{Complex32, DFState, UNIT_NORM_INIT};

use df::tract::*;

use ndarray::{Array1, Array2, Array3, Array4, ArrayD, ArrayView4, Axis, ShapeError};
use numpy::{
IntoPyArray, PyArray1, PyArray2, PyArray3, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2,
- PyReadonlyArray3, PyReadonlyArrayDyn,
+ PyReadonlyArray3, PyReadonlyArrayDyn
};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;


#[pyclass]
struct DF {
- state: DFState,
+ state: DFState,
}

#[pyclass(unsendable)]
struct DFTractPy {
tract: DfTract,
}

#[pymethods]
#[allow(clippy::upper_case_acronyms)]
impl DFTractPy {
#[new]
fn new() -> Self {
DFTractPy {
tract: Default::default()
}
}

unsafe fn process<'py>(
&mut self,
py: Python<'py>,
input: &PyArray2<f32>
) -> PyResult<&'py PyArray2<f32>> {
let channels = input.shape()[0];
let mut output = Array2::<f32>::zeros((channels, self.tract.hop_size));
let input = input.as_array(); // unsafe fn
let _ = self.tract.process(input, output.view_mut()).expect("Error during df::process");

Ok(output.into_pyarray(py))
}
}

#[pymethods]
@@ -138,6 +171,7 @@ impl DF {
#[pymodule]
fn libdf(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<DF>()?;
m.add_class::<DFTractPy>()?;

#[pyfn(m)]
#[pyo3(name = "erb")]
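The new `DFTractPy` class exposes the tract-based streaming pipeline to Python one hop of samples at a time. A hypothetical usage sketch (the module name `libdf` comes from the crate's `#[pymodule]`; the 480-sample hop is an assumption based on the default 48 kHz model):

```
import numpy as np
from libdf import DFTractPy  # built from pyDF via maturin

model = DFTractPy()
hop = 480  # assumed: 10 ms hop at 48 kHz
frame = np.zeros((1, hop), dtype=np.float32)  # [channels, hop] input samples
enhanced = model.process(frame)               # enhanced frame of the same shape
```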
49 changes: 49 additions & 0 deletions torchDF/README.md
@@ -0,0 +1,49 @@
# TorchDF

Commit against which the comparison was made: https://github.com/Rikorose/DeepFilterNet/commit/ca46bf54afaf8ace3272aaee5931b4317bd6b5f4

Installation:
```
cd path/to/DeepFilterNet/
pip install maturin poetry poethepoet
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
maturin build --release -m pyDF/Cargo.toml

cd DeepFilterNet
export PYTHONPATH=$PWD

cd ../torchDF
poetry install
poe install-torch-cpu
```

This directory contains offline and streaming implementations of DeepFilterNet3 in pure torch. The streaming model can be fully exported to ONNX using `model_onnx_export.py`.

Every script and test has to be run inside the poetry environment.

To run tests:
```
poetry run python -m pytest -v
```
We compare this model to the existing `enhance` method (which is partly written in Rust) and to the tract model (which is written purely in Rust). All tests pass, so the model is working.
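A parity check of this kind typically bounds the maximum deviation between the two pipelines on identical input; a sketch of the idea (function name and tolerance are illustrative, not the suite's actual API):

```
import numpy as np

def assert_outputs_match(torch_out: np.ndarray, rust_out: np.ndarray, tol: float = 1e-4):
    # Both outputs are [channels, samples]; a small absolute tolerance absorbs
    # fp32 rounding differences between the torch and Rust pipelines.
    assert torch_out.shape == rust_out.shape
    assert np.abs(torch_out - rust_out).max() < tol
```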

To enhance audio using streaming implementation:
```
poetry run python torch_df_streaming_minimal.py --audio-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --output-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary_enhanced.wav
```

To convert model to onnx and run tests:
```
poetry run python model_onnx_export.py --test --performance --inference-path examples/A1CIM28ZUCA8RX_M_Street_Near_Regular_SP_Mobile_Primary.wav --ort
```
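The exported streaming graph can then be driven frame by frame from onnxruntime. A rough sketch under assumed names and fixed state shapes (the export script defines the real input/output names):

```
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("denoiser_model.onnx")  # assumed export filename
# Streaming exports carry recurrent/buffer states as explicit inputs; feed each
# run's state outputs back in as the next run's state inputs.
states = {
    i.name: np.zeros(i.shape, dtype=np.float32)
    for i in sess.get_inputs()
    if i.name != "input_frame"  # assumed audio input name
}
frame = np.zeros(480, dtype=np.float32)  # assumed hop size
# Assumes the first output is the enhanced frame and the rest are new states.
enhanced, *new_states = sess.run(None, {"input_frame": frame, **states})
```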

TODO:
* Issues about split + simplify
* Thinking of offline method exportability + compatibility with streaming functions
* torch.where(..., ..., 0) export issue
* dynamo.export check
* thinking of torchDF naming
* rfft hacks tests
* torch.nonzero thinking
* rfft nn.module
* more static methods