diff --git a/liegroups/torch/so3.py b/liegroups/torch/so3.py index 1495c25..42c3639 100644 --- a/liegroups/torch/so3.py +++ b/liegroups/torch/so3.py @@ -24,16 +24,22 @@ def exp(cls, phi): mat = phi.new_empty(phi.shape[0], cls.dim, cls.dim) angle = phi.norm(p=2, dim=1) + cuda = mat.is_cuda # Near phi==0, use first order Taylor expansion small_angle_mask = utils.isclose(angle, 0.) small_angle_inds = small_angle_mask.nonzero(as_tuple=False).squeeze_(dim=1) if len(small_angle_inds) > 0: - mat[small_angle_inds] = \ - torch.eye(cls.dim, dtype=phi.dtype).expand_as(mat[small_angle_inds]) + \ - cls.wedge(phi[small_angle_inds]) - + if cuda: + mat[small_angle_inds] = \ + torch.eye(cls.dim, dtype=phi.dtype).cuda().expand_as(mat[small_angle_inds]) + \ + cls.wedge(phi[small_angle_inds]) + else: + mat[small_angle_inds] = \ + torch.eye(cls.dim, dtype=phi.dtype).expand_as(mat[small_angle_inds]) + \ + cls.wedge(phi[small_angle_inds]) + # Otherwise... large_angle_mask = small_angle_mask.logical_not() large_angle_inds = large_angle_mask.nonzero(as_tuple=False).squeeze_(dim=1) @@ -46,9 +52,12 @@ def exp(cls, phi): dim=2).expand_as(mat[large_angle_inds]) c = angle.cos().unsqueeze_(dim=1).unsqueeze_( dim=2).expand_as(mat[large_angle_inds]) - - A = c * torch.eye(cls.dim, dtype=phi.dtype).unsqueeze_(dim=0).expand_as( - mat[large_angle_inds]) + if cuda: + A = c * torch.eye(cls.dim, dtype=phi.dtype).cuda().unsqueeze_(dim=0).expand_as( + mat[large_angle_inds]) + else: + A = c * torch.eye(cls.dim, dtype=phi.dtype).unsqueeze_(dim=0).expand_as( + mat[large_angle_inds]) B = (1. - c) * utils.outer(axis, axis) C = s * cls.wedge(axis) @@ -123,7 +132,7 @@ def inv_left_jacobian(cls, phi): if phi.shape[1] != cls.dof: raise ValueError( "phi must have shape ({},) or (N,{})".format(cls.dof, cls.dof)) - + cuda = phi.is_cuda jac = phi.new_empty(phi.shape[0], cls.dof, cls.dof) angle = phi.norm(p=2, dim=1) @@ -131,9 +140,14 @@ def inv_left_jacobian(cls, phi): small_angle_mask = utils.isclose(angle, 0.) small_angle_inds = small_angle_mask.nonzero(as_tuple=False).squeeze_(dim=1) if len(small_angle_inds) > 0: - jac[small_angle_inds] = \ - torch.eye(cls.dof, dtype=phi.dtype).expand_as(jac[small_angle_inds]) - \ - 0.5 * cls.wedge(phi[small_angle_inds]) + if cuda: + jac[small_angle_inds] = \ + torch.eye(cls.dof, dtype=phi.dtype).cuda().expand_as(jac[small_angle_inds]) - \ + 0.5 * cls.wedge(phi[small_angle_inds]) + else: + jac[small_angle_inds] = \ + torch.eye(cls.dof, dtype=phi.dtype).expand_as(jac[small_angle_inds]) - \ + 0.5 * cls.wedge(phi[small_angle_inds]) # Otherwise... large_angle_mask = small_angle_mask.logical_not() @@ -151,10 +165,14 @@ def inv_left_jacobian(cls, phi): dim=2).expand_as(jac[large_angle_inds]) hacha.unsqueeze_(dim=1).unsqueeze_( dim=2).expand_as(jac[large_angle_inds]) - - A = hacha * \ - torch.eye(cls.dof, dtype=phi.dtype).unsqueeze_( - dim=0).expand_as(jac[large_angle_inds]) + if cuda: + A = hacha * \ + torch.eye(cls.dof, dtype=phi.dtype).cuda().unsqueeze_( + dim=0).expand_as(jac[large_angle_inds]) + else: + A = hacha * \ + torch.eye(cls.dof, dtype=phi.dtype).unsqueeze_( + dim=0).expand_as(jac[large_angle_inds]) B = (1. - hacha) * utils.outer(axis, axis) C = -ha * cls.wedge(axis) @@ -173,14 +191,20 @@ def left_jacobian(cls, phi): jac = phi.new_empty(phi.shape[0], cls.dof, cls.dof) angle = phi.norm(p=2, dim=1) + cuda = phi.is_cuda # Near phi==0, use first order Taylor expansion small_angle_mask = utils.isclose(angle, 0.) small_angle_inds = small_angle_mask.nonzero(as_tuple=False).squeeze_(dim=1) if len(small_angle_inds) > 0: - jac[small_angle_inds] = \ - torch.eye(cls.dof, dtype=phi.dtype).expand_as(jac[small_angle_inds]) + \ - 0.5 * cls.wedge(phi[small_angle_inds]) + if cuda: + jac[small_angle_inds] = \ + torch.eye(cls.dof, dtype=phi.dtype).cuda().expand_as(jac[small_angle_inds]) + \ + 0.5 * cls.wedge(phi[small_angle_inds]) + else: + jac[small_angle_inds] = \ + torch.eye(cls.dof, dtype=phi.dtype).expand_as(jac[small_angle_inds]) + \ + 0.5 * cls.wedge(phi[small_angle_inds]) # Otherwise... large_angle_mask = small_angle_mask.logical_not() @@ -192,11 +216,16 @@ def left_jacobian(cls, phi): angle.unsqueeze(dim=1).expand(len(angle), cls.dof) s = angle.sin() c = angle.cos() - - A = (s / angle).unsqueeze_(dim=1).unsqueeze_( - dim=2).expand_as(jac[large_angle_inds]) * \ - torch.eye(cls.dof, dtype=phi.dtype).unsqueeze_(dim=0).expand_as( - jac[large_angle_inds]) + if cuda: + A = (s / angle).unsqueeze_(dim=1).unsqueeze_( + dim=2).expand_as(jac[large_angle_inds]) * \ + torch.eye(cls.dof, dtype=phi.dtype).cuda().unsqueeze_(dim=0).expand_as( + jac[large_angle_inds]) + else: + A = (s / angle).unsqueeze_(dim=1).unsqueeze_( + dim=2).expand_as(jac[large_angle_inds]) * \ + torch.eye(cls.dof, dtype=phi.dtype).unsqueeze_(dim=0).expand_as( + jac[large_angle_inds]) B = (1. - s / angle).unsqueeze_(dim=1).unsqueeze_( dim=2).expand_as(jac[large_angle_inds]) * \ utils.outer(axis, axis) @@ -213,7 +242,7 @@ def log(self): mat = self.mat.unsqueeze(dim=0) else: mat = self.mat - + cuda = mat.is_cuda phi = mat.new_empty(mat.shape[0], self.dof) # The cosine of the rotation angle is related to the utils.trace of C @@ -226,9 +255,14 @@ def log(self): small_angle_inds = small_angle_mask.nonzero(as_tuple=False).squeeze_(dim=1) if len(small_angle_inds) > 0: - phi[small_angle_inds, :] = \ - self.vee(mat[small_angle_inds] - - torch.eye(self.dim, dtype=mat.dtype).expand_as(mat[small_angle_inds])) + if cuda: + phi[small_angle_inds, :] = \ + self.vee(mat[small_angle_inds] - + torch.eye(self.dim, dtype=mat.dtype).cuda().expand_as(mat[small_angle_inds])) + else: + phi[small_angle_inds, :] = \ + self.vee(mat[small_angle_inds] - + torch.eye(self.dim, dtype=mat.dtype).expand_as(mat[small_angle_inds])) # Otherwise... large_angle_mask = small_angle_mask.logical_not() diff --git a/liegroups/torch/utils.py b/liegroups/torch/utils.py index 38129d0..d208a42 100644 --- a/liegroups/torch/utils.py +++ b/liegroups/torch/utils.py @@ -42,8 +42,10 @@ def trace(mat): # Default batch size is 1 if mat.dim() < 3: mat = mat.unsqueeze(dim=0) - + if mat.is_cuda: + tr = (torch.eye(mat.shape[1], dtype=mat.dtype).cuda() * mat).sum(dim=1).sum(dim=1) # Element-wise multiply by identity and take the sum - tr = (torch.eye(mat.shape[1], dtype=mat.dtype) * mat).sum(dim=1).sum(dim=1) + else: + tr = (torch.eye(mat.shape[1], dtype=mat.dtype) * mat).sum(dim=1).sum(dim=1) return tr.view(mat.shape[0])