Skip to content

Commit

Permalink
refactor: Update export_gltf function to use Camera and ImageSlice ob…
Browse files Browse the repository at this point in the history
…ject for setting up the 3d scene
  • Loading branch information
provos committed Jun 26, 2024
1 parent 6957e0d commit d855a6b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 52 deletions.
20 changes: 14 additions & 6 deletions gltf.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,8 @@ def create_card(gltf_obj, i, corners_3d, subdivisions=300, depth_map=None, displ

def export_gltf(
output_path,
aspect_ratio,
focal_length,
camera_distance,
card_corners_3d_list,
cam,
image_slices,
image_paths,
depth_paths = [],
displacement_scale=0.0,
Expand All @@ -318,12 +316,21 @@ def export_gltf(
aspect_ratio (float): The aspect ratio of the camera.
focal_length (float): The focal length of the camera.
camera_distance (float): The distance of the camera from the origin.
card_corners_3d_list (list): List of 3D corner coordinates for each card.
cam (Camera): The camera object for the scene.
image_slices (list): List of 3D corner coordinates for each card.
image_paths (list): List of file paths for each image slice.
depth_paths (list, optional): List of file paths for each depth map. Defaults to [].
displacement_scale (float, optional): The scale of the displacement. Defaults to 0.0.
inline_images (bool, optional): Whether to inline the images in the glTF file. Defaults to True.
"""

# compute pre-requisites
image_height, image_width = image_slices[0].image.shape[:2]
camera_matrix = cam.camera_matrix(image_width, image_height)
aspect_ratio = float(camera_matrix[0, 2]) / camera_matrix[1, 2]
focal_length = cam.focal_length
camera_distance = cam.camera_distance

# Create a new glTF object
gltf_obj = gltf.GLTF2(
scene=0
Expand All @@ -345,7 +352,8 @@ def export_gltf(
alpha_mode = "MASK" if support_dof else "BLEND"

# Create the card objects (planes)
for i, corners_3d in enumerate(card_corners_3d_list):
for i, image_slice in enumerate(image_slices):
corners_3d = image_slice.create_card(image_height, image_width, cam)
# Translaton hack so that we can put the depth on the node
z_transform = corners_3d[0][2]
corners_3d[:, 2] -= z_transform
Expand Down
5 changes: 1 addition & 4 deletions segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,9 @@ def process_image(image_path, output_path, num_slices=5,

image_paths = [output_path /
f"image_slice_{i}.png" for i in range(num_slices)]
# fix it
aspect_ratio = float(camera_matrix[0, 2]) / camera_matrix[1, 2]

output_path = Path(output_path) / "model.gltf"
export_gltf(output_path, aspect_ratio, camera.focal_length, camera.camera_distance,
card_corners_3d_list, image_paths)
export_gltf(output_path, camera, image_slices, image_paths)


def main():
Expand Down
41 changes: 6 additions & 35 deletions test_webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,6 @@ def test_export_state_as_gltf(self, mock_export_gltf, mock_postprocess_depth_map
# Test case 1: Displacement scale is 0
state = AppState()
state.image_slices = self.state.image_slices
camera_matrix = self.camera.camera_matrix(100, 100)
card_corners_3d_list = [slice.create_card(
100, 100, self.camera) for slice in state.image_slices]

mock_export_gltf.return_value = Path("output.gltf")

Expand All @@ -282,16 +279,9 @@ def test_export_state_as_gltf(self, mock_export_gltf, mock_postprocess_depth_map
expected_call = mock_export_gltf.call_args_list[0]
expected_args, expected_kwargs = expected_call
self.assertEqual(expected_args[0], Path("output_dir/model.gltf"))
self.assertAlmostEqual(expected_args[1], float(
camera_matrix[0, 2]) / camera_matrix[1, 2])
self.assertEqual(expected_args[2], 50)
self.assertEqual(expected_args[3], 10)
for expected_corner, actual_corner in zip(expected_args[4], card_corners_3d_list):
np.testing.assert_array_almost_equal(
expected_corner, actual_corner)
image_slices_filenames = [slice.filename for slice in self.state.image_slices]
self.assertEqual(expected_args[5], image_slices_filenames)
self.assertEqual(expected_args[6], [])
self.assertEqual(expected_args[3], image_slices_filenames)
self.assertEqual(expected_args[4], [])
self.assertEqual(expected_kwargs["displacement_scale"], 0)

@patch("PIL.Image.fromarray")
Expand All @@ -303,8 +293,6 @@ def test_export_state_as_gltf_with_displacement(
# Test case 2: Displacement scale is greater than 0
state = AppState()
state.image_slices = self.state.image_slices
camera_matrix = self.camera.camera_matrix(100, 100)
card_corners_3d_list = [slice.create_card(100, 100, self.camera) for slice in state.image_slices]

mock_export_gltf.return_value = Path("output.gltf")

Expand Down Expand Up @@ -334,27 +322,17 @@ def test_export_state_as_gltf_with_displacement(
expected_call = mock_export_gltf.call_args_list[0]
expected_args, expected_kwargs = expected_call
self.assertEqual(expected_args[0], Path("output_dir/model.gltf"))
self.assertAlmostEqual(expected_args[1], float(
camera_matrix[0, 2]) / camera_matrix[1, 2])
self.assertEqual(expected_args[2], 50)
self.assertEqual(expected_args[3], 10)
for expected_corner, actual_corner in zip(expected_args[4], card_corners_3d_list):
np.testing.assert_array_almost_equal(
expected_corner, actual_corner)

image_slices_filenames = [slice.filename for slice in self.state.image_slices]
self.assertEqual(expected_args[5], image_slices_filenames)
self.assertEqual(expected_args[6], [self.mock_depth_file] * 3)
self.assertEqual(expected_args[3], image_slices_filenames)
self.assertEqual(expected_args[4], [self.mock_depth_file] * 3)
self.assertEqual(expected_kwargs["displacement_scale"], 1)

@patch("webui.export_gltf")
def test_export_state_as_gltf_with_upscaled(self, mock_export_gltf):
# Test case 3: Upscaled slices exist
state = AppState()
state.image_slices = self.state.image_slices
camera_matrix = self.camera.camera_matrix(100, 100)
card_corners_3d_list = [slice.create_card(
100, 100, self.camera) for slice in state.image_slices]

# Pretend the upscaled file exists
mock_upscaled_file = MagicMock()
Expand All @@ -368,15 +346,8 @@ def test_export_state_as_gltf_with_upscaled(self, mock_export_gltf):
expected_call = mock_export_gltf.call_args_list[0]
expected_args, expected_kwargs = expected_call
self.assertEqual(expected_args[0], Path("output_dir/model.gltf"))
self.assertAlmostEqual(expected_args[1], float(
camera_matrix[0, 2]) / camera_matrix[1, 2])
self.assertEqual(expected_args[2], 50)
self.assertEqual(expected_args[3], 10)
for expected_corner, actual_corner in zip(expected_args[4], card_corners_3d_list):
np.testing.assert_array_almost_equal(
expected_corner, actual_corner)
self.assertEqual(expected_args[5], [mock_upscaled_file] * 3)
self.assertEqual(expected_args[6], [self.mock_depth_file] * 3)
self.assertEqual(expected_args[3], [mock_upscaled_file] * 3)
self.assertEqual(expected_args[4], [self.mock_depth_file] * 3)
self.assertEqual(expected_kwargs["displacement_scale"], 1)

# TODO: Add more test cases for generate_depth_map, postprocess_depth_map, and export_gltf
Expand Down
10 changes: 3 additions & 7 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,14 +1204,11 @@ def gltf_create(


def export_state_as_gltf(
state, filename,
state: AppState, filename,
camera,
displacement_scale, modelname='midas',
inline_images=True,
support_dof=False):
camera_matrix = state.camera_matrix()
card_corners_3d_list = state.get_cards()

depth_filenames = []
if displacement_scale > 0:
for i, slice_image in enumerate(state.image_slices):
Expand All @@ -1238,11 +1235,10 @@ def export_state_as_gltf(
else:
slices_filenames.append(slice_image.filename)

aspect_ratio = float(camera_matrix[0, 2]) / camera_matrix[1, 2]
output_path = Path(filename) / state.MODEL_FILE
gltf_path = export_gltf(
output_path, aspect_ratio, camera.focal_length, camera.camera_distance,
card_corners_3d_list, slices_filenames, depth_filenames,
output_path, camera,
state.image_slices, slices_filenames, depth_filenames,
displacement_scale=displacement_scale,
inline_images=inline_images,
support_dof=support_dof)
Expand Down

0 comments on commit d855a6b

Please sign in to comment.