Skip to content

Commit

Permalink
[MINOR] fix dtype issue 21
Browse files Browse the repository at this point in the history
  • Loading branch information
arturodcb committed Aug 15, 2024
1 parent 0715e40 commit f7bafbc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions pymlg/torch/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def from_components(C: torch.Tensor, r: torch.Tensor):
if not (C.shape[0] == r.shape[0]):
raise ValueError("Batch dimension for SE(3) components don't match.")

X = batch_eye(C.shape[0], 4, 4)
X = batch_eye(C.shape[0], 4, 4, dtype = C.dtype)

X[:, 0:3, 0:3] = C
X[:, 0:3, 3] = r.squeeze(2)
Expand Down Expand Up @@ -189,7 +189,7 @@ def log(X : torch.Tensor):
def odot(b : torch.Tensor):
X = torch.zeros(b.shape[0], 4, 6)
X[:, 0:3, 0:3] = SO3.odot(b[0:3])
X[:, 0:3, 3:6] = b[:, 3] * batch_eye(b.shape[0], 3, 3)
X[:, 0:3, 3:6] = b[:, 3] * batch_eye(b.shape[0], 3, 3, dtype=b.dtype)
return X

@staticmethod
Expand All @@ -211,15 +211,15 @@ def adjoint_algebra(Xi):
return A

@staticmethod
def identity(N=1):
return batch_eye(N, 4, 4)
def identity(N=1, dtype=torch.float32):
return batch_eye(N, 4, 4, dtype=dtype)

@staticmethod
def left_jacobian(xi):
xi_phi = xi[:, 0:3]
xi_r = xi[:, 3:6]

J_left = batch_eye(xi.shape[0], 6, 6)
J_left = batch_eye(xi.shape[0], 6, 6, dtype=xi.dtype)

small_angle_mask = is_close(
torch.linalg.norm(xi_phi, dim=1), 0.0, SE3._small_angle_tol
Expand Down Expand Up @@ -256,7 +256,7 @@ def left_jacobian_inv(xi):
xi_phi = xi[:, 0:3]
xi_r = xi[:, 3:6]

J_left = batch_eye(xi.shape[0], 6, 6)
J_left = batch_eye(xi.shape[0], 6, 6, dtype=xi.dtype)

small_angle_mask = is_close(
torch.linalg.norm(xi_phi, dim=1), 0.0, SE3._small_angle_tol
Expand Down
4 changes: 2 additions & 2 deletions pymlg/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def batch_vector(N, v : torch.Tensor):

return v.repeat(N, 1, 1)

def batch_eye(N, n, m):
def batch_eye(N, n, m, dtype = torch.float32):
"""
Generate a batched set of identity matricies by using torch.repeat()
"""

b = torch.eye(n, m)
b = torch.eye(n, m, dtype=dtype)

return b.repeat(N, 1, 1)

0 comments on commit f7bafbc

Please sign in to comment.