Skip to content

Commit

Permalink
[MINOR] fix dtype issue 21
Browse files Browse the repository at this point in the history
  • Loading branch information
arturodcb committed Aug 15, 2024
1 parent f7bafbc commit ba8f67d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
18 changes: 9 additions & 9 deletions pymlg/torch/se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def from_components(C, r):
"""
Construct an SE(2) matrix from a rotation matrix and translation vector.
"""
T = torch.zeros(C.shape[0], 3, 3)
T = torch.zeros(C.shape[0], 3, 3, dtype=C.dtype)

T[:, 0:2, 0:2] = C
T[:, 0:2, 2] = r.view(-1, 2)
Expand All @@ -45,7 +45,7 @@ def wedge(xi):
phi = xi[:, 0]
xi_r = xi[:, 1:]
Xi_phi = SO2.wedge(phi)
Xi = torch.zeros(xi.shape[0], 3, 3)
Xi = torch.zeros(xi.shape[0], 3, 3, dtype=xi.dtype)
Xi[:, 0:2, 0:2] = Xi_phi
Xi[:, 0:2, 2] = xi_r.view(-1, 2)
return Xi
Expand All @@ -72,15 +72,15 @@ def log(T):
Xi_phi = SO2.log(T[:, 0:2, 0:2])
r = T[:, 0:2, 2].unsqueeze(2)
xi_r = SE2.V_matrix_inv(SO2.vee(Xi_phi)) @ r
Xi = torch.zeros(T.shape[0], 3, 3)
Xi = torch.zeros(T.shape[0], 3, 3, dtype=T.dtype)
Xi[:, 0:2, 0:2] = Xi_phi
Xi[:, 0:2, 2] = xi_r.squeeze(2)
return Xi

@staticmethod
def odot(b):

X = torch.zeros(b.shape[0], 3, 3)
X = torch.zeros(b.shape[0], 3, 3, dtype=b.dtype)
X[:, 0:2, 0] = SO2.odot(b[:, :2]).squeeze(2)
X[:, 0:2, 1:3] = batch_eye(b.shape[0], 2, 2) * b[:, 2].unsqueeze(2)

Expand All @@ -103,7 +103,7 @@ def left_jacobian(xi):
large_angle_mask = small_angle_mask.logical_not()
large_angle_inds = large_angle_mask.nonzero(as_tuple=True)[0]

J = torch.zeros(xi.shape[0], 3, 3)
J = torch.zeros(xi.shape[0], 3, 3, dtype=xi.dtype)

if small_angle_inds.numel():
A = (1 - 1.0 / 6.0 * phi_sq[small_angle_inds]).view(-1)
Expand Down Expand Up @@ -146,7 +146,7 @@ def adjoint(T):
# build Om matrix manually (will this break the DAG?)
Om = torch.Tensor([[0, -1], [1, 0]]).repeat(T.shape[0], 1, 1)

A = torch.zeros(T.shape[0], 3, 3)
A = torch.zeros(T.shape[0], 3, 3, dtype=T.dtype)
A[:, 0, 0] = 1
A[:, 1:, 0] = -(Om @ r.unsqueeze(2)).squeeze(2)
A[:, 1:, 1:] = C
Expand All @@ -155,7 +155,7 @@ def adjoint(T):

@staticmethod
def adjoint_algebra(Xi):
A = torch.zeros(Xi.shape[0], 3, 3)
A = torch.zeros(Xi.shape[0], 3, 3, dtype=Xi.dtype)
A[:, 1, 0] = Xi[:, 1, 2]
A[:, 2, 0] = -Xi[:, 0, 2]
A[:, 1:, 1:] = Xi[:, 0:2, 0:2]
Expand All @@ -171,7 +171,7 @@ def V_matrix(phi):
large_angle_mask = small_angle_mask.logical_not()
large_angle_inds = large_angle_mask.nonzero(as_tuple=True)[0]

V = batch_eye(phi.shape[0], 2, 2)
V = batch_eye(phi.shape[0], 2, 2, dtype=phi.dtype)

if small_angle_inds.numel():
V[small_angle_inds] += .5 * SO2.wedge(phi[small_angle_inds])
Expand All @@ -193,7 +193,7 @@ def V_matrix_inv(phi):
large_angle_mask = small_angle_mask.logical_not()
large_angle_inds = large_angle_mask.nonzero(as_tuple=True)[0]

