diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8ceb28c..d0b9f4e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,11 +30,35 @@ jobs: with: toolchain: ${{ matrix.rust-toolchain }} + - name: Install ONNX Runtime on Windows + if: matrix.os == 'windows-latest' + run: | + Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip" + Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP" + echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV + + - name: Install ONNX Runtime on macOS + if: matrix.os == 'macos-latest' + run: | + curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz" + mkdir -p $HOME/onnxruntime + tar -xzf onnxruntime.tgz -C $HOME/onnxruntime + echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV + + + - name: Set ONNX Runtime library path for macOS + if: matrix.os == 'macos-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV + + - name: Set ONNX Runtime library path for Windows + if: matrix.os == 'windows-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/onnxruntime.dll" >> $GITHUB_ENV + + - name: lint run: cargo clippy -- -Dwarnings - name: build - run: cargo build - - # - name: build (web) - # run: cargo build --example=minimal --target wasm32-unknown-unknown --release + run: cargo build --features "ort/load-dynamic" + env: + ORT_DYLIB_PATH: ${{ env.ORT_DYLIB_PATH }} diff --git a/.gitignore b/.gitignore index eaec9ea..28fc69c 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ www/assets/ *.mp4 mediamtx/ +onnxruntime/ diff --git a/Cargo.toml b/Cargo.toml index a858955..9bf9e70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,12 +25,21 @@ exclude = [ default-run = "viewer" +[features] +default = [ + "person_matting", +] + +person_matting = ["bevy_ort", "ort", "ndarray"] + [dependencies] anyhow = "1.0" async-compat = "0.2" -bytes = "1.5.0" +bevy_ort = { version = "0.5", optional = true } +bytes = "1.5" futures = "0.3" +ndarray = { version = "0.15", optional = true } openh264 = "0.5" retina = "0.4" tokio = { version = "1.36", features = ["full"] } @@ -50,6 +59,18 @@ features = [ ] +[dependencies.ort] +version = "2.0.0-alpha.4" +optional = true +default-features = false +features = [ + "cuda", + "load-dynamic", + "ndarray", + "openvino", +] + + [profile.dev.package."*"] opt-level = 3 diff --git a/README.md b/README.md index e56a22f..e6827fc 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,10 @@ rust bevy light field camera array tooling - [X] grid view of light field camera array - [X] stream to files with recording controls +- [X] person segmentation post-process (batch across streams) +- [X] async segmentation model inference +- [ ] foreground extraction post-process and visualization mode - [ ] playback nersemble recordings with annotations -- [ ] person segmentation post-process (batch across streams) - [ ] camera array calibration - [ ] 3d reconstruction dataset preparation - [ ] real-time 3d reconstruction viewer @@ -27,6 +29,11 @@ rust bevy light field camera array tooling the viewer opens a window and displays the light field camera array, with post-process options +> see execution provider [bevy_ort documentation](https://github.com/mosure/bevy_ort?tab=readme-ov-file#run-the-example-person-segmentation-model-modnet) for better performance + +- windows: `cargo run --release --features "ort/cuda"` + + ### controls - `r` to start recording @@ -156,5 +163,6 @@ it is useful to test the light field viewer with emulated camera streams ## credits - [bevy_video](https://github.com/PortalCloudInc/bevy_video) - [gaussian_avatars](https://github.com/ShenhanQian/GaussianAvatars) +- [modnet](https://github.com/ZHKKKe/MODNet) - [nersemble](https://github.com/tobias-kirschstein/nersemble) - [paddle_seg_matting](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.9/Matting/docs/quick_start_en.md) diff --git a/src/lib.rs b/src/lib.rs index 6e77d68..418a510 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,5 @@ +#[cfg(feature = "person_matting")] +pub mod matting; + pub mod mp4; pub mod stream; diff --git a/src/matting.rs b/src/matting.rs new file mode 100644 index 0000000..6dd5cae --- /dev/null +++ b/src/matting.rs @@ -0,0 +1,147 @@ +use bevy::{ + prelude::*, + ecs::system::CommandQueue, + tasks::{block_on, futures_lite::future, AsyncComputeTaskPool, Task}, +}; +use bevy_ort::{ + BevyOrtPlugin, + inputs, + models::modnet::{ + images_to_modnet_input, + modnet_output_to_luma_images, + }, + Onnx, +}; + +use crate::stream::StreamId; + + +#[derive(Component, Clone, Debug, Reflect)] +pub struct MattedStream { + pub stream_id: StreamId, + pub input: Handle, + pub output: Handle, +} + + +pub struct MattingPlugin; +impl Plugin for MattingPlugin { + fn build(&self, app: &mut App) { + app.add_plugins(BevyOrtPlugin); + app.register_type::(); + app.init_resource::(); + app.add_systems(Startup, load_modnet); + app.add_systems(Update, matting_inference); + } +} + + +#[derive(Resource, Default)] +pub struct Modnet { + pub onnx: Handle, +} + + +fn load_modnet( + asset_server: Res, + mut modnet: ResMut, +) { + let modnet_handle: Handle = asset_server.load("modnet_photographic_portrait_matting.onnx"); + modnet.onnx = modnet_handle; +} + + +#[derive(Default)] +struct ModnetComputePipeline(Option>); + + +fn matting_inference( + mut commands: Commands, + images: Res>, + modnet: Res, + matted_streams: Query< + ( + Entity, + &MattedStream, + ) + >, + onnx_assets: Res>, + mut pipeline_local: Local, +) { + if let Some(pipeline) = pipeline_local.0.as_mut() { + if let Some(mut commands_queue) = block_on(future::poll_once(pipeline)) { + commands.append(&mut commands_queue); + pipeline_local.0 = None; + } + + return; + } + + let thread_pool = AsyncComputeTaskPool::get(); + + let inputs = matted_streams.iter() + .map(|(_, matted_stream)| { + images.get(matted_stream.input.clone()).unwrap() + }) + .collect::>(); + + let uninitialized = inputs.iter().any(|image| image.size() == (32, 32).into()); + if uninitialized { + return; + } + + let max_inference_size = (256, 256).into(); + let input = images_to_modnet_input( + inputs, + max_inference_size, + ); + + if onnx_assets.get(&modnet.onnx).is_none() { + return; + } + + let onnx = onnx_assets.get(&modnet.onnx).unwrap(); + let session_arc = onnx.session.clone(); + + let outputs = matted_streams.iter() + .map(|(_, matted_stream)| matted_stream.output.clone()) + .collect::>(); + + let task = thread_pool.spawn(async move { + let mask_images: Result, String> = (|| { + let session_lock = session_arc.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; + + let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string())?; + let outputs = session.run(input_values).map_err(|e| e.to_string()); + + let binding = outputs.ok().unwrap(); + let output_value: &ort::Value = binding.get("output").unwrap(); + + Ok(modnet_output_to_luma_images(output_value)) + })(); + + match mask_images { + Ok(mut mask_images) => { + let mut command_queue = CommandQueue::default(); + + command_queue.push(move |world: &mut World| { + let mut images = world.get_resource_mut::>().unwrap(); + + outputs.iter() + .for_each(|output| { + images.insert(output, mask_images.pop().unwrap()); + }); + }); + + command_queue + }, + Err(error) => { + eprintln!("inference failed: {}", error); + CommandQueue::default() + } + } + }); + + *pipeline_local = ModnetComputePipeline(Some(task)); +} diff --git a/tools/viewer.rs b/tools/viewer.rs index d135f99..53e8ece 100644 --- a/tools/viewer.rs +++ b/tools/viewer.rs @@ -18,6 +18,12 @@ use bevy_light_field::stream::{ RtspStreamDescriptor, RtspStreamManager, RtspStreamPlugin, StreamId }; +#[cfg(feature = "person_matting")] +use bevy_light_field::matting::{ + MattedStream, + MattingPlugin, +}; + const RTSP_URIS: [&str; 2] = [ "rtsp://192.168.1.23/user=admin&password=admin123&channel=1&stream=0.sdp?", @@ -44,6 +50,9 @@ fn main() { ..default() }), RtspStreamPlugin, + + #[cfg(feature = "person_matting")] + MattingPlugin, )) .add_systems(Startup, create_streams) .add_systems(Startup, setup_camera) @@ -65,6 +74,13 @@ fn create_streams( primary_window: Query<&Window, With>, ) { let window = primary_window.single(); + + #[cfg(feature = "person_matting")] + let elements = RTSP_URIS.len() * 2; + + #[cfg(not(feature = "person_matting"))] + let elements = RTSP_URIS.len(); + let ( columns, rows, @@ -73,20 +89,20 @@ fn create_streams( ) = calculate_grid_dimensions( window.width(), window.height(), - RTSP_URIS.len() + elements, ); - let images: Vec> = RTSP_URIS.iter() + let size = Extent3d { + width: 32, + height: 32, + ..default() + }; + + let input_images: Vec> = RTSP_URIS.iter() .enumerate() .map(|(index, &url)| { let entity = commands.spawn_empty().id(); - let size = Extent3d { - width: 32, - height: 32, - ..default() - }; - let mut image = Image { asset_usage: RenderAssetUsages::all(), texture_descriptor: TextureDescriptor { @@ -120,6 +136,38 @@ fn create_streams( }) .collect(); + let output_images = input_images.iter() + .enumerate() + .map(|(index, image)| { + let mut output_image = Image { + asset_usage: RenderAssetUsages::all(), + texture_descriptor: TextureDescriptor { + label: None, + size, + dimension: TextureDimension::D2, + format: TextureFormat::Rgba8UnormSrgb, + mip_level_count: 1, + sample_count: 1, + usage: TextureUsages::COPY_DST + | TextureUsages::TEXTURE_BINDING + | TextureUsages::RENDER_ATTACHMENT, + view_formats: &[TextureFormat::Rgba8UnormSrgb], + }, + ..default() + }; + output_image.resize(size); + let output_image = images.add(output_image); + + commands.spawn(MattedStream { + stream_id: StreamId(index), + input: image.clone(), + output: output_image.clone(), + }); + + output_image + }) + .collect::>(); + commands.spawn(NodeBundle { style: Style { display: Display::Grid, @@ -133,15 +181,26 @@ fn create_streams( ..default() }) .with_children(|builder| { - images.iter() - .for_each(|image| { + input_images.iter() + .zip(output_images.iter()) + .for_each(|(input, output)| { + builder.spawn(ImageBundle { + style: Style { + width: Val::Px(sprite_width), + height: Val::Px(sprite_height), + ..default() + }, + image: UiImage::new(input.clone()), + ..default() + }); + builder.spawn(ImageBundle { style: Style { width: Val::Px(sprite_width), height: Val::Px(sprite_height), ..default() }, - image: UiImage::new(image.clone()), + image: UiImage::new(output.clone()), ..default() }); });