
Pure torch implementation #433

Draft · wants to merge 4 commits into base: main

Conversation


grazder commented Sep 20, 2023

I've created an issue about the pure torch reimplementation: #430

Sharing the code. This is a draft PR, so the work is still in progress and I may make some changes later.
You can find my implementation in the torchDF folder; there is also a README.md there with some details.

I'll be glad to hear your feedback.

There are also some changes in deepfilternet3.py, modules.py, and multiframe.py. They were necessary to reach 100% compatibility between the streaming tract model and the offline enhance method.

This fork was created based on ca46bf5, so some code may be a little outdated.

The offline model's torch implementation is in torchDF/torch_df_offline.py.

The streaming model's torch implementation is in torchDF/torch_df_streaming.py.

To convert the streaming model to ONNX, you can use torchDF/model_onnx_export.py.

@dingchaoyue

Thank you for your great work!

In torch_df_offline.py, sample_rate is set to 48000 and
self.erb_indices = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 5, 7, 7, 8, 10, 12, 13, 15, 18, 20, 24, 28, 31, 37, 42, 50, 56, 67])

When sample_rate=16000, what should self.erb_indices be set to?


grazder commented Oct 7, 2023

@dingchaoyue

Here is the original code; you can calculate the feature bank using it:

pub fn erb_fb(sr: usize, fft_size: usize, nb_bands: usize, min_nb_freqs: usize) -> Vec<usize> {
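
For reference, here is a rough Python port of that function (my own sketch; the ERB-scale constants and the last-band adjustment should be double-checked against libDF):

    import math

    def freq2erb(freq_hz):
        # Standard ERB-rate scale; constants assumed to match libDF.
        return 9.265 * math.log1p(freq_hz / (24.7 * 9.265))

    def erb2freq(n_erb):
        return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1.0)

    def erb_fb(sr, fft_size, nb_bands, min_nb_freqs):
        # Split the fft_size // 2 + 1 frequency bins into nb_bands ERB-spaced bands.
        freq_width = sr / fft_size
        erb_low, erb_high = freq2erb(0.0), freq2erb(sr / 2)
        step = (erb_high - erb_low) / nb_bands
        widths, prev_freq, freq_over = [], 0, 0
        for i in range(1, nb_bands + 1):
            fb = round(erb2freq(erb_low + i * step) / freq_width)
            nb_freqs = fb - prev_freq - freq_over
            if nb_freqs < min_nb_freqs:
                # Not enough bins in this band; borrow from the next one.
                freq_over = min_nb_freqs - nb_freqs
                nb_freqs = min_nb_freqs
            else:
                freq_over = 0
            widths.append(nb_freqs)
            prev_freq = fb
        widths[-1] += 1  # the onesided spectrum has fft_size // 2 + 1 bins
        too_large = sum(widths) - (fft_size // 2 + 1)
        if too_large > 0:
            widths[-1] -= too_large
        return widths

    # e.g. erb_fb(sr=16000, fft_size=320, nb_bands=32, min_nb_freqs=2)
    # (fft_size and min_nb_freqs for 16 kHz are assumptions; match your model config)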

@dingchaoyue

Thank you. Problem solved.

@Rikorose (Owner)

Hi, thanks for this PR! I would be interested in merging it. A few things that I want to discuss:

  1. Only keep things related to the pytorch real-time path in here. Make separate PRs for other stuff, e.g. jemalloc for pyDF etc.
  2. If possible, I would like to not introduce another Python wheel. Do you see any option to include this version as a tool.poetry.script within the DeepFilterNet wheel?
  3. Could you maybe add a test here?


grazder commented Oct 19, 2023

  1. I'm using the pyDF changes to run a test that compares the original Rust and the Torch streaming implementations. What do you think about the tests then? Do I need to remove them, or should we keep them in this PR?
  2. Yeah, I see this implementation more as a way for users to understand the model. We don't need to introduce it as another wheel. I can try to add it as a tool.poetry.script and figure out how to make it better. Can you describe the logic behind it and how you use tool.poetry.script right now?
  3. I can add it here, yes. But we need to decide what to do with the tests (as I described in point 1).

@StuartIanNaylor

Hey, if anyone can recreate the LADSPA plugin with ONNX, please do; tract's single-threaded inference needs a pretty big single core.

@Rikorose (Owner)

A few notes:
1.1 Don't remove existing features that pyDF depends on (logging, transforms).
1.2 Don't add features in this PR that are unrelated (jemalloc).
1.3 Hide the tract dependency behind a feature flag. It should not be compiled by default:

[features]
tract = ["deep_filter/tract", "deep_filter/default-model"]

However, I am not sure if there is a way to add this as an optional dependency of deepfilterlib in the pyproject file, because by default the tract dependency is not necessary for a standard use case (e.g. pytorch inference or pytorch training). Can you add deepfilterlib without tract as it is currently, and only add deepfilterlib[tract] when DeepFilterNet is installed with an additional tract feature?
1.4 Is tract_core really necessary? It should be pulled in via libDF/tract.
2.1 You can just add an additional readme_py_rt.md or so in DeepFilterNet and create a script at DeepFilterNet/scripts. Then you can add this python script to the tool.poetry.scripts section (see the sketch after this list). You can also move the other scripts, e.g. model_onnx_export.py, to the scripts folder. But we need to find a solution for the tract dependency of deepfilterlib.
2.2 A different option would be to create a submodule, e.g. at DeepFilterNet/py_rt/. Maybe this is more convenient.
4. Remove the .wav example. Please use the existing samples at assets/noisy_snr0.wav.
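
For 2.1, such an entry could look like this (script name and module path are made up for illustration):

    [tool.poetry.scripts]
    deepfilternet-torch-rt = "df.scripts.torch_df_streaming:main"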


grazder commented Oct 31, 2023

@Rikorose
I saw your comment; right now I have more important tasks, but I'll come back in a while and fix it.

@hulucky1102

Hello author, when I build the model I get "TypeError: forward() takes 3 positional arguments but 4 were given". This seems to be caused by the missing hidden-state parameters. How can I solve it? I don't know where the loaded enc module code is located; adding hidden vectors to the encoder module still does not solve the problem. The hidden state is not loaded during training.
[screenshot]


grazder commented Jan 3, 2024

Hello! Can you provide more details? What code are you running?

The Encoder in this reimplementation takes 4 parameters:

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/DeepFilterNet/df/deepfilternet3.py#L168

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/torch_df_streaming.py#L492

self, feat_erb: Tensor, feat_spec: Tensor, hidden: Tensor

Btw, this code is for inference only; I didn't do anything with training.

Also, I haven't committed to this branch in a long time, so nothing should break if you run the code with no changes.

What do you mean by "build"? Do the tests pass? You can see how to run the tests here.

@hulucky1102

Hello, this is because I had the df library installed in my environment, so the calls went to the functions in the installed library instead of the functions you modified.

@hulucky1102

Hello, I generated the onnx model through your model_onnx_export.py and read audio from the device to denoise it in real time through onnx. The result is not good. Is it because I am not using the code correctly?
[screenshot]


grazder commented Jan 4, 2024

@hulucky1102 You can look here for an example of how to run inference with the streaming version correctly:

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/test_torchdf.py#L70

Also, check that you export with the always_apply_all_stages=True parameter:

https://github.com/grazder/DeepFilterNet/blob/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF/model_onnx_export.py#L232
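
Roughly, the streaming loop looks like this (a minimal onnxruntime sketch; the file name, tensor names, hop size, and state layout are illustrative, so check the exported graph for the real ones):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("streaming_df.onnx")  # hypothetical file name
    hop = 480  # assumed hop size in samples

    # Zero-initialize every non-audio input from the graph's own metadata,
    # assuming state outputs are named like the corresponding state inputs.
    states = {
        i.name: np.zeros([d if isinstance(d, int) else 1 for d in i.shape], np.float32)
        for i in sess.get_inputs() if i.name != "input_frame"
    }
    out_names = [o.name for o in sess.get_outputs()]

    audio = np.zeros(48000, dtype=np.float32)  # stand-in for the real input signal
    enhanced = []
    for start in range(0, len(audio) - hop + 1, hop):
        feeds = {"input_frame": audio[start:start + hop], **states}
        result = dict(zip(out_names, sess.run(None, feeds)))
        enhanced.append(result["enhanced_frame"])         # hypothetical output name
        states = {name: result[name] for name in states}  # carry state to next frame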

@hulucky1102

Thank you very much. This problem was caused by me passing in atten_lim_db. Is there some setting in the model that causes muting? It makes some of my speech incoherent, so I would like to remove the muting setting from the model.

@hulucky1102

Both use DeepFilterNet3: one uses the enhance API and the other uses the onnx model. The two results are inconsistent; the onnx result has vocal loss.
[screenshots]


grazder commented Jan 6, 2024

Is it happening only with ONNX inference? Or with torch inference of the streaming version too?

@hulucky1102

This situation occurs in both onnx and torch streaming.

@RyanMetcalfeInt8

Hi @grazder,

Thanks for your amazing work! I was hoping you could help clear up something for me.

As part of your change set, you added a new 'hidden_states' input/output tensor to Encoder, ErbDecoder, and DfDecoder (in deepfilternet3.py). And I see that in your streaming implementation, these are used as part of the state management logic.

What I am confused about is that the main branch's streaming implementation (i.e. the Rust implementation) appears to work using these onnx models, DeepFilterNet3_onnx.tar.gz, but I don't see these 'hidden states' tensors exposed as inputs / outputs of those models. So how does the Rust implementation work without the ability to manage them? It seems necessary, based on your pure pytorch implementation.

It's very possible that I overlooked something simple.

Thanks again!
Ryan

@FisherDom

hi, @grazder
thanks for your work. I found a possible issue in the code (or maybe I'm wrong):
In torch_df_streaming.py, self.rolling_spec_buf_y_shape = (self.df_order + 2, self.freq_size, 2), but line 706 reads current_spec = new_rolling_spec_buf_y[self.df_order - 1]. I think the index should be (self.df_order + 2 - 1). Would love to hear your opinion.
Best wishes!


grazder commented Apr 7, 2024

Hello, we need to store two future frames for a single frame prediction, so self.df_order - 1 is correct.

You can also find it in the original Rust code:

let spec = self.rolling_spec_buf_y.get_mut(self.df_order - 1).unwrap();
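
To make the indexing concrete (a toy illustration, not the repo's code):

    # With df_order = 5 the rolling buffer holds df_order + 2 = 7 spectra,
    # newest last. Index df_order - 1 = 4 is the frame being enhanced,
    # which leaves indices 5 and 6 as the two future (lookahead) frames.
    df_order = 5
    rolling_buf = ["t-4", "t-3", "t-2", "t-1", "t", "t+1", "t+2"]
    current_spec = rolling_buf[df_order - 1]  # -> "t"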

@FisherDom

Thanks for your answer! @grazder

Sorry to bother you again. I would like to get a tflite model that can perform real-time inference on a DSP. If I use the onnx model obtained from your model_onnx_export.py file, will there be any problems when converting the onnx model to a tf model? I noticed that you are using a newly registered operator.

Secondly, if I want to use C for inference on the DSP, would you suggest that I use the original author's single complete model or the three separate models? I have relatively little experience with deployment.

I wish everything goes well with your work!


grazder commented Apr 10, 2024

@FisherDom

will there be any problems when converting the onnx model to tf model

I think that you can face problems with RFFT / IRFFT. I don't know a lot about tf operations, so I can't say exactly.

I noticed that you are using a newly registered operator

Yeah, the new operator is in the torchDF_main branch. It gave me roughly a 2.5x speedup. There are also some other graph optimizations in that branch.

You can use the torchDF-changes branch (from this PR). In that variant RFFT / IRFFT are implemented as matmuls, which is suboptimal, but I bet you will not face problems with RFFT / IRFFT export.
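
For context, expressing an RFFT as matmuls looks roughly like this (a generic sketch, not the branch's actual code):

    import math
    import torch

    def rfft_matrices(n):
        # DFT bases for the onesided spectrum: rows are bins k, columns samples t.
        k = torch.arange(n // 2 + 1, dtype=torch.float32).unsqueeze(1)
        t = torch.arange(n, dtype=torch.float32).unsqueeze(0)
        ang = -2.0 * math.pi * k * t / n
        return torch.cos(ang), torch.sin(ang)

    cos_m, sin_m = rfft_matrices(960)
    x = torch.randn(960)
    real, imag = cos_m @ x, sin_m @ x
    # real + 1j * imag matches torch.fft.rfft(x) up to float error,
    # and plain matmuls export to ONNX without any FFT ops.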

I want to use C language for inference on DSP, would you suggest that I use the original author's single complete model or three models?

Well, you can use the C API; you can find a build in the GitHub actions, or look at the actions config.
I haven't benchmarked the original Rust implementation against mine in a long time, but when I last checked, the original implementation (in Rust) had the same speed as the torchDF-changes branch. torchDF_main is now much faster, so you can try it.

Also, if you want to use C, you can try the onnxruntime C API.


viki347 commented Apr 22, 2024

@grazder Hello! Thank you very much for your contribution in providing the pure torch code; it saved me a lot of time that would otherwise have been spent learning Rust.
When I ran your code,
torch_df_offline.py (https://github.com/grazder/DeepFilterNet/blob/torchDF-changes/torchDF/torch_df_offline.py),
I found a little bug: the durations do not match. The duration of the output audio file is shorter than the input. Upon inspecting your code, I found that the frame_synthesis function does not consider the second part of the last audio block.
Below is the result of my modification for your reference.

    def frame_synthesis(self, input_data, i_last_record, out_chunks):  # two additional arguments added
        """
        Original code - libDF/src/lib.rs - frame_synthesis()
        Inverse rfft for one frame. Every frame is summed with the buffer from the
        previous frame, and the buffer is saved for the next frame.

        Parameters:
            input_data: Complex[F] - Enhanced audio spectrogram

        Returns:
            output:     Float[f] - Enhanced audio
        """
        x = torch.fft.irfft(input_data, norm='forward') * self.window
        x_first, x_second = torch.split(x, [self.frame_size, x.shape[0] - self.frame_size])
        output = x_first + self.synthesis_mem

        self.synthesis_mem = x_second

        if i_last_record == out_chunks:  # if this chunk of the output wav is the last chunk
            output = output + x_second + x_second

        return output

@FisherDom

@grazder Sorry to bother you! I want to quantize the model to 8 bits now, but I have only quantized CV models before. Can you give me some hints or information? Should it be quantized in torch or in onnx?
Best wishes!


grazder commented May 20, 2024

@FisherDom

Hello! I've tried to quantize with ONNX here - https://github.com/grazder/DeepFilterNet/blob/torchDF-temp/torchDF/model_onnx_export.py
But it didn't give me anything; it seems the old ops became faster, but because of the many quantize / dequantize nodes the model didn't become much faster or smaller.
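
For anyone trying the same, the dynamic quantization entry point in onnxruntime looks like this (file names are placeholders, not the exact script from that branch):

    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        "streaming_df.onnx",        # placeholder input path
        "streaming_df_int8.onnx",   # placeholder output path
        weight_type=QuantType.QInt8,
    )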

@nyabinary

Any updates on this?
