diff --git a/libDF/Cargo.toml b/libDF/Cargo.toml
index c7a76abe0..32275e4d7 100644
--- a/libDF/Cargo.toml
+++ b/libDF/Cargo.toml
@@ -95,7 +95,7 @@ thiserror = { version = "1.0", optional = true }
 anyhow = { version = "1.0", optional = true, features = ["backtrace"] }
 ctrlc = { version = "3.2", optional = true }
 hound = { version = "3.4", optional = true }
-hdf5 = { version = "^0.8", optional = true, git = "https://github.com/aldanor/hdf5-rust/", rev="refs/pull/216/head" }
+hdf5 = { version = "^0.8", optional = true, git = "https://github.com/aldanor/hdf5-rust/", rev = "refs/pull/216/head" }
 ndarray = { version = "^0.15", optional = true, features = ["serde"] }
 ndarray-rand = { version = "^0.14", optional = true }
 rayon = { version = "1.5", optional = true }
@@ -108,10 +108,10 @@ claxon = { version = "^0.4", optional = true }
 env_logger = { version = "0.10", optional = true }
 clap = { version = "4.0", optional = true, features = ["derive"] }
 rust-ini = { version = "^0.18", optional = true }
-tract-core = { version = "^0.18.3", optional = true }
-tract-onnx = { version = "^0.18.3", optional = true }
-tract-pulse = { version = "^0.18.3", optional = true }
-tract-hir = { version = "^0.18.3", optional = true }
+tract-core = { version = "^0.18.6-pre", optional = true, git = "https://github.com/sonos/tract/", rev = "refs/pull/907/head" }
+tract-onnx = { version = "^0.18.6-pre", optional = true, git = "https://github.com/sonos/tract/", rev = "refs/pull/907/head" }
+tract-pulse = { version = "^0.18.6-pre", optional = true, git = "https://github.com/sonos/tract/", rev = "refs/pull/907/head" }
+tract-hir = { version = "^0.18.6-pre", optional = true, git = "https://github.com/sonos/tract/", rev = "refs/pull/907/head" }
 flate2 = { version = "1.0.24", optional = true }
 tar = { version = "0.4.38", optional = true }
 
diff --git a/libDF/src/bin/enhance_wav.rs b/libDF/src/bin/enhance_wav.rs
index d29ff844d..27634dc3d 100644
--- a/libDF/src/bin/enhance_wav.rs
+++ b/libDF/src/bin/enhance_wav.rs
@@ -47,6 +47,9 @@ struct Args {
     /// If used with multiple channels, reduce the mask with max (1) or mean (2)
     #[arg(long, value_parser, default_value_t = 1)]
     reduce_mask: i32,
+    /// Convert model to half floats (16 bit)
+    #[arg(long)]
+    half_floats: bool,
     /// Logging verbosity
     #[arg(
         long,
@@ -86,6 +89,7 @@ fn main() -> Result<()> {
         args.max_db_erb_thresh,
         args.max_db_df_thresh,
         args.reduce_mask.try_into().unwrap(),
+        args.half_floats,
     );
     let df_params = if let Some(tar) = args.model.as_ref() {
         match DfParams::new(tar.clone()) {
diff --git a/libDF/src/tract.rs b/libDF/src/tract.rs
index 20a4f11e9..6193a91f8 100644
--- a/libDF/src/tract.rs
+++ b/libDF/src/tract.rs
@@ -98,6 +98,7 @@ pub struct RuntimeParams {
     max_db_erb_thresh: f32,
     max_db_df_thresh: f32,
     reduce_mask: ReduceMask,
+    half_floats: bool,
 }
 impl RuntimeParams {
     pub fn new(
@@ -108,6 +109,7 @@ impl RuntimeParams {
         max_db_erb_thresh: f32,
         max_db_df_thresh: f32,
         reduce_mask: ReduceMask,
+        half_floats: bool,
     ) -> Self {
         Self {
             n_ch,
@@ -117,15 +119,25 @@ impl RuntimeParams {
             max_db_erb_thresh,
             max_db_df_thresh,
             reduce_mask,
+            half_floats,
         }
     }
     pub fn default_with_ch(channels: usize) -> Self {
-        RuntimeParams::new(channels, false, 100., -10., 30., 20., ReduceMask::MEAN)
+        RuntimeParams::new(
+            channels,
+            false,
+            100.,
+            -10.,
+            30.,
+            20.,
+            ReduceMask::MEAN,
+            false,
+        )
     }
 }
 impl Default for RuntimeParams {
     fn default() -> Self {
-        RuntimeParams::new(1, false, 100., -10., 30., 20., ReduceMask::MEAN)
+        RuntimeParams::new(1, false, 100., -10., 30., 20., ReduceMask::MEAN, false)
     }
 }
 
@@ -149,6 +161,7 @@ pub struct DfTract {
     pub n_freqs: usize,
     pub df_order: usize,
     pub post_filter: bool,
+    pub half_floats: bool,
     pub alpha: f32,
     pub min_db_thresh: f32,
     pub max_db_erb_thresh: f32,
@@ -167,7 +180,7 @@ pub struct DfTract {
 #[cfg(all(not(feature = "capi"), feature = "default_model"))]
 impl Default for DfTract {
     fn default() -> Self {
-        let r_params = RuntimeParams::new(1, false, 100., -10., 30., 20., ReduceMask::MEAN);
+        let r_params = RuntimeParams::default();
         let df_params = DfParams::default();
         DfTract::new(df_params, &r_params).expect("Could not load DfTract")
     }
@@ -181,12 +194,13 @@ impl DfTract {
         let model_cfg = config.section(Some("deepfilternet")).unwrap();
         let df_cfg = config.section(Some("df")).unwrap();
         let ch = rp.n_ch;
+        let h = rp.half_floats;
 
-        let (enc, _enc_delay) = init_encoder_from_read(&mut Cursor::new(dfp.enc), df_cfg, ch)?;
+        let (enc, _enc_delay) = init_encoder_from_read(&mut Cursor::new(dfp.enc), df_cfg, ch, h)?;
         let (erb_dec, _erb_dec_delay) =
-            init_erb_decoder_from_read(&mut Cursor::new(dfp.erb_dec), model_cfg, df_cfg, ch)?;
+            init_erb_decoder_from_read(&mut Cursor::new(dfp.erb_dec), model_cfg, df_cfg, ch, h)?;
         let (df_dec, _df_dec_delay) =
-            init_df_decoder_from_read(&mut Cursor::new(dfp.df_dec), model_cfg, df_cfg, ch)?;
+            init_df_decoder_from_read(&mut Cursor::new(dfp.df_dec), model_cfg, df_cfg, ch, h)?;
         let enc = SimpleState::new(enc.into_runnable()?)?;
         let erb_dec = SimpleState::new(erb_dec.into_runnable()?)?;
         let df_dec = SimpleState::new(df_dec.into_runnable()?)?;
@@ -225,10 +239,11 @@ impl DfTract {
             Some(10f32.powf(-atten_lim / 20.))
         };
         let spec_shape = [1, 1, 1, n_freqs, 2];
-        let spec_buf = unsafe { Tensor::uninitialized_dt(f32::datum_type(), &spec_shape)? };
-        let erb_buf = unsafe { Tensor::uninitialized_dt(f32::datum_type(), &[1, 1, 1, nb_erb])? };
-        let cplx_buf = unsafe { Tensor::uninitialized_dt(f32::datum_type(), &[1, 1, nb_df, 2])? };
-        // let mut cplx_buf_t = unsafe { Tensor::uninitialized_dt(f32::datum_type(), &[1, 2, 1, nb_df])? };
+        let dt = f32::datum_type();
+        let spec_buf = unsafe { Tensor::uninitialized_dt(dt, &spec_shape)? };
+        let erb_buf = unsafe { Tensor::uninitialized_dt(dt, &[1, 1, 1, nb_erb])? };
+        let cplx_buf = unsafe { Tensor::uninitialized_dt(dt, &[1, 1, nb_df, 2])? };
+        // let mut cplx_buf_t = unsafe { Tensor::uninitialized_dt(dt, &[1, 2, 1, nb_df])? };
         let m_zeros = vec![0.; nb_erb];
 
         let model_type = config.section(Some("train")).unwrap().get("model").unwrap();
@@ -282,6 +297,7 @@ impl DfTract {
             rolling_spec_buf_x,
             df_states,
             post_filter: rp.post_filter,
+            half_floats: rp.half_floats,
         };
         m.init()?;
         #[cfg(feature = "timings")]
@@ -378,13 +394,24 @@ impl DfTract {
         #[cfg(feature = "timings")]
         let t1 = Instant::now();
         // Run encoder
-        let mut enc_emb = self.enc.run(tvec!(
-            self.erb_buf.clone(),
-            self.cplx_buf.clone().permute_axes(&[0, 3, 1, 2])?
-        ))?;
+        let mut enc_emb = if self.half_floats {
+            unsafe {
+                let erb = to_f16(&self.erb_buf)?;
+                let cpl = to_f16(&self.cplx_buf)?.permute_axes(&[0, 3, 1, 2])?;
+                self.enc.run(tvec!(
+                    erb,
+                    cpl
+                ))?
+            }
+        } else {
+            self.enc.run(tvec!(
+                self.erb_buf.clone(),
+                self.cplx_buf.clone().permute_axes(&[0, 3, 1, 2])?
+            ))?
+        };
         #[cfg(feature = "timings")]
         let t2 = Instant::now();
-        let &lsnr = enc_emb.pop().unwrap().to_scalar::<f32>()?;
+        let lsnr = enc_emb.pop().unwrap().cast_to_scalar::<f32>()?;
         let c0 = enc_emb.pop().unwrap().into_tensor();
         let emb = enc_emb.pop().unwrap().into_tensor();
         let (apply_erb, apply_erb_zeros, apply_df) = if lsnr < self.min_db_thresh {
@@ -426,6 +453,8 @@
             .unwrap()
             .into_tensor()
             .into_shape(&[self.ch, self.nb_erb])?
+            .cast_to::<f32>()?
+            .into_owned()
             .into_array()?;
         if self.ch > 1 {
             m = match self.reduce_mask {
@@ -531,6 +560,7 @@
             .into_dimensionality()?;
     // Zero relevant frequency bins of output
     o_f.slice_mut(s![.., ..nb_df]).fill(Complex32::default());
+    let coefs = coefs.cast_to::<f32>()?;
     let coefs_arr: ArrayView3<Complex32> =
         as_arrayview_complex(coefs.to_array_view::<f32>()?, &[ch, nb_df, df_order])
             .into_dimensionality()?;
@@ -559,14 +589,16 @@ fn init_encoder_impl(
     mut m: InferenceModel,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     log::debug!("Start init encoder.");
     let s = tract_pulse::fact::stream_dim();
+    let dt = f32::datum_type();
 
     let nb_erb = df_cfg.get("nb_erb").unwrap().parse::<usize>()?;
     let nb_df = df_cfg.get("nb_df").unwrap().parse::<usize>()?;
-    let feat_erb = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, 1, s, nb_erb));
-    let feat_spec = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, 2, s, nb_df));
+    let feat_erb = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, 1, s, nb_erb));
+    let feat_spec = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, 2, s, nb_df));
 
     log::debug!(
         "Encoder input: \n feat_erb [{:?}]\n feat_spec [{:?}]",
@@ -578,7 +610,6 @@
         .with_input_fact(1, feat_spec)?
         .with_input_names(["feat_erb", "feat_spec"])?
         .with_output_names(["e0", "e1", "e2", "e3", "emb", "c0", "lsnr"])?;
-
     m.analyse(true)?;
     let mut m = m.into_typed()?;
 
@@ -586,21 +617,33 @@
     let pulsed = PulsedModel::new(&m, 1)?;
     let delay = pulsed.output_fact(0)?.delay;
     log::info!("Init encoder with delay: {}", delay);
-    let m = pulsed.into_typed()?.into_optimized()?;
+    let mut m = pulsed.into_typed()?;
+
+    if half_floats {
+        use tract_core::model::translator::Translate;
+        m = tract_core::half::HalfTranslator.translate_model(&m)?;
+    }
+    m = m.into_optimized()?;
     Ok((m, delay))
 }
-fn init_encoder(m: &Path, df_cfg: &ini::Properties, n_ch: usize) -> Result<(TypedModel, usize)> {
+fn init_encoder(
+    m: &Path,
+    df_cfg: &ini::Properties,
+    n_ch: usize,
+    half_floats: bool,
+) -> Result<(TypedModel, usize)> {
     let m = tract_onnx::onnx().with_ignore_output_shapes(true).model_for_path(m)?;
-    init_encoder_impl(m, df_cfg, n_ch)
+    init_encoder_impl(m, df_cfg, n_ch, half_floats)
 }
 fn init_encoder_from_read(
     m: &mut dyn Read,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     let m = tract_onnx::onnx().with_ignore_output_shapes(true).model_for_read(m)?;
-    init_encoder_impl(m, df_cfg, n_ch)
+    init_encoder_impl(m, df_cfg, n_ch, half_floats)
 }
 
 fn init_erb_decoder_impl(
@@ -608,24 +651,23 @@
     net_cfg: &ini::Properties,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     log::debug!("Start init ERB decoder.");
     let s = tract_pulse::fact::stream_dim();
+    let dt = f32::datum_type();
 
     let nb_erb = df_cfg.get("nb_erb").unwrap().parse::<usize>()?;
     let layer_width = net_cfg.get("conv_ch").unwrap().parse::<usize>()?;
     let n_hidden = layer_width * nb_erb / 4;
 
-    let emb = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, s, n_hidden));
+    let emb = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, s, n_hidden));
     let e3f = nb_erb / 4;
-    let e3 = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, layer_width, s, e3f));
-    let e2 = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, layer_width, s, e3f));
+    let e3 = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, layer_width, s, e3f));
+    let e2 = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, layer_width, s, e3f));
     let e1f = nb_erb / 2;
-    let e1 = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, layer_width, s, e1f));
-    let e0 = InferenceFact::dt_shape(
-        f32::datum_type(),
-        shapefactoid!(n_ch, layer_width, s, nb_erb),
-    );
+    let e1 = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, layer_width, s, e1f));
+    let e0 = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, layer_width, s, nb_erb));
     log::debug!(
         "ERB decoder input: \n emb [{:?}]\n e3 [{:?}]\n e2 [{:?}]\n e1 [{:?}]\n e0 [{:?}]",
         emb.shape,
@@ -651,7 +693,13 @@
     let pulsed = PulsedModel::new(&m, 1)?;
     let delay = pulsed.output_fact(0)?.delay;
     log::info!("Init ERB decoder with delay: {}", delay);
-    let m = pulsed.into_typed()?.into_optimized()?;
+    let mut m = pulsed.into_typed()?;
+
+    if half_floats {
+        use tract_core::model::translator::Translate;
+        m = tract_core::half::HalfTranslator.translate_model(&m)?;
+    }
+    m = m.into_optimized()?;
     Ok((m, delay))
 }
 fn init_erb_decoder(
@@ -659,18 +707,20 @@
     net_cfg: &ini::Properties,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     let m = tract_onnx::onnx().with_ignore_output_shapes(true).model_for_path(m)?;
-    init_erb_decoder_impl(m, net_cfg, df_cfg, n_ch)
+    init_erb_decoder_impl(m, net_cfg, df_cfg, n_ch, half_floats)
 }
 fn init_erb_decoder_from_read(
     m: &mut dyn Read,
     net_cfg: &ini::Properties,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     let m = tract_onnx::onnx().with_ignore_output_shapes(true).model_for_read(m)?;
-    init_erb_decoder_impl(m, net_cfg, df_cfg, n_ch)
+    init_erb_decoder_impl(m, net_cfg, df_cfg, n_ch, half_floats)
 }
 
 fn init_df_decoder_impl(
@@ -678,20 +728,19 @@
     net_cfg: &ini::Properties,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     log::debug!("Start init DF decoder.");
     let s = tract_pulse::fact::stream_dim();
+    let dt = f32::datum_type();
 
     let nb_erb = df_cfg.get("nb_erb").unwrap().parse::<usize>()?;
     let nb_df = df_cfg.get("nb_df").unwrap().parse::<usize>()?;
     let layer_width = net_cfg.get("conv_ch").unwrap().parse::<usize>()?;
     let n_hidden = layer_width * nb_erb / 4;
 
-    let emb = InferenceFact::dt_shape(f32::datum_type(), shapefactoid!(n_ch, s, n_hidden));
-    let c0 = InferenceFact::dt_shape(
-        f32::datum_type(),
-        shapefactoid!(n_ch, layer_width, s, nb_df),
-    );
+    let emb = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, s, n_hidden));
+    let c0 = InferenceFact::dt_shape(dt, shapefactoid!(n_ch, layer_width, s, nb_df));
 
     log::debug!(
         "ERB decoder input: \n emb [{:?}]\n c0 [{:?}]",
@@ -711,7 +760,13 @@
     let pulsed = PulsedModel::new(&m, 1)?;
     let delay = pulsed.output_fact(0)?.delay;
     log::info!("Init DF decoder with delay: {}", delay);
-    let m = pulsed.into_typed()?.into_optimized()?;
+    let mut m = pulsed.into_typed()?;
+
+    if half_floats {
+        use tract_core::model::translator::Translate;
+        m = tract_core::half::HalfTranslator.translate_model(&m)?;
+    }
+    m = m.into_optimized()?;
     Ok((m, delay))
 }
 fn init_df_decoder(
@@ -719,18 +774,20 @@
     net_cfg: &ini::Properties,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     let m = tract_onnx::onnx().with_ignore_output_shapes(true).model_for_path(m)?;
-    init_df_decoder_impl(m, net_cfg, df_cfg, n_ch)
+    init_df_decoder_impl(m, net_cfg, df_cfg, n_ch, half_floats)
 }
 fn init_df_decoder_from_read(
     m: &mut dyn Read,
     net_cfg: &ini::Properties,
     df_cfg: &ini::Properties,
     n_ch: usize,
+    half_floats: bool,
 ) -> Result<(TypedModel, usize)> {
     let m = tract_onnx::onnx().with_ignore_output_shapes(true).model_for_read(m)?;
-    init_df_decoder_impl(m, net_cfg, df_cfg, n_ch)
+    init_df_decoder_impl(m, net_cfg, df_cfg, n_ch, half_floats)
 }
 
 fn calc_norm_alpha(sr: usize, hop_size: usize, tau: f32) -> f32 {
@@ -803,3 +860,12 @@
         ArrayViewMutD::from_shape_ptr(shape, ptr)
     }
 }
+pub unsafe fn to_f16(x: &Tensor) -> Result<Tensor> {
+    let mut tmp = Tensor::uninitialized_dt(f16::datum_type(), x.shape())?;
+    let x_v = x.to_array_view_unchecked::<f32>();
+    let mut t_v = tmp.to_array_view_mut_unchecked::<f16>();
+    for (&i, o) in x_v.iter().zip(t_v.iter_mut()) {
+        *o = f16::from_f32(i)
+    }
+    Ok(tmp)
+}
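For reference, a minimal usage sketch of the new option from the library side. It assumes the crate's `df::tract` module path and the bundled default model (`default_model` feature) provided by the surrounding codebase; only the `RuntimeParams::new` signature and the `DfTract::new` call are taken from the diff above.

```rust
use df::tract::{DfParams, DfTract, ReduceMask, RuntimeParams};

fn main() -> anyhow::Result<()> {
    // Same settings as RuntimeParams::default(), except half_floats = true, so the
    // encoder and both decoders are translated to f16 before graph optimization.
    let rp = RuntimeParams::new(1, false, 100., -10., 30., 20., ReduceMask::MEAN, true);
    let _model = DfTract::new(DfParams::default(), &rp)?;
    Ok(())
}
```

From the `enhance_wav` binary, the same behaviour should be reachable via the new clap flag (`--half-floats`), since the added bool field is declared with `#[arg(long)]`.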