Skip to content

Commit

Permalink
fix: cloud transform rotation and scale (#109)
Browse files Browse the repository at this point in the history
* fix: #108 - cloud transform rotation and scale

* fix: lint
  • Loading branch information
mosure authored Jun 9, 2024
1 parent 0adb211 commit 176f048
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/gaussian/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub enum GaussianCloudDrawMode {
pub struct GaussianCloudSettings {
pub aabb: bool,
pub global_scale: f32,
pub global_transform: GlobalTransform,
pub transform: Transform,
pub visualize_bounding_box: bool,
pub visualize_depth: bool,
pub sort_mode: SortMode,
Expand All @@ -37,7 +37,7 @@ impl Default for GaussianCloudSettings {
Self {
aabb: false,
global_scale: 1.0,
global_transform: Transform::IDENTITY.into(),
transform: Transform::IDENTITY,
visualize_bounding_box: false,
visualize_depth: false,
sort_mode: SortMode::default(),
Expand Down
2 changes: 1 addition & 1 deletion src/render/bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@group(0) @binding(1) var<uniform> globals: Globals;

struct GaussianUniforms {
global_transform: mat4x4<f32>,
transform: mat4x4<f32>,
global_scale: f32,
count: u32,
count_root_ceil: u32,
Expand Down
25 changes: 16 additions & 9 deletions src/render/gaussian.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ fn compute_cov3d(scale: vec3<f32>, rotation: vec4<f32>) -> array<f32, 6> {
let y = rotation.z;
let z = rotation.w;

let T = mat3x3<f32>(
gaussian_uniforms.transform[0].xyz,
gaussian_uniforms.transform[1].xyz,
gaussian_uniforms.transform[2].xyz,
);

let R = mat3x3<f32>(
1.0 - 2.0 * (y * y + z * z),
2.0 * (x * y - r * z),
Expand All @@ -160,14 +166,15 @@ fn compute_cov3d(scale: vec3<f32>, rotation: vec4<f32>) -> array<f32, 6> {

let M = S * R;
let Sigma = transpose(M) * M;
let TS = T * Sigma * transpose(T);

return array<f32, 6>(
Sigma[0][0],
Sigma[0][1],
Sigma[0][2],
Sigma[1][1],
Sigma[1][2],
Sigma[2][2],
TS[0][0],
TS[0][1],
TS[0][2],
TS[1][1],
TS[1][2],
TS[2][2],
);
}

Expand Down Expand Up @@ -320,7 +327,7 @@ fn vs_points(

let position = vec4<f32>(get_position(splat_index), 1.0);

let transformed_position = (gaussian_uniforms.global_transform * position).xyz;
let transformed_position = (gaussian_uniforms.transform * position).xyz;
let projected_position = world_to_clip(transformed_position);

discard_quad |= !in_frustum(projected_position.xyz);
Expand Down Expand Up @@ -353,8 +360,8 @@ fn vs_points(
let first_position = vec4<f32>(get_position(get_entry(1u).value), 1.0);
let last_position = vec4<f32>(get_position(get_entry(gaussian_uniforms.count - 1u).value), 1.0);

let min_position = (gaussian_uniforms.global_transform * first_position).xyz;
let max_position = (gaussian_uniforms.global_transform * last_position).xyz;
let min_position = (gaussian_uniforms.transform * first_position).xyz;
let max_position = (gaussian_uniforms.transform * last_position).xyz;

let camera_position = view.world_position;

Expand Down
2 changes: 1 addition & 1 deletion src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ pub fn extract_gaussians(
let cloud = gaussian_cloud_res.get(cloud_handle).unwrap();

let settings_uniform = GaussianCloudUniform {
transform: settings.global_transform.compute_matrix(),
transform: settings.transform.compute_matrix(),
global_scale: settings.global_scale,
count: cloud.count as u32,
count_root_ceil: (cloud.count as f32).sqrt().ceil() as u32,
Expand Down
2 changes: 1 addition & 1 deletion src/sort/radix.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn radix_sort_a(
}
var key: u32 = 0xFFFFFFFFu; // Stream compaction for frustum culling
let position = vec4<f32>(get_position(entry_index), 1.0);
let transformed_position = (gaussian_uniforms.global_transform * position).xyz;
let transformed_position = (gaussian_uniforms.transform * position).xyz;
let clip_space_pos = world_to_clip(transformed_position);
if(in_frustum(clip_space_pos.xyz)) {
// key = bitcast<u32>(1.0 - clip_space_pos.z);
Expand Down
11 changes: 7 additions & 4 deletions src/sort/rayon.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use bevy::{
prelude::*,
asset::LoadState,
math::Vec3A,
utils::Instant,
};

Expand Down Expand Up @@ -36,10 +37,10 @@ pub fn rayon_sort(
&GaussianCloudSettings,
)>,
cameras: Query<(
&GlobalTransform,
&Transform,
&Camera3d,
)>,
mut last_camera_position: Local<Vec3>,
mut last_camera_position: Local<Vec3A>,
mut last_sort_time: Local<Option<Instant>>,
mut period: Local<std::time::Duration>,
mut sort_done: Local<bool>,
Expand All @@ -63,7 +64,7 @@ pub fn rayon_sort(
camera_transform,
_camera,
) in cameras.iter() {
let camera_position = camera_transform.compute_transform().translation;
let camera_position = camera_transform.compute_affine().translation;
let camera_movement = *last_camera_position != camera_position;

if camera_movement {
Expand Down Expand Up @@ -104,7 +105,9 @@ pub fn rayon_sort(
.zip(sorted_entries.sorted.par_iter_mut())
.enumerate()
.for_each(|(idx, (position, sort_entry))| {
let position = Vec3::from_slice(position.as_ref());
let position = Vec3A::from_slice(position.as_ref());
let position = settings.transform.compute_affine().transform_point3a(position);

let delta = camera_position - position;

sort_entry.key = bytemuck::cast(delta.length_squared());
Expand Down
11 changes: 7 additions & 4 deletions src/sort/std.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use bevy::{
prelude::*,
asset::LoadState,
math::Vec3A,
utils::Instant,
};

Expand Down Expand Up @@ -35,10 +36,10 @@ pub fn std_sort(
&GaussianCloudSettings,
)>,
cameras: Query<(
&GlobalTransform,
&Transform,
&Camera3d,
)>,
mut last_camera_position: Local<Vec3>,
mut last_camera_position: Local<Vec3A>,
mut last_sort_time: Local<Option<Instant>>,
mut period: Local<std::time::Duration>,
mut camera_debounce: Local<bool>,
Expand All @@ -61,7 +62,7 @@ pub fn std_sort(
camera_transform,
_camera,
) in cameras.iter() {
let camera_position = camera_transform.compute_transform().translation;
let camera_position = camera_transform.compute_affine().translation;
let camera_movement = *last_camera_position != camera_position;

if camera_movement {
Expand Down Expand Up @@ -107,7 +108,9 @@ pub fn std_sort(
.zip(sorted_entries.sorted.iter_mut())
.enumerate()
.for_each(|(idx, (position, sort_entry))| {
let position = Vec3::from_slice(position.as_ref());
let position = Vec3A::from_slice(position.as_ref());
let position = settings.transform.compute_affine().transform_point3a(position);

let delta = camera_position - position;

sort_entry.key = bytemuck::cast(delta.length_squared());
Expand Down

0 comments on commit 176f048

Please sign in to comment.