Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support setting image_size=(width, height) #45

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,8 @@ tests/data/test_rasterize1.png
tests/data/test_rasterize2.png
examples/data/example1.gif
examples/data/example2_optimization.gif

# vscode
.vscode/
# pycharm
.idea/
32 changes: 20 additions & 12 deletions neural_renderer/cuda/rasterize_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ std::vector<at::Tensor> forward_face_index_map_cuda(
at::Tensor depth_map,
at::Tensor face_inv_map,
at::Tensor faces_inv,
int image_size,
int width,
int height,
float near,
float far,
int return_rgb,
Expand All @@ -27,7 +28,8 @@ std::vector<at::Tensor> forward_texture_sampling_cuda(
at::Tensor rgb_map,
at::Tensor sampling_index_map,
at::Tensor sampling_weight_map,
int image_size,
int width,
int height,
float eps);

at::Tensor backward_pixel_map_cuda(
Expand All @@ -38,7 +40,8 @@ at::Tensor backward_pixel_map_cuda(
at::Tensor grad_rgb_map,
at::Tensor grad_alpha_map,
at::Tensor grad_faces,
int image_size,
int width,
int height,
float eps,
int return_rgb,
int return_alpha);
Expand All @@ -59,7 +62,8 @@ at::Tensor backward_depth_map_cuda(
at::Tensor weight_map,
at::Tensor grad_depth_map,
at::Tensor grad_faces,
int image_size);
int width,
int height);

// C++ interface

Expand All @@ -74,7 +78,8 @@ std::vector<at::Tensor> forward_face_index_map(
at::Tensor depth_map,
at::Tensor face_inv_map,
at::Tensor faces_inv,
int image_size,
int width,
int height,
float near,
float far,
int return_rgb,
Expand All @@ -90,7 +95,7 @@ std::vector<at::Tensor> forward_face_index_map(

return forward_face_index_map_cuda(faces, face_index_map, weight_map,
depth_map, face_inv_map, faces_inv,
image_size, near, far,
width, height, near, far,
return_rgb, return_alpha, return_depth);
}

Expand All @@ -103,7 +108,8 @@ std::vector<at::Tensor> forward_texture_sampling(
at::Tensor rgb_map,
at::Tensor sampling_index_map,
at::Tensor sampling_weight_map,
int image_size,
int width,
int height,
float eps) {

CHECK_INPUT(faces);
Expand All @@ -118,7 +124,7 @@ std::vector<at::Tensor> forward_texture_sampling(
return forward_texture_sampling_cuda(faces, textures, face_index_map,
weight_map, depth_map, rgb_map,
sampling_index_map, sampling_weight_map,
image_size, eps);
width, height, eps);
}

at::Tensor backward_pixel_map(
Expand All @@ -129,7 +135,8 @@ at::Tensor backward_pixel_map(
at::Tensor grad_rgb_map,
at::Tensor grad_alpha_map,
at::Tensor grad_faces,
int image_size,
int width,
int height,
float eps,
int return_rgb,
int return_alpha) {
Expand All @@ -144,7 +151,7 @@ at::Tensor backward_pixel_map(

return backward_pixel_map_cuda(faces, face_index_map, rgb_map, alpha_map,
grad_rgb_map, grad_alpha_map, grad_faces,
image_size, eps, return_rgb, return_alpha);
width, height, eps, return_rgb, return_alpha);
}

at::Tensor backward_textures(
Expand Down Expand Up @@ -174,7 +181,8 @@ at::Tensor backward_depth_map(
at::Tensor weight_map,
at::Tensor grad_depth_map,
at::Tensor grad_faces,
int image_size) {
int width,
int height) {

CHECK_INPUT(faces);
CHECK_INPUT(depth_map);
Expand All @@ -187,7 +195,7 @@ at::Tensor backward_depth_map(
return backward_depth_map_cuda(faces, depth_map, face_index_map,
face_inv_map, weight_map,
grad_depth_map, grad_faces,
image_size);
width, height);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand Down
Loading