V_inv = batch_eye(phi.shape[0], 2, 2)
V_inv = batch_eye(phi.shape[0], 2, 2, dtype=phi.dtype)

if small_angle_inds.numel():
V_inv[small_angle_inds] -= .5 * SO2.wedge(phi[small_angle_inds])
Expand Down
6 changes: 3 additions & 3 deletions pymlg/torch/se23.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def from_components(C: torch.Tensor, v: torch.Tensor, r: torch.Tensor):
if not (C.shape[0] == v.shape[0] == r.shape[0]):
raise ValueError("Batch dimension for SE_2(3) components don't match.")

X = batch_eye(C.shape[0], 5, 5)
X = batch_eye(C.shape[0], 5, 5, dtype=C.dtype)

X[:, 0:3, 0:3] = C
X[:, 0:3, 3] = v.squeeze(2)
Expand Down Expand Up @@ -193,7 +193,7 @@ def left_jacobian(xi):
xi_v = xi[:, 3:6]
xi_r = xi[:, 6:9]

J_left = batch_eye(xi.shape[0], 9, 9)
J_left = batch_eye(xi.shape[0], 9, 9, dtype=xi.dtype)

small_angle_mask = is_close(
torch.linalg.norm(xi_phi, dim=1), 0.0, SE23._small_angle_tol
Expand Down Expand Up @@ -242,7 +242,7 @@ def left_jacobian_inv(xi):
xi_v = xi[:, 3:6]
xi_r = xi[:, 6:9]

J_left = batch_eye(xi.shape[0], 9, 9)
J_left = batch_eye(xi.shape[0], 9, 9, dtype=xi.dtype)

small_angle_mask = is_close(
torch.linalg.norm(xi_phi, dim=1), 0.0, SE23._small_angle_tol
Expand Down
6 changes: 3 additions & 3 deletions pymlg/torch/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def wedge(xi: torch.Tensor):
) # this yields a (N, 3, 4) matrix that must now be blocked with a (1, 4) batched matrix

# generating a (N, 1, 4) batched matrix to append
b1 = torch.tensor([0, 0, 0, 0]).reshape(1, 1, 4)
b1 = torch.tensor([0, 0, 0, 0], dtype = xi.dtype).reshape(1, 1, 4)
block = b1.repeat(Xi.shape[0], 1, 1)

return torch.cat((Xi, block), dim=1)
Expand Down Expand Up @@ -187,7 +187,7 @@ def log(X : torch.Tensor):

@staticmethod
def odot(b : torch.Tensor):
X = torch.zeros(b.shape[0], 4, 6)
X = torch.zeros(b.shape[0], 4, 6, dtype=b.dtype)
X[:, 0:3, 0:3] = SO3.odot(b[0:3])
X[:, 0:3, 3:6] = b[:, 3] * batch_eye(b.shape[0], 3, 3, dtype=b.dtype)
return X
Expand All @@ -204,7 +204,7 @@ def adjoint(X):

@staticmethod
def adjoint_algebra(Xi):
A = torch.zeros(Xi.shape[0], 6, 6)
A = torch.zeros(Xi.shape[0], 6, 6, dtype=Xi.dtype)
A[:, 0:3, 0:3] = Xi[:, 0:3, 0:3]
A[:, 3:6, 0:3] = SO3.wedge(Xi[:, 0:3, 3])
A[:, 3:6, 3:6] = Xi[:, 0:3, 0:3]
Expand Down
4 changes: 2 additions & 2 deletions pymlg/torch/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def left_jacobian(xi):
large_angle_mask = small_angle_mask.logical_not()
large_angle_inds = large_angle_mask.nonzero(as_tuple=True)[0]

J_left = torch.empty(xi.shape[0], 3, 3)
J_left = torch.empty(xi.shape[0], 3, 3, dtype=xi.dtype)

cross_xi = SO3.wedge(xi)

Expand Down Expand Up @@ -326,7 +326,7 @@ def left_jacobian_inv(xi):
large_angle_mask = small_angle_mask.logical_not()
large_angle_inds = large_angle_mask.nonzero(as_tuple=True)[0]

J_left = torch.empty(xi.shape[0], 3, 3)
J_left = torch.empty(xi.shape[0], 3, 3, dtype=xi.dtype)

cross_xi = SO3.wedge(xi)

Expand Down

0 comments on commit ba8f67d

Please sign in to comment.