diff --git a/camera.py b/camera.py index 4125842..da33a98 100644 --- a/camera.py +++ b/camera.py @@ -223,7 +223,7 @@ def angle_to_rotation_matrix(a,axis): M = M.roll((roll,roll),dims=(-2,-1)) return M -def get_center_and_ray(opt,pose,intr=None): # [HW,2] +def get_center_and_ray(opt,pose,intr=None,ray_idx=None): # [HW,2] # given the intrinsic/extrinsic matrices, get the camera center and ray directions] assert(opt.camera.model=="perspective") with torch.no_grad(): @@ -232,6 +232,9 @@ def get_center_and_ray(opt,pose,intr=None): # [HW,2] x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5) Y,X = torch.meshgrid(y_range,x_range) # [H,W] xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] + if ray_idx is not None: + # consider only subset of rays + xy_grid = xy_grid[ray_idx] # compute center and ray batch_size = len(pose) xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2] diff --git a/model/nerf.py b/model/nerf.py index b0dcb2c..efd9b6d 100644 --- a/model/nerf.py +++ b/model/nerf.py @@ -231,12 +231,9 @@ def get_pose(self,opt,var,mode=None): def render(self,opt,pose,intr=None,ray_idx=None,mode=None): batch_size = len(pose) - center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3] + center,ray = camera.get_center_and_ray(opt,pose,intr=intr,ray_idx=ray_idx) # [B,HW,3] while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministic reproducible - center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3] - if ray_idx is not None: - # consider only subset of rays - center,ray = center[:,ray_idx],ray[:,ray_idx] + center,ray = camera.get_center_and_ray(opt,pose,intr=intr,ray_idx=ray_idx) # [B,HW,3] if opt.camera.ndc: # convert center/ray representations to NDC center,ray = camera.convert_NDC(opt,center,ray,intr=intr